diff --git a/go.mod b/go.mod index af50c78..61b9601 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,10 @@ module github.com/hibare/GoGeoIP go 1.20 require ( - github.com/go-chi/chi/v5 v5.0.8 - github.com/go-chi/render v1.0.2 + github.com/go-chi/chi/v5 v5.0.10 github.com/go-playground/validator/v10 v10.14.1 github.com/google/uuid v1.3.0 - github.com/joho/godotenv v1.5.1 + github.com/hibare/GoCommon v0.0.4 github.com/oschwald/geoip2-golang v1.9.0 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e @@ -16,12 +15,12 @@ require ( ) require ( - github.com/ajg/form v1.5.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/joho/godotenv v1.5.1 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/oschwald/maxminddb-golang v1.11.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 5d3f8d4..b60d682 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,11 @@ -github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= -github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= -github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= -github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= -github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg= -github.com/go-chi/render v1.0.2/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= +github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= +github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= @@ -19,6 +15,8 @@ github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+j github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hibare/GoCommon v0.0.4 h1:NUv8o+clZQxGoShzUCY1LRdEgpPdKP0h21injTVo7AM= +github.com/hibare/GoCommon v0.0.4/go.mod h1:1YOvHY7UqXqAzF1SeczZkMyujqpzeiigAubv3XJguAw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= diff --git a/internal/api/handler/geoip.go b/internal/api/handler/geoip.go index bbc2df2..b576fd3 100644 --- a/internal/api/handler/geoip.go +++ b/internal/api/handler/geoip.go @@ -7,7 +7,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/go-chi/chi/v5" - "github.com/go-chi/render" + commonErrors "github.com/hibare/GoCommon/pkg/errors" + commonHttp "github.com/hibare/GoCommon/pkg/http" "github.com/hibare/GoGeoIP/internal/constants" "github.com/hibare/GoGeoIP/internal/maxmind" ) @@ -20,15 +21,11 @@ func GeoIP(w http.ResponseWriter, r *http.Request) { if err != nil { log.Errorf("Error fetching record for ip %s, %s", ip, err) if errors.Is(err, constants.ErrInvalidIP) { - render.Status(r, http.StatusBadRequest) + commonHttp.WriteErrorResponse(w, http.StatusBadRequest, err) } else { - render.Status(r, http.StatusInternalServerError) + commonHttp.WriteErrorResponse(w, http.StatusInternalServerError, commonErrors.ErrInternalServerError) } - render.JSON(w, r, ErrorStruct{ - Message: err.Error(), - }) return } - - render.JSON(w, r, ipGeo) + commonHttp.WriteJsonResponse(w, http.StatusOK, ipGeo) } diff --git a/internal/api/handler/healthcheck.go b/internal/api/handler/healthcheck.go deleted file mode 100644 index cadb43e..0000000 --- a/internal/api/handler/healthcheck.go +++ /dev/null @@ -1,10 +0,0 @@ -package handler - -import ( - "encoding/json" - "net/http" -) - -func HealthCheck(response http.ResponseWriter, request *http.Request) { - json.NewEncoder(response).Encode(map[string]bool{"ok": true}) -} diff --git a/internal/api/handler/utils.go b/internal/api/handler/utils.go deleted file mode 100644 index eabc4f5..0000000 --- a/internal/api/handler/utils.go +++ /dev/null @@ -1,5 +0,0 @@ -package handler - -type ErrorStruct struct { - Message string `json:"message"` -} diff --git a/internal/api/middlewares/token_auth.go b/internal/api/middlewares/token_auth.go deleted file mode 100644 index 0fe3fb6..0000000 --- a/internal/api/middlewares/token_auth.go +++ /dev/null @@ -1,41 +0,0 @@ -package middlewares - -import ( - "net/http" - - "github.com/go-chi/render" - log "github.com/sirupsen/logrus" - - "github.com/hibare/GoGeoIP/internal/api/handler" - "github.com/hibare/GoGeoIP/internal/config" - "github.com/hibare/GoGeoIP/internal/constants" - "github.com/hibare/GoGeoIP/internal/utils" -) - -const AuthHeaderName = "Authorization" - -func TokenAuth(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Infof("Client: [%s] %s", r.RemoteAddr, r.RequestURI) - - apiKey := r.Header.Get(AuthHeaderName) - - if apiKey == "" { - render.Status(r, http.StatusUnauthorized) - render.JSON(w, r, handler.ErrorStruct{ - Message: constants.ErrUnauthorized.Error(), - }) - return - } - - if utils.SliceContains(apiKey, config.Current.API.APIKeys) { - next.ServeHTTP(w, r) - return - } - - render.Status(r, http.StatusUnauthorized) - render.JSON(w, r, handler.ErrorStruct{ - Message: constants.ErrUnauthorized.Error(), - }) - }) -} diff --git a/internal/api/server.go b/internal/api/server.go index 99638fd..cf098c9 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -13,8 +13,9 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-playground/validator/v10" + commonHandler "github.com/hibare/GoCommon/pkg/http/handler" + commonMiddleware "github.com/hibare/GoCommon/pkg/http/middleware" "github.com/hibare/GoGeoIP/internal/api/handler" - "github.com/hibare/GoGeoIP/internal/api/middlewares" "github.com/hibare/GoGeoIP/internal/config" ) @@ -32,7 +33,9 @@ func (a *App) setRouters() { func (a *App) Get(path string, f func(w http.ResponseWriter, r *http.Request), protected bool) { if protected { pr := chi.NewRouter() - pr.Use(middlewares.TokenAuth) + pr.Use(func(h http.Handler) http.Handler { + return commonMiddleware.TokenAuth(h, config.Current.API.APIKeys) + }) pr.Get(path, f) a.Router.Mount("/", pr) } else { @@ -56,7 +59,7 @@ func (a *App) Delete(path string, f func(w http.ResponseWriter, r *http.Request) } func (a *App) HealthCheck(w http.ResponseWriter, r *http.Request) { - handler.HealthCheck(w, r) + commonHandler.HealthCheck(w, r) } func (a *App) GeoIP(w http.ResponseWriter, r *http.Request) { diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 7935396..3aa1df8 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -7,8 +7,8 @@ import ( "os" "testing" - "github.com/hibare/GoGeoIP/internal/api/handler" - "github.com/hibare/GoGeoIP/internal/api/middlewares" + "github.com/hibare/GoCommon/pkg/errors" + commonMiddleware "github.com/hibare/GoCommon/pkg/http/middleware" "github.com/hibare/GoGeoIP/internal/config" "github.com/hibare/GoGeoIP/internal/constants" "github.com/hibare/GoGeoIP/internal/testhelper" @@ -28,62 +28,24 @@ func TestMain(m *testing.M) { os.Exit(code) } -func TestHealthCheckHandler(t *testing.T) { - testCases := []struct { - Name string - URL string - }{ - { - Name: "URL without trailing slash", - URL: "/api/v1/health", - }, { - Name: "URL with trailing slash", - URL: "/api/v1/health/", - }, - } - - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - r, err := http.NewRequest("GET", tc.URL, nil) - - assert.NoError(t, err) - - w := httptest.NewRecorder() - - app.Router.ServeHTTP(w, r) - - assert.Equal(t, http.StatusOK, w.Code) - - expectedBody := map[string]bool{"ok": true} - responseBody := map[string]bool{} - - err = json.NewDecoder(w.Body).Decode(&responseBody) - - assert.NoError(t, err) - - if assert.NotNil(t, responseBody) { - assert.Equal(t, responseBody, expectedBody) - } - }) - } -} - func TestGeoIP400(t *testing.T) { testCases := []struct { Name string URL string - expectedBody handler.ErrorStruct + expectedBody errors.Error }{ { Name: "URL without trailing slash", URL: "/api/v1/ip/8.8.8", - expectedBody: handler.ErrorStruct{ + expectedBody: errors.Error{ + Code: http.StatusBadRequest, Message: constants.ErrInvalidIP.Error(), }, }, { Name: "URL with trailing slash", URL: "/api/v1/ip/8.8.8/", - expectedBody: handler.ErrorStruct{ + expectedBody: errors.Error{ + Code: http.StatusBadRequest, Message: constants.ErrInvalidIP.Error(), }, }, @@ -94,7 +56,7 @@ func TestGeoIP400(t *testing.T) { r, err := http.NewRequest("GET", tc.URL, nil) assert.NoError(t, err) - r.Header.Add(middlewares.AuthHeaderName, config.Current.API.APIKeys[0]) + r.Header.Add(commonMiddleware.AuthHeaderName, config.Current.API.APIKeys[0]) w := httptest.NewRecorder() @@ -102,7 +64,7 @@ func TestGeoIP400(t *testing.T) { assert.Equal(t, http.StatusBadRequest, w.Code) - responseBody := handler.ErrorStruct{} + responseBody := errors.Error{} err = json.NewDecoder(w.Body).Decode(&responseBody) assert.NoError(t, err) @@ -118,19 +80,21 @@ func TestGeoIP401(t *testing.T) { testCases := []struct { Name string URL string - expectedBody handler.ErrorStruct + expectedBody errors.Error }{ { Name: "URL without trailing slash", URL: "/api/v1/ip", - expectedBody: handler.ErrorStruct{ - Message: constants.ErrUnauthorized.Error(), + expectedBody: errors.Error{ + Code: http.StatusUnauthorized, + Message: errors.ErrUnauthorized.Error(), }, }, { Name: "URL with trailing slash", URL: "/api/v1/ip/", - expectedBody: handler.ErrorStruct{ - Message: constants.ErrUnauthorized.Error(), + expectedBody: errors.Error{ + Code: http.StatusUnauthorized, + Message: errors.ErrUnauthorized.Error(), }, }, } @@ -146,7 +110,7 @@ func TestGeoIP401(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, w.Code) - responseBody := handler.ErrorStruct{} + responseBody := errors.Error{} err = json.NewDecoder(w.Body).Decode(&responseBody) assert.NoError(t, err) @@ -177,7 +141,7 @@ func TestGeoIP404(t *testing.T) { r, err := http.NewRequest("GET", tc.URL, nil) assert.NoError(t, err) - r.Header.Add(middlewares.AuthHeaderName, config.Current.API.APIKeys[0]) + r.Header.Add(commonMiddleware.AuthHeaderName, config.Current.API.APIKeys[0]) w := httptest.NewRecorder() @@ -209,7 +173,7 @@ func TestGeoIP500(t *testing.T) { r, err := http.NewRequest("GET", tc.URL, nil) assert.NoError(t, err) - r.Header.Add(middlewares.AuthHeaderName, config.Current.API.APIKeys[0]) + r.Header.Add(commonMiddleware.AuthHeaderName, config.Current.API.APIKeys[0]) w := httptest.NewRecorder() @@ -246,7 +210,7 @@ func TestGeoIP200(t *testing.T) { r, err := http.NewRequest("GET", tc.URL, nil) assert.NoError(t, err) - r.Header.Add(middlewares.AuthHeaderName, config.Current.API.APIKeys[0]) + r.Header.Add(commonMiddleware.AuthHeaderName, config.Current.API.APIKeys[0]) w := httptest.NewRecorder() diff --git a/internal/config/config.go b/internal/config/config.go index 6227b44..3f4f103 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,8 +8,8 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" + "github.com/hibare/GoCommon/pkg/env" "github.com/hibare/GoGeoIP/internal/constants" - "github.com/hibare/GoGeoIP/internal/env" ) type UtilConfig struct { diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 7f99f1d..6c2dd92 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -17,7 +17,6 @@ const ( var ( ErrChecksumMismatch = errors.New("checksum Mismatch") ErrInvalidIP = errors.New("invalid IP") - ErrUnauthorized = errors.New("unauthorized") ) const ( diff --git a/internal/env/env.go b/internal/env/env.go deleted file mode 100644 index f008615..0000000 --- a/internal/env/env.go +++ /dev/null @@ -1,51 +0,0 @@ -package env - -import ( - "os" - "strconv" - "strings" - "time" - - "github.com/joho/godotenv" -) - -// Load loads an optional .env file -func Load() { - godotenv.Load() -} - -// MustString returns the content of the environment variable with the given key or the given fallback -func MustString(key, fallback string) string { - value, found := os.LookupEnv(key) - if !found { - return fallback - } - return value -} - -// MustBool uses MustString and parses it into a boolean -func MustBool(key string, fallback bool) bool { - parsed, _ := strconv.ParseBool(MustString(key, strconv.FormatBool(fallback))) - return parsed -} - -// MustInt uses MustString and parses it into an integer -func MustInt(key string, fallback int) int { - parsed, _ := strconv.Atoi(MustString(key, strconv.Itoa(fallback))) - return parsed -} - -// MustDuration uses MustString and parses it into a duration -func MustDuration(key string, fallback time.Duration) time.Duration { - parsed, _ := time.ParseDuration(MustString(key, fallback.String())) - return parsed -} - -// MustStringSlice uses MustString and parses it into a slice of strings -func MustStringSlice(key string, fallback []string) []string { - value := MustString(key, "") - if value == "" { - return fallback - } - return strings.Split(value, ",") -} diff --git a/internal/env/env_test.go b/internal/env/env_test.go deleted file mode 100644 index 32d03da..0000000 --- a/internal/env/env_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package env - -import ( - "os" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestEnv(t *testing.T) { - // Set Sample environment variables - os.Setenv("STRING_ENV", "test_string") - os.Setenv("BOOL_ENV", "true") - os.Setenv("INT_ENV", "100") - os.Setenv("DURATION_ENV", "24h") - os.Setenv("SLICE_ENV", "1,2,3") - - Load() - - assert.Equal(t, "test_string", MustString("STRING_ENV", "")) - assert.True(t, true, MustBool("BOOL_ENV", false)) - assert.Equal(t, 100, MustInt("INT_ENV", 0)) - assert.Equal(t, 24*time.Hour, MustDuration("DURATION_ENV", time.Duration(0))) - assert.Equal(t, []string{"1", "2", "3"}, MustStringSlice("SLICE_ENV", []string{})) - - assert.Equal(t, "default_string", MustString("DEFAULT_STRING_ENV", "default_string")) - assert.Equal(t, false, MustBool("DEFAULT_BOOL_ENV", false)) - assert.Equal(t, 1, MustInt("DEFAULT_INT_ENV", 1)) - assert.Equal(t, time.Duration(1), MustDuration("DEFAULT_DURATION_ENV", time.Duration(1))) - assert.Equal(t, []string{"1"}, MustStringSlice("DEFAULT_SLICE_ENV", []string{"1"})) - - // Unset Sample environment variables - os.Unsetenv("STRING_ENV") - os.Unsetenv("BOOL_ENV") - os.Unsetenv("INT_ENV") - os.Unsetenv("DURATION_ENV") - os.Unsetenv("SLICE_ENV") -} diff --git a/internal/utils/utils.go b/internal/utils/utils.go deleted file mode 100644 index 2dca3a0..0000000 --- a/internal/utils/utils.go +++ /dev/null @@ -1,11 +0,0 @@ -package utils - -// StringInSlice checks if a string is present in slice -func SliceContains[T comparable](a T, list []T) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go deleted file mode 100644 index bc2c25a..0000000 --- a/internal/utils/utils_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package utils - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestStringInSlice(t *testing.T) { - assert.True(t, SliceContains("1", []string{"1", "2", "3"})) - assert.False(t, SliceContains("11", []string{"1", "2", "3"})) - assert.True(t, SliceContains(2, []int{1, 2, 3})) - assert.False(t, SliceContains(22, []int{1, 2, 3})) - assert.True(t, SliceContains(3.3, []float64{1.1, 2.2, 3.3})) - assert.False(t, SliceContains(33.3, []float64{1.1, 2.2, 3.3})) -}