diff --git a/libs/middleware/upgrade_required.go b/libs/middleware/upgrade_required.go new file mode 100644 index 000000000..bba7a0e60 --- /dev/null +++ b/libs/middleware/upgrade_required.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "errors" + "net/http" + "time" + + "github.com/brave-intl/bat-go/libs/handlers" +) + +var ( + errUpgradeRequired = errors.New("upgrade required, cutoff exceeded") +) + +// NewUpgradeRequiredByMiddleware passes a service into the context +func NewUpgradeRequiredByMiddleware(cutoff time.Time) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if time.Now().Before(cutoff) { + next.ServeHTTP(w, r) + return + } + ae := handlers.AppError{ + Cause: errUpgradeRequired, + Message: "upgrade required, cutoff exceeded", + Code: http.StatusUpgradeRequired, + } + ae.ServeHTTP(w, r) + }) + } +} diff --git a/libs/middleware/upgrade_required_test.go b/libs/middleware/upgrade_required_test.go new file mode 100644 index 000000000..609b7baa8 --- /dev/null +++ b/libs/middleware/upgrade_required_test.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestUpgradeRequiredByMiddleware(t *testing.T) { + // after cutoff + wrappedHandler := NewUpgradeRequiredByMiddleware(time.Now().Add(-1 * time.Second))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + server := httptest.NewServer(wrappedHandler) + defer server.Close() + resp, _ := http.Get(server.URL) + assert.Equal(t, resp.StatusCode, http.StatusUpgradeRequired, "status code should be upgrade required") + + // not yet cutoff + wrappedHandler = NewUpgradeRequiredByMiddleware(time.Now().Add(1 * time.Second))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + server = httptest.NewServer(wrappedHandler) + defer server.Close() + resp, _ = http.Get(server.URL) + assert.Equal(t, resp.StatusCode, http.StatusOK, "status code should be OK") +} diff --git a/services/grant/cmd/grant.go b/services/grant/cmd/grant.go index 897cf8f78..a175a4b9a 100644 --- a/services/grant/cmd/grant.go +++ b/services/grant/cmd/grant.go @@ -393,17 +393,29 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context, // add runnable jobs: jobs = append(jobs, promotionService.Jobs()...) - r.Mount("/v1/promotions", promotion.Router(promotionService)) - r.Mount("/v2/promotions", promotion.RouterV2(promotionService)) + // vbat expired var from env + vbatExpires, err := time.Parse(time.RFC3339, "2023-11-02T00:00:00Z") // default 11/2/23 + if err != nil { + logger.Panic().Err(err).Msg("failed to parse vbatExpires time") + } + if os.Getenv("VBAT_EXPIRES") != "" { // use what is in the environment if exists + vbatExpires, err = time.Parse(time.RFC3339, os.Getenv("VBAT_EXPIRES")) + if err != nil { + logger.Panic().Err(err).Msg("failed to parse vbatExpires time") + } + } + + r.Mount("/v1/promotions", promotion.Router(promotionService, vbatExpires)) + r.Mount("/v2/promotions", promotion.RouterV2(promotionService, vbatExpires)) - sRouter, err := promotion.SuggestionsRouter(promotionService) + sRouter, err := promotion.SuggestionsRouter(promotionService, vbatExpires) if err != nil { logger.Panic().Err(err).Msg("failed to initialize the suggestions router") } r.Mount("/v1/suggestions", sRouter) - sV2Router, err := promotion.SuggestionsV2Router(promotionService) + sV2Router, err := promotion.SuggestionsV2Router(promotionService, vbatExpires) if err != nil { logger.Panic().Err(err).Msg("failed to initialize the suggestions router") } @@ -411,7 +423,7 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context, r.Mount("/v2/suggestions", sV2Router) // temporarily house batloss events in promotion to avoid widespread conflicts later - r.Mount("/v1/wallets", promotion.WalletEventRouter(promotionService)) + r.Mount("/v1/wallets", promotion.WalletEventRouter(promotionService, vbatExpires)) skuOrderRepo := repository.NewOrder() skuOrderItemRepo := repository.NewOrderItem() diff --git a/services/promotion/controllers.go b/services/promotion/controllers.go index 6c7a1f040..8caf43362 100644 --- a/services/promotion/controllers.go +++ b/services/promotion/controllers.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "strconv" + "time" "github.com/asaskevich/govalidator" "github.com/brave-intl/bat-go/libs/clients" @@ -30,8 +31,9 @@ import ( ) // RouterV2 for promotion endpoints -func RouterV2(service *Service) chi.Router { +func RouterV2(service *Service, vbatExpires time.Time) chi.Router { r := chi.NewRouter() + r.Use(middleware.NewUpgradeRequiredByMiddleware(vbatExpires)) if os.Getenv("ENV") != "local" { r.Method("POST", "/", middleware.SimpleTokenAuthorizedOnly(CreatePromotion(service))) } else { @@ -45,8 +47,9 @@ func RouterV2(service *Service) chi.Router { } // Router for promotion endpoints -func Router(service *Service) chi.Router { +func Router(service *Service, vbatExpires time.Time) chi.Router { r := chi.NewRouter() + r.Use(middleware.NewUpgradeRequiredByMiddleware(vbatExpires)) if os.Getenv("ENV") != "local" { r.Method("POST", "/", middleware.SimpleTokenAuthorizedOnly(CreatePromotion(service))) } else { @@ -67,8 +70,9 @@ func Router(service *Service) chi.Router { } // SuggestionsV2Router for suggestions endpoints -func SuggestionsV2Router(service *Service) (chi.Router, error) { +func SuggestionsV2Router(service *Service, vbatExpires time.Time) (chi.Router, error) { r := chi.NewRouter() + r.Use(middleware.NewUpgradeRequiredByMiddleware(vbatExpires)) var ( enableLinkingDraining bool err error @@ -88,8 +92,9 @@ func SuggestionsV2Router(service *Service) (chi.Router, error) { } // SuggestionsRouter for suggestions endpoints -func SuggestionsRouter(service *Service) (chi.Router, error) { +func SuggestionsRouter(service *Service, vbatExpires time.Time) (chi.Router, error) { r := chi.NewRouter() + r.Use(middleware.NewUpgradeRequiredByMiddleware(vbatExpires)) r.Method("POST", "/", middleware.InstrumentHandler("MakeSuggestion", MakeSuggestion(service))) var ( @@ -111,8 +116,9 @@ func SuggestionsRouter(service *Service) (chi.Router, error) { } // WalletEventRouter for reporting bat loss events -func WalletEventRouter(service *Service) chi.Router { +func WalletEventRouter(service *Service, vbatExpires time.Time) chi.Router { r := chi.NewRouter() + r.Use(middleware.NewUpgradeRequiredByMiddleware(vbatExpires)) r.Method("POST", "/{walletId}/events/batloss/{reportId}", middleware.HTTPSignedOnly(service)(middleware.InstrumentHandler("PostReportWalletEvent", PostReportWalletEvent(service)))) return r }