Skip to content

Commit

Permalink
work in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Koeng101 committed Mar 4, 2024
1 parent 38dafa4 commit 5445040
Show file tree
Hide file tree
Showing 11 changed files with 499 additions and 42 deletions.
203 changes: 172 additions & 31 deletions api/api/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"os"
"regexp"
"strings"

"github.com/gorilla/websocket"
"github.com/sashabaranov/go-openai"
Expand All @@ -19,10 +20,14 @@ import (
// To test multiple times:
// for i in {1..20}; do echo "Run #$i"; go test; done

// OPENAI_API_KEY =""
// API_KEY=""
// MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
// BASE_URL="https://api.deepinfra.com/v1/openai"

// API_KEY=""
// MODEL="gpt-4-0125-preview"
// BASE_URL="https://api.openai.com/v1"

/*
*****************************************************************************
Expand Down Expand Up @@ -116,25 +121,98 @@ func chatHandler(w http.ResponseWriter, r *http.Request) {
/*
*****************************************************************************
# Question asker
*****************************************************************************
*/

// MagicJSONIncantation contains text that somehow consistently gets AI models to return valid JSON.
var MagicJSONIncantation = "Please respond ONLY with valid json that conforms to this json_schema: %s\n. Do not include additional text other than the object json as we will load this object with json.loads()."

// AskQuestions asks the model a set of yes/no questions about a conversation.
//
// A system message built from MagicJSONIncantation and jsonSchema is added
// after the supplied messages, the completion is requested with "}" as a stop
// token, and the reply is decoded into a map of answer name to boolean.
//
// The caller's messages slice is never modified. On any failure the returned
// map is nil and the error describes which step failed.
func AskQuestions(ctx context.Context, client *openai.Client, model string, jsonSchema string, messages []openai.ChatCompletionMessage) (map[string]bool, error) {
	// Copy before appending so we cannot write into spare capacity of the
	// caller's backing array.
	prompt := make([]openai.ChatCompletionMessage, 0, len(messages)+1)
	prompt = append(prompt, messages...)
	prompt = append(prompt, openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleSystem,
		Content: fmt.Sprintf(MagicJSONIncantation, jsonSchema),
	})

	// Use the caller's context so cancellation and deadlines are honoured.
	resp, err := client.CreateChatCompletion(
		ctx,
		openai.ChatCompletionRequest{
			Model:    model,
			Messages: prompt,
			Stop:     []string{"}"}, // this stop token makes sure nothing else is generated.
		},
	)
	if err != nil {
		return nil, fmt.Errorf("requesting chat completion: %w", err)
	}
	if len(resp.Choices) == 0 {
		return nil, fmt.Errorf("got zero responses")
	}

	// The stop token itself is not part of the returned content, so put it
	// back to close the JSON object before decoding.
	response := resp.Choices[0].Message.Content + "}"
	var resultMap map[string]bool
	if err := json.Unmarshal([]byte(response), &resultMap); err != nil {
		return nil, fmt.Errorf("decoding model answer %q: %w", response, err)
	}
	return resultMap, nil
}

// toolJSONSchema returns a draft-07 JSON schema, serialised as a string,
// describing an object with one required boolean property per known tool
// question. Additional properties are disallowed.
func toolJSONSchema() string {
	type question struct {
		key    string
		prompt string
	}

	questions := []question{
		{key: "code", prompt: "The user is asking for code to be written"},
	}

	props := make(map[string]interface{}, len(questions))
	names := make([]string, 0, len(questions))
	for i := range questions {
		q := questions[i]
		// Every question is mandatory, so its key also goes in "required".
		names = append(names, q.key)
		props[q.key] = map[string]interface{}{
			"description": q.prompt,
			"type":        "boolean",
		}
	}

	// The value is built only from maps, strings, string slices and bools,
	// which encoding/json can always encode, so the error is ignored.
	encoded, _ := json.Marshal(map[string]interface{}{
		"$schema":              "http://json-schema.org/draft-07/schema#", // Assuming draft-07; adjust if necessary
		"additionalProperties": false,
		"properties":           props,
		"required":             names,
		"type":                 "object",
	})
	return string(encoded)
}

/*
*****************************************************************************
# Examples
*****************************************************************************
*/

//go:embed examples.lua
var examples string
var Examples string

// Example describes a single Lua example that can be offered to the model.
// NOTE(review): values appear to be produced by parseExamples from the
// embedded examples.lua — confirm against parseExamples.
type Example struct {
	// Name identifies the example; WriteCode looks it up in the map returned
	// by RequiredFunctions.
	Name string
	// Question is presumably the yes/no question (the "-- QUESTION:" line in
	// examples.lua) asked to decide whether the example is needed — verify.
	Question string
	// Text is the example's Lua source that WriteCode injects into the prompt.
	Text string
}

// MagicJSONIncantation contains text that somehow consistently gets AI models to return valid JSON.
var MagicJSONIncantation = "Please respond ONLY with valid json that conforms to this json_schema: %s\n. Do not include additional text other than the object json as we will load this object with json.loads()."

// FunctionExamples is a complete list of DnaDesign lua examples.
var FunctionExamples = parseExamples(examples)
var FunctionExamples = parseExamples(Examples)

// FunctionExamplesJSONSchema is a JSON schema containing questions of whether
// or not a given user request requires a
Expand Down Expand Up @@ -205,38 +283,101 @@ func generateJSONSchemaFromExamples(examples []Example) string {

// RequiredFunctions takes in a userRequest and returns a map of the examples
// that should be inserted along with that request to generate lua code.
func RequiredFunctions(ctx context.Context, client *openai.Client, model string, userRequest string) (map[string]bool, error) {
var resultMap map[string]bool
resp, err := client.CreateChatCompletion(
// RequiredFunctions builds a JSON schema from the parsed Lua examples and
// forwards it, together with the conversation, to AskQuestions. The result is
// whatever AskQuestions returns for that schema.
func RequiredFunctions(ctx context.Context, client *openai.Client, model string, messages []openai.ChatCompletionMessage) (map[string]bool, error) {
	parsed := parseExamples(Examples)
	schema := generateJSONSchemaFromExamples(parsed)
	return AskQuestions(ctx, client, model, schema, messages)
}

/*
*****************************************************************************
# Code writing
*****************************************************************************
*/

// MagicLuaIncantation contains text that gets the AI models to return valid lua.
var MagicLuaIncantation = "Please respond ONLY with valid lua. The lua will be run inside of a sandbox, so do not write or read files: only print data out. Be as concise as possible. The following functions are preloaded in the sandbox, but you must apply them to the user's problem: \n```lua\n%s\n```"

// WriteCode takes in user messages and opens a streaming chat completion that
// writes Lua code for the request.
//
// It first calls RequiredFunctions to decide which examples are relevant, then
// adds a system message (MagicLuaIncantation) containing the text of every
// example answered true, and finally starts the stream. The caller owns the
// returned stream and must Close it. The caller's messages slice is never
// modified.
func WriteCode(ctx context.Context, client *openai.Client, model string, messages []openai.ChatCompletionMessage) (*openai.ChatCompletionStream, error) {
	// If we need to write code, first get the required functions.
	examplesToInject, err := RequiredFunctions(ctx, client, model, messages)
	if err != nil {
		return nil, fmt.Errorf("selecting required functions: %w", err)
	}

	// Collect the text of the selected examples. The map value must be
	// checked, not just key presence: every question gets an answer, so a
	// key mapped to false means "not needed".
	var exampleText strings.Builder
	for _, example := range parseExamples(Examples) {
		if examplesToInject[example.Name] {
			exampleText.WriteString(example.Text)
			exampleText.WriteString("\n")
		}
	}

	// Copy before appending so we cannot write into spare capacity of the
	// caller's backing array.
	prompt := make([]openai.ChatCompletionMessage, 0, len(messages)+1)
	prompt = append(prompt, messages...)
	prompt = append(prompt, openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleSystem,
		Content: fmt.Sprintf(MagicLuaIncantation, exampleText.String()),
	})

	// Now, we create the stream, bound to the caller's context so that
	// cancellation and deadlines are honoured.
	return client.CreateChatCompletionStream(
		ctx,
		openai.ChatCompletionRequest{
			Model:    model,
			Messages: prompt,
			Stream:   true,
		},
	)
}

func WriteCodeString(ctx context.Context, client *openai.Client, model string, messages []openai.ChatCompletionMessage) (string, error) {
stream, err := WriteCode(ctx, client, model, messages)
if err != nil {
return resultMap, err
return "", err
}
if len(resp.Choices) == 0 {
return resultMap, fmt.Errorf("Got zero responses")
defer stream.Close()
var buffer strings.Builder
for {
var response openai.ChatCompletionStreamResponse
response, err = stream.Recv()
if errors.Is(err, io.EOF) {
err = nil
break
}

if err != nil {
break
}
buffer.WriteString(response.Choices[0].Delta.Content)
}
response := resp.Choices[0].Message.Content + "}" // add back stop token
err = json.Unmarshal([]byte(response), &resultMap)
if err != nil {
fmt.Println(response)
return resultMap, err

return ParseLuaFromLLM(buffer.String()), err
}

// ParseLuaFromLLM extracts Lua source from a model reply that may wrap the
// code in a markdown code fence.
//
// Resolution order:
//  1. the contents of the first closed ```lua ... ``` block;
//  2. otherwise the contents of the first closed ``` ... ``` block;
//  3. otherwise, if the reply starts with an opening fence that is never
//     closed (for example a truncated stream), everything after that fence;
//  4. otherwise the input unchanged.
//
// Fence contents are returned verbatim, including surrounding newlines.
func ParseLuaFromLLM(input string) string {
	const (
		luaFence = "```lua"
		fence    = "```"
	)
	openers := []string{luaFence, fence}

	// Closed blocks: prefer the lua-tagged fence over a generic one.
	for _, opener := range openers {
		if body, ok := extractFencedBlock(input, opener, fence); ok {
			return body
		}
	}

	// Unclosed leading fence: returning the raw input would leave the fence
	// marker in front of the code, which is never valid Lua.
	trimmed := strings.TrimSpace(input)
	for _, opener := range openers {
		if strings.HasPrefix(trimmed, opener) {
			return strings.TrimPrefix(trimmed, opener)
		}
	}

	// Return original if no markers found.
	return input
}

// extractFencedBlock returns the text between the first occurrence of opener
// and the next occurrence of closer, reporting whether both were found.
func extractFencedBlock(input, opener, closer string) (string, bool) {
	start := strings.Index(input, opener)
	if start == -1 {
		return "", false
	}
	rest := input[start+len(opener):]
	end := strings.Index(rest, closer)
	if end == -1 {
		return "", false
	}
	return rest[:end], true
}
41 changes: 39 additions & 2 deletions api/api/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"fmt"
"os"
"strings"
"testing"

"github.com/koeng101/dnadesign/api/api"
"github.com/koeng101/dnadesign/api/gen"
"github.com/sashabaranov/go-openai"
)

Expand All @@ -25,9 +27,10 @@ func TestAiFastaParse(t *testing.T) {
ctx := context.Background()
userRequest := `I would like you to parse the following FASTA and return to me the headers in a csv file.
The fasta:
>test\nATGC\ntest2\nGATC
>test\nATGC\n>test2\nGATC
`
examples, err := api.RequiredFunctions(ctx, client, model, userRequest)
message := openai.ChatCompletionMessage{Role: "user", Content: userRequest}
examples, err := api.RequiredFunctions(ctx, client, model, []openai.ChatCompletionMessage{message})
if err != nil {
t.Errorf("Got err: %s", err)
}
Expand All @@ -38,3 +41,37 @@ The fasta:
}
}
}

// TestWriteCodeString asks a live model to write Lua that converts a small
// FASTA snippet to CSV, runs the generated code in the sandbox and compares
// the printed output. It is skipped unless API_KEY is set; BASE_URL and MODEL
// select the endpoint and model.
func TestWriteCodeString(t *testing.T) {
	apiKey := os.Getenv("API_KEY")
	if apiKey == "" {
		// Skip rather than silently pass so the test report shows it did not run.
		t.Skip("API_KEY is not set; skipping test that calls a live model")
	}
	baseURL := os.Getenv("BASE_URL")
	model := os.Getenv("MODEL")
	config := openai.DefaultConfig(apiKey)
	if baseURL != "" {
		config.BaseURL = baseURL
	}
	client := openai.NewClientWithConfig(config)
	ctx := context.Background()
	userRequest := "Please parse the following FASTA and return me a csv. Add headers identifier and sequence to the top of the csv. Data:\n```>test\nATGC\n>test2\nGATC\n"

	message := openai.ChatCompletionMessage{Role: "user", Content: userRequest}
	code, err := api.WriteCodeString(ctx, client, model, []openai.ChatCompletionMessage{message})
	if err != nil {
		// Without generated code there is nothing meaningful left to run.
		t.Fatalf("Got error: %s", err)
	}

	// run the code
	output, err := app.ExecuteLua(code, []gen.Attachment{})
	if err != nil {
		t.Fatalf("No error should be found. Got err: %s\nGenerated code:\n%s", err, code)
	}
	expectedOutput := `identifier,sequence
test,ATGC
test2,GATC`
	if strings.TrimSpace(output) != strings.TrimSpace(expectedOutput) {
		// Pass the values as arguments so a '%' in model output cannot be
		// misread as a formatting verb.
		t.Errorf("Unexpected response. Expected: %s\nGot: %s\nGenerated code:\n%s", expectedOutput, output, code)
	}
}
2 changes: 1 addition & 1 deletion api/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (app *App) ExecuteLua(data string, attachments []gen.Attachment) (string, e
L.SetGlobal("print", L.NewFunction(customPrint(&buffer)))

// Add IO functions
L.SetGlobal("fasta_parse", L.NewFunction(app.LuaIoFastaParse))
L.SetGlobal("fastaParse", L.NewFunction(app.LuaIoFastaParse))

// Execute the Lua script
if err := L.DoString(data); err != nil {
Expand Down
8 changes: 5 additions & 3 deletions api/api/api_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package api
package api_test

import (
"net/http/httptest"
"os"
"strings"
"testing"

"github.com/koeng101/dnadesign/api/api"
)

var app App
var app api.App

func TestMain(m *testing.M) {
app = InitializeApp()
app = api.InitializeApp()
code := m.Run()
os.Exit(code)
}
Expand Down
2 changes: 1 addition & 1 deletion api/api/examples.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ IO examples.
-- QUESTION: Does the request require the parsing of FASTA formatted data? Return a boolean.
-- fastaParse parses a fasta file into a list of tables of "identifier" and
-- "sequence"
parsedFasta = fastaParse(">test\nATGC\ntest2\nGATC")
parsedFasta = fastaParse(">test\nATGC\n>test2\nGATC")
print(parsedFasta[1]["identifier"]) -- returns "test"
print(parsedFasta[2]["sequence"]) -- returns "GATC"
-- END
Expand Down
20 changes: 17 additions & 3 deletions api/api/lua_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
package api
package api_test

import (
"testing"

"github.com/koeng101/dnadesign/api/api"
"github.com/koeng101/dnadesign/api/gen"
)

// TestApp_Examples executes the embedded Lua examples in the sandbox and
// checks the text they print.
func TestApp_Examples(t *testing.T) {
	output, err := app.ExecuteLua(api.Examples, []gen.Attachment{})
	if err != nil {
		// The output comparison is meaningless if execution failed.
		t.Fatalf("No error should be found. Got err: %s", err)
	}
	expectedOutput := `test
GATC
`
	if output != expectedOutput {
		// Pass the values as arguments so a '%' in the output cannot be
		// misread as a formatting verb; %q makes whitespace differences visible.
		t.Errorf("Unexpected response. Expected: %q\nGot: %q", expectedOutput, output)
	}
}

func TestApp_LuaIoFastaParse(t *testing.T) {
luaScript := `
parsed_fasta = fasta_parse(attachments["input.fasta"])
parsedFasta = fastaParse(attachments["input.fasta"])
print(parsed_fasta[1].identifier)
print(parsedFasta[1].identifier)
`
inputFasta := `>AAD44166.1
LCLYTHIGRNIYYGSYLYSETWNTGIMLLLITMATAFMGYVLPWGQMSFWGATVITNLFSAIPYIGTNLV
Expand Down
3 changes: 3 additions & 0 deletions api/api/output.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
identifier,sequence
test,ATGC
test2,GATC
Loading

0 comments on commit 5445040

Please sign in to comment.