From 7a47197d8b61a134ed95aeddedf18a044e2e0891 Mon Sep 17 00:00:00 2001
From: ycabrer <43866176+ycabrer@users.noreply.github.com>
Date: Wed, 27 Oct 2021 04:32:10 -0600
Subject: [PATCH] Allow for usage of env var `K8S_HOST` in psql (#1026)

* Allow for usage of env var `K8S_HOST` in psql

* small typo fix

* typo fix
---
 cmd/database.go | 109 ++++++++++++++++++++++++++++++++++++------------
 1 file changed, 83 insertions(+), 26 deletions(-)

diff --git a/cmd/database.go b/cmd/database.go
index 0c04082..6732ede 100644
--- a/cmd/database.go
+++ b/cmd/database.go
@@ -11,36 +11,98 @@ import (
 	"gorm.io/gorm"
 )
 
-func savePgsql(jsonInfo string) {
-	envVars := map[string]string{
-		"PGSQL_HOST":     viper.GetString("PGSQL_HOST"),
-		"PGSQL_USER":     viper.GetString("PGSQL_USER"),
-		"PGSQL_DBNAME":   viper.GetString("PGSQL_DBNAME"),
-		"PGSQL_SSLMODE":  viper.GetString("PGSQL_SSLMODE"),
-		"PGSQL_PASSWORD": viper.GetString("PGSQL_PASSWORD"),
+type PsqlConnInfo struct {
+	Host     string
+	User     string
+	DbName   string
+	SslMode  string
+	Password string
+}
+
+func getPsqlConnInfo() (PsqlConnInfo, error) {
+	var host string
+	if value := viper.GetString("PGSQL_HOST"); value != "" {
+		host = value
+	} else {
+		return PsqlConnInfo{}, fmt.Errorf("%s_PGSQL_HOST env var is required", envVarsPrefix)
 	}
 
-	for k, v := range envVars {
-		if v == "" {
-			exitWithError(fmt.Errorf("environment variable %s is missing", envVarsPrefix+"_"+k))
-		}
+	var user string
+	if value := viper.GetString("PGSQL_USER"); value != "" {
+		user = value
+	} else {
+		return PsqlConnInfo{}, fmt.Errorf("%s_PGSQL_USER env var is required", envVarsPrefix)
+	}
+
+	var dbName string
+	if value := viper.GetString("PGSQL_DBNAME"); value != "" {
+		dbName = value
+	} else {
+		return PsqlConnInfo{}, fmt.Errorf("%s_PGSQL_USER env var is required", envVarsPrefix)
+	}
+
+	var sslMode string
+	if value := viper.GetString("PGSQL_SSLMODE"); value != "" {
+		sslMode = value
+	} else {
+		return PsqlConnInfo{}, fmt.Errorf("%s_PGSQL_SSLMODE env var is required", envVarsPrefix)
 	}
 
-	connInfo := fmt.Sprintf("host=%s user=%s dbname=%s sslmode=%s password=%s",
-		envVars["PGSQL_HOST"],
-		envVars["PGSQL_USER"],
-		envVars["PGSQL_DBNAME"],
-		envVars["PGSQL_SSLMODE"],
-		envVars["PGSQL_PASSWORD"],
+	var password string
+	if value := viper.GetString("PGSQL_PASSWORD"); value != "" {
+		password = value
+	} else {
+		return PsqlConnInfo{}, fmt.Errorf("%s_PGSQL_PASSWORD env var is required", envVarsPrefix)
+	}
+
+	return PsqlConnInfo{
+		Host:     host,
+		User:     user,
+		DbName:   dbName,
+		SslMode:  sslMode,
+		Password: password,
+	}, nil
+}
+
+func (c *PsqlConnInfo) toString() string {
+	return fmt.Sprintf("host=%s user=%s dbname=%s sslmode=%s password=%s",
+		c.Host,
+		c.User,
+		c.DbName,
+		c.SslMode,
+		c.Password,
 	)
+}
 
-	hostname, err := os.Hostname()
+func savePgsql(jsonInfo string) {
+	var hostname string
+	if value := viper.GetString("K8S_HOST"); value != "" {
+		// Adhere to the ScanHost column definition below
+		if len(value) > 63 {
+			exitWithError(fmt.Errorf("%s_K8S_HOST value's length must be less than 63 chars", envVarsPrefix))
+		}
+
+		hostname = value
+	} else {
+		host, err := os.Hostname()
+		if err != nil {
+			exitWithError(fmt.Errorf("received error looking up hostname: %s", err))
+		}
+
+		hostname = host
+	}
+
+	PsqlConnInfo, err := getPsqlConnInfo()
 	if err != nil {
-		exitWithError(fmt.Errorf("received error looking up hostname: %s", err))
+		exitWithError(err)
 	}
 
-	timestamp := time.Now()
+	db, err := gorm.Open(postgres.Open(PsqlConnInfo.toString()), &gorm.Config{})
+	if err != nil {
+		exitWithError(fmt.Errorf("received error connecting to database: %s", err))
+	}
 
+	timestamp := time.Now()
 	type ScanResult struct {
 		gorm.Model
 		ScanHost string    `gorm:"type:varchar(63) not null"` // https://www.ietf.org/rfc/rfc1035.txt
@@ -48,12 +110,7 @@ func savePgsql(jsonInfo string) {
 		ScanInfo string    `gorm:"type:jsonb not null"`
 	}
 
-	db, err := gorm.Open(postgres.Open(connInfo), &gorm.Config{})
-	if err != nil {
-		exitWithError(fmt.Errorf("received error connecting to database: %s", err))
-	}
-
 	db.Debug().AutoMigrate(&ScanResult{})
 	db.Save(&ScanResult{ScanHost: hostname, ScanTime: timestamp, ScanInfo: jsonInfo})
-	glog.V(2).Info(fmt.Sprintf("successfully stored result to: %s", envVars["PGSQL_HOST"]))
+	glog.V(2).Info(fmt.Sprintf("successfully stored result to: %s", PsqlConnInfo.Host))
 }
-- 
GitLab