From 8da914ad595d151eb9ac09b50a6eb003341f812f Mon Sep 17 00:00:00 2001
From: JonasS <jonass@dev.jsje.de>
Date: Mon, 17 Jul 2023 23:03:08 +0200
Subject: [PATCH] bugfix: Do not crash on invalid lookups (closes #109, thanks
 @tniebergall) - check return value of lookup methods for nil, as documented
 in case something cannot be found)

---
 driver/cleanup.go       |  2 +-
 driver/driver.go        | 25 +++++++--------------
 driver/hetzner_query.go | 48 ++++++++++++++++++++++++++++++++++-------
 3 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/driver/cleanup.go b/driver/cleanup.go
index 1f1ac48..6bb4a9d 100644
--- a/driver/cleanup.go
+++ b/driver/cleanup.go
@@ -42,7 +42,7 @@ func (d *Driver) destroyServer() error {
 		return nil
 	}
 
-	srv, err := d.getServerHandle()
+	srv, err := d.getServerHandleNullable()
 	if err != nil {
 		return errors.Wrap(err, "could not get server handle")
 	}
diff --git a/driver/driver.go b/driver/driver.go
index 14180fd..ae2181a 100644
--- a/driver/driver.go
+++ b/driver/driver.go
@@ -55,8 +55,8 @@ type Driver struct {
 	AdditionalKeyIDs     []int
 	cachedAdditionalKeys []*hcloud.SSHKey
 
-	WaitOnError int
-	WaitOnPolling int
+	WaitOnError           int
+	WaitOnPolling         int
 	WaitForRunningTimeout int
 
 	// internal housekeeping
@@ -99,12 +99,12 @@ const (
 	defaultSSHPort = 22
 	defaultSSHUser = "root"
 
-	flagWaitOnError                = "hetzner-wait-on-error"
-	defaultWaitOnError             = 0
-	flagWaitOnPolling              = "hetzner-wait-on-polling"
-	defaultWaitOnPolling           = 1
-	flagWaitForRunningTimeout      = "hetzner-wait-for-running-timeout"
-	defaultWaitForRunningTimeout   = 0
+	flagWaitOnError              = "hetzner-wait-on-error"
+	defaultWaitOnError           = 0
+	flagWaitOnPolling            = "hetzner-wait-on-polling"
+	defaultWaitOnPolling         = 1
+	flagWaitForRunningTimeout    = "hetzner-wait-for-running-timeout"
+	defaultWaitForRunningTimeout = 0
 
 	legacyFlagUserDataFromFile = "hetzner-user-data-from-file"
 	legacyFlagDisablePublic4   = "hetzner-disable-public-4"
@@ -608,9 +608,6 @@ func (d *Driver) Start() error {
 	if err != nil {
 		return errors.Wrap(err, "could not get server handle")
 	}
-	if srv == nil {
-		return errors.New("server not found")
-	}
 
 	act, _, err := d.getClient().Server.Poweron(context.Background(), srv)
 	if err != nil {
@@ -628,9 +625,6 @@ func (d *Driver) Stop() error {
 	if err != nil {
 		return errors.Wrap(err, "could not get server handle")
 	}
-	if srv == nil {
-		return errors.New("server not found")
-	}
 
 	act, _, err := d.getClient().Server.Shutdown(context.Background(), srv)
 	if err != nil {
@@ -648,9 +642,6 @@ func (d *Driver) Kill() error {
 	if err != nil {
 		return errors.Wrap(err, "could not get server handle")
 	}
-	if srv == nil {
-		return errors.New("server not found")
-	}
 
 	act, _, err := d.getClient().Server.Poweroff(context.Background(), srv)
 	if err != nil {
diff --git a/driver/hetzner_query.go b/driver/hetzner_query.go
index 4b2b87b..4a50f79 100644
--- a/driver/hetzner_query.go
+++ b/driver/hetzner_query.go
@@ -21,7 +21,10 @@ func (d *Driver) getLocation() (*hcloud.Location, error) {
 
 	location, _, err := d.getClient().Location.GetByName(context.Background(), d.Location)
 	if err != nil {
-		return location, errors.Wrap(err, "could not get location by name")
+		return nil, errors.Wrap(err, "could not get location by name")
+	}
+	if location == nil {
+		return nil, fmt.Errorf("unknown location: %v", d.Location)
 	}
 	d.cachedLocation = location
 	return location, nil
@@ -34,7 +37,10 @@ func (d *Driver) getType() (*hcloud.ServerType, error) {
 
 	stype, _, err := d.getClient().ServerType.GetByName(context.Background(), d.Type)
 	if err != nil {
-		return stype, errors.Wrap(err, "could not get type by name")
+		return nil, errors.Wrap(err, "could not get type by name")
+	}
+	if stype == nil {
+		return nil, fmt.Errorf("unknown server type: %v", d.Type)
 	}
 	d.cachedType = stype
 	return instrumented(stype), nil
@@ -51,7 +57,10 @@ func (d *Driver) getImage() (*hcloud.Image, error) {
 	if d.ImageID != 0 {
 		image, _, err = d.getClient().Image.GetByID(context.Background(), d.ImageID)
 		if err != nil {
-			return image, errors.Wrap(err, fmt.Sprintf("could not get image by id %v", d.ImageID))
+			return nil, errors.Wrap(err, fmt.Sprintf("could not get image by id %v", d.ImageID))
+		}
+		if image == nil {
+			return nil, fmt.Errorf("image id not found: %v", d.ImageID)
 		}
 	} else {
 		arch, err := d.getImageArchitectureForLookup()
@@ -61,7 +70,10 @@ func (d *Driver) getImage() (*hcloud.Image, error) {
 
 		image, _, err = d.getClient().Image.GetByNameAndArchitecture(context.Background(), d.Image, arch)
 		if err != nil {
-			return image, errors.Wrap(err, fmt.Sprintf("could not get image by name %v", d.Image))
+			return nil, errors.Wrap(err, fmt.Sprintf("could not get image by name %v", d.Image))
+		}
+		if image == nil {
+			return nil, fmt.Errorf("image not found: %v[%v]", d.Image, arch)
 		}
 	}
 
@@ -87,12 +99,15 @@ func (d *Driver) getKey() (*hcloud.SSHKey, error) {
 		return d.cachedKey, nil
 	}
 
-	stype, _, err := d.getClient().SSHKey.GetByID(context.Background(), d.KeyID)
+	key, _, err := d.getClient().SSHKey.GetByID(context.Background(), d.KeyID)
 	if err != nil {
-		return stype, errors.Wrap(err, "could not get sshkey by ID")
+		return nil, errors.Wrap(err, "could not get sshkey by ID")
 	}
-	d.cachedKey = stype
-	return instrumented(stype), nil
+	if key == nil {
+		return nil, fmt.Errorf("key not found: %v", d.KeyID)
+	}
+	d.cachedKey = key
+	return instrumented(key), nil
 }
 
 func (d *Driver) getRemoteKeyWithSameFingerprint(publicKeyBytes []byte) (*hcloud.SSHKey, error) {
@@ -107,10 +122,24 @@ func (d *Driver) getRemoteKeyWithSameFingerprint(publicKeyBytes []byte) (*hcloud
 	if err != nil {
 		return remoteKey, errors.Wrap(err, "could not get sshkey by fingerprint")
 	}
+	if remoteKey == nil {
+		return nil, fmt.Errorf("key not found by fingerprint: %v", fp)
+	}
 	return instrumented(remoteKey), nil
 }
 
 func (d *Driver) getServerHandle() (*hcloud.Server, error) {
+	srv, err := d.getServerHandleNullable()
+	if err != nil {
+		return nil, err
+	}
+	if srv == nil {
+		return nil, fmt.Errorf("server does not exist: %v", d.ServerID)
+	}
+	return srv, nil
+}
+
+func (d *Driver) getServerHandleNullable() (*hcloud.Server, error) {
 	if d.cachedServer != nil {
 		return d.cachedServer, nil
 	}
@@ -134,6 +163,9 @@ func (d *Driver) waitForAction(a *hcloud.Action) error {
 		if err != nil {
 			return errors.Wrap(err, "could not get client by ID")
 		}
+		if act == nil {
+			return fmt.Errorf("action not found: %v", a.ID)
+		}
 
 		if act.Status == hcloud.ActionStatusSuccess {
 			log.Debugf(" -> finished %s[%d]", act.Command, act.ID)
-- 
GitLab