From 8da914ad595d151eb9ac09b50a6eb003341f812f Mon Sep 17 00:00:00 2001 From: JonasS <jonass@dev.jsje.de> Date: Mon, 17 Jul 2023 23:03:08 +0200 Subject: [PATCH] bugfix: Do not crash on invalid lookups (closes #109, thanks @tniebergall) - check return value of lookup methods for nil, as documented in case something cannot be found) --- driver/cleanup.go | 2 +- driver/driver.go | 25 +++++++-------------- driver/hetzner_query.go | 48 ++++++++++++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/driver/cleanup.go b/driver/cleanup.go index 1f1ac48..6bb4a9d 100644 --- a/driver/cleanup.go +++ b/driver/cleanup.go @@ -42,7 +42,7 @@ func (d *Driver) destroyServer() error { return nil } - srv, err := d.getServerHandle() + srv, err := d.getServerHandleNullable() if err != nil { return errors.Wrap(err, "could not get server handle") } diff --git a/driver/driver.go b/driver/driver.go index 14180fd..ae2181a 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -55,8 +55,8 @@ type Driver struct { AdditionalKeyIDs []int cachedAdditionalKeys []*hcloud.SSHKey - WaitOnError int - WaitOnPolling int + WaitOnError int + WaitOnPolling int WaitForRunningTimeout int // internal housekeeping @@ -99,12 +99,12 @@ const ( defaultSSHPort = 22 defaultSSHUser = "root" - flagWaitOnError = "hetzner-wait-on-error" - defaultWaitOnError = 0 - flagWaitOnPolling = "hetzner-wait-on-polling" - defaultWaitOnPolling = 1 - flagWaitForRunningTimeout = "hetzner-wait-for-running-timeout" - defaultWaitForRunningTimeout = 0 + flagWaitOnError = "hetzner-wait-on-error" + defaultWaitOnError = 0 + flagWaitOnPolling = "hetzner-wait-on-polling" + defaultWaitOnPolling = 1 + flagWaitForRunningTimeout = "hetzner-wait-for-running-timeout" + defaultWaitForRunningTimeout = 0 legacyFlagUserDataFromFile = "hetzner-user-data-from-file" legacyFlagDisablePublic4 = "hetzner-disable-public-4" @@ -608,9 +608,6 @@ func (d *Driver) Start() error { if err != nil { return errors.Wrap(err, "could not get server handle") } - if srv == nil { - return errors.New("server not found") - } act, _, err := d.getClient().Server.Poweron(context.Background(), srv) if err != nil { @@ -628,9 +625,6 @@ func (d *Driver) Stop() error { if err != nil { return errors.Wrap(err, "could not get server handle") } - if srv == nil { - return errors.New("server not found") - } act, _, err := d.getClient().Server.Shutdown(context.Background(), srv) if err != nil { @@ -648,9 +642,6 @@ func (d *Driver) Kill() error { if err != nil { return errors.Wrap(err, "could not get server handle") } - if srv == nil { - return errors.New("server not found") - } act, _, err := d.getClient().Server.Poweroff(context.Background(), srv) if err != nil { diff --git a/driver/hetzner_query.go b/driver/hetzner_query.go index 4b2b87b..4a50f79 100644 --- a/driver/hetzner_query.go +++ b/driver/hetzner_query.go @@ -21,7 +21,10 @@ func (d *Driver) getLocation() (*hcloud.Location, error) { location, _, err := d.getClient().Location.GetByName(context.Background(), d.Location) if err != nil { - return location, errors.Wrap(err, "could not get location by name") + return nil, errors.Wrap(err, "could not get location by name") + } + if location == nil { + return nil, fmt.Errorf("unknown location: %v", d.Location) } d.cachedLocation = location return location, nil @@ -34,7 +37,10 @@ func (d *Driver) getType() (*hcloud.ServerType, error) { stype, _, err := d.getClient().ServerType.GetByName(context.Background(), d.Type) if err != nil { - return stype, errors.Wrap(err, "could not get type by name") + return nil, errors.Wrap(err, "could not get type by name") + } + if stype == nil { + return nil, fmt.Errorf("unknown server type: %v", d.Type) } d.cachedType = stype return instrumented(stype), nil @@ -51,7 +57,10 @@ func (d *Driver) getImage() (*hcloud.Image, error) { if d.ImageID != 0 { image, _, err = d.getClient().Image.GetByID(context.Background(), d.ImageID) if err != nil { - return image, errors.Wrap(err, fmt.Sprintf("could not get image by id %v", d.ImageID)) + return nil, errors.Wrap(err, fmt.Sprintf("could not get image by id %v", d.ImageID)) + } + if image == nil { + return nil, fmt.Errorf("image id not found: %v", d.ImageID) } } else { arch, err := d.getImageArchitectureForLookup() @@ -61,7 +70,10 @@ func (d *Driver) getImage() (*hcloud.Image, error) { image, _, err = d.getClient().Image.GetByNameAndArchitecture(context.Background(), d.Image, arch) if err != nil { - return image, errors.Wrap(err, fmt.Sprintf("could not get image by name %v", d.Image)) + return nil, errors.Wrap(err, fmt.Sprintf("could not get image by name %v", d.Image)) + } + if image == nil { + return nil, fmt.Errorf("image not found: %v[%v]", d.Image, arch) } } @@ -87,12 +99,15 @@ func (d *Driver) getKey() (*hcloud.SSHKey, error) { return d.cachedKey, nil } - stype, _, err := d.getClient().SSHKey.GetByID(context.Background(), d.KeyID) + key, _, err := d.getClient().SSHKey.GetByID(context.Background(), d.KeyID) if err != nil { - return stype, errors.Wrap(err, "could not get sshkey by ID") + return nil, errors.Wrap(err, "could not get sshkey by ID") } - d.cachedKey = stype - return instrumented(stype), nil + if key == nil { + return nil, fmt.Errorf("key not found: %v", d.KeyID) + } + d.cachedKey = key + return instrumented(key), nil } func (d *Driver) getRemoteKeyWithSameFingerprint(publicKeyBytes []byte) (*hcloud.SSHKey, error) { @@ -107,10 +122,24 @@ func (d *Driver) getRemoteKeyWithSameFingerprint(publicKeyBytes []byte) (*hcloud if err != nil { return remoteKey, errors.Wrap(err, "could not get sshkey by fingerprint") } + if remoteKey == nil { + return nil, fmt.Errorf("key not found by fingerprint: %v", fp) + } return instrumented(remoteKey), nil } func (d *Driver) getServerHandle() (*hcloud.Server, error) { + srv, err := d.getServerHandleNullable() + if err != nil { + return nil, err + } + if srv == nil { + return nil, fmt.Errorf("server does not exist: %v", d.ServerID) + } + return srv, nil +} + +func (d *Driver) getServerHandleNullable() (*hcloud.Server, error) { if d.cachedServer != nil { return d.cachedServer, nil } @@ -134,6 +163,9 @@ func (d *Driver) waitForAction(a *hcloud.Action) error { if err != nil { return errors.Wrap(err, "could not get client by ID") } + if act == nil { + return fmt.Errorf("action not found: %v", a.ID) + } if act.Status == hcloud.ActionStatusSuccess { log.Debugf(" -> finished %s[%d]", act.Command, act.ID) -- GitLab