diff --git a/.vscode/launch.json b/.vscode/launch.json index 4e1eefb..6f337ee 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,22 +5,15 @@ "version": "0.2.0", "configurations": [ { - "name": "wingman how to list files in the current directory", + "name": "try", "type": "go", "request": "launch", "mode": "auto", "program": "${workspaceFolder}", "console": "integratedTerminal", "args": [ - "how", - "to", - "list", - "files", - "in", - "the", - "current", - "directory" + "install postgresql" ] } ] -} +} \ No newline at end of file diff --git a/Readme.md b/Readme.md index d0f7e75..464d124 100644 --- a/Readme.md +++ b/Readme.md @@ -26,8 +26,11 @@ Then, you need to create a file `~/.wingman.yaml` with the following content: ```yaml openai_token = +openai_model = gpt-4o ``` +`openai_model` can be omitted if you want to use the default model, which is `gpt-3.5-turbo`. + ### Usage ```bash diff --git a/cmd/root.go b/cmd/root.go index c4e3720..3d01670 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,5 @@ /* Copyright © 2023 NAME HERE - */ package cmd @@ -28,10 +27,14 @@ var rootCmd = &cobra.Command{ if openAIToken == "" { return fmt.Errorf("openai_token token is not set. Please set it in config file or environment variable") } + openaiModel := viper.GetString("openai_model") + if openaiModel == "" { + openaiModel = openai.GPT3Dot5Turbo + } client := openai.NewClient(openAIToken) - app, err := wingman.NewApp(client) + app, err := wingman.NewApp(client, openaiModel) if err != nil { return err } diff --git a/go.mod b/go.mod index 45cb12a..961574d 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/harnyk/wingman go 1.20 require ( - github.com/sashabaranov/go-openai v1.5.7 + github.com/sashabaranov/go-openai v1.24.1 github.com/spf13/cobra v1.6.1 github.com/spf13/viper v1.15.0 ) diff --git a/go.sum b/go.sum index a98bce5..6ec3d43 100644 --- a/go.sum +++ b/go.sum @@ -201,6 +201,8 @@ github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBO github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sashabaranov/go-openai v1.5.7 h1:8DGgRG+P7yWixte5j720y6yiXgY3Hlgcd0gcpHdltfo= github.com/sashabaranov/go-openai v1.5.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI= +github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/spf13/afero v1.9.3 h1:41FoI0fD7OR7mGcKE/aOiLkGreyf8ifIOQmJANWogMk= github.com/spf13/afero v1.9.3/go.mod h1:iUV7ddyEEZPO5gA3zD4fJt6iStLlL+Lg4m2cihcDf8Y= diff --git a/internal/wingman/app.go b/internal/wingman/app.go index 46ac56e..885f3ce 100644 --- a/internal/wingman/app.go +++ b/internal/wingman/app.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "strings" + "text/template" "github.com/sashabaranov/go-openai" ) @@ -20,19 +21,21 @@ import ( // 7. If the user wants to exit, exit type App struct { - OpenAIClient *openai.Client - envContext EnvironmentContext + client *openai.Client + envContext EnvironmentContext + model string } -func NewApp(openAIClient *openai.Client) (*App, error) { +func NewApp(openAIClient *openai.Client, openaiModel string) (*App, error) { context, err := NewContext() if err != nil { return nil, err } return &App{ - OpenAIClient: openAIClient, - envContext: context, + client: openAIClient, + envContext: context, + model: openaiModel, }, nil } @@ -48,11 +51,10 @@ func (a *App) Loop(query string) error { stopSpinner := StartSpinner() resp, err := a.getResponse(query) + stopSpinner() if err != nil { - stopSpinner() return err } - stopSpinner() DisplayResponse(query, resp) @@ -81,14 +83,14 @@ func (a *App) getResponse(query string) (Response, error) { prompt := a.createPrompt(query) - resp, err := a.OpenAIClient.CreateChatCompletion( + resp, err := a.client.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, - N: 1, - Stop: []string{ - "[END]", + Model: a.model, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONObject, }, + N: 1, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -103,7 +105,7 @@ func (a *App) getResponse(query string) (Response, error) { } raw := resp.Choices[0] - response, err := ParseResponse(raw.Message.Content) + response, err := ParseResponseJSON(raw.Message.Content) if err != nil { return Response{}, err @@ -142,34 +144,59 @@ func (a *App) runCommand(command string) error { return nil } -func (a *App) createPrompt(userPrompt string) string { - sbPrompt := strings.Builder{} +const promptTemplateTextJsonFormat = ` +Provide a terminal command which would do the following:\n +{{ .UserPrompt }} + +Environment Context: +OS: {{ .Context.OS }} +Shell: {{ .Context.Shell }} +User: {{ .Context.User }} +Instruction: it is very important that your response is in the JSON format, corresponding to the following JSON schema: + +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "command": { + "description": "the command to run directly in the shell", + "type": "string" + }, + "explanation": { + "description": "a brief explanation how the command works", + "type": "string" + } + }, + "required": ["command", "explanation"], + "additionalProperties": false +} + +Example: - sbPrompt.WriteString("Provide a terminal command which would do the following:\n") - sbPrompt.WriteString(userPrompt) - sbPrompt.WriteString("\n\n") - sbPrompt.WriteString("Environment Context:\n") +User prompt: - sbPrompt.WriteString("OS: ") - sbPrompt.WriteString(a.envContext.OS) - sbPrompt.WriteString("\n") +list all files in the current directory - sbPrompt.WriteString("Shell: ") - sbPrompt.WriteString(a.envContext.Shell) - sbPrompt.WriteString("\n") +Response: +{ + "command": "ls -la", + "explanation": "ls lists all files in the current directory" +} - sbPrompt.WriteString("User: ") - sbPrompt.WriteString(a.envContext.User) - sbPrompt.WriteString("\n") +In the explanation field use Markdown formatting if needed. +It is also important to pay attention on producing secure syntax, e.g. using proper quotes in the command. +Do not use Markdown formatting in the command field. +` - sbPrompt.WriteString("Instruction: it is very important to reply in the following format (the response must terminate with [END]):\n\n") - sbPrompt.WriteString("[COMMAND]:\n") - sbPrompt.WriteString("some command, e.g. ls -la\n") - sbPrompt.WriteString("[EXPLANATION]:\n") - sbPrompt.WriteString("a brief explanation how the command works\n") - sbPrompt.WriteString("[END]\n") - sbPrompt.WriteString("\n") - sbPrompt.WriteString("In the EXPLANATION field use Markdown formatting if needed\n") +var promptTemplate = template.Must(template.New("prompt").Parse(promptTemplateTextJsonFormat)) +func (a *App) createPrompt(userPrompt string) string { + sbPrompt := strings.Builder{} + if err := promptTemplate.Execute(&sbPrompt, map[string]interface{}{ + "Context": a.envContext, + "UserPrompt": userPrompt, + }); err != nil { + log.Fatal(err) + } return sbPrompt.String() } diff --git a/internal/wingman/response.go b/internal/wingman/response.go index 49fd4a6..9733929 100644 --- a/internal/wingman/response.go +++ b/internal/wingman/response.go @@ -1,57 +1,23 @@ package wingman import ( + "encoding/json" "strings" ) type Response struct { - Command string - Explanation string + Command string `json:"command"` + Explanation string `json:"explanation"` } -// example: -// ai noise -// [COMMAND]: -// ls -la -// the command is multiline -// [EXPLANATION]: -// list all files in the current directory -// this is a multiline explanation -// [END] -// ai noise +func ParseResponseJSON(response string) (Response, error) { -func ParseResponse(response string) (Response, error) { - lines := strings.Split(response, "\n") + response = strings.Trim(response, "\n\r ") + res := Response{} - sbCommand := strings.Builder{} - sbExplanation := strings.Builder{} - - var sbCurrent *strings.Builder - - for _, line := range lines { - if strings.HasPrefix(line, "[COMMAND]:") { - sbCurrent = &sbCommand - continue - } - - if strings.HasPrefix(line, "[EXPLANATION]:") { - sbCurrent = &sbExplanation - continue - } - - if strings.HasPrefix(line, "[END]") { - break - } - - if sbCurrent == nil { - continue - } - sbCurrent.WriteString(line) - sbCurrent.WriteString("\n") + if err := json.Unmarshal([]byte(response), &res); err != nil { + return Response{}, err } - return Response{ - Command: strings.Trim(sbCommand.String(), "\n"), - Explanation: strings.Trim(sbExplanation.String(), "\n"), - }, nil + return res, nil } diff --git a/internal/wingman/ui.go b/internal/wingman/ui.go index af9752b..c4a3138 100644 --- a/internal/wingman/ui.go +++ b/internal/wingman/ui.go @@ -12,11 +12,6 @@ import ( "github.com/manifoldco/promptui" ) -// use promptui to create a menu: -// 1. Run This Command -// 2. Revise Query -// 3. Exit - type MenuAction int const ( @@ -33,7 +28,7 @@ func Menu() (MenuAction, error) { index, _, err := prompt.Run() if err != nil { - return 0, fmt.Errorf("Prompt failed %v", err) + return 0, fmt.Errorf("prompt failed %v", err) } return MenuAction(index + 1), nil @@ -76,7 +71,7 @@ func ReviseQuery(initialQuery string) (string, error) { Validate: func(input string) error { input = strings.TrimSpace(input) if len(input) < 1 { - return fmt.Errorf("Query must be at least 1 character") + return fmt.Errorf("query must be at least 1 character") } return nil },