From 243166c53cd7959a121f158ad93411b9cbda3d9d Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Mon, 29 Jan 2024 11:14:25 +0000 Subject: [PATCH] feat: add jwt validation to android webhook (#2318) * feat: add jwt validation to android webhook * test: add mock to android test and rename field --- services/skus/controllers.go | 122 ++++++++++++++---- services/skus/controllers_pvt_test.go | 176 ++++++++++++++++++++++++++ services/skus/controllers_test.go | 13 ++ services/skus/input.go | 5 +- services/skus/service.go | 42 ++++++ 5 files changed, 328 insertions(+), 30 deletions(-) create mode 100644 services/skus/controllers_pvt_test.go diff --git a/services/skus/controllers.go b/services/skus/controllers.go index a661b7aff..8c4f44db6 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "strconv" + "strings" "github.com/asaskevich/govalidator" "github.com/go-chi/chi" @@ -16,6 +17,7 @@ import ( uuid "github.com/satori/go.uuid" "github.com/stripe/stripe-go/v72" "github.com/stripe/stripe-go/v72/webhook" + "google.golang.org/api/idtoken" "github.com/brave-intl/bat-go/libs/clients/radom" appctx "github.com/brave-intl/bat-go/libs/context" @@ -948,59 +950,59 @@ func WebhookRouter(service *Service) chi.Router { // HandleAndroidWebhook is the handler for the Google Playstore webhooks func HandleAndroidWebhook(service *Service) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() - var ( - ctx = r.Context() - req = new(AndroidNotification) - validationErrMap = map[string]interface{}{} // for tracking our validation errors - ) + l := logging.Logger(ctx, "payments").With().Str("func", "HandleAndroidWebhook").Logger() - // get logger - logger := logging.Logger(ctx, "payments").With(). - Str("func", "HandleAndroidWebhook"). - Logger() + if err := service.gcpValidator.validate(ctx, r); err != nil { + l.Error().Err(err).Msg("invalid request") + return handlers.WrapError(err, "invalid request", http.StatusUnauthorized) + } - // read the payload payload, err := requestutils.Read(r.Context(), r.Body) if err != nil { - logger.Error().Err(err).Msg("failed to read the payload") + l.Error().Err(err).Msg("failed to read payload") return handlers.WrapValidationError(err) } - // validate the payload - if err := inputs.DecodeAndValidate(context.Background(), req, payload); err != nil { - logger.Debug().Str("payload", string(payload)). - Msg("failed to decode and validate the payload") + l.Info().Str("payload", string(payload)).Msg("") + + var validationErrMap = map[string]interface{}{} + + var req AndroidNotification + if err := inputs.DecodeAndValidate(context.Background(), &req, payload); err != nil { validationErrMap["request-body-decode"] = err.Error() + l.Error().Interface("validation_map", validationErrMap).Msg("validation_error") + return handlers.ValidationError("Error validating request", validationErrMap) } - // extract out the Developer notification + l.Info().Interface("req", req).Msg("") + dn, err := req.Message.GetDeveloperNotification() if err != nil { validationErrMap["invalid-developer-notification"] = err.Error() + l.Error().Interface("validation_map", validationErrMap).Msg("validation_error") + return handlers.ValidationError("Error validating request", validationErrMap) } + l.Info().Interface("developer_notification", dn).Msg("") + if dn == nil || dn.SubscriptionNotification.PurchaseToken == "" { - logger.Error().Interface("validation-errors", validationErrMap). - Msg("failed to get developer notification from message") validationErrMap["invalid-developer-notification-token"] = "notification has no purchase token" + l.Error().Interface("validation_map", validationErrMap).Msg("validation_error") + return handlers.ValidationError("Error validating request", validationErrMap) } - // if we had any validation errors, return the validation error map to the caller - if len(validationErrMap) != 0 { - return handlers.ValidationError("Error validating request url", validationErrMap) - } + l.Info().Msg("verify_developer_notification") err = service.verifyDeveloperNotification(ctx, dn) if err != nil { - logger.Error().Err(err).Msg("failed to verify subscription notification") + l.Error().Err(err).Msg("failed to verify subscription notification") switch { case errors.Is(err, errNotFound): - return handlers.WrapError(err, "failed to verify subscription notification", - http.StatusNotFound) + return handlers.WrapError(err, "failed to verify subscription notification", http.StatusNotFound) default: - return handlers.WrapError(err, "failed to verify subscription notification", - http.StatusInternalServerError) + return handlers.WrapError(err, "failed to verify subscription notification", http.StatusInternalServerError) } } @@ -1008,6 +1010,72 @@ func HandleAndroidWebhook(service *Service) handlers.AppHandler { } } +const ( + errAuthHeaderEmpty model.Error = "skus: gcp authorization header is empty" + errAuthHeaderFormat model.Error = "skus: gcp authorization header invalid format" + errInvalidIssuer model.Error = "skus: gcp invalid issuer" + errInvalidEmail model.Error = "skus: gcp invalid email" + errEmailNotVerified model.Error = "skus: gcp email not verified" +) + +type gcpTokenValidator interface { + Validate(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) +} + +type gcpValidatorConfig struct { + audience string + issuer string + serviceAccount string + disabled bool +} + +type gcpPushNotificationValidator struct { + validator gcpTokenValidator + cfg gcpValidatorConfig +} + +func newGcpPushNotificationValidator(gcpTokenValidator gcpTokenValidator, cfg gcpValidatorConfig) *gcpPushNotificationValidator { + return &gcpPushNotificationValidator{ + validator: gcpTokenValidator, + cfg: cfg, + } +} + +func (g *gcpPushNotificationValidator) validate(ctx context.Context, r *http.Request) error { + if g.cfg.disabled { + return nil + } + + ah := r.Header.Get("Authorization") + if ah == "" { + return errAuthHeaderEmpty + } + + token := strings.Split(ah, " ") + if len(token) != 2 { + return errAuthHeaderFormat + } + + p, err := g.validator.Validate(ctx, token[1], g.cfg.audience) + if err != nil { + return fmt.Errorf("invalid authentication token: %w", err) + } + + if p.Issuer == "" || p.Issuer != g.cfg.issuer { + return errInvalidIssuer + } + + if p.Claims["email"] != g.cfg.serviceAccount { + return errInvalidEmail + } + + if p.Claims["email_verified"] != true { + return errEmailNotVerified + } + + return nil +} + // HandleIOSWebhook is the handler for ios iap webhooks func HandleIOSWebhook(service *Service) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { diff --git a/services/skus/controllers_pvt_test.go b/services/skus/controllers_pvt_test.go new file mode 100644 index 000000000..0628726dc --- /dev/null +++ b/services/skus/controllers_pvt_test.go @@ -0,0 +1,176 @@ +package skus + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/api/idtoken" +) + +func TestNewGogglePushNotificationValidator_IsValid(t *testing.T) { + type tcGiven struct { + req *http.Request + cfg gcpValidatorConfig + tokenValidator gcpTokenValidator + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + testCases := []testCase{ + { + name: "disabled", + given: tcGiven{ + cfg: gcpValidatorConfig{disabled: true}, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + { + name: "invalid_no_authorization_header", + given: tcGiven{ + req: newRequest(""), + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errAuthHeaderEmpty) + }, + }, + { + name: "invalid_authorization_header_format", + given: tcGiven{ + req: newRequest("some-random-header-value"), + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errAuthHeaderFormat) + }, + }, + { + name: "invalid_authentication_token", + given: tcGiven{ + req: newRequest("Bearer: some-token"), + tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) { + return nil, errors.New("error") + }}, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorContains(t, err, "invalid authentication token: ") + }, + }, + { + name: "invalid_issuer_empty", + given: tcGiven{ + req: newRequest("Bearer: some-token"), + tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) { + return &idtoken.Payload{}, nil + }}, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errInvalidIssuer) + }, + }, + { + name: "invalid_issuer_not_equal", + given: tcGiven{ + req: newRequest("Bearer: some-token"), + cfg: gcpValidatorConfig{ + issuer: "issuer-1", + }, + tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) { + return &idtoken.Payload{Issuer: "issuer-2"}, nil + }}, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errInvalidIssuer) + }, + }, + { + name: "invalid_email", + given: tcGiven{ + req: newRequest("Bearer: some-token"), + cfg: gcpValidatorConfig{ + issuer: "issuer-1", + serviceAccount: "service-account-1", + }, + tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) { + issuer := "issuer-1" + claims := map[string]interface{}{"email": "service-account-2"} + return &idtoken.Payload{Issuer: issuer, Claims: claims}, nil + }}, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errInvalidEmail) + }, + }, + { + name: "invalid_email_not_verified", + given: tcGiven{ + req: newRequest("Bearer: some-token"), + cfg: gcpValidatorConfig{ + issuer: "issuer-1", + serviceAccount: "service-account-1", + }, + tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) { + issuer := "issuer-1" + claims := map[string]interface{}{"email": "service-account-1"} + return &idtoken.Payload{Issuer: issuer, Claims: claims}, nil + }}, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errEmailNotVerified) + }, + }, + { + name: "valid_request", + given: tcGiven{ + req: newRequest("Bearer: some-token"), + cfg: gcpValidatorConfig{ + issuer: "issuer-1", + serviceAccount: "service-account-1", + }, + tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) { + issuer := "issuer-1" + claims := map[string]interface{}{"email": "service-account-1", "email_verified": true} + return &idtoken.Payload{Issuer: issuer, Claims: claims}, nil + }}, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + v := newGcpPushNotificationValidator(tc.given.tokenValidator, tc.given.cfg) + actual := v.validate(context.TODO(), tc.given.req) + tc.assertErr(t, actual) + }) + } +} + +func newRequest(headerValue string) *http.Request { + r := httptest.NewRequest(http.MethodPost, "https://some-url.com", nil) + r.Header.Add("authorization", headerValue) + return r +} + +type mockGcpTokenValidator struct { + fnValidate func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) +} + +func (m mockGcpTokenValidator) Validate(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) { + if m.fnValidate == nil { + return nil, nil + } + return m.fnValidate(ctx, idToken, audience) +} diff --git a/services/skus/controllers_test.go b/services/skus/controllers_test.go index eb37ebfaa..cee0f3b8b 100644 --- a/services/skus/controllers_test.go +++ b/services/skus/controllers_test.go @@ -379,6 +379,8 @@ func (suite *ControllersTestSuite) TestAndroidWebhook() { }, } + suite.service.gcpValidator = &mockGcpRequestValidator{} + handler := HandleAndroidWebhook(suite.service) // notification message @@ -1909,3 +1911,14 @@ func (v *mockVendorReceiptValidator) validateGoogle(ctx context.Context, receipt return v.fnValidateGoogle(ctx, receipt) } + +type mockGcpRequestValidator struct { + fnValidate func(ctx context.Context, r *http.Request) error +} + +func (m *mockGcpRequestValidator) validate(ctx context.Context, r *http.Request) error { + if m.fnValidate == nil { + return nil + } + return m.fnValidate(ctx, r) +} diff --git a/services/skus/input.go b/services/skus/input.go index de6615d6c..8f8a4902c 100644 --- a/services/skus/input.go +++ b/services/skus/input.go @@ -261,19 +261,18 @@ func (anm *AndroidNotificationMessage) Validate(ctx context.Context) error { // GetDeveloperNotification - Extract the developer notification from the android notification message func (anm *AndroidNotificationMessage) GetDeveloperNotification() (*DeveloperNotification, error) { - var devNotification = new(DeveloperNotification) buf := make([]byte, base64.StdEncoding.DecodedLen(len([]byte(anm.Data)))) - // base64 decode the bytes n, err := base64.StdEncoding.Decode(buf, []byte(anm.Data)) if err != nil { return nil, fmt.Errorf("failed to decode input base64: %w", err) } - // read the json values + if err := json.Unmarshal(buf[:n], devNotification); err != nil { return nil, fmt.Errorf("failed to decode input json: %w", err) } + return devNotification, nil } diff --git a/services/skus/service.go b/services/skus/service.go index a93607bbc..ee8089674 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "os" + "strconv" "strings" "sync" "time" @@ -27,6 +28,7 @@ import ( "github.com/stripe/stripe-go/v72/checkout/session" "github.com/stripe/stripe-go/v72/client" "github.com/stripe/stripe-go/v72/sub" + "google.golang.org/api/idtoken" appctx "github.com/brave-intl/bat-go/libs/context" errorutils "github.com/brave-intl/bat-go/libs/errors" @@ -96,6 +98,10 @@ type vendorReceiptValidator interface { validateGoogle(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) } +type gcpRequestValidator interface { + validate(ctx context.Context, r *http.Request) error +} + // Service contains datastore type Service struct { orderRepo orderStoreSvc @@ -120,6 +126,7 @@ type Service struct { radomSellerAddress string vendorReceiptValid vendorReceiptValidator + gcpValidator gcpRequestValidator } // PauseWorker - pause worker until time specified @@ -249,6 +256,40 @@ func InitService(ctx context.Context, datastore Datastore, walletService *wallet return nil, err } + idv, err := idtoken.NewValidator(ctx) + if err != nil { + return nil, err + } + + disabled, _ := strconv.ParseBool(os.Getenv("GCP_PUSH_NOTIFICATION")) + if disabled { + sublogger.Warn().Msg("gcp push notification is disabled") + } + + aud := os.Getenv("GCP_PUSH_SUBSCRIPTION_AUDIENCE") + if aud == "" { + sublogger.Warn().Msg("gcp push subscription audience is empty") + } + + iss := os.Getenv("GCP_CERT_ISSUER") + if iss == "" { + sublogger.Warn().Msg("gcp cert issuer is empty") + } + + sa := os.Getenv("GCP_PUSH_SUBSCRIPTION_SERVICE_ACCOUNT") + if sa == "" { + sublogger.Warn().Msg("gcp push subscription service account is empty") + } + + conf := gcpValidatorConfig{ + audience: aud, + issuer: iss, + serviceAccount: sa, + disabled: disabled, + } + + gcpValidator := newGcpPushNotificationValidator(idv, conf) + service := &Service{ orderRepo: orderRepo, issuerRepo: issuerRepo, @@ -264,6 +305,7 @@ func InitService(ctx context.Context, datastore Datastore, walletService *wallet radomClient: radomClient, radomSellerAddress: radomSellerAddress, vendorReceiptValid: rcptValidator, + gcpValidator: gcpValidator, } // setup runnable jobs