diff --git a/.golangci.yml b/.golangci.yml
index f09da4e3c..e6c4cf24b 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -9,6 +9,7 @@ linters-settings:
- github.com/checkmarx/ast-cli/internal
- github.com/gookit/color
- github.com/CheckmarxDev/containers-resolver/pkg/containerResolver
+ - github.com/Checkmarx/gen-ai-prompts/prompts/sast_result_remediation
- github.com/spf13/viper
- github.com/checkmarxDev/gpt-wrapper
- github.com/spf13/cobra
diff --git a/go.mod b/go.mod
index cf83f60bf..30f70c3f6 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@ require (
github.com/CheckmarxDev/containers-resolver v1.0.6
github.com/MakeNowJust/heredoc v1.0.0
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230721160222-85da2fd1cc4c
+ github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/gomarkdown/markdown v0.0.0-20230922112808-5421fefb8386
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
diff --git a/go.sum b/go.sum
index 130b6239c..71103cfba 100644
--- a/go.sum
+++ b/go.sum
@@ -60,6 +60,8 @@ github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbi
github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8=
github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63 h1:SCuTcE+CFvgjbIxUNL8rsdB2sAhfuNx85HvxImKta3g=
+github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63/go.mod h1:MI6lfLerXU+5eTV/EPTDavgnV3owz3GPT4g/msZBWPo=
github.com/CheckmarxDev/containers-resolver v1.0.6 h1:Y0CKTR5tlw0YV+nQpz44kF0sZxWwCyvgYtjOukfYm0E=
github.com/CheckmarxDev/containers-resolver v1.0.6/go.mod h1:S3m6qscOWqaJJw56hR/hZxBVdcZRn8AnRGU/6jtONI4=
github.com/CycloneDX/cyclonedx-go v0.8.0 h1:FyWVj6x6hoJrui5uRQdYZcSievw3Z32Z88uYzG/0D6M=
diff --git a/internal/commands/chat-sast.go b/internal/commands/chat-sast.go
index 88ceb623f..d39a81d7c 100644
--- a/internal/commands/chat-sast.go
+++ b/internal/commands/chat-sast.go
@@ -4,6 +4,7 @@ import (
"fmt"
"strconv"
+ sastchat "github.com/Checkmarx/gen-ai-prompts/prompts/sast_result_remediation"
"github.com/checkmarx/ast-cli/internal/commands/util/printer"
"github.com/checkmarx/ast-cli/internal/logger"
"github.com/checkmarx/ast-cli/internal/params"
@@ -17,8 +18,6 @@ import (
"github.com/spf13/cobra"
)
-const ScanResultsFileErrorFormat = "Error reading and parsing scan results %s"
-const CreatePromptErrorFormat = "Error creating prompt for result ID %s"
const UserInputRequiredErrorFormat = "%s is required when %s is provided"
const AiGuidedRemediationDisabledError = "The AI Guided Remediation is disabled in your tenant account"
@@ -83,7 +82,7 @@ func runChatSast(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.Tenant
var newMessages []message.Message
if newConversation {
- systemPrompt, userPrompt, e := buildPrompt(scanResultsFile, sastResultID, sourceDir)
+ systemPrompt, userPrompt, e := sastchat.BuildPrompt(scanResultsFile, sastResultID, sourceDir)
if e != nil {
logger.PrintIfVerbose(e.Error())
return outputError(cmd, id, e)
@@ -109,7 +108,7 @@ func runChatSast(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.Tenant
responseContent := getMessageContents(response)
- responseContent = addDescriptionForIdentifier(responseContent)
+ responseContent = sastchat.AddDescriptionForIdentifier(responseContent)
return printer.Print(cmd.OutOrStdout(), &OutputModel{
ConversationID: id.String(),
@@ -137,34 +136,6 @@ func isAiGuidedRemediationEnabled(tenantWrapper wrappers.TenantConfigurationWrap
return false
}
-func buildPrompt(scanResultsFile, sastResultID, sourceDir string) (systemPrompt, userPrompt string, err error) {
- scanResults, err := ReadResultsSAST(scanResultsFile)
- if err != nil {
- return "", "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(ScanResultsFileErrorFormat, scanResultsFile), err)
- }
-
- if sastResultID == "" {
- return "", "", errors.Errorf(fmt.Sprintf("error in build-prompt: currently only --%s is supported", params.ChatSastResultID))
- }
-
- sastResult, err := GetResultByID(scanResults, sastResultID)
- if err != nil {
- return "", "", fmt.Errorf("error in build-prompt: %w", err)
- }
-
- sources, err := GetSourcesForResult(sastResult, sourceDir)
- if err != nil {
- return "", "", fmt.Errorf("error in build-prompt: %w", err)
- }
-
- prompt, err := CreateUserPrompt(sastResult, sources)
- if err != nil {
- return "", "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(CreatePromptErrorFormat, sastResultID), err)
- }
-
- return GetSystemPrompt(), prompt, nil
-}
-
func getMessageContents(response []message.Message) []string {
var responseContent []string
for _, r := range response {
diff --git a/internal/commands/chat-sast_test.go b/internal/commands/chat-sast_test.go
index 6673a8e34..6161d0114 100644
--- a/internal/commands/chat-sast_test.go
+++ b/internal/commands/chat-sast_test.go
@@ -46,6 +46,8 @@ func TestChatSastNoUserInput(t *testing.T) {
}
func TestChatSastInvalidScanResultsFile(t *testing.T) {
+ const ScanResultsFileErrorFormat = "Error reading and parsing SAST results file '%s'"
+
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--chat-apikey", "apiKey",
"--scan-results-file", "invalidFile",
diff --git a/internal/commands/sast-prompt.go b/internal/commands/sast-prompt.go
deleted file mode 100644
index f6f85dd55..000000000
--- a/internal/commands/sast-prompt.go
+++ /dev/null
@@ -1,168 +0,0 @@
-package commands
-
-import (
- "fmt"
- "strings"
-)
-
-const systemPrompt = `You are the Checkmarx AI Guided Remediation bot who can answer technical questions related to the results of Checkmarx Static Application
-Security Testing (SAST). You should be able to analyze and understand both the technical aspects of the security results and the common queries users may have
-about the results. You should also be capable of delivering clear, concise, and informative answers to help take appropriate action based on the findings.
-If a question irrelevant to the mentioned source code or SAST result is asked, answer 'I am the AI Guided Remediation assistant and can answer only on questions
-related to source code or SAST results or SAST Queries'.`
-
-const (
- confidence = "**CONFIDENCE:**"
- explanation = "**EXPLANATION:**"
- fix = "**PROPOSED REMEDIATION:**"
- code = "```"
-)
-
-const (
- confidenceDescription = " A score between 0 (low) and 100 (high) indicating the degree of confidence in the exploitability of this vulnerability in the context of your code.
"
- explanationDescription = " An OpenAI generated description of the vulnerability.
"
- fixDescription = " A customized snippet, generated by OpenAI, that can be used to remediate the vulnerability in your code.
"
-)
-
-// This constant is used to format the identifiers (confidence, explanation, fix) and their descriptions with HTML tags
-const identifierTitleForamt = "%s%s"
-
-const userPromptTemplate = `Checkmarx Static Application Security Testing (SAST) detected the %s vulnerability within the provided %s code snippet.
-The attack vector is presented by code snippets annotated by comments in the form ` + "`//SAST Node #X: element (element-type)`" + ` where X is
-the node index in the result, ` + "`element`" + ` is the name of the element through which the data flows, and the ` + "`element-type`" + ` is it's type.
-The first and last nodes are indicated by ` + "`(input ...)` and `(output ...)`" + ` respectively:
-` + code + `
-%s
-` + code + `
-Please review the code above and provide a confidence score ranging from 0 to 100.
-A score of 0 means you believe the result is completely incorrect, unexploitable, and a false positive.
-A score of 100 means you believe the result is completely correct, exploitable, and a true positive.
-
-Instructions for confidence score computation:
-
-1. The confidence score of a vulnerability which can be done from the Internet is much higher than from the local console.
-2. The confidence score of a vulnerability which can be done by anonymous user is much higher than of an authenticated user.
-3. The confidence score of a vulnerability with a vector starting with a stored input (like from files/db etc) cannot be more than 50.
-This is also known as a second-order vulnerability
-4. Pay your special attention to the first and last code snippet - whether a specific vulnerability found by Checkmarx SAST can start/occur here,
-or it's a false positive.
-5. If you don't find enough evidence about a vulnerability, just lower the score.
-6. If you are not sure, just lower the confidence - we don't want to have false positive results with a high confidence score.
-
-Please provide a brief explanation for your confidence score, don't mention all the instruction above.
-
-Next, please provide code that remediates the vulnerability so that a developer can copy paste instead of the snippet above.
-
-Your analysis MUST be presented in the following format:
-` + confidence +
- `number
-` + "\n" + explanation +
- `short_text
-` + "\n" + fix + ":" +
- `fixed_snippet`
-
-func GetSystemPrompt() string {
- return systemPrompt
-}
-
-func CreateUserPrompt(result *Result, sources map[string][]string) (string, error) {
- promptSource, err := createSourceForPrompt(result, sources)
- if err != nil {
- return "", err
- }
- return fmt.Sprintf(userPromptTemplate, result.Data.QueryName, result.Data.LanguageName, promptSource), nil
-}
-
-func createSourceForPrompt(result *Result, sources map[string][]string) (string, error) {
- var sourcePrompt []string
- methodsInPrompt := make(map[string][]string)
- for i := range result.Data.Nodes {
- node := result.Data.Nodes[i]
- sourceFilename := strings.ReplaceAll(node.FileName, "\\", "/")
- methodLines, exists := methodsInPrompt[sourceFilename+":"+node.Method]
- if !exists {
- m, err := GetMethodByMethodLine(sourceFilename, sources[sourceFilename], node.MethodLine, node.Line, false)
- methodLines = m
- if err != nil {
- return "", fmt.Errorf("error getting method %s: %v", node.Method, err)
- }
- } else if len(methodLines) < node.Line-node.MethodLine+1 {
- m, err := GetMethodByMethodLine(sourceFilename, sources[sourceFilename], node.MethodLine, node.Line, true)
- methodLines = m
- if err != nil {
- return "", fmt.Errorf("error getting method %s: %v", node.Method, err)
- }
- }
- lineInMethod := node.Line - node.MethodLine
- var edge string
- if i == 0 {
- edge = " (input)"
- } else if i == len(result.Data.Nodes)-1 {
- edge = " (output)"
- } else {
- edge = ""
- }
-
- // change UnknownReference to something more informational like VariableReference or TypeNameReference
- nodeType := node.DomType
- if node.DomType == "UnknownReference" {
- if node.TypeName == "" {
- nodeType = "VariableReference"
- } else {
- nodeType = node.TypeName + "Reference"
- }
- }
- methodLines[lineInMethod] += fmt.Sprintf("//SAST Node #%d%s: %s (%s)", i, edge, node.Name, nodeType)
- methodsInPrompt[sourceFilename+":"+node.Method] = methodLines
- }
-
- for _, methodLines := range methodsInPrompt {
- methodLines = append(methodLines, "// method continues ...")
- sourcePrompt = append(sourcePrompt, methodLines...)
- }
-
- return strings.Join(sourcePrompt, "\n"), nil
-}
-
-func GetMethodByMethodLine(filename string, lines []string, methodLineNumber, nodeLineNumber int, tagged bool) ([]string, error) {
- if methodLineNumber < 1 || methodLineNumber > len(lines) {
- return nil, fmt.Errorf("method line number %d is out of range", methodLineNumber)
- }
-
- if nodeLineNumber < 1 || nodeLineNumber > len(lines) {
- return nil, fmt.Errorf("node line number %d is out of range", nodeLineNumber)
- }
-
- if nodeLineNumber < methodLineNumber {
- return nil, fmt.Errorf("node line number %d is less than method line number %d", nodeLineNumber, methodLineNumber)
- }
-
- // Adjust line number to 0-based index for slice access
- startIndex := methodLineNumber - 1
- numberOfLines := nodeLineNumber - methodLineNumber + 1
- methodLines := lines[startIndex : startIndex+numberOfLines]
- if !tagged {
- methodLines[0] += fmt.Sprintf("// %s:%d", filename, methodLineNumber)
- }
- return methodLines, nil
-}
-
-func addDescriptionForIdentifier(responseContent []string) []string {
- identifiersDescription := map[string]string{
- confidence: confidenceDescription,
- explanation: explanationDescription,
- fix: fixDescription,
- }
- if len(responseContent) > 0 {
- for i := 0; i < len(responseContent); i++ {
- for identifier, description := range identifiersDescription {
- responseContent[i] = replaceIdentifierTitleIfNeeded(responseContent[i], identifier, description)
- }
- }
- }
- return responseContent
-}
-
-func replaceIdentifierTitleIfNeeded(input, identifier, identifierDescription string) string {
- return strings.Replace(input, identifier, fmt.Sprintf(identifierTitleForamt, identifier, identifierDescription), 1)
-}
diff --git a/internal/commands/sast-prompt_test.go b/internal/commands/sast-prompt_test.go
deleted file mode 100644
index 9af903994..000000000
--- a/internal/commands/sast-prompt_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package commands
-
-import (
- "fmt"
- "testing"
-)
-
-const expectedOutputFormat = "**CONFIDENCE:** " +
- "A score between 0 (low) and 100 (high) indicating the degree of confidence in the exploitability of this vulnerability in the context of your code. " +
- "
%s**EXPLANATION:** " +
- "An OpenAI generated description of the vulnerability.
%s**PROPOSED REMEDIATION:** " +
- "A customized snippet, generated by OpenAI, that can be used to remediate the vulnerability in your code.
%s"
-
-func getExpectedOutput(confidenceNumber, explanationText, fixText string) string {
- return fmt.Sprintf(expectedOutputFormat, confidenceNumber, explanationText, fixText)
-}
-
-func TestAddDescriptionForIdentifiers(t *testing.T) {
- input := confidence + " 35 " + explanation + " this is a short explanation." + fix + " a fixed snippet"
- expected := getExpectedOutput(" 35 ", " this is a short explanation.", " a fixed snippet")
- output := getActual(input, t)
-
- if output[len(output)-1] != expected {
- t.Errorf("Expected %q, but got %q", expected, output)
- }
-}
-
-func TestAddNewlinesIfNecessarySomeNewlines(t *testing.T) {
- input := confidence + " 35 " + explanation + " this is a short explanation.\n" + fix + " a fixed snippet"
- expected := getExpectedOutput(" 35 ", " this is a short explanation.\n", " a fixed snippet")
-
- output := getActual(input, t)
-
- if output[len(output)-1] != expected {
- t.Errorf("Expected %q, but got %q", expected, output)
- }
-}
-
-func TestAddNewlinesIfNecessaryAllNewlines(t *testing.T) {
- input := confidence + " 35\n " + explanation + " this is a short explanation.\n" + fix + " a fixed snippet"
- expected := getExpectedOutput(" 35\n ", " this is a short explanation.\n", " a fixed snippet")
-
- output := getActual(input, t)
-
- if output[len(output)-1] != expected {
- t.Errorf("Expected %q, but got %q", expected, output)
- }
-}
-
-func getActual(input string, t *testing.T) []string {
- someText := "some text"
- response := []string{someText, someText, input}
- output := addDescriptionForIdentifier(response)
- for i := 0; i < len(output)-1; i++ {
- if output[i] != response[i] {
- t.Errorf("All strings except last expected to stay the same")
- }
- }
- return output
-}
diff --git a/internal/commands/sast-results-json.go b/internal/commands/sast-results-json.go
deleted file mode 100644
index 8c8e64963..000000000
--- a/internal/commands/sast-results-json.go
+++ /dev/null
@@ -1,150 +0,0 @@
-package commands
-
-import (
- "encoding/json"
- "fmt"
- "os"
-)
-
-// Define the Go structs that match the JSON structure
-type ResultPartial struct {
- Type string `json:"type"`
- Label string `json:"label"`
- ID string `json:"id"`
- SimilarityID string `json:"similarityId"`
- Status string `json:"status"`
- State string `json:"state"`
- Severity string `json:"severity"`
- Created string `json:"created"`
- FirstFoundAt string `json:"firstFoundAt"`
- FoundAt string `json:"foundAt"`
- FirstScanID string `json:"firstScanId"`
- Description string `json:"description"`
- DescriptionHTML string `json:"descriptionHTML"`
- Data json.RawMessage `json:"data"`
- Comments map[string]interface{} `json:"comments"`
- VulnerabilityDetails json.RawMessage `json:"vulnerabilityDetails"`
-}
-type Result struct {
- Type string `json:"type"`
- Label string `json:"label"`
- ID string `json:"id"`
- SimilarityID string `json:"similarityId"`
- Status string `json:"status"`
- State string `json:"state"`
- Severity string `json:"severity"`
- Created string `json:"created"`
- FirstFoundAt string `json:"firstFoundAt"`
- FoundAt string `json:"foundAt"`
- FirstScanID string `json:"firstScanId"`
- Description string `json:"description"`
- DescriptionHTML string `json:"descriptionHTML"`
- Data Data `json:"data"`
- Comments map[string]interface{} `json:"comments"`
- VulnerabilityDetails VulnerabilityDetails `json:"vulnerabilityDetails"`
-}
-
-type Data struct {
- QueryID uint64 `json:"queryId"`
- QueryName string `json:"queryName"`
- Group string `json:"group"`
- ResultHash string `json:"resultHash"`
- LanguageName string `json:"languageName"`
- Nodes []Node `json:"nodes"`
-}
-
-type Node struct {
- ID string `json:"id"`
- Line int `json:"line"`
- Name string `json:"name"`
- Column int `json:"column"`
- Length int `json:"length"`
- Method string `json:"method"`
- NodeID int `json:"nodeID"`
- DomType string `json:"domType"`
- FileName string `json:"fileName"`
- FullName string `json:"fullName"`
- TypeName string `json:"typeName"`
- MethodLine int `json:"methodLine"`
- Definitions string `json:"definitions"`
-}
-
-type VulnerabilityDetails struct {
- CweID int `json:"cweId"`
- Cvss map[string]interface{} `json:"cvss"`
- Compliances []string `json:"compliances"`
-}
-
-type ScanResultsPartial struct {
- Results json.RawMessage `json:"results"`
- TotalCount int `json:"totalCount"`
- ScanID string `json:"scanID"`
-}
-
-type ScanResults struct {
- Results []*Result `json:"results"`
- TotalCount int `json:"totalCount"`
- ScanID string `json:"scanID"`
-}
-
-func ReadResultsSAST(filename string) (*ScanResults, error) {
- bytes, err := os.ReadFile(filename)
- if err != nil {
- return nil, err
- }
-
- // Unmarshal the JSON data into the ScanResults struct
- var scanResultsPartial ScanResultsPartial
- if err := json.Unmarshal(bytes, &scanResultsPartial); err != nil {
- return nil, err
- }
-
- var results []*Result
- var resultsPartial []*ResultPartial
- if err := json.Unmarshal(scanResultsPartial.Results, &resultsPartial); err != nil {
- return nil, err
- }
-
- for _, resultPartial := range resultsPartial {
- if resultPartial.Type != "sast" {
- continue
- }
- var data Data
- if err := json.Unmarshal(resultPartial.Data, &data); err != nil {
- return nil, err
- }
- var vulnerabilityDetails VulnerabilityDetails
- if err := json.Unmarshal(resultPartial.VulnerabilityDetails, &vulnerabilityDetails); err != nil {
- return nil, err
- }
-
- result := &Result{resultPartial.Type,
- resultPartial.Label,
- resultPartial.ID,
- resultPartial.SimilarityID,
- resultPartial.Status,
- resultPartial.State,
- resultPartial.Severity,
- resultPartial.Created,
- resultPartial.FirstFoundAt,
- resultPartial.FoundAt,
- resultPartial.FirstScanID,
- resultPartial.Description,
- resultPartial.DescriptionHTML,
- data,
- resultPartial.Comments,
- vulnerabilityDetails}
- results = append(results, result)
- }
- scanResults := ScanResults{results, scanResultsPartial.TotalCount, scanResultsPartial.ScanID}
- return &scanResults, nil
-}
-
-func GetResultByID(results *ScanResults, resultID string) (*Result, error) {
- for _, result := range results.Results {
- if result.ID == resultID {
- return result, nil
- }
- }
- return &Result{}, fmt.Errorf("result ID %s not found", resultID)
-}
diff --git a/internal/commands/sast-sources.go b/internal/commands/sast-sources.go
deleted file mode 100644
index ece464391..000000000
--- a/internal/commands/sast-sources.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package commands
-
-import (
- "bufio"
- "os"
- "path/filepath"
- "strings"
-)
-
-func GetSourcesForResult(scanResult *Result, sourceDir string) (map[string][]string, error) {
- sourceFilenames := make(map[string]bool)
- for i := range scanResult.Data.Nodes {
- sourceFilename := strings.ReplaceAll(scanResult.Data.Nodes[i].FileName, "\\", "/")
- sourceFilenames[sourceFilename] = true
- }
-
- fileContents, err := GetFileContents(sourceFilenames, sourceDir)
- if err != nil {
- return nil, err
- }
-
- return fileContents, nil
-}
-
-func GetFileContents(filenames map[string]bool, sourceDir string) (map[string][]string, error) {
- fileContents := make(map[string][]string)
-
- for filename := range filenames {
- sourceFilename := filepath.Join(sourceDir, filename)
- file, err := os.Open(sourceFilename)
- if err != nil {
- return nil, err
- }
-
- scanner := bufio.NewScanner(file)
- var lines []string
- for scanner.Scan() {
- lines = append(lines, scanner.Text())
- }
-
- err = file.Close()
- if err != nil {
- return nil, err
- }
-
- if err := scanner.Err(); err != nil {
- return nil, err
- }
-
- fileContents[filename] = lines
- }
-
- return fileContents, nil
-}