Skip to content

Commit

Permalink
Add basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Javex committed May 22, 2024
1 parent 496b625 commit e1fa157
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 15 deletions.
30 changes: 15 additions & 15 deletions fail2ban.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (
)

func init() {
caddy.RegisterModule(Middleware{})
caddy.RegisterModule(Fail2Ban{})
// httpcaddyfile.RegisterHandlerDirective("visitor_ip", parseCaddyfile)
}

// Middleware implements an HTTP handler that writes the
// Fail2Ban implements an HTTP handler that writes the
// visitor's IP address to a file or stream.
type Middleware struct {
type Fail2Ban struct {
// The file or stream to write to. Can be "stdout"
// or "stderr".
Output string `json:"output,omitempty"`
Expand All @@ -33,20 +33,20 @@ type Middleware struct {
}

// CaddyModule returns the Caddy module information.
func (Middleware) CaddyModule() caddy.ModuleInfo {
func (Fail2Ban) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
ID: "http.matchers.fail2ban",
New: func() caddy.Module { return new(Middleware) },
New: func() caddy.Module { return new(Fail2Ban) },
}
}

// Provision implements caddy.Provisioner.
func (m *Middleware) Provision(ctx caddy.Context) error {
func (m *Fail2Ban) Provision(ctx caddy.Context) error {
m.logger = ctx.Logger()
return nil
}

func (m *Middleware) getBannedIps() ([]string, error) {
func (m *Fail2Ban) getBannedIps() ([]string, error) {

// Open banfile
// Try to open file
Expand Down Expand Up @@ -79,14 +79,14 @@ func (m *Middleware) getBannedIps() ([]string, error) {
}

// Validate implements caddy.Validator.
// func (m *Middleware) Validate() error {
// func (m *Fail2Ban) Validate() error {
// // if m.w == nil {
// // return fmt.Errorf("no writer")
// // }
// return nil
// }

func (m *Middleware) Match(req *http.Request) bool {
func (m *Fail2Ban) Match(req *http.Request) bool {
remote_ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
m.logger.Error("Error parsing remote addr into IP & port", zap.String("remote_addr", req.RemoteAddr), zap.Error(err))
Expand Down Expand Up @@ -121,7 +121,7 @@ func (m *Middleware) Match(req *http.Request) bool {
}

// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
func (m *Fail2Ban) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
for d.Next() {
switch v := d.Val(); v {
case "fail2ban":
Expand All @@ -139,15 +139,15 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {

// parseCaddyfile unmarshals tokens from h into a new Middleware.
// func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
// var m Middleware
// var m Fail2Ban
// err := m.UnmarshalCaddyfile(h.Dispenser)
// return m, err
// }

// Interface guards
var (
_ caddy.Provisioner = (*Middleware)(nil)
// _ caddy.Validator = (*Middleware)(nil)
_ caddyhttp.RequestMatcher = (*Middleware)(nil)
_ caddyfile.Unmarshaler = (*Middleware)(nil)
_ caddy.Provisioner = (*Fail2Ban)(nil)
// _ caddy.Validator = (*Fail2Ban)(nil)
_ caddyhttp.RequestMatcher = (*Fail2Ban)(nil)
_ caddyfile.Unmarshaler = (*Fail2Ban)(nil)
)
135 changes: 135 additions & 0 deletions fail2ban_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package caddy_fail2ban

import (
"context"
"fmt"
"net/http/httptest"
"os"
"path"
"strings"
"testing"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
)

func setupTest(t *testing.T) (string, string) {
t.Helper()
tempDir, err := os.MkdirTemp("", "caddy-fail2ban-test")
if err != nil {
t.Fatalf("failed to create temporary directory: %v", err)
}

fail2banFile := path.Join(tempDir, "banned-ips")
return tempDir, fail2banFile
}

func cleanupTest(t *testing.T, tempDir string) {
t.Helper()
err := os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("error removing temp directory: %v", err)
}
}

func TestModule(t *testing.T) {
tempDir, fail2banFile := setupTest(t)
defer cleanupTest(t, tempDir)

d := caddyfile.NewTestDispenser(fmt.Sprintf(`fail2ban %s`, fail2banFile))

m := Fail2Ban{}
err := m.UnmarshalCaddyfile(d)
if err != nil {
t.Errorf("unmarshal error: %v", err)
}

ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()

err = m.Provision(ctx)
if err != nil {
t.Errorf("error provisioning: %v", err)
}

req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))

if got, exp := m.Match(req), false; got != exp {
t.Errorf("unexpected match. got: %t, exp: %t", got, exp)
}

bannedIps, err := m.getBannedIps()
if err != nil {
t.Errorf("error loading banned ips: %v", err)
}

if got, exp := len(bannedIps), 0; got != exp {
t.Errorf("unexpected number of banned IPs. got: %d, exp: %d", got, exp)
}
}

func TestHeaderBan(t *testing.T) {
tempDir, fail2banFile := setupTest(t)
defer cleanupTest(t, tempDir)

d := caddyfile.NewTestDispenser(fmt.Sprintf(`fail2ban %s`, fail2banFile))

m := Fail2Ban{}
err := m.UnmarshalCaddyfile(d)
if err != nil {
t.Errorf("unmarshal error: %v", err)
}

ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()

err = m.Provision(ctx)
if err != nil {
t.Errorf("error provisioning: %v", err)
}

req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))
req.Header.Add("X-Caddy-Ban", "1")

if got, exp := m.Match(req), true; got != exp {
t.Errorf("unexpected match. got: %t, exp: %t", got, exp)
}
}

func TestBanIp(t *testing.T) {
tempDir, fail2banFile := setupTest(t)
defer cleanupTest(t, tempDir)

d := caddyfile.NewTestDispenser(fmt.Sprintf(`fail2ban %s`, fail2banFile))

m := Fail2Ban{}
err := m.UnmarshalCaddyfile(d)
if err != nil {
t.Errorf("unmarshal error: %v", err)
}

ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()

err = m.Provision(ctx)
if err != nil {
t.Errorf("error provisioning: %v", err)
}

req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))
req.RemoteAddr = "127.0.0.1:1337"

if m.Match(req) {
t.Errorf("IP banned unexpectedly")
}

// ban IP
os.WriteFile(fail2banFile, []byte("127.0.0.1"), 0644)

req = httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))
req.RemoteAddr = "127.0.0.1:1337"

if !m.Match(req) {
t.Errorf("IP should have been banned")
}
}

0 comments on commit e1fa157

Please sign in to comment.