Skip to content

Commit

Permalink
fix: handle play store subscription notifications
Browse files Browse the repository at this point in the history
  • Loading branch information
clD11 committed Feb 2, 2024
1 parent a32650a commit ecbad0f
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 172 deletions.
2 changes: 1 addition & 1 deletion libs/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func WrapError(err error, msg string, passedCode int) *AppError {
}

// RenderContent based on the header
func RenderContent(ctx context.Context, v interface{}, w http.ResponseWriter, status int) *AppError {
func RenderContent(_ context.Context, v interface{}, w http.ResponseWriter, status int) *AppError {
switch w.Header().Get("content-type") {
case "application/json":
var b bytes.Buffer
Expand Down
6 changes: 4 additions & 2 deletions services/grant/cmd/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/go-chi/chi"
chiware "github.com/go-chi/chi/middleware"
"github.com/go-playground/validator/v10"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -468,6 +469,7 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context,

// initialize skus service keys for credentials to use
skus.InitEncryptionKeys()
valid := validator.New()

{
origins := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
Expand All @@ -478,7 +480,7 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context,

r.Mount("/v1/credentials", skus.CredentialRouter(skusService, authMwr))
r.Mount("/v2/credentials", skus.CredentialV2Router(skusService, authMwr))
r.Mount("/v1/orders", skus.Router(skusService, authMwr, middleware.InstrumentHandler, corsOpts))
r.Mount("/v1/orders", skus.Router(skusService, authMwr, middleware.InstrumentHandler, corsOpts, valid))

subr := chi.NewRouter()
orderh := handler.NewOrder(skusService)
Expand Down Expand Up @@ -511,7 +513,7 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context,
r.Mount("/v1/orders-new", subr)
}

r.Mount("/v1/webhooks", skus.WebhookRouter(skusService))
r.Mount("/v1/webhooks", skus.WebhookRouter(skusService, valid))
r.Mount("/v1/votes", skus.VoteRouter(skusService, middleware.InstrumentHandler))

// add profiling flag to enable profiling routes
Expand Down
121 changes: 71 additions & 50 deletions services/skus/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func Router(
authMwr middlewareFn,
metricsMwr middleware.InstrumentHandlerDef,
copts cors.Options,
valid *validator.Validate,
) chi.Router {
r := chi.NewRouter()

Expand Down Expand Up @@ -100,14 +101,9 @@ func Router(
metricsMwr("CreateAnonCardTransaction", CreateAnonCardTransaction(svc)),
)

// Receipt validation.
{
valid := validator.New()

r.Method(http.MethodPost, "/{orderID}/submit-receipt", metricsMwr("SubmitReceipt", corsMwrPost(handleSubmitReceipt(svc, valid))))
r.Method(http.MethodPost, "/receipt", metricsMwr("createOrderFromReceipt", corsMwrPost(handleCreateOrderFromReceipt(svc, valid))))
r.Method(http.MethodPost, "/{orderID}/receipt", metricsMwr("checkOrderReceipt", authMwr(handleCheckOrderReceipt(svc, valid))))
}
r.Method(http.MethodPost, "/{orderID}/submit-receipt", metricsMwr("SubmitReceipt", corsMwrPost(handleSubmitReceipt(svc, valid))))
r.Method(http.MethodPost, "/receipt", metricsMwr("createOrderFromReceipt", corsMwrPost(handleCreateOrderFromReceipt(svc, valid))))
r.Method(http.MethodPost, "/{orderID}/receipt", metricsMwr("checkOrderReceipt", authMwr(handleCheckOrderReceipt(svc, valid))))

r.Route("/{orderID}/credentials", func(cr chi.Router) {
cr.Use(NewCORSMwr(copts, http.MethodGet, http.MethodPost))
Expand Down Expand Up @@ -993,76 +989,101 @@ func VerifyCredentialV1(service *Service) handlers.AppHandler {
}

// WebhookRouter - handles calls from various payment method webhooks informing payments of completion
func WebhookRouter(service *Service) chi.Router {
func WebhookRouter(service *Service, valid *validator.Validate) chi.Router {
r := chi.NewRouter()
r.Method("POST", "/stripe", middleware.InstrumentHandler("HandleStripeWebhook", HandleStripeWebhook(service)))
r.Method("POST", "/radom", middleware.InstrumentHandler("HandleRadomWebhook", HandleRadomWebhook(service)))
r.Method("POST", "/android", middleware.InstrumentHandler("HandleAndroidWebhook", HandleAndroidWebhook(service)))
r.Method("POST", "/android", middleware.InstrumentHandler("HandleAndroidWebhook", handleAndroidWebhook(service, valid)))
r.Method("POST", "/ios", middleware.InstrumentHandler("HandleIOSWebhook", HandleIOSWebhook(service)))
return r
}

// HandleAndroidWebhook is the handler for the Google Playstore webhooks
func HandleAndroidWebhook(service *Service) handlers.AppHandler {
const errInternalServer model.Error = "internal server error"

func handleAndroidWebhook(service *Service, valid *validator.Validate) handlers.AppHandler {
return func(w http.ResponseWriter, r *http.Request) *handlers.AppError {
ctx := r.Context()

l := logging.Logger(ctx, "payments").With().Str("func", "HandleAndroidWebhook").Logger()
l := logging.Logger(ctx, "skus").With().Str("func", "handleAndroidWebhook").Logger()

if err := service.gcpValidator.validate(ctx, r); err != nil {
l.Error().Err(err).Msg("invalid request")
return handlers.WrapError(err, "invalid request", http.StatusUnauthorized)
b, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB))
if err != nil {
l.Warn().Err(err).Msg("error reading request body")
return handlers.ValidationError("request body", err)
}

payload, err := requestutils.Read(r.Context(), r.Body)
dn, err := parseDeveloperNotification(b)
if err != nil {
l.Error().Err(err).Msg("failed to read payload")
return handlers.WrapValidationError(err)
l.Warn().Err(err).Msg("error parsing notification")
return handlers.ValidationError("parse notification", err)
}

l.Info().Str("payload", string(payload)).Msg("")
if err := valid.StructCtx(ctx, &dn); err != nil {
l.Warn().Err(err).Msg("validation errors")

var validationErrMap = map[string]interface{}{}
verrs, ok := collectValidationErrors(err)
if !ok {
return handlers.ValidationError("request", map[string]interface{}{"request-body": err.Error()})
}

var req AndroidNotification
if err := inputs.DecodeAndValidate(context.Background(), &req, payload); err != nil {
validationErrMap["request-body-decode"] = err.Error()
l.Error().Interface("validation_map", validationErrMap).Msg("validation_error")
return handlers.ValidationError("Error validating request", validationErrMap)
return handlers.ValidationError("request", verrs)
}

l.Info().Interface("req", req).Msg("")
if err := service.verifyDeveloperNotification(ctx, dn); err != nil {
switch {
case errors.Is(err, model.ErrOrderNotFound):
// This error can legitimately be returned when a user has not linked their mobile purchase.
// Currently, there is no way to distinguish between a missing order and a user that
// has not linked. We have no choice but to ack the message.
return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK)

dn, err := req.Message.GetDeveloperNotification()
if err != nil {
validationErrMap["invalid-developer-notification"] = err.Error()
l.Error().Interface("validation_map", validationErrMap).Msg("validation_error")
return handlers.ValidationError("Error validating request", validationErrMap)
default:
l.Error().Err(err).Msg("error verifying notification")
return handlers.WrapError(errInternalServer, "error verifying notification", http.StatusInternalServerError)
}
}

l.Info().Interface("developer_notification", dn).Msg("")
return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK)
}
}

type gcpMsgWrapper struct {
Message gcpMsg `json:"message"`
}

if dn == nil || dn.SubscriptionNotification.PurchaseToken == "" {
validationErrMap["invalid-developer-notification-token"] = "notification has no purchase token"
l.Error().Interface("validation_map", validationErrMap).Msg("validation_error")
return handlers.ValidationError("Error validating request", validationErrMap)
}
type gcpMsg struct {
Data string `json:"data"`
MessageID string `json:"messageId"`
}

l.Info().Msg("verify_developer_notification")
type developerNotification struct {
PackageName string `json:"packageName" validate:"required"`
SubscriptionNotification subscriptionNotification `json:"subscriptionNotification"`
}

err = service.verifyDeveloperNotification(ctx, dn)
if err != nil {
l.Error().Err(err).Msg("failed to verify subscription notification")
switch {
case errors.Is(err, errNotFound):
return handlers.WrapError(err, "failed to verify subscription notification", http.StatusNotFound)
default:
return handlers.WrapError(err, "failed to verify subscription notification", http.StatusInternalServerError)
}
}
type subscriptionNotification struct {
NotificationType int `json:"notificationType" validate:"required"`
PurchaseToken string `json:"purchaseToken" validate:"required"`
SubscriptionID string `json:"subscriptionId" validate:"required"`
}

return handlers.RenderContent(ctx, "event received", w, http.StatusOK)
func parseDeveloperNotification(raw []byte) (developerNotification, error) {
var m gcpMsgWrapper
if err := json.Unmarshal(raw, &m); err != nil {
return developerNotification{}, fmt.Errorf("error unmarshaling msg wrapper: %w", err)
}

b, err := base64.StdEncoding.DecodeString(m.Message.Data)
if err != nil {
return developerNotification{}, fmt.Errorf("error decoding msg data: %w", err)
}

var dn developerNotification
if err := json.Unmarshal(b, &dn); err != nil {
return developerNotification{}, fmt.Errorf("error unmarshaling developer notification: %w", err)
}

return dn, nil
}

const (
Expand Down
74 changes: 74 additions & 0 deletions services/skus/controllers_noint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,77 @@ func TestHandleReceiptErr(t *testing.T) {
})
}
}

func TestParseNotification(t *testing.T) {
type tcGiven struct {
raw []byte
}

type exp struct {
dn developerNotification
mustErr must.ErrorAssertionFunc
}

type testCase struct {
name string
given tcGiven
exp exp
}

tests := []testCase{
{
name: "error_msg_wrapper",
exp: exp{
mustErr: func(t must.TestingT, err error, i ...interface{}) {
must.ErrorContains(t, err, "error unmarshaling msg wrapper: ")
},
},
},
{
name: "error_msg_data",
given: tcGiven{raw: []byte(`{"message":{"data":"not-base64"}}`)},
exp: exp{
mustErr: func(t must.TestingT, err error, i ...interface{}) {
must.ErrorContains(t, err, "error decoding msg data: ")
},
},
},
{
name: "error_developer_notification",
given: tcGiven{raw: []byte(`{"message":{"data":"dGVzdA=="}}`)},
exp: exp{
mustErr: func(t must.TestingT, err error, i ...interface{}) {
must.ErrorContains(t, err, "error unmarshaling developer notification: ")
},
},
},
{
name: "success",
given: tcGiven{raw: []byte(`{"message":{"data":"eyJ2ZXJzaW9uIjoidmVyc2lvbiIsInBhY2thZ2VOYW1lIjoicGFja2FnZS1uYW1lIiwic3Vic2NyaXB0aW9uTm90aWZpY2F0aW9uIjp7InZlcnNpb24iOiJ2ZXJzaW9uIiwibm90aWZpY2F0aW9uVHlwZSI6MSwicHVyY2hhc2VUb2tlbiI6InB1cmNoYXNlLXRva2VuIiwic3Vic2NyaXB0aW9uSWQiOiJzdWJzY3JpcHRpb24taWQifX0="}}`)},
exp: exp{
dn: developerNotification{
PackageName: "package-name",
SubscriptionNotification: subscriptionNotification{
NotificationType: 1,
PurchaseToken: "purchase-token",
SubscriptionID: "subscription-id",
},
},
mustErr: func(t must.TestingT, err error, i ...interface{}) {
must.NoError(t, err)
},
},
},
}

for i := range tests {
tc := tests[i]

t.Run(tc.name, func(t *testing.T) {
actual, err := parseDeveloperNotification(tc.given.raw)
tc.exp.mustErr(t, err)

should.Equal(t, tc.exp.dn, actual)
})
}
}
Loading

0 comments on commit ecbad0f

Please sign in to comment.