diff --git a/app/bot/mocks/openai_client.go b/app/bot/mocks/openai_client.go new file mode 100644 index 00000000..7e0a1aa3 --- /dev/null +++ b/app/bot/mocks/openai_client.go @@ -0,0 +1,78 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mocks + +import ( + "context" + "github.com/sashabaranov/go-openai" + "sync" +) + +// OpenAIClientMock is a mock implementation of bot.OpenAIClient. +// +// func TestSomethingThatUsesOpenAIClient(t *testing.T) { +// +// // make and configure a mocked bot.OpenAIClient +// mockedOpenAIClient := &OpenAIClientMock{ +// CreateChatCompletionFunc: func(contextMoqParam context.Context, chatCompletionRequest openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { +// panic("mock out the CreateChatCompletion method") +// }, +// } +// +// // use mockedOpenAIClient in code that requires bot.OpenAIClient +// // and then make assertions. +// +// } +type OpenAIClientMock struct { + // CreateChatCompletionFunc mocks the CreateChatCompletion method. + CreateChatCompletionFunc func(contextMoqParam context.Context, chatCompletionRequest openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) + + // calls tracks calls to the methods. + calls struct { + // CreateChatCompletion holds details about calls to the CreateChatCompletion method. + CreateChatCompletion []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // ChatCompletionRequest is the chatCompletionRequest argument value. + ChatCompletionRequest openai.ChatCompletionRequest + } + } + lockCreateChatCompletion sync.RWMutex +} + +// CreateChatCompletion calls CreateChatCompletionFunc. +func (mock *OpenAIClientMock) CreateChatCompletion(contextMoqParam context.Context, chatCompletionRequest openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + if mock.CreateChatCompletionFunc == nil { + panic("OpenAIClientMock.CreateChatCompletionFunc: method is nil but OpenAIClient.CreateChatCompletion was just called") + } + callInfo := struct { + ContextMoqParam context.Context + ChatCompletionRequest openai.ChatCompletionRequest + }{ + ContextMoqParam: contextMoqParam, + ChatCompletionRequest: chatCompletionRequest, + } + mock.lockCreateChatCompletion.Lock() + mock.calls.CreateChatCompletion = append(mock.calls.CreateChatCompletion, callInfo) + mock.lockCreateChatCompletion.Unlock() + return mock.CreateChatCompletionFunc(contextMoqParam, chatCompletionRequest) +} + +// CreateChatCompletionCalls gets all the calls that were made to CreateChatCompletion. +// Check the length with: +// +// len(mockedOpenAIClient.CreateChatCompletionCalls()) +func (mock *OpenAIClientMock) CreateChatCompletionCalls() []struct { + ContextMoqParam context.Context + ChatCompletionRequest openai.ChatCompletionRequest +} { + var calls []struct { + ContextMoqParam context.Context + ChatCompletionRequest openai.ChatCompletionRequest + } + mock.lockCreateChatCompletion.RLock() + calls = mock.calls.CreateChatCompletion + mock.lockCreateChatCompletion.RUnlock() + return calls +} diff --git a/app/bot/openai/openai.go b/app/bot/openai/openai.go index 06bb7688..7cec4f2f 100644 --- a/app/bot/openai/openai.go +++ b/app/bot/openai/openai.go @@ -304,3 +304,8 @@ func (o *OpenAI) Summary(text string) (response string, err error) { func (o *OpenAI) ReactOn() []string { return []string{"chat!", "gpt!", "ai!", "чат!"} } + +// CreateChatCompletion exposes the underlying openai.CreateChatCompletion method +func (o *OpenAI) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return o.client.CreateChatCompletion(ctx, req) +} diff --git a/app/bot/spam_openai.go b/app/bot/spam_openai.go new file mode 100644 index 00000000..6063c52b --- /dev/null +++ b/app/bot/spam_openai.go @@ -0,0 +1,117 @@ +package bot + +import ( + "bufio" + "context" + "fmt" + "io" + "log" + "strings" + + "github.com/sashabaranov/go-openai" +) + +//go:generate moq --out mocks/openai_client.go --pkg mocks --skip-ensure . OpenAIClient + +// SpamOpenAIFilter bot, checks if user is a spammer using openai api call +type SpamOpenAIFilter struct { + dry bool + superUser SuperUser + openai OpenAIClient + maxLen int + + spamPrompt string + enabled bool + approvedUsers map[int64]bool +} + +// OpenAIClient is interface for OpenAI client with the possibility to mock it +type OpenAIClient interface { + CreateChatCompletion(context.Context, openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) +} + +// NewSpamOpenAIFilter makes a spam detecting bot +func NewSpamOpenAIFilter(spamSamples io.Reader, openai OpenAIClient, maxLen int, superUser SuperUser, dry bool) *SpamOpenAIFilter { + log.Printf("[INFO] Spam bot (openai)") + res := &SpamOpenAIFilter{dry: dry, approvedUsers: map[int64]bool{}, superUser: superUser, openai: openai} + + scanner := bufio.NewScanner(spamSamples) + for scanner.Scan() { + res.spamPrompt += scanner.Text() + "\n" + } + if err := scanner.Err(); err != nil { + log.Printf("[WARN] failed to read spam samples, error=%v", err) + res.enabled = false + } else { + res.enabled = true + } + if len(res.spamPrompt) > maxLen { + res.spamPrompt = res.spamPrompt[:maxLen] + } + return res +} + +// OnMessage checks if user already approved and if not checks if user is a spammer +func (s *SpamOpenAIFilter) OnMessage(msg Message) (response Response) { + if s.approvedUsers[msg.From.ID] { + return Response{} + } + + if s.superUser.IsSuper(msg.From.Username) { + return Response{} // don't check super users for spam + } + + if !s.isSpam(msg.Text) { + log.Printf("[INFO] user %s (%d) is not a spammer, added to aproved", msg.From.Username, msg.From.ID) + s.approvedUsers[msg.From.ID] = true + return Response{} // not a spam + } + + log.Printf("[INFO] user %s detected as spammer, msg: %q", msg.From.Username, msg.Text) + if s.dry { + return Response{ + Text: fmt.Sprintf("this is spam from %q, but I'm in dry mode, so I'll do nothing yet", msg.From.Username), + Send: true, ReplyTo: msg.ID, + } + } + return Response{Text: "this is spam! go to ban, " + msg.From.DisplayName, Send: true, ReplyTo: msg.ID, BanInterval: permanentBanDuration, DeleteReplyTo: true} +} + +// Help returns help message +func (s *SpamOpenAIFilter) Help() string { return "" } + +// ReactOn keys +func (s *SpamOpenAIFilter) ReactOn() []string { return []string{} } + +// isSpam checks if a given message is similar to any of the known bad messages. +func (s *SpamOpenAIFilter) isSpam(message string) bool { + + messages := []openai.ChatCompletionMessage{} + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleSystem, + Content: "this is the list of spam messages. I will give you a messages to detect if this is spam or not and you will answer with a single world \"OK\" or \"SPAM\"\n\n" + s.spamPrompt, + }) + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: message, + }) + + resp, err := s.openai.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + MaxTokens: 1024, + Messages: messages, + }, + ) + if err != nil { + log.Printf("[WARN] failed to check spam, error=%v", err) + return false + } + if len(resp.Choices) == 0 { + log.Printf("[WARN] empty response from openai") + return false + } + + return strings.Contains(resp.Choices[0].Message.Content, "SPAM") +} diff --git a/app/bot/spam_openai_test.go b/app/bot/spam_openai_test.go new file mode 100644 index 00000000..e6eb17fa --- /dev/null +++ b/app/bot/spam_openai_test.go @@ -0,0 +1,83 @@ +package bot + +import ( + "context" + "strings" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/radio-t/super-bot/app/bot/mocks" +) + +func TestSpamOpenAIFilter_isSpam(t *testing.T) { + spamSamples := strings.NewReader("spam1\nspam2\nspam3") + mockOpenAI := &mocks.OpenAIClientMock{} + + filter := NewSpamOpenAIFilter(spamSamples, mockOpenAI, 4096, nil, false) + require.True(t, filter.enabled) + assert.True(t, len(filter.spamPrompt) <= 4096) + + mockOpenAI.CreateChatCompletionFunc = func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Content: "SPAM", + }, + }, + }, + }, nil + } + + assert.True(t, filter.isSpam("this is a spam message")) + assert.Equal(t, 1, len(mockOpenAI.CreateChatCompletionCalls())) + + mockOpenAI.CreateChatCompletionFunc = func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Content: "OK", + }, + }, + }, + }, nil + } + + assert.False(t, filter.isSpam("this is not a spam message")) + assert.Equal(t, 2, len(mockOpenAI.CreateChatCompletionCalls())) +} + +func TestSpamOpenAIFilter_OnMessage(t *testing.T) { + superUser := &mocks.SuperUser{IsSuperFunc: func(userName string) bool { + if userName == "super" || userName == "admin" { + return true + } + return false + }} + + spamSamples := strings.NewReader("spam1\nspam2\nspam3") + mockOpenAI := &mocks.OpenAIClientMock{ + CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Content: "OK", + }, + }, + }, + }, nil + }, + } + + filter := NewSpamOpenAIFilter(spamSamples, mockOpenAI, 4096, superUser, false) + + msg := Message{From: User{ID: 1, Username: "user1"}, Text: "hello"} + resp := filter.OnMessage(msg) + + assert.Empty(t, resp.Text) +} diff --git a/app/main.go b/app/main.go index 39b151fa..4c4f3540 100644 --- a/app/main.go +++ b/app/main.go @@ -154,6 +154,8 @@ func main() { } multiBot = append(multiBot, bot.NewSpamLocalFilter(spamFh, opts.SpamFilter.Threshold, opts.SuperUsers, opts.SpamFilter.Dry)) + multiBot = append(multiBot, bot.NewSpamOpenAIFilter(spamFh, openAIBot, opts.OpenAI.MaxSymbolsRequest, + opts.SuperUsers, opts.SpamFilter.Dry)) } } else { log.Print("[INFO] spam filter disabled")