Skip to content

Commit

Permalink
feat: add jwt validation to android webhook (#2318)
Browse files Browse the repository at this point in the history
* feat: add jwt validation to android webhook

* test: add mock to android test and rename field
  • Loading branch information
clD11 authored Jan 29, 2024
1 parent 972a597 commit 243166c
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 30 deletions.
122 changes: 95 additions & 27 deletions services/skus/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import (
"net/http"
"os"
"strconv"
"strings"

"github.com/asaskevich/govalidator"
"github.com/go-chi/chi"
"github.com/go-chi/cors"
uuid "github.com/satori/go.uuid"
"github.com/stripe/stripe-go/v72"
"github.com/stripe/stripe-go/v72/webhook"
"google.golang.org/api/idtoken"

"github.com/brave-intl/bat-go/libs/clients/radom"
appctx "github.com/brave-intl/bat-go/libs/context"
Expand Down Expand Up @@ -948,66 +950,132 @@ func WebhookRouter(service *Service) chi.Router {
// HandleAndroidWebhook is the handler for the Google Playstore webhooks
func HandleAndroidWebhook(service *Service) handlers.AppHandler {
return func(w http.ResponseWriter, r *http.Request) *handlers.AppError {
ctx := r.Context()

var (
ctx = r.Context()
req = new(AndroidNotification)
validationErrMap = map[string]interface{}{} // for tracking our validation errors
)
l := logging.Logger(ctx, "payments").With().Str("func", "HandleAndroidWebhook").Logger()

// get logger
logger := logging.Logger(ctx, "payments").With().
Str("func", "HandleAndroidWebhook").
Logger()
if err := service.gcpValidator.validate(ctx, r); err != nil {
l.Error().Err(err).Msg("invalid request")
return handlers.WrapError(err, "invalid request", http.StatusUnauthorized)
}

// read the payload
payload, err := requestutils.Read(r.Context(), r.Body)
if err != nil {
logger.Error().Err(err).Msg("failed to read the payload")
l.Error().Err(err).Msg("failed to read payload")
return handlers.WrapValidationError(err)
}

// validate the payload
if err := inputs.DecodeAndValidate(context.Background(), req, payload); err != nil {
logger.Debug().Str("payload", string(payload)).
Msg("failed to decode and validate the payload")
l.Info().Str("payload", string(payload)).Msg("")

var validationErrMap = map[string]interface{}{}

var req AndroidNotification
if err := inputs.DecodeAndValidate(context.Background(), &req, payload); err != nil {
validationErrMap["request-body-decode"] = err.Error()
l.Error().Interface("validation_map", validationErrMap).Msg("validation_error")
return handlers.ValidationError("Error validating request", validationErrMap)
}

// extract out the Developer notification
l.Info().Interface("req", req).Msg("")

dn, err := req.Message.GetDeveloperNotification()
if err != nil {
validationErrMap["invalid-developer-notification"] = err.Error()
l.Error().Interface("validation_map", validationErrMap).Msg("validation_error")
return handlers.ValidationError("Error validating request", validationErrMap)
}

l.Info().Interface("developer_notification", dn).Msg("")

if dn == nil || dn.SubscriptionNotification.PurchaseToken == "" {
logger.Error().Interface("validation-errors", validationErrMap).
Msg("failed to get developer notification from message")
validationErrMap["invalid-developer-notification-token"] = "notification has no purchase token"
l.Error().Interface("validation_map", validationErrMap).Msg("validation_error")
return handlers.ValidationError("Error validating request", validationErrMap)
}

// if we had any validation errors, return the validation error map to the caller
if len(validationErrMap) != 0 {
return handlers.ValidationError("Error validating request url", validationErrMap)
}
l.Info().Msg("verify_developer_notification")

err = service.verifyDeveloperNotification(ctx, dn)
if err != nil {
logger.Error().Err(err).Msg("failed to verify subscription notification")
l.Error().Err(err).Msg("failed to verify subscription notification")
switch {
case errors.Is(err, errNotFound):
return handlers.WrapError(err, "failed to verify subscription notification",
http.StatusNotFound)
return handlers.WrapError(err, "failed to verify subscription notification", http.StatusNotFound)
default:
return handlers.WrapError(err, "failed to verify subscription notification",
http.StatusInternalServerError)
return handlers.WrapError(err, "failed to verify subscription notification", http.StatusInternalServerError)
}
}

return handlers.RenderContent(ctx, "event received", w, http.StatusOK)
}
}

const (
errAuthHeaderEmpty model.Error = "skus: gcp authorization header is empty"
errAuthHeaderFormat model.Error = "skus: gcp authorization header invalid format"
errInvalidIssuer model.Error = "skus: gcp invalid issuer"
errInvalidEmail model.Error = "skus: gcp invalid email"
errEmailNotVerified model.Error = "skus: gcp email not verified"
)

type gcpTokenValidator interface {
Validate(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error)
}

type gcpValidatorConfig struct {
audience string
issuer string
serviceAccount string
disabled bool
}

type gcpPushNotificationValidator struct {
validator gcpTokenValidator
cfg gcpValidatorConfig
}

func newGcpPushNotificationValidator(gcpTokenValidator gcpTokenValidator, cfg gcpValidatorConfig) *gcpPushNotificationValidator {
return &gcpPushNotificationValidator{
validator: gcpTokenValidator,
cfg: cfg,
}
}

func (g *gcpPushNotificationValidator) validate(ctx context.Context, r *http.Request) error {
if g.cfg.disabled {
return nil
}

ah := r.Header.Get("Authorization")
if ah == "" {
return errAuthHeaderEmpty
}

token := strings.Split(ah, " ")
if len(token) != 2 {
return errAuthHeaderFormat
}

p, err := g.validator.Validate(ctx, token[1], g.cfg.audience)
if err != nil {
return fmt.Errorf("invalid authentication token: %w", err)
}

if p.Issuer == "" || p.Issuer != g.cfg.issuer {
return errInvalidIssuer
}

if p.Claims["email"] != g.cfg.serviceAccount {
return errInvalidEmail
}

if p.Claims["email_verified"] != true {
return errEmailNotVerified
}

return nil
}

// HandleIOSWebhook is the handler for ios iap webhooks
func HandleIOSWebhook(service *Service) handlers.AppHandler {
return func(w http.ResponseWriter, r *http.Request) *handlers.AppError {
Expand Down
176 changes: 176 additions & 0 deletions services/skus/controllers_pvt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package skus

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/api/idtoken"
)

func TestNewGogglePushNotificationValidator_IsValid(t *testing.T) {
type tcGiven struct {
req *http.Request
cfg gcpValidatorConfig
tokenValidator gcpTokenValidator
}

type testCase struct {
name string
given tcGiven
assertErr assert.ErrorAssertionFunc
}

testCases := []testCase{
{
name: "disabled",
given: tcGiven{
cfg: gcpValidatorConfig{disabled: true},
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.NoError(t, err)
},
},
{
name: "invalid_no_authorization_header",
given: tcGiven{
req: newRequest(""),
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, errAuthHeaderEmpty)
},
},
{
name: "invalid_authorization_header_format",
given: tcGiven{
req: newRequest("some-random-header-value"),
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, errAuthHeaderFormat)
},
},
{
name: "invalid_authentication_token",
given: tcGiven{
req: newRequest("Bearer: some-token"),
tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) {
return nil, errors.New("error")
}},
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorContains(t, err, "invalid authentication token: ")
},
},
{
name: "invalid_issuer_empty",
given: tcGiven{
req: newRequest("Bearer: some-token"),
tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) {
return &idtoken.Payload{}, nil
}},
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, errInvalidIssuer)
},
},
{
name: "invalid_issuer_not_equal",
given: tcGiven{
req: newRequest("Bearer: some-token"),
cfg: gcpValidatorConfig{
issuer: "issuer-1",
},
tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) {
return &idtoken.Payload{Issuer: "issuer-2"}, nil
}},
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, errInvalidIssuer)
},
},
{
name: "invalid_email",
given: tcGiven{
req: newRequest("Bearer: some-token"),
cfg: gcpValidatorConfig{
issuer: "issuer-1",
serviceAccount: "service-account-1",
},
tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) {
issuer := "issuer-1"
claims := map[string]interface{}{"email": "service-account-2"}
return &idtoken.Payload{Issuer: issuer, Claims: claims}, nil
}},
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, errInvalidEmail)
},
},
{
name: "invalid_email_not_verified",
given: tcGiven{
req: newRequest("Bearer: some-token"),
cfg: gcpValidatorConfig{
issuer: "issuer-1",
serviceAccount: "service-account-1",
},
tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) {
issuer := "issuer-1"
claims := map[string]interface{}{"email": "service-account-1"}
return &idtoken.Payload{Issuer: issuer, Claims: claims}, nil
}},
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, errEmailNotVerified)
},
},
{
name: "valid_request",
given: tcGiven{
req: newRequest("Bearer: some-token"),
cfg: gcpValidatorConfig{
issuer: "issuer-1",
serviceAccount: "service-account-1",
},
tokenValidator: mockGcpTokenValidator{fnValidate: func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) {
issuer := "issuer-1"
claims := map[string]interface{}{"email": "service-account-1", "email_verified": true}
return &idtoken.Payload{Issuer: issuer, Claims: claims}, nil
}},
},
assertErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.NoError(t, err)
},
},
}

for i := range testCases {
tc := testCases[i]

t.Run(tc.name, func(t *testing.T) {
v := newGcpPushNotificationValidator(tc.given.tokenValidator, tc.given.cfg)
actual := v.validate(context.TODO(), tc.given.req)
tc.assertErr(t, actual)
})
}
}

func newRequest(headerValue string) *http.Request {
r := httptest.NewRequest(http.MethodPost, "https://some-url.com", nil)
r.Header.Add("authorization", headerValue)
return r
}

type mockGcpTokenValidator struct {
fnValidate func(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error)
}

func (m mockGcpTokenValidator) Validate(ctx context.Context, idToken string, audience string) (*idtoken.Payload, error) {
if m.fnValidate == nil {
return nil, nil
}
return m.fnValidate(ctx, idToken, audience)
}
13 changes: 13 additions & 0 deletions services/skus/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ func (suite *ControllersTestSuite) TestAndroidWebhook() {
},
}

suite.service.gcpValidator = &mockGcpRequestValidator{}

handler := HandleAndroidWebhook(suite.service)

// notification message
Expand Down Expand Up @@ -1909,3 +1911,14 @@ func (v *mockVendorReceiptValidator) validateGoogle(ctx context.Context, receipt

return v.fnValidateGoogle(ctx, receipt)
}

type mockGcpRequestValidator struct {
fnValidate func(ctx context.Context, r *http.Request) error
}

func (m *mockGcpRequestValidator) validate(ctx context.Context, r *http.Request) error {
if m.fnValidate == nil {
return nil
}
return m.fnValidate(ctx, r)
}
5 changes: 2 additions & 3 deletions services/skus/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,19 +261,18 @@ func (anm *AndroidNotificationMessage) Validate(ctx context.Context) error {

// GetDeveloperNotification - Extract the developer notification from the android notification message
func (anm *AndroidNotificationMessage) GetDeveloperNotification() (*DeveloperNotification, error) {

var devNotification = new(DeveloperNotification)
buf := make([]byte, base64.StdEncoding.DecodedLen(len([]byte(anm.Data))))

// base64 decode the bytes
n, err := base64.StdEncoding.Decode(buf, []byte(anm.Data))
if err != nil {
return nil, fmt.Errorf("failed to decode input base64: %w", err)
}
// read the json values

if err := json.Unmarshal(buf[:n], devNotification); err != nil {
return nil, fmt.Errorf("failed to decode input json: %w", err)
}

return devNotification, nil
}

Expand Down
Loading

0 comments on commit 243166c

Please sign in to comment.