diff --git a/internal/commands/scan.go b/internal/commands/scan.go index 464549c49..23ce2355a 100644 --- a/internal/commands/scan.go +++ b/internal/commands/scan.go @@ -1525,6 +1525,12 @@ func runCreateScanCommand( if timeoutMinutes < 0 { return errors.Errorf("--%s should be equal or higher than 0", commonParams.ScanTimeoutFlag) } + threshold, _ := cmd.Flags().GetString(commonParams.Threshold) + thresholdMap := parseThreshold(threshold) + err = validateThresholds(thresholdMap) + if err != nil { + return err + } scanModel, zipFilePath, err := createScanModel( cmd, uploadsWrapper, @@ -1588,7 +1594,7 @@ func runCreateScanCommand( return err } - err = applyThreshold(cmd, resultsWrapper, scanResponseModel) + err = applyThreshold(cmd, resultsWrapper, scanResponseModel, thresholdMap) if err != nil { return err } @@ -1830,14 +1836,12 @@ func applyThreshold( cmd *cobra.Command, resultsWrapper wrappers.ResultsWrapper, scanResponseModel *wrappers.ScanResponseModel, + thresholdMap map[string]int, ) error { - threshold, _ := cmd.Flags().GetString(commonParams.Threshold) - if strings.TrimSpace(threshold) == "" { + if len(thresholdMap) == 0 { return nil } - thresholdMap := parseThreshold(threshold) - summaryMap, err := getSummaryThresholdMap(resultsWrapper, scanResponseModel) if err != nil { return err @@ -1872,25 +1876,22 @@ func applyThreshold( } func parseThreshold(threshold string) map[string]int { + if strings.TrimSpace(threshold) == "" { + return nil + } thresholdMap := make(map[string]int) if threshold != "" { threshold = strings.ReplaceAll(strings.ReplaceAll(threshold, " ", ""), ",", ";") thresholdLimits := strings.Split(strings.ToLower(threshold), ";") for _, limits := range thresholdLimits { - limit := strings.Split(limits, "=") - engineName := limit[0] - engineName = strings.Replace(engineName, commonParams.KicsType, commonParams.IacType, 1) - if len(limit) > 1 { - intLimit, err := strconv.Atoi(limit[1]) - if err != nil { - log.Println("Error parsing threshold limit: ", err) - } else { - thresholdMap[engineName] = intLimit - } + engineName, intLimit, err := parseThresholdLimit(limits) + if err != nil { + log.Printf("%s", err) + } else { + thresholdMap[engineName] = intLimit } } } - return thresholdMap } @@ -2495,6 +2496,35 @@ func validateCreateScanFlags(cmd *cobra.Command) error { return nil } +func validateThresholds(thresholdMap map[string]int) error { + var errMsgBuilder strings.Builder + + for engineName, limit := range thresholdMap { + if limit < 1 { + errMsgBuilder.WriteString(errors.Errorf("Invalid value for threshold limit %s. Threshold should be greater or equal to 1.\n", engineName).Error()) + } + } + + errMsg := errMsgBuilder.String() + if errMsg != "" { + return errors.New(errMsg) + } + return nil +} + +func parseThresholdLimit(limit string) (engineName string, intLimit int, err error) { + parts := strings.Split(limit, "=") + engineName = strings.Replace(parts[0], commonParams.KicsType, commonParams.IacType, 1) + if len(parts) <= 1 { + return engineName, 0, errors.Errorf("Error parsing threshold limit: missing values\n") + } + intLimit, err = strconv.Atoi(parts[1]) + if err != nil { + err = errors.Errorf("%s: Error parsing threshold limit: %v\n", engineName, err) + } + return engineName, intLimit, err +} + func validateBooleanString(value string) error { if value == "" { return nil diff --git a/internal/commands/scan_test.go b/internal/commands/scan_test.go index da25cc89f..cf553b3fc 100644 --- a/internal/commands/scan_test.go +++ b/internal/commands/scan_test.go @@ -678,3 +678,88 @@ func TestCreateScanProjectTagsCheckResendToScan(t *testing.T) { err := executeTestCommand(cmd, baseArgs...) assert.NilError(t, err) } + +func Test_parseThresholdLimit(t *testing.T) { + type args struct { + limit string + } + tests := []struct { + name string + args args + wantEngineName string + wantIntLimit int + wantErr bool + }{ + { + name: "Test parseThresholdLimit with valid limit Success", + args: args{limit: "sast-low=1"}, + wantEngineName: "sast-low", + wantIntLimit: 1, + wantErr: false, + }, + { + name: "Test parseThresholdLimit with invalid limit Fail", + args: args{limit: "kics-medium=error"}, + wantEngineName: "iac-security-medium", + wantIntLimit: 0, + wantErr: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + gotEngineName, gotIntLimit, err := parseThresholdLimit(tt.args.limit) + if (err != nil) != tt.wantErr { + t.Errorf("parseThresholdLimit() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotEngineName != tt.wantEngineName { + t.Errorf("parseThresholdLimit() gotEngineName = %v, want %v", gotEngineName, tt.wantEngineName) + } + if gotIntLimit != tt.wantIntLimit { + t.Errorf("parseThresholdLimit() gotIntLimit = %v, want %v", gotIntLimit, tt.wantIntLimit) + } + }) + } +} + +func Test_validateThresholds(t *testing.T) { + tests := []struct { + name string + thresholdMap map[string]int + wantErr bool + }{ + { + name: "Valid Thresholds", + thresholdMap: map[string]int{ + "sast-medium": 5, + "sast-high": 10, + }, + wantErr: false, + }, + { + name: "Invalid Threshold - Negative Limit", + thresholdMap: map[string]int{ + "sca-medium": -3, + }, + wantErr: true, + }, + { + name: "Invalid Threshold - Zero Limit", + thresholdMap: map[string]int{ + "sca-high": 0, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + err := validateThresholds(tt.thresholdMap) + if (err != nil) != tt.wantErr { + t.Errorf("validateThresholds() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}