From 5fb133cd02e22bfcf65bd790df1371a6e4b781a2 Mon Sep 17 00:00:00 2001
From: Daniel Pacak <pacak.daniel@gmail.com>
Date: Wed, 1 May 2019 21:43:06 +0200
Subject: [PATCH] Adjust the semantics of scored and unscored flags

---
 check/check_test.go    | 14 ++++++++++++++
 check/controls.go      |  6 ++++++
 check/controls_test.go | 41 +++++++++++++++++++++++++++++++++++------
 cmd/common.go          | 39 ++++++++++++++-------------------------
 cmd/common_test.go     | 15 ++++++++++++---
 cmd/root.go            |  4 ++--
 6 files changed, 83 insertions(+), 36 deletions(-)

diff --git a/check/check_test.go b/check/check_test.go
index 27c3c64..46515a9 100644
--- a/check/check_test.go
+++ b/check/check_test.go
@@ -1,3 +1,17 @@
+// Copyright © 2017 Aqua Security Software Ltd. <info@aquasec.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 package check
 
 import (
diff --git a/check/controls.go b/check/controls.go
index 84635a1..0a6183c 100644
--- a/check/controls.go
+++ b/check/controls.go
@@ -17,6 +17,7 @@ package check
 import (
 	"encoding/json"
 	"fmt"
+	"github.com/golang/glog"
 	"gopkg.in/yaml.v2"
 )
 
@@ -49,6 +50,7 @@ type Summary struct {
 	Info int `json:"total_info"`
 }
 
+// Predicate a predicate on the given Group and Check arguments.
 type Predicate func(group *Group, check *Check) bool
 
 // NewControls instantiates a new master Controls object.
@@ -134,6 +136,8 @@ func summarize(controls *Controls, state State) {
 		controls.Summary.Warn++
 	case INFO:
 		controls.Summary.Info++
+	default:
+		glog.Warningf("Unrecognized state %s", state)
 	}
 }
 
@@ -147,5 +151,7 @@ func summarizeGroup(group *Group, state State) {
 		group.Warn++
 	case INFO:
 		group.Info++
+	default:
+		glog.Warningf("Unrecognized state %s", state)
 	}
 }
diff --git a/check/controls_test.go b/check/controls_test.go
index d0d480c..4c6a6aa 100644
--- a/check/controls_test.go
+++ b/check/controls_test.go
@@ -1,3 +1,17 @@
+// Copyright © 2017 Aqua Security Software Ltd. <info@aquasec.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 package check
 
 import (
@@ -79,7 +93,7 @@ groups:
 
 func TestControls_RunChecks(t *testing.T) {
 
-	t.Run("Should run all checks", func(t *testing.T) {
+	t.Run("Should run checks matching the filter and update summaries", func(t *testing.T) {
 		// given
 		runner := new(mockRunner)
 		// and
@@ -108,15 +122,30 @@ groups:
 		// then
 		assert.Equal(t, 2, len(controls.Groups))
 		// and
-		assert.Equal(t, "G1", controls.Groups[0].ID)
-		assert.Equal(t, "G1/C1", controls.Groups[0].Checks[0].ID)
+		G1 := controls.Groups[0]
+		assert.Equal(t, "G1", G1.ID)
+		assert.Equal(t, "G1/C1", G1.Checks[0].ID)
+		assertEqualGroupSummary(t, 1, 0, 0, 0, G1)
 		// and
-		assert.Equal(t, "G2", controls.Groups[1].ID)
-		assert.Equal(t, "G2/C1", controls.Groups[1].Checks[0].ID)
+		G2 := controls.Groups[1]
+		assert.Equal(t, "G2", G2.ID)
+		assert.Equal(t, "G2/C1", G2.Checks[0].ID)
+		assertEqualGroupSummary(t, 0, 1, 0, 0, G2)
 		// and
-		// TODO We can assert that group and controls summaries are updated.
+		assert.Equal(t, 1, controls.Summary.Pass)
+		assert.Equal(t, 1, controls.Summary.Fail)
+		assert.Equal(t, 0, controls.Summary.Info)
+		assert.Equal(t, 0, controls.Summary.Warn)
 		// and
 		runner.AssertExpectations(t)
 	})
 
 }
+
+func assertEqualGroupSummary(t *testing.T, pass, fail, info, warn int, actual *Group) {
+	t.Helper()
+	assert.Equal(t, pass, actual.Pass)
+	assert.Equal(t, fail, actual.Fail)
+	assert.Equal(t, info, actual.Info)
+	assert.Equal(t, warn, actual.Warn)
+}
diff --git a/cmd/common.go b/cmd/common.go
index cbc8f91..ef09d82 100644
--- a/cmd/common.go
+++ b/cmd/common.go
@@ -25,15 +25,11 @@ import (
 	"github.com/spf13/viper"
 )
 
-var (
-	errmsgs string
-)
-
-// NewRunFilter constructs a Predicate based on FilterOptions which determines whether tested Checks should be run or not.
-func NewRunFilter(opts FilterOpts) check.Predicate {
+// NewRunFilter constructs a Predicate based on FilterOpts which determines whether tested Checks should be run or not.
+func NewRunFilter(opts FilterOpts) (check.Predicate, error) {
 
 	if opts.CheckList != "" && opts.GroupList != "" {
-		exitWithError(fmt.Errorf("group option and check option can't be used together"))
+		return nil, fmt.Errorf("group option and check option can't be used together")
 	}
 
 	var groupIDs map[string]bool
@@ -47,31 +43,21 @@ func NewRunFilter(opts FilterOpts) check.Predicate {
 	}
 
 	return func(g *check.Group, c *check.Check) bool {
+		var test = true
 		if len(groupIDs) > 0 {
 			_, ok := groupIDs[g.ID]
-			if !ok {
-				return false
-			}
+			test = test && ok
 		}
 
 		if len(checkIDs) > 0 {
 			_, ok := checkIDs[c.ID]
-			if !ok {
-				return false
-			}
+			test = test && ok
 		}
 
-		if opts.Scored && opts.Unscored {
-			return true
-		}
-		if opts.Scored {
-			return c.Scored
-		}
-		if opts.Unscored {
-			return !c.Scored
-		}
-		return true
-	}
+		test = test && (opts.Scored && c.Scored || opts.Unscored && !c.Scored)
+
+		return test
+	}, nil
 }
 
 func runChecks(nodetype check.NodeType) {
@@ -111,7 +97,10 @@ func runChecks(nodetype check.NodeType) {
 	}
 
 	runner := check.NewRunner()
-	filter := NewRunFilter(filterOpts)
+	filter, err := NewRunFilter(filterOpts)
+	if err != nil {
+		exitWithError(fmt.Errorf("error setting up run filter: %v", err))
+	}
 
 	summary = controls.RunChecks(runner, filter)
 
diff --git a/cmd/common_test.go b/cmd/common_test.go
index 40fc906..b014e12 100644
--- a/cmd/common_test.go
+++ b/cmd/common_test.go
@@ -64,7 +64,7 @@ func TestNewRunFilter(t *testing.T) {
 
 		{
 			Name:       "Should return true when group flag contains group's ID",
-			FilterOpts: FilterOpts{GroupList: "G1,G2,G3"},
+			FilterOpts: FilterOpts{Scored: true, Unscored: true, GroupList: "G1,G2,G3"},
 			Group:      &check.Group{ID: "G2"},
 			Check:      &check.Check{},
 			Expected:   true,
@@ -79,7 +79,7 @@ func TestNewRunFilter(t *testing.T) {
 
 		{
 			Name:       "Should return true when check flag contains check's ID",
-			FilterOpts: FilterOpts{CheckList: "C1,C2,C3"},
+			FilterOpts: FilterOpts{Scored: true, Unscored: true, CheckList: "C1,C2,C3"},
 			Group:      &check.Group{},
 			Check:      &check.Check{ID: "C2"},
 			Expected:   true,
@@ -95,9 +95,18 @@ func TestNewRunFilter(t *testing.T) {
 
 	for _, testCase := range testCases {
 		t.Run(testCase.Name, func(t *testing.T) {
-			filter := NewRunFilter(testCase.FilterOpts)
+			filter, _ := NewRunFilter(testCase.FilterOpts)
 			assert.Equal(t, testCase.Expected, filter(testCase.Group, testCase.Check))
 		})
 	}
 
+	t.Run("Should return error when both group and check flags are used", func(t *testing.T) {
+		// given
+		opts := FilterOpts{GroupList: "G1", CheckList: "C1"}
+		// when
+		_, err := NewRunFilter(opts)
+		// then
+		assert.EqualError(t, err, "group option and check option can't be used together")
+	})
+
 }
diff --git a/cmd/root.go b/cmd/root.go
index 1a3e844..2f481ba 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -85,8 +85,8 @@ func init() {
 	RootCmd.PersistentFlags().BoolVar(&noRemediations, "noremediations", false, "Disable printing of remediations section")
 	RootCmd.PersistentFlags().BoolVar(&jsonFmt, "json", false, "Prints the results as JSON")
 	RootCmd.PersistentFlags().BoolVar(&pgSQL, "pgsql", false, "Save the results to PostgreSQL")
-	RootCmd.PersistentFlags().BoolVar(&filterOpts.Scored, "scored", false, "Run only scored CIS checks")
-	RootCmd.PersistentFlags().BoolVar(&filterOpts.Unscored, "unscored", false, "Run only unscored CIS checks")
+	RootCmd.PersistentFlags().BoolVar(&filterOpts.Scored, "scored", true, "Run the scored CIS checks")
+	RootCmd.PersistentFlags().BoolVar(&filterOpts.Unscored, "unscored", true, "Run the unscored CIS checks")
 
 	RootCmd.PersistentFlags().StringVarP(
 		&filterOpts.CheckList,
-- 
GitLab