diff --git a/fail2ban.go b/fail2ban.go index e95c70b..c88b82f 100644 --- a/fail2ban.go +++ b/fail2ban.go @@ -14,13 +14,13 @@ import ( ) func init() { - caddy.RegisterModule(Middleware{}) + caddy.RegisterModule(Fail2Ban{}) // httpcaddyfile.RegisterHandlerDirective("visitor_ip", parseCaddyfile) } -// Middleware implements an HTTP handler that writes the +// Fail2Ban implements an HTTP handler that writes the // visitor's IP address to a file or stream. -type Middleware struct { +type Fail2Ban struct { // The file or stream to write to. Can be "stdout" // or "stderr". Output string `json:"output,omitempty"` @@ -33,20 +33,20 @@ type Middleware struct { } // CaddyModule returns the Caddy module information. -func (Middleware) CaddyModule() caddy.ModuleInfo { +func (Fail2Ban) CaddyModule() caddy.ModuleInfo { return caddy.ModuleInfo{ ID: "http.matchers.fail2ban", - New: func() caddy.Module { return new(Middleware) }, + New: func() caddy.Module { return new(Fail2Ban) }, } } // Provision implements caddy.Provisioner. -func (m *Middleware) Provision(ctx caddy.Context) error { +func (m *Fail2Ban) Provision(ctx caddy.Context) error { m.logger = ctx.Logger() return nil } -func (m *Middleware) getBannedIps() ([]string, error) { +func (m *Fail2Ban) getBannedIps() ([]string, error) { // Open banfile // Try to open file @@ -79,14 +79,14 @@ func (m *Middleware) getBannedIps() ([]string, error) { } // Validate implements caddy.Validator. -// func (m *Middleware) Validate() error { +// func (m *Fail2Ban) Validate() error { // // if m.w == nil { // // return fmt.Errorf("no writer") // // } // return nil // } -func (m *Middleware) Match(req *http.Request) bool { +func (m *Fail2Ban) Match(req *http.Request) bool { remote_ip, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { m.logger.Error("Error parsing remote addr into IP & port", zap.String("remote_addr", req.RemoteAddr), zap.Error(err)) @@ -121,7 +121,7 @@ func (m *Middleware) Match(req *http.Request) bool { } // UnmarshalCaddyfile implements caddyfile.Unmarshaler. -func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { +func (m *Fail2Ban) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { for d.Next() { switch v := d.Val(); v { case "fail2ban": @@ -139,15 +139,15 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { // parseCaddyfile unmarshals tokens from h into a new Middleware. // func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { -// var m Middleware +// var m Fail2Ban // err := m.UnmarshalCaddyfile(h.Dispenser) // return m, err // } // Interface guards var ( - _ caddy.Provisioner = (*Middleware)(nil) - // _ caddy.Validator = (*Middleware)(nil) - _ caddyhttp.RequestMatcher = (*Middleware)(nil) - _ caddyfile.Unmarshaler = (*Middleware)(nil) + _ caddy.Provisioner = (*Fail2Ban)(nil) + // _ caddy.Validator = (*Fail2Ban)(nil) + _ caddyhttp.RequestMatcher = (*Fail2Ban)(nil) + _ caddyfile.Unmarshaler = (*Fail2Ban)(nil) ) diff --git a/fail2ban_test.go b/fail2ban_test.go new file mode 100644 index 0000000..aca80e3 --- /dev/null +++ b/fail2ban_test.go @@ -0,0 +1,135 @@ +package caddy_fail2ban + +import ( + "context" + "fmt" + "net/http/httptest" + "os" + "path" + "strings" + "testing" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" +) + +func setupTest(t *testing.T) (string, string) { + t.Helper() + tempDir, err := os.MkdirTemp("", "caddy-fail2ban-test") + if err != nil { + t.Fatalf("failed to create temporary directory: %v", err) + } + + fail2banFile := path.Join(tempDir, "banned-ips") + return tempDir, fail2banFile +} + +func cleanupTest(t *testing.T, tempDir string) { + t.Helper() + err := os.RemoveAll(tempDir) + if err != nil { + t.Fatalf("error removing temp directory: %v", err) + } +} + +func TestModule(t *testing.T) { + tempDir, fail2banFile := setupTest(t) + defer cleanupTest(t, tempDir) + + d := caddyfile.NewTestDispenser(fmt.Sprintf(`fail2ban %s`, fail2banFile)) + + m := Fail2Ban{} + err := m.UnmarshalCaddyfile(d) + if err != nil { + t.Errorf("unmarshal error: %v", err) + } + + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + err = m.Provision(ctx) + if err != nil { + t.Errorf("error provisioning: %v", err) + } + + req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader("")) + + if got, exp := m.Match(req), false; got != exp { + t.Errorf("unexpected match. got: %t, exp: %t", got, exp) + } + + bannedIps, err := m.getBannedIps() + if err != nil { + t.Errorf("error loading banned ips: %v", err) + } + + if got, exp := len(bannedIps), 0; got != exp { + t.Errorf("unexpected number of banned IPs. got: %d, exp: %d", got, exp) + } +} + +func TestHeaderBan(t *testing.T) { + tempDir, fail2banFile := setupTest(t) + defer cleanupTest(t, tempDir) + + d := caddyfile.NewTestDispenser(fmt.Sprintf(`fail2ban %s`, fail2banFile)) + + m := Fail2Ban{} + err := m.UnmarshalCaddyfile(d) + if err != nil { + t.Errorf("unmarshal error: %v", err) + } + + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + err = m.Provision(ctx) + if err != nil { + t.Errorf("error provisioning: %v", err) + } + + req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader("")) + req.Header.Add("X-Caddy-Ban", "1") + + if got, exp := m.Match(req), true; got != exp { + t.Errorf("unexpected match. got: %t, exp: %t", got, exp) + } +} + +func TestBanIp(t *testing.T) { + tempDir, fail2banFile := setupTest(t) + defer cleanupTest(t, tempDir) + + d := caddyfile.NewTestDispenser(fmt.Sprintf(`fail2ban %s`, fail2banFile)) + + m := Fail2Ban{} + err := m.UnmarshalCaddyfile(d) + if err != nil { + t.Errorf("unmarshal error: %v", err) + } + + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + err = m.Provision(ctx) + if err != nil { + t.Errorf("error provisioning: %v", err) + } + + req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader("")) + req.RemoteAddr = "127.0.0.1:1337" + + if m.Match(req) { + t.Errorf("IP banned unexpectedly") + } + + // ban IP + os.WriteFile(fail2banFile, []byte("127.0.0.1"), 0644) + + req = httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader("")) + req.RemoteAddr = "127.0.0.1:1337" + + if !m.Match(req) { + t.Errorf("IP should have been banned") + } +}