Skip to content

Commit

Permalink
feat: configurable model; ref: json response format
Browse files Browse the repository at this point in the history
  • Loading branch information
harnyk committed Jun 1, 2024
1 parent 2c55dbe commit d61f58e
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 99 deletions.
13 changes: 3 additions & 10 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,15 @@
"version": "0.2.0",
"configurations": [
{
"name": "wingman how to list files in the current directory",
"name": "try",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}",
"console": "integratedTerminal",
"args": [
"how",
"to",
"list",
"files",
"in",
"the",
"current",
"directory"
"install postgresql"
]
}
]
}
}
3 changes: 3 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ Then, you need to create a file `~/.wingman.yaml` with the following content:

```yaml
openai_token: <your key>
openai_model: gpt-4o
```

`openai_model` can be omitted if you want to use the default model, which is `gpt-3.5-turbo`.

### Usage

```bash
Expand Down
7 changes: 5 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd

Expand Down Expand Up @@ -28,10 +27,14 @@ var rootCmd = &cobra.Command{
if openAIToken == "" {
return fmt.Errorf("openai_token token is not set. Please set it in config file or environment variable")
}
openaiModel := viper.GetString("openai_model")
if openaiModel == "" {
openaiModel = openai.GPT3Dot5Turbo
}

client := openai.NewClient(openAIToken)

app, err := wingman.NewApp(client)
app, err := wingman.NewApp(client, openaiModel)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/harnyk/wingman
go 1.20

require (
github.com/sashabaranov/go-openai v1.5.7
github.com/sashabaranov/go-openai v1.24.1
github.com/spf13/cobra v1.6.1
github.com/spf13/viper v1.15.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBO
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.5.7 h1:8DGgRG+P7yWixte5j720y6yiXgY3Hlgcd0gcpHdltfo=
github.com/sashabaranov/go-openai v1.5.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI=
github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/spf13/afero v1.9.3 h1:41FoI0fD7OR7mGcKE/aOiLkGreyf8ifIOQmJANWogMk=
github.com/spf13/afero v1.9.3/go.mod h1:iUV7ddyEEZPO5gA3zD4fJt6iStLlL+Lg4m2cihcDf8Y=
Expand Down
99 changes: 63 additions & 36 deletions internal/wingman/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"strings"
"text/template"

"github.com/sashabaranov/go-openai"
)
Expand All @@ -20,19 +21,21 @@ import (
// 7. If the user wants to exit, exit

type App struct {
OpenAIClient *openai.Client
envContext EnvironmentContext
client *openai.Client
envContext EnvironmentContext
model string
}

func NewApp(openAIClient *openai.Client) (*App, error) {
func NewApp(openAIClient *openai.Client, openaiModel string) (*App, error) {
context, err := NewContext()
if err != nil {
return nil, err
}

return &App{
OpenAIClient: openAIClient,
envContext: context,
client: openAIClient,
envContext: context,
model: openaiModel,
}, nil
}

Expand All @@ -48,11 +51,10 @@ func (a *App) Loop(query string) error {

stopSpinner := StartSpinner()
resp, err := a.getResponse(query)
stopSpinner()
if err != nil {
stopSpinner()
return err
}
stopSpinner()

DisplayResponse(query, resp)

Expand Down Expand Up @@ -81,14 +83,14 @@ func (a *App) getResponse(query string) (Response, error) {

prompt := a.createPrompt(query)

resp, err := a.OpenAIClient.CreateChatCompletion(
resp, err := a.client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
N: 1,
Stop: []string{
"[END]",
Model: a.model,
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONObject,
},
N: 1,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Expand All @@ -103,7 +105,7 @@ func (a *App) getResponse(query string) (Response, error) {
}

raw := resp.Choices[0]
response, err := ParseResponse(raw.Message.Content)
response, err := ParseResponseJSON(raw.Message.Content)

if err != nil {
return Response{}, err
Expand Down Expand Up @@ -142,34 +144,59 @@ func (a *App) runCommand(command string) error {
return nil
}

func (a *App) createPrompt(userPrompt string) string {
sbPrompt := strings.Builder{}
const promptTemplateTextJsonFormat = `
Provide a terminal command which would do the following:\n
{{ .UserPrompt }}
Environment Context:
OS: {{ .Context.OS }}
Shell: {{ .Context.Shell }}
User: {{ .Context.User }}
Instruction: it is very important that your response is in the JSON format, corresponding to the following JSON schema:
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"command": {
"description": "the command to run directly in the shell",
"type": "string"
},
"explanation": {
"description": "a brief explanation how the command works",
"type": "string"
}
},
"required": ["command", "explanation"],
"additionalProperties": false
}
Example:
sbPrompt.WriteString("Provide a terminal command which would do the following:\n")
sbPrompt.WriteString(userPrompt)
sbPrompt.WriteString("\n\n")
sbPrompt.WriteString("Environment Context:\n")
User prompt:
sbPrompt.WriteString("OS: ")
sbPrompt.WriteString(a.envContext.OS)
sbPrompt.WriteString("\n")
list all files in the current directory
sbPrompt.WriteString("Shell: ")
sbPrompt.WriteString(a.envContext.Shell)
sbPrompt.WriteString("\n")
Response:
{
"command": "ls -la",
"explanation": "ls lists all files in the current directory"
}
sbPrompt.WriteString("User: ")
sbPrompt.WriteString(a.envContext.User)
sbPrompt.WriteString("\n")
In the explanation field use Markdown formatting if needed.
It is also important to pay attention on producing secure syntax, e.g. using proper quotes in the command.
Do not use Markdown formatting in the command field.
`

sbPrompt.WriteString("Instruction: it is very important to reply in the following format (the response must terminate with [END]):\n\n")
sbPrompt.WriteString("[COMMAND]:\n")
sbPrompt.WriteString("some command, e.g. ls -la\n")
sbPrompt.WriteString("[EXPLANATION]:\n")
sbPrompt.WriteString("a brief explanation how the command works\n")
sbPrompt.WriteString("[END]\n")
sbPrompt.WriteString("\n")
sbPrompt.WriteString("In the EXPLANATION field use Markdown formatting if needed\n")
var promptTemplate = template.Must(template.New("prompt").Parse(promptTemplateTextJsonFormat))

func (a *App) createPrompt(userPrompt string) string {
sbPrompt := strings.Builder{}
if err := promptTemplate.Execute(&sbPrompt, map[string]interface{}{
"Context": a.envContext,
"UserPrompt": userPrompt,
}); err != nil {
log.Fatal(err)
}
return sbPrompt.String()
}
52 changes: 9 additions & 43 deletions internal/wingman/response.go
Original file line number Diff line number Diff line change
@@ -1,57 +1,23 @@
package wingman

import (
"encoding/json"
"strings"
)

type Response struct {
Command string
Explanation string
Command string `json:"command"`
Explanation string `json:"explanation"`
}

// example:
// ai noise
// [COMMAND]:
// ls -la
// the command is multiline
// [EXPLANATION]:
// list all files in the current directory
// this is a multiline explanation
// [END]
// ai noise
func ParseResponseJSON(response string) (Response, error) {

func ParseResponse(response string) (Response, error) {
lines := strings.Split(response, "\n")
response = strings.Trim(response, "\n\r ")
res := Response{}

sbCommand := strings.Builder{}
sbExplanation := strings.Builder{}

var sbCurrent *strings.Builder

for _, line := range lines {
if strings.HasPrefix(line, "[COMMAND]:") {
sbCurrent = &sbCommand
continue
}

if strings.HasPrefix(line, "[EXPLANATION]:") {
sbCurrent = &sbExplanation
continue
}

if strings.HasPrefix(line, "[END]") {
break
}

if sbCurrent == nil {
continue
}
sbCurrent.WriteString(line)
sbCurrent.WriteString("\n")
if err := json.Unmarshal([]byte(response), &res); err != nil {
return Response{}, err
}

return Response{
Command: strings.Trim(sbCommand.String(), "\n"),
Explanation: strings.Trim(sbExplanation.String(), "\n"),
}, nil
return res, nil
}
9 changes: 2 additions & 7 deletions internal/wingman/ui.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ import (
"github.com/manifoldco/promptui"
)

// use promptui to create a menu:
// 1. Run This Command
// 2. Revise Query
// 3. Exit

type MenuAction int

const (
Expand All @@ -33,7 +28,7 @@ func Menu() (MenuAction, error) {
index, _, err := prompt.Run()

if err != nil {
return 0, fmt.Errorf("Prompt failed %v", err)
return 0, fmt.Errorf("prompt failed %v", err)
}

return MenuAction(index + 1), nil
Expand Down Expand Up @@ -76,7 +71,7 @@ func ReviseQuery(initialQuery string) (string, error) {
Validate: func(input string) error {
input = strings.TrimSpace(input)
if len(input) < 1 {
return fmt.Errorf("Query must be at least 1 character")
return fmt.Errorf("query must be at least 1 character")
}
return nil
},
Expand Down

0 comments on commit d61f58e

Please sign in to comment.