From b4237ccb737104dbe01c4f026c2f07b2ff795fb1 Mon Sep 17 00:00:00 2001
From: Liz Rice <liz@lizrice.com>
Date: Fri, 23 Jun 2017 12:04:46 +0100
Subject: [PATCH] Better error handling when reading YAML files

---
 check/controls.go  | 14 +++++---------
 check/test_test.go |  5 ++++-
 cmd/common.go      |  8 ++++++--
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/check/controls.go b/check/controls.go
index 4e7cde5..dfea006 100644
--- a/check/controls.go
+++ b/check/controls.go
@@ -17,7 +17,6 @@ package check
 import (
 	"encoding/json"
 	"fmt"
-	"os"
 
 	yaml "gopkg.in/yaml.v2"
 )
@@ -46,19 +45,16 @@ type Summary struct {
 }
 
 // NewControls instantiates a new master Controls object.
-func NewControls(t NodeType, in []byte) *Controls {
-	var err error
+func NewControls(t NodeType, in []byte) (*Controls, error) {
 	c := new(Controls)
 
-	err = yaml.Unmarshal(in, c)
+	err := yaml.Unmarshal(in, c)
 	if err != nil {
-		fmt.Fprintf(os.Stderr, "%s\n", err)
-		os.Exit(1)
+		return nil, fmt.Errorf("failed to unmarshal YAML: %s", err)
 	}
 
 	if t != c.Type {
-		fmt.Fprintf(os.Stderr, "non-%s controls file specified\n", t)
-		os.Exit(1)
+		return nil, fmt.Errorf("non-%s controls file specified", t)
 	}
 
 	// Prepare audit commands
@@ -68,7 +64,7 @@ func NewControls(t NodeType, in []byte) *Controls {
 		}
 	}
 
-	return c
+	return c, nil
 }
 
 // RunGroup runs all checks in a group.
diff --git a/check/test_test.go b/check/test_test.go
index 8e495cc..76774f9 100644
--- a/check/test_test.go
+++ b/check/test_test.go
@@ -30,7 +30,10 @@ func init() {
 	if err != nil {
 		panic("Failed reading test data: " + err.Error())
 	}
-	controls = NewControls(MASTER, in)
+	controls, err = NewControls(MASTER, in)
+	if err != nil {
+		panic("Failed creating test controls: " + err.Error())
+	}
 }
 
 func TestTestExecute(t *testing.T) {
diff --git a/cmd/common.go b/cmd/common.go
index 63940d4..8362ea2 100644
--- a/cmd/common.go
+++ b/cmd/common.go
@@ -88,7 +88,7 @@ func runChecks(t check.NodeType) {
 
 	in, err := ioutil.ReadFile(file)
 	if err != nil {
-		fmt.Fprintf(os.Stderr, "error opening %s controls file: %s\n", t, err)
+		fmt.Fprintf(os.Stderr, "error opening %s controls file: %v\n", t, err)
 		os.Exit(1)
 	}
 
@@ -97,7 +97,11 @@ func runChecks(t check.NodeType) {
 	s = strings.Replace(s, "$etcdConfDir", viper.Get("etcdConfDir").(string), -1)
 	s = strings.Replace(s, "$flanneldConfDir", viper.Get("flanneldConfDir").(string), -1)
 
-	controls := check.NewControls(t, []byte(s))
+	controls, err := check.NewControls(t, []byte(s))
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "error setting up %s controls: %v\n", t, err)
+		os.Exit(1)
+	}
 
 	if groupList != "" && checkList == "" {
 		ids := cleanIDs(groupList)
-- 
GitLab