diff --git a/.golangci.yml b/.golangci.yml index f09da4e3c..e6c4cf24b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,6 +9,7 @@ linters-settings: - github.com/checkmarx/ast-cli/internal - github.com/gookit/color - github.com/CheckmarxDev/containers-resolver/pkg/containerResolver + - github.com/Checkmarx/gen-ai-prompts/prompts/sast_result_remediation - github.com/spf13/viper - github.com/checkmarxDev/gpt-wrapper - github.com/spf13/cobra diff --git a/go.mod b/go.mod index cf83f60bf..30f70c3f6 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/CheckmarxDev/containers-resolver v1.0.6 github.com/MakeNowJust/heredoc v1.0.0 github.com/checkmarxDev/gpt-wrapper v0.0.0-20230721160222-85da2fd1cc4c + github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/gomarkdown/markdown v0.0.0-20230922112808-5421fefb8386 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 diff --git a/go.sum b/go.sum index 130b6239c..71103cfba 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbi github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63 h1:SCuTcE+CFvgjbIxUNL8rsdB2sAhfuNx85HvxImKta3g= +github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63/go.mod h1:MI6lfLerXU+5eTV/EPTDavgnV3owz3GPT4g/msZBWPo= github.com/CheckmarxDev/containers-resolver v1.0.6 h1:Y0CKTR5tlw0YV+nQpz44kF0sZxWwCyvgYtjOukfYm0E= github.com/CheckmarxDev/containers-resolver v1.0.6/go.mod h1:S3m6qscOWqaJJw56hR/hZxBVdcZRn8AnRGU/6jtONI4= github.com/CycloneDX/cyclonedx-go v0.8.0 h1:FyWVj6x6hoJrui5uRQdYZcSievw3Z32Z88uYzG/0D6M= diff --git a/internal/commands/chat-sast.go b/internal/commands/chat-sast.go index 88ceb623f..d39a81d7c 100644 --- a/internal/commands/chat-sast.go +++ b/internal/commands/chat-sast.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" + sastchat "github.com/Checkmarx/gen-ai-prompts/prompts/sast_result_remediation" "github.com/checkmarx/ast-cli/internal/commands/util/printer" "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" @@ -17,8 +18,6 @@ import ( "github.com/spf13/cobra" ) -const ScanResultsFileErrorFormat = "Error reading and parsing scan results %s" -const CreatePromptErrorFormat = "Error creating prompt for result ID %s" const UserInputRequiredErrorFormat = "%s is required when %s is provided" const AiGuidedRemediationDisabledError = "The AI Guided Remediation is disabled in your tenant account" @@ -83,7 +82,7 @@ func runChatSast(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.Tenant var newMessages []message.Message if newConversation { - systemPrompt, userPrompt, e := buildPrompt(scanResultsFile, sastResultID, sourceDir) + systemPrompt, userPrompt, e := sastchat.BuildPrompt(scanResultsFile, sastResultID, sourceDir) if e != nil { logger.PrintIfVerbose(e.Error()) return outputError(cmd, id, e) @@ -109,7 +108,7 @@ func runChatSast(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.Tenant responseContent := getMessageContents(response) - responseContent = addDescriptionForIdentifier(responseContent) + responseContent = sastchat.AddDescriptionForIdentifier(responseContent) return printer.Print(cmd.OutOrStdout(), &OutputModel{ ConversationID: id.String(), @@ -137,34 +136,6 @@ func isAiGuidedRemediationEnabled(tenantWrapper wrappers.TenantConfigurationWrap return false } -func buildPrompt(scanResultsFile, sastResultID, sourceDir string) (systemPrompt, userPrompt string, err error) { - scanResults, err := ReadResultsSAST(scanResultsFile) - if err != nil { - return "", "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(ScanResultsFileErrorFormat, scanResultsFile), err) - } - - if sastResultID == "" { - return "", "", errors.Errorf(fmt.Sprintf("error in build-prompt: currently only --%s is supported", params.ChatSastResultID)) - } - - sastResult, err := GetResultByID(scanResults, sastResultID) - if err != nil { - return "", "", fmt.Errorf("error in build-prompt: %w", err) - } - - sources, err := GetSourcesForResult(sastResult, sourceDir) - if err != nil { - return "", "", fmt.Errorf("error in build-prompt: %w", err) - } - - prompt, err := CreateUserPrompt(sastResult, sources) - if err != nil { - return "", "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(CreatePromptErrorFormat, sastResultID), err) - } - - return GetSystemPrompt(), prompt, nil -} - func getMessageContents(response []message.Message) []string { var responseContent []string for _, r := range response { diff --git a/internal/commands/chat-sast_test.go b/internal/commands/chat-sast_test.go index 6673a8e34..6161d0114 100644 --- a/internal/commands/chat-sast_test.go +++ b/internal/commands/chat-sast_test.go @@ -46,6 +46,8 @@ func TestChatSastNoUserInput(t *testing.T) { } func TestChatSastInvalidScanResultsFile(t *testing.T) { + const ScanResultsFileErrorFormat = "Error reading and parsing SAST results file '%s'" + buffer, err := executeRedirectedTestCommand("chat", "sast", "--chat-apikey", "apiKey", "--scan-results-file", "invalidFile", diff --git a/internal/commands/sast-prompt.go b/internal/commands/sast-prompt.go deleted file mode 100644 index f6f85dd55..000000000 --- a/internal/commands/sast-prompt.go +++ /dev/null @@ -1,168 +0,0 @@ -package commands - -import ( - "fmt" - "strings" -) - -const systemPrompt = `You are the Checkmarx AI Guided Remediation bot who can answer technical questions related to the results of Checkmarx Static Application -Security Testing (SAST). You should be able to analyze and understand both the technical aspects of the security results and the common queries users may have -about the results. You should also be capable of delivering clear, concise, and informative answers to help take appropriate action based on the findings. -If a question irrelevant to the mentioned source code or SAST result is asked, answer 'I am the AI Guided Remediation assistant and can answer only on questions -related to source code or SAST results or SAST Queries'.` - -const ( - confidence = "**CONFIDENCE:**" - explanation = "**EXPLANATION:**" - fix = "**PROPOSED REMEDIATION:**" - code = "```" -) - -const ( - confidenceDescription = " A score between 0 (low) and 100 (high) indicating the degree of confidence in the exploitability of this vulnerability in the context of your code.
" - explanationDescription = " An OpenAI generated description of the vulnerability.
" - fixDescription = " A customized snippet, generated by OpenAI, that can be used to remediate the vulnerability in your code.
" -) - -// This constant is used to format the identifiers (confidence, explanation, fix) and their descriptions with HTML tags -const identifierTitleForamt = "%s%s" - -const userPromptTemplate = `Checkmarx Static Application Security Testing (SAST) detected the %s vulnerability within the provided %s code snippet. -The attack vector is presented by code snippets annotated by comments in the form ` + "`//SAST Node #X: element (element-type)`" + ` where X is -the node index in the result, ` + "`element`" + ` is the name of the element through which the data flows, and the ` + "`element-type`" + ` is it's type. -The first and last nodes are indicated by ` + "`(input ...)` and `(output ...)`" + ` respectively: -` + code + ` -%s -` + code + ` -Please review the code above and provide a confidence score ranging from 0 to 100. -A score of 0 means you believe the result is completely incorrect, unexploitable, and a false positive. -A score of 100 means you believe the result is completely correct, exploitable, and a true positive. - -Instructions for confidence score computation: - -1. The confidence score of a vulnerability which can be done from the Internet is much higher than from the local console. -2. The confidence score of a vulnerability which can be done by anonymous user is much higher than of an authenticated user. -3. The confidence score of a vulnerability with a vector starting with a stored input (like from files/db etc) cannot be more than 50. -This is also known as a second-order vulnerability -4. Pay your special attention to the first and last code snippet - whether a specific vulnerability found by Checkmarx SAST can start/occur here, -or it's a false positive. -5. If you don't find enough evidence about a vulnerability, just lower the score. -6. If you are not sure, just lower the confidence - we don't want to have false positive results with a high confidence score. - -Please provide a brief explanation for your confidence score, don't mention all the instruction above. - -Next, please provide code that remediates the vulnerability so that a developer can copy paste instead of the snippet above. - -Your analysis MUST be presented in the following format: -` + confidence + - `number -` + "\n" + explanation + - `short_text -` + "\n" + fix + ":" + - `fixed_snippet` - -func GetSystemPrompt() string { - return systemPrompt -} - -func CreateUserPrompt(result *Result, sources map[string][]string) (string, error) { - promptSource, err := createSourceForPrompt(result, sources) - if err != nil { - return "", err - } - return fmt.Sprintf(userPromptTemplate, result.Data.QueryName, result.Data.LanguageName, promptSource), nil -} - -func createSourceForPrompt(result *Result, sources map[string][]string) (string, error) { - var sourcePrompt []string - methodsInPrompt := make(map[string][]string) - for i := range result.Data.Nodes { - node := result.Data.Nodes[i] - sourceFilename := strings.ReplaceAll(node.FileName, "\\", "/") - methodLines, exists := methodsInPrompt[sourceFilename+":"+node.Method] - if !exists { - m, err := GetMethodByMethodLine(sourceFilename, sources[sourceFilename], node.MethodLine, node.Line, false) - methodLines = m - if err != nil { - return "", fmt.Errorf("error getting method %s: %v", node.Method, err) - } - } else if len(methodLines) < node.Line-node.MethodLine+1 { - m, err := GetMethodByMethodLine(sourceFilename, sources[sourceFilename], node.MethodLine, node.Line, true) - methodLines = m - if err != nil { - return "", fmt.Errorf("error getting method %s: %v", node.Method, err) - } - } - lineInMethod := node.Line - node.MethodLine - var edge string - if i == 0 { - edge = " (input)" - } else if i == len(result.Data.Nodes)-1 { - edge = " (output)" - } else { - edge = "" - } - - // change UnknownReference to something more informational like VariableReference or TypeNameReference - nodeType := node.DomType - if node.DomType == "UnknownReference" { - if node.TypeName == "" { - nodeType = "VariableReference" - } else { - nodeType = node.TypeName + "Reference" - } - } - methodLines[lineInMethod] += fmt.Sprintf("//SAST Node #%d%s: %s (%s)", i, edge, node.Name, nodeType) - methodsInPrompt[sourceFilename+":"+node.Method] = methodLines - } - - for _, methodLines := range methodsInPrompt { - methodLines = append(methodLines, "// method continues ...") - sourcePrompt = append(sourcePrompt, methodLines...) - } - - return strings.Join(sourcePrompt, "\n"), nil -} - -func GetMethodByMethodLine(filename string, lines []string, methodLineNumber, nodeLineNumber int, tagged bool) ([]string, error) { - if methodLineNumber < 1 || methodLineNumber > len(lines) { - return nil, fmt.Errorf("method line number %d is out of range", methodLineNumber) - } - - if nodeLineNumber < 1 || nodeLineNumber > len(lines) { - return nil, fmt.Errorf("node line number %d is out of range", nodeLineNumber) - } - - if nodeLineNumber < methodLineNumber { - return nil, fmt.Errorf("node line number %d is less than method line number %d", nodeLineNumber, methodLineNumber) - } - - // Adjust line number to 0-based index for slice access - startIndex := methodLineNumber - 1 - numberOfLines := nodeLineNumber - methodLineNumber + 1 - methodLines := lines[startIndex : startIndex+numberOfLines] - if !tagged { - methodLines[0] += fmt.Sprintf("// %s:%d", filename, methodLineNumber) - } - return methodLines, nil -} - -func addDescriptionForIdentifier(responseContent []string) []string { - identifiersDescription := map[string]string{ - confidence: confidenceDescription, - explanation: explanationDescription, - fix: fixDescription, - } - if len(responseContent) > 0 { - for i := 0; i < len(responseContent); i++ { - for identifier, description := range identifiersDescription { - responseContent[i] = replaceIdentifierTitleIfNeeded(responseContent[i], identifier, description) - } - } - } - return responseContent -} - -func replaceIdentifierTitleIfNeeded(input, identifier, identifierDescription string) string { - return strings.Replace(input, identifier, fmt.Sprintf(identifierTitleForamt, identifier, identifierDescription), 1) -} diff --git a/internal/commands/sast-prompt_test.go b/internal/commands/sast-prompt_test.go deleted file mode 100644 index 9af903994..000000000 --- a/internal/commands/sast-prompt_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package commands - -import ( - "fmt" - "testing" -) - -const expectedOutputFormat = "**CONFIDENCE:** " + - "A score between 0 (low) and 100 (high) indicating the degree of confidence in the exploitability of this vulnerability in the context of your code. " + - "
%s**EXPLANATION:** " + - "An OpenAI generated description of the vulnerability.
%s**PROPOSED REMEDIATION:** " + - "A customized snippet, generated by OpenAI, that can be used to remediate the vulnerability in your code.
%s" - -func getExpectedOutput(confidenceNumber, explanationText, fixText string) string { - return fmt.Sprintf(expectedOutputFormat, confidenceNumber, explanationText, fixText) -} - -func TestAddDescriptionForIdentifiers(t *testing.T) { - input := confidence + " 35 " + explanation + " this is a short explanation." + fix + " a fixed snippet" - expected := getExpectedOutput(" 35 ", " this is a short explanation.", " a fixed snippet") - output := getActual(input, t) - - if output[len(output)-1] != expected { - t.Errorf("Expected %q, but got %q", expected, output) - } -} - -func TestAddNewlinesIfNecessarySomeNewlines(t *testing.T) { - input := confidence + " 35 " + explanation + " this is a short explanation.\n" + fix + " a fixed snippet" - expected := getExpectedOutput(" 35 ", " this is a short explanation.\n", " a fixed snippet") - - output := getActual(input, t) - - if output[len(output)-1] != expected { - t.Errorf("Expected %q, but got %q", expected, output) - } -} - -func TestAddNewlinesIfNecessaryAllNewlines(t *testing.T) { - input := confidence + " 35\n " + explanation + " this is a short explanation.\n" + fix + " a fixed snippet" - expected := getExpectedOutput(" 35\n ", " this is a short explanation.\n", " a fixed snippet") - - output := getActual(input, t) - - if output[len(output)-1] != expected { - t.Errorf("Expected %q, but got %q", expected, output) - } -} - -func getActual(input string, t *testing.T) []string { - someText := "some text" - response := []string{someText, someText, input} - output := addDescriptionForIdentifier(response) - for i := 0; i < len(output)-1; i++ { - if output[i] != response[i] { - t.Errorf("All strings except last expected to stay the same") - } - } - return output -} diff --git a/internal/commands/sast-results-json.go b/internal/commands/sast-results-json.go deleted file mode 100644 index 8c8e64963..000000000 --- a/internal/commands/sast-results-json.go +++ /dev/null @@ -1,150 +0,0 @@ -package commands - -import ( - "encoding/json" - "fmt" - "os" -) - -// Define the Go structs that match the JSON structure -type ResultPartial struct { - Type string `json:"type"` - Label string `json:"label"` - ID string `json:"id"` - SimilarityID string `json:"similarityId"` - Status string `json:"status"` - State string `json:"state"` - Severity string `json:"severity"` - Created string `json:"created"` - FirstFoundAt string `json:"firstFoundAt"` - FoundAt string `json:"foundAt"` - FirstScanID string `json:"firstScanId"` - Description string `json:"description"` - DescriptionHTML string `json:"descriptionHTML"` - Data json.RawMessage `json:"data"` - Comments map[string]interface{} `json:"comments"` - VulnerabilityDetails json.RawMessage `json:"vulnerabilityDetails"` -} -type Result struct { - Type string `json:"type"` - Label string `json:"label"` - ID string `json:"id"` - SimilarityID string `json:"similarityId"` - Status string `json:"status"` - State string `json:"state"` - Severity string `json:"severity"` - Created string `json:"created"` - FirstFoundAt string `json:"firstFoundAt"` - FoundAt string `json:"foundAt"` - FirstScanID string `json:"firstScanId"` - Description string `json:"description"` - DescriptionHTML string `json:"descriptionHTML"` - Data Data `json:"data"` - Comments map[string]interface{} `json:"comments"` - VulnerabilityDetails VulnerabilityDetails `json:"vulnerabilityDetails"` -} - -type Data struct { - QueryID uint64 `json:"queryId"` - QueryName string `json:"queryName"` - Group string `json:"group"` - ResultHash string `json:"resultHash"` - LanguageName string `json:"languageName"` - Nodes []Node `json:"nodes"` -} - -type Node struct { - ID string `json:"id"` - Line int `json:"line"` - Name string `json:"name"` - Column int `json:"column"` - Length int `json:"length"` - Method string `json:"method"` - NodeID int `json:"nodeID"` - DomType string `json:"domType"` - FileName string `json:"fileName"` - FullName string `json:"fullName"` - TypeName string `json:"typeName"` - MethodLine int `json:"methodLine"` - Definitions string `json:"definitions"` -} - -type VulnerabilityDetails struct { - CweID int `json:"cweId"` - Cvss map[string]interface{} `json:"cvss"` - Compliances []string `json:"compliances"` -} - -type ScanResultsPartial struct { - Results json.RawMessage `json:"results"` - TotalCount int `json:"totalCount"` - ScanID string `json:"scanID"` -} - -type ScanResults struct { - Results []*Result `json:"results"` - TotalCount int `json:"totalCount"` - ScanID string `json:"scanID"` -} - -func ReadResultsSAST(filename string) (*ScanResults, error) { - bytes, err := os.ReadFile(filename) - if err != nil { - return nil, err - } - - // Unmarshal the JSON data into the ScanResults struct - var scanResultsPartial ScanResultsPartial - if err := json.Unmarshal(bytes, &scanResultsPartial); err != nil { - return nil, err - } - - var results []*Result - var resultsPartial []*ResultPartial - if err := json.Unmarshal(scanResultsPartial.Results, &resultsPartial); err != nil { - return nil, err - } - - for _, resultPartial := range resultsPartial { - if resultPartial.Type != "sast" { - continue - } - var data Data - if err := json.Unmarshal(resultPartial.Data, &data); err != nil { - return nil, err - } - var vulnerabilityDetails VulnerabilityDetails - if err := json.Unmarshal(resultPartial.VulnerabilityDetails, &vulnerabilityDetails); err != nil { - return nil, err - } - - result := &Result{resultPartial.Type, - resultPartial.Label, - resultPartial.ID, - resultPartial.SimilarityID, - resultPartial.Status, - resultPartial.State, - resultPartial.Severity, - resultPartial.Created, - resultPartial.FirstFoundAt, - resultPartial.FoundAt, - resultPartial.FirstScanID, - resultPartial.Description, - resultPartial.DescriptionHTML, - data, - resultPartial.Comments, - vulnerabilityDetails} - results = append(results, result) - } - scanResults := ScanResults{results, scanResultsPartial.TotalCount, scanResultsPartial.ScanID} - return &scanResults, nil -} - -func GetResultByID(results *ScanResults, resultID string) (*Result, error) { - for _, result := range results.Results { - if result.ID == resultID { - return result, nil - } - } - return &Result{}, fmt.Errorf("result ID %s not found", resultID) -} diff --git a/internal/commands/sast-sources.go b/internal/commands/sast-sources.go deleted file mode 100644 index ece464391..000000000 --- a/internal/commands/sast-sources.go +++ /dev/null @@ -1,54 +0,0 @@ -package commands - -import ( - "bufio" - "os" - "path/filepath" - "strings" -) - -func GetSourcesForResult(scanResult *Result, sourceDir string) (map[string][]string, error) { - sourceFilenames := make(map[string]bool) - for i := range scanResult.Data.Nodes { - sourceFilename := strings.ReplaceAll(scanResult.Data.Nodes[i].FileName, "\\", "/") - sourceFilenames[sourceFilename] = true - } - - fileContents, err := GetFileContents(sourceFilenames, sourceDir) - if err != nil { - return nil, err - } - - return fileContents, nil -} - -func GetFileContents(filenames map[string]bool, sourceDir string) (map[string][]string, error) { - fileContents := make(map[string][]string) - - for filename := range filenames { - sourceFilename := filepath.Join(sourceDir, filename) - file, err := os.Open(sourceFilename) - if err != nil { - return nil, err - } - - scanner := bufio.NewScanner(file) - var lines []string - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - - err = file.Close() - if err != nil { - return nil, err - } - - if err := scanner.Err(); err != nil { - return nil, err - } - - fileContents[filename] = lines - } - - return fileContents, nil -}