Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI | Add threshold user input validation (AST-40169) #715

Merged
merged 11 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions internal/commands/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,12 @@ func runCreateScanCommand(
if timeoutMinutes < 0 {
return errors.Errorf("--%s should be equal or higher than 0", commonParams.ScanTimeoutFlag)
}
threshold, _ := cmd.Flags().GetString(commonParams.Threshold)
thresholdMap := parseThreshold(threshold)
err = validateThresholds(thresholdMap)
if err != nil {
return err
}
scanModel, zipFilePath, err := createScanModel(
cmd,
uploadsWrapper,
Expand Down Expand Up @@ -1588,7 +1594,7 @@ func runCreateScanCommand(
return err
}

err = applyThreshold(cmd, resultsWrapper, scanResponseModel)
err = applyThreshold(cmd, resultsWrapper, scanResponseModel, thresholdMap)
if err != nil {
return err
}
Expand Down Expand Up @@ -1830,14 +1836,12 @@ func applyThreshold(
cmd *cobra.Command,
resultsWrapper wrappers.ResultsWrapper,
scanResponseModel *wrappers.ScanResponseModel,
thresholdMap map[string]int,
) error {
threshold, _ := cmd.Flags().GetString(commonParams.Threshold)
if strings.TrimSpace(threshold) == "" {
if len(thresholdMap) == 0 {
return nil
}

thresholdMap := parseThreshold(threshold)

summaryMap, err := getSummaryThresholdMap(resultsWrapper, scanResponseModel)
if err != nil {
return err
Expand Down Expand Up @@ -1872,25 +1876,22 @@ func applyThreshold(
}

func parseThreshold(threshold string) map[string]int {
if strings.TrimSpace(threshold) == "" {
return nil
}
thresholdMap := make(map[string]int)
if threshold != "" {
threshold = strings.ReplaceAll(strings.ReplaceAll(threshold, " ", ""), ",", ";")
thresholdLimits := strings.Split(strings.ToLower(threshold), ";")
for _, limits := range thresholdLimits {
limit := strings.Split(limits, "=")
engineName := limit[0]
engineName = strings.Replace(engineName, commonParams.KicsType, commonParams.IacType, 1)
if len(limit) > 1 {
intLimit, err := strconv.Atoi(limit[1])
if err != nil {
log.Println("Error parsing threshold limit: ", err)
} else {
thresholdMap[engineName] = intLimit
}
engineName, intLimit, err := parseThresholdLimit(limits)
if err != nil {
log.Printf("%s", err)
} else {
thresholdMap[engineName] = intLimit
}
}
}

return thresholdMap
}

Expand Down Expand Up @@ -2495,6 +2496,35 @@ func validateCreateScanFlags(cmd *cobra.Command) error {
return nil
}

func validateThresholds(thresholdMap map[string]int) error {
var errMsgBuilder strings.Builder

for engineName, limit := range thresholdMap {
if limit < 1 {
errMsgBuilder.WriteString(errors.Errorf("Invalid value for threshold limit %s. Threshold should be greater or equal to 1.\n", engineName).Error())
}
}

errMsg := errMsgBuilder.String()
if errMsg != "" {
return errors.New(errMsg)
}
return nil
}

func parseThresholdLimit(limit string) (engineName string, intLimit int, err error) {
parts := strings.Split(limit, "=")
engineName = strings.Replace(parts[0], commonParams.KicsType, commonParams.IacType, 1)
OrShamirCM marked this conversation as resolved.
Show resolved Hide resolved
if len(parts) <= 1 {
return engineName, 0, errors.Errorf("Error parsing threshold limit: missing values\n")
}
intLimit, err = strconv.Atoi(parts[1])
if err != nil {
err = errors.Errorf("%s: Error parsing threshold limit: %v\n", engineName, err)
}
return engineName, intLimit, err
}

func validateBooleanString(value string) error {
if value == "" {
return nil
Expand Down
85 changes: 85 additions & 0 deletions internal/commands/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,3 +678,88 @@ func TestCreateScanProjectTagsCheckResendToScan(t *testing.T) {
err := executeTestCommand(cmd, baseArgs...)
assert.NilError(t, err)
}

func Test_parseThresholdLimit(t *testing.T) {
type args struct {
limit string
}
tests := []struct {
name string
args args
wantEngineName string
wantIntLimit int
wantErr bool
}{
{
name: "Test parseThresholdLimit with valid limit Success",
args: args{limit: "sast-low=1"},
wantEngineName: "sast-low",
wantIntLimit: 1,
wantErr: false,
},
{
name: "Test parseThresholdLimit with invalid limit Fail",
args: args{limit: "kics-medium=error"},
wantEngineName: "iac-security-medium",
wantIntLimit: 0,
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
gotEngineName, gotIntLimit, err := parseThresholdLimit(tt.args.limit)
if (err != nil) != tt.wantErr {
t.Errorf("parseThresholdLimit() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotEngineName != tt.wantEngineName {
t.Errorf("parseThresholdLimit() gotEngineName = %v, want %v", gotEngineName, tt.wantEngineName)
}
if gotIntLimit != tt.wantIntLimit {
t.Errorf("parseThresholdLimit() gotIntLimit = %v, want %v", gotIntLimit, tt.wantIntLimit)
}
})
}
}

func Test_validateThresholds(t *testing.T) {
tests := []struct {
name string
thresholdMap map[string]int
wantErr bool
}{
{
name: "Valid Thresholds",
thresholdMap: map[string]int{
"sast-medium": 5,
"sast-high": 10,
},
wantErr: false,
},
{
name: "Invalid Threshold - Negative Limit",
thresholdMap: map[string]int{
"sca-medium": -3,
},
wantErr: true,
},
{
name: "Invalid Threshold - Zero Limit",
thresholdMap: map[string]int{
"sca-high": 0,
},
wantErr: true,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
err := validateThresholds(tt.thresholdMap)
if (err != nil) != tt.wantErr {
t.Errorf("validateThresholds() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
Loading