-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package bot | ||
|
||
import ( | ||
"bufio" | ||
"context" | ||
"fmt" | ||
"io" | ||
"log" | ||
"strings" | ||
|
||
"github.com/sashabaranov/go-openai" | ||
) | ||
|
||
//go:generate moq --out mocks/openai_client.go --pkg mocks --skip-ensure . OpenAIClient | ||
|
||
// SpamOpenAIFilter bot, checks if user is a spammer using openai api call | ||
type SpamOpenAIFilter struct { | ||
dry bool | ||
superUser SuperUser | ||
openai OpenAIClient | ||
maxLen int | ||
|
||
spamPrompt string | ||
enabled bool | ||
approvedUsers map[int64]bool | ||
} | ||
|
||
// OpenAIClient is interface for OpenAI client with the possibility to mock it | ||
type OpenAIClient interface { | ||
CreateChatCompletion(context.Context, openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) | ||
} | ||
|
||
// NewSpamOpenAIFilter makes a spam detecting bot | ||
func NewSpamOpenAIFilter(spamSamples io.Reader, openai OpenAIClient, maxLen int, superUser SuperUser, dry bool) *SpamOpenAIFilter { | ||
log.Printf("[INFO] Spam bot (openai)") | ||
res := &SpamOpenAIFilter{dry: dry, approvedUsers: map[int64]bool{}, superUser: superUser, openai: openai} | ||
|
||
scanner := bufio.NewScanner(spamSamples) | ||
for scanner.Scan() { | ||
res.spamPrompt += scanner.Text() + "\n" | ||
} | ||
if err := scanner.Err(); err != nil { | ||
log.Printf("[WARN] failed to read spam samples, error=%v", err) | ||
res.enabled = false | ||
} else { | ||
res.enabled = true | ||
} | ||
if len(res.spamPrompt) > maxLen { | ||
res.spamPrompt = res.spamPrompt[:maxLen] | ||
} | ||
return res | ||
} | ||
|
||
// OnMessage checks if user already approved and if not checks if user is a spammer | ||
func (s *SpamOpenAIFilter) OnMessage(msg Message) (response Response) { | ||
if s.approvedUsers[msg.From.ID] { | ||
return Response{} | ||
} | ||
|
||
if s.superUser.IsSuper(msg.From.Username) { | ||
return Response{} // don't check super users for spam | ||
} | ||
|
||
if !s.isSpam(msg.Text) { | ||
log.Printf("[INFO] user %s (%d) is not a spammer, added to aproved", msg.From.Username, msg.From.ID) | ||
s.approvedUsers[msg.From.ID] = true | ||
return Response{} // not a spam | ||
} | ||
|
||
log.Printf("[INFO] user %s detected as spammer, msg: %q", msg.From.Username, msg.Text) | ||
if s.dry { | ||
return Response{ | ||
Text: fmt.Sprintf("this is spam from %q, but I'm in dry mode, so I'll do nothing yet", msg.From.Username), | ||
Send: true, ReplyTo: msg.ID, | ||
} | ||
} | ||
return Response{Text: "this is spam! go to ban, " + msg.From.DisplayName, Send: true, ReplyTo: msg.ID, BanInterval: permanentBanDuration, DeleteReplyTo: true} | ||
} | ||
|
||
// Help returns help message | ||
func (s *SpamOpenAIFilter) Help() string { return "" } | ||
|
||
// ReactOn keys | ||
func (s *SpamOpenAIFilter) ReactOn() []string { return []string{} } | ||
|
||
// isSpam checks if a given message is similar to any of the known bad messages. | ||
func (s *SpamOpenAIFilter) isSpam(message string) bool { | ||
|
||
messages := []openai.ChatCompletionMessage{} | ||
messages = append(messages, openai.ChatCompletionMessage{ | ||
Role: openai.ChatMessageRoleSystem, | ||
Content: "this is the list of spam messages. I will give you a messages to detect if this is spam or not and you will answer with a single world \"OK\" or \"SPAM\"\n\n" + s.spamPrompt, | ||
}) | ||
messages = append(messages, openai.ChatCompletionMessage{ | ||
Role: openai.ChatMessageRoleUser, | ||
Content: message, | ||
}) | ||
|
||
resp, err := s.openai.CreateChatCompletion( | ||
context.Background(), | ||
openai.ChatCompletionRequest{ | ||
Model: openai.GPT3Dot5Turbo, | ||
MaxTokens: 1024, | ||
Messages: messages, | ||
}, | ||
) | ||
if err != nil { | ||
log.Printf("[WARN] failed to check spam, error=%v", err) | ||
return false | ||
} | ||
if len(resp.Choices) == 0 { | ||
log.Printf("[WARN] empty response from openai") | ||
return false | ||
} | ||
|
||
return strings.Contains(resp.Choices[0].Message.Content, "SPAM") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package bot | ||
|
||
import ( | ||
"context" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/sashabaranov/go-openai" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/radio-t/super-bot/app/bot/mocks" | ||
) | ||
|
||
func TestSpamOpenAIFilter_isSpam(t *testing.T) { | ||
spamSamples := strings.NewReader("spam1\nspam2\nspam3") | ||
mockOpenAI := &mocks.OpenAIClientMock{} | ||
|
||
filter := NewSpamOpenAIFilter(spamSamples, mockOpenAI, 4096, nil, false) | ||
require.True(t, filter.enabled) | ||
assert.True(t, len(filter.spamPrompt) <= 4096) | ||
|
||
mockOpenAI.CreateChatCompletionFunc = func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { | ||
return openai.ChatCompletionResponse{ | ||
Choices: []openai.ChatCompletionChoice{ | ||
{ | ||
Message: openai.ChatCompletionMessage{ | ||
Content: "SPAM", | ||
}, | ||
}, | ||
}, | ||
}, nil | ||
} | ||
|
||
assert.True(t, filter.isSpam("this is a spam message")) | ||
assert.Equal(t, 1, len(mockOpenAI.CreateChatCompletionCalls())) | ||
|
||
mockOpenAI.CreateChatCompletionFunc = func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { | ||
return openai.ChatCompletionResponse{ | ||
Choices: []openai.ChatCompletionChoice{ | ||
{ | ||
Message: openai.ChatCompletionMessage{ | ||
Content: "OK", | ||
}, | ||
}, | ||
}, | ||
}, nil | ||
} | ||
|
||
assert.False(t, filter.isSpam("this is not a spam message")) | ||
assert.Equal(t, 2, len(mockOpenAI.CreateChatCompletionCalls())) | ||
} | ||
|
||
func TestSpamOpenAIFilter_OnMessage(t *testing.T) { | ||
superUser := &mocks.SuperUser{IsSuperFunc: func(userName string) bool { | ||
if userName == "super" || userName == "admin" { | ||
return true | ||
} | ||
return false | ||
}} | ||
|
||
spamSamples := strings.NewReader("spam1\nspam2\nspam3") | ||
mockOpenAI := &mocks.OpenAIClientMock{ | ||
CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { | ||
return openai.ChatCompletionResponse{ | ||
Choices: []openai.ChatCompletionChoice{ | ||
{ | ||
Message: openai.ChatCompletionMessage{ | ||
Content: "OK", | ||
}, | ||
}, | ||
}, | ||
}, nil | ||
}, | ||
} | ||
|
||
filter := NewSpamOpenAIFilter(spamSamples, mockOpenAI, 4096, superUser, false) | ||
|
||
msg := Message{From: User{ID: 1, Username: "user1"}, Text: "hello"} | ||
resp := filter.OnMessage(msg) | ||
|
||
assert.Empty(t, resp.Text) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters