diff --git a/app/bot/spam.go b/app/bot/spam_cas.go similarity index 77% rename from app/bot/spam.go rename to app/bot/spam_cas.go index 75fba6b..0442e00 100644 --- a/app/bot/spam.go +++ b/app/bot/spam_cas.go @@ -8,8 +8,8 @@ import ( "time" ) -// SpamFilter bot, checks if user is a spammer using CAS API -type SpamFilter struct { +// SpamCasFilter bot, checks if user is a spammer using CAS API +type SpamCasFilter struct { casAPI string dry bool client HTTPClient @@ -22,14 +22,14 @@ type SpamFilter struct { // they are considered to be restricted forever. var permanentBanDuration = time.Hour * 24 * 400 -// NewSpamFilter makes a spam detecting bot -func NewSpamFilter(api string, client HTTPClient, superUser SuperUser, dry bool) *SpamFilter { - log.Printf("[INFO] Spam bot with %s", api) - return &SpamFilter{casAPI: api, client: client, dry: dry, approvedUsers: map[int64]bool{}, superUser: superUser} +// NewSpamCasFilter makes a spam detecting bot +func NewSpamCasFilter(api string, client HTTPClient, superUser SuperUser, dry bool) *SpamCasFilter { + log.Printf("[INFO] Spam bot (cas) with %s", api) + return &SpamCasFilter{casAPI: api, client: client, dry: dry, approvedUsers: map[int64]bool{}, superUser: superUser} } // OnMessage checks if user already approved and if not checks if user is a spammer -func (s *SpamFilter) OnMessage(msg Message) (response Response) { +func (s *SpamCasFilter) OnMessage(msg Message) (response Response) { if s.approvedUsers[msg.From.ID] { return Response{} } @@ -80,7 +80,7 @@ func (s *SpamFilter) OnMessage(msg Message) (response Response) { } // Help returns help message -func (s *SpamFilter) Help() string { return "" } +func (s *SpamCasFilter) Help() string { return "" } // ReactOn keys -func (s *SpamFilter) ReactOn() []string { return []string{} } +func (s *SpamCasFilter) ReactOn() []string { return []string{} } diff --git a/app/bot/spam_test.go b/app/bot/spam_cas_test.go similarity index 94% rename from app/bot/spam_test.go rename to app/bot/spam_cas_test.go index 41aba43..36d81d0 100644 --- a/app/bot/spam_test.go +++ b/app/bot/spam_cas_test.go @@ -19,7 +19,7 @@ func TestNewSpamFilter(t *testing.T) { return false }} client := &mocks.HTTPClient{} - sf := NewSpamFilter("http://localhost", client, su, false) + sf := NewSpamCasFilter("http://localhost", client, su, false) assert.NotNil(t, sf) assert.Equal(t, "http://localhost", sf.casAPI) assert.Equal(t, client, sf.client) @@ -92,7 +92,7 @@ func TestSpamFilter_OnMessage(t *testing.T) { }, } - s := NewSpamFilter("http://localhost", mockedHTTPClient, su, tt.dryMode) + s := NewSpamCasFilter("http://localhost", mockedHTTPClient, su, tt.dryMode) msg := Message{ From: User{ @@ -127,7 +127,7 @@ func TestSpamFilter_OnMessageCheckOnce(t *testing.T) { return false }} - s := NewSpamFilter("http://localhost", mockedHTTPClient, su, false) + s := NewSpamCasFilter("http://localhost", mockedHTTPClient, su, false) res := s.OnMessage(Message{From: User{ID: 1, Username: "testuser"}, ID: 1, Text: "Hello"}) assert.Equal(t, Response{}, res) assert.Len(t, mockedHTTPClient.DoCalls(), 1, "Do should be called once") diff --git a/app/bot/spam_local.go b/app/bot/spam_local.go new file mode 100644 index 0000000..2825437 --- /dev/null +++ b/app/bot/spam_local.go @@ -0,0 +1,127 @@ +package bot + +import ( + "bufio" + "fmt" + "io" + "log" + "math" + "strings" +) + +// SpamLocalFilter bot, checks if user is a spammer using internal matching +type SpamLocalFilter struct { + dry bool + superUser SuperUser + threshold float64 + + enabled bool + spamMessages []string + approvedUsers map[int64]bool +} + +// NewSpamLocalFilter makes a spam detecting bot +func NewSpamLocalFilter(spamSamples io.Reader, threshold float64, superUser SuperUser, dry bool) *SpamLocalFilter { + log.Printf("[INFO] Spam bot (local), threshold=%0.2f", threshold) + res := &SpamLocalFilter{dry: dry, approvedUsers: map[int64]bool{}, superUser: superUser, threshold: threshold} + scanner := bufio.NewScanner(spamSamples) + for scanner.Scan() { + res.spamMessages = append(res.spamMessages, scanner.Text()) + } + if err := scanner.Err(); err != nil { + log.Printf("[WARN] failed to read spam samples, error=%v", err) + res.enabled = false + } else { + res.enabled = true + } + return res +} + +// OnMessage checks if user already approved and if not checks if user is a spammer +func (s *SpamLocalFilter) OnMessage(msg Message) (response Response) { + if !s.enabled { + return Response{} + } + + if s.approvedUsers[msg.From.ID] { + return Response{} + } + + if s.superUser.IsSuper(msg.From.Username) { + return Response{} // don't check super users for spam + } + + if !s.isSpam(msg.Text) { + log.Printf("[INFO] user %s is not a spammer, added to aproved", msg.From.Username) + s.approvedUsers[msg.From.ID] = true + return Response{} // not a spam + } + + log.Printf("[INFO] user %s detected as spammer, msg: %q", msg.From.Username, msg.Text) + if s.dry { + return Response{ + Text: fmt.Sprintf("this is spam from %q, but I'm in dry mode, so I'll do nothing yet", msg.From.Username), + Send: true, ReplyTo: msg.ID, + } + } + return Response{Text: "this is spam! go to ban, " + msg.From.DisplayName, Send: true, ReplyTo: msg.ID, BanInterval: permanentBanDuration, DeleteReplyTo: true} +} + +// Help returns help message +func (s *SpamLocalFilter) Help() string { return "" } + +// ReactOn keys +func (s *SpamLocalFilter) ReactOn() []string { return []string{} } + +// isSpam checks if a given message is similar to any of the known bad messages. +func (s *SpamLocalFilter) isSpam(message string) bool { + tokenizedMessage := s.tokenize(message) + maxSimilarity := 0.0 + for _, spam := range s.spamMessages { + similarity := s.cosineSimilarity(tokenizedMessage, s.tokenize(spam)) + if similarity > maxSimilarity { + maxSimilarity = similarity + } + if similarity >= s.threshold { + return true + } + } + log.Printf("[DEBUG] spam similarity: %0.2f", maxSimilarity) + return false +} + +// tokenize takes a string and returns a map where the keys are unique words (tokens) +// and the values are the frequencies of those words in the string. +func (s *SpamLocalFilter) tokenize(inp string) map[string]int { + tokenFrequency := make(map[string]int) + tokens := strings.Fields(inp) + for _, token := range tokens { + tokenFrequency[strings.ToLower(token)]++ + } + return tokenFrequency +} + +// cosineSimilarity calculates the cosine similarity between two token frequency maps. +func (s *SpamLocalFilter) cosineSimilarity(a, b map[string]int) float64 { + if len(a) == 0 || len(b) == 0 { + return 0.0 + } + + dotProduct := 0 // sum of product of corresponding frequencies + normA, normB := 0, 0 // square root of sum of squares of frequencies + + for key, val := range a { + dotProduct += val * b[key] + normA += val * val + } + for _, val := range b { + normB += val * val + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + // cosine similarity formula + return float64(dotProduct) / (math.Sqrt(float64(normA)) * math.Sqrt(float64(normB))) +} diff --git a/app/bot/spam_local_test.go b/app/bot/spam_local_test.go new file mode 100644 index 0000000..34fa133 --- /dev/null +++ b/app/bot/spam_local_test.go @@ -0,0 +1,70 @@ +package bot + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/radio-t/super-bot/app/bot/mocks" +) + +func TestSpamLocalFilter_OnMessage(t *testing.T) { + superUser := &mocks.SuperUser{IsSuperFunc: func(userName string) bool { + if userName == "super" || userName == "admin" { + return true + } + return false + }} + spamSamples := strings.NewReader("win free iPhone\nlottery prize") + + filter := NewSpamLocalFilter(spamSamples, 0.5, superUser, false) + + tests := []struct { + msg Message + expected Response + }{ + { + Message{From: User{ID: 1, Username: "john", DisplayName: "John"}, Text: "Hello, how are you?", ID: 1}, + Response{}, + }, + { + Message{From: User{ID: 2, Username: "spammer", DisplayName: "Spammer"}, Text: "Win a free iPhone now!", ID: 2}, + Response{Text: "this is spam! go to ban, Spammer", Send: true, ReplyTo: 2, BanInterval: permanentBanDuration, DeleteReplyTo: true}, + }, + { + Message{From: User{ID: 3, Username: "super", DisplayName: "SuperUser"}, Text: "Win a free iPhone now!", ID: 3}, + Response{}, + }, + } + + for _, test := range tests { + assert.Equal(t, test.expected, filter.OnMessage(test.msg)) + } +} + +func TestIsSpam(t *testing.T) { + spamSamples := strings.NewReader("win free iPhone\nlottery prize") + filter := NewSpamLocalFilter(spamSamples, 0.5, nil, false) // SuperUser set to nil for this test + + tests := []struct { + name string + message string + threshold float64 + expected bool + }{ + {"Not Spam", "Hello, how are you?", 0.5, false}, + {"Exact Match", "Win a free iPhone now!", 0.5, true}, + {"Similar Match", "You won a lottery prize!", 0.3, true}, + {"High Threshold", "You won a lottery prize!", 0.9, false}, + {"Partial Match", "win free", 0.9, false}, + {"Low Threshold", "win free", 0.8, true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + filter.threshold = test.threshold // Update threshold for each test case + assert.Equal(t, test.expected, filter.isSpam(test.message)) + }) + } +} diff --git a/app/main.go b/app/main.go index c09460e..8ae7b75 100644 --- a/app/main.go +++ b/app/main.go @@ -47,10 +47,12 @@ var opts struct { ExportBroadcastUsers events.SuperUser `long:"broadcast" description:"broadcast-users"` SpamFilter struct { - Enabled bool `long:"enabled" env:"ENABLED" description:"enable spam filter"` - API string `long:"api" env:"CAS_API" default:"https://api.cas.chat" description:"CAS API"` - TimeOut time.Duration `long:"timeout" env:"TIMEOUT" default:"5s" description:"CAS timeout"` - Dry bool `long:"dry" env:"DRY" description:"dry mode, no bans"` + Enabled bool `long:"enabled" env:"ENABLED" description:"enable spam filter"` + API string `long:"api" env:"CAS_API" default:"https://api.cas.chat" description:"CAS API"` + TimeOut time.Duration `long:"timeout" env:"TIMEOUT" default:"5s" description:"CAS timeout"` + SpamSamples string `long:"spam-samples" env:"SPAM_SAMPLES" default:"" description:"path to spam samples"` + SpamThreshold float64 `long:"spam-threshold" env:"SPAM_THRESHOLD" default:"0.5" description:"spam threshold"` + Dry bool `long:"dry" env:"DRY" description:"dry mode, no bans"` } `group:"spam-filter" namespace:"spam-filter" env-namespace:"SPAM_FILTER"` OpenAI struct { @@ -144,7 +146,15 @@ func main() { log.Printf("[INFO] spam filter enabled, api=%s, timeout=%s, dry=%v", opts.SpamFilter.API, opts.SpamFilter.TimeOut, opts.SpamFilter.Dry) httpCasClient := &http.Client{Timeout: opts.SpamFilter.TimeOut} - multiBot = append(multiBot, bot.NewSpamFilter(opts.SpamFilter.API, httpCasClient, opts.SuperUsers, opts.SpamFilter.Dry)) + multiBot = append(multiBot, bot.NewSpamCasFilter(opts.SpamFilter.API, httpCasClient, opts.SuperUsers, opts.SpamFilter.Dry)) + if opts.SpamFilter.SpamSamples != "" { + spamFh, err := os.Open(opts.SpamFilter.SpamSamples) + if err != nil { + log.Fatalf("[ERROR] failed to open spam samples file %s, %v", opts.SpamFilter.SpamSamples, err) + } + multiBot = append(multiBot, bot.NewSpamLocalFilter(spamFh, opts.SpamFilter.SpamThreshold, + opts.SuperUsers, opts.SpamFilter.Dry)) + } } else { log.Print("[INFO] spam filter disabled") }