diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 54fc0d50d..7d07ccfac 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -362,7 +362,7 @@ func CancelOrder(service *Service) handlers.AppHandler { return handlers.WrapError(err, "Error retrieving the order", http.StatusInternalServerError) } - return handlers.RenderContent(ctx, nil, w, http.StatusOK) + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) }) } @@ -959,10 +959,10 @@ func VerifyCredentialV1(service *Service) handlers.AppHandler { // WebhookRouter - handles calls from various payment method webhooks informing payments of completion func WebhookRouter(service *Service) chi.Router { r := chi.NewRouter() - r.Method("POST", "/stripe", middleware.InstrumentHandler("HandleStripeWebhook", HandleStripeWebhook(service))) + r.Method("POST", "/stripe", middleware.InstrumentHandler("HandleStripeWebhook", handleStripeWebhook(service))) r.Method("POST", "/radom", middleware.InstrumentHandler("HandleRadomWebhook", HandleRadomWebhook(service))) r.Method("POST", "/android", middleware.InstrumentHandler("HandleAndroidWebhook", HandleAndroidWebhook(service))) - r.Method("POST", "/ios", middleware.InstrumentHandler("HandleIOSWebhook", HandleIOSWebhook(service))) + r.Method("POST", "/ios", middleware.InstrumentHandler("HandleIOSWebhook", handleIOSWebhook(service))) return r } @@ -978,7 +978,7 @@ func HandleAndroidWebhook(service *Service) handlers.AppHandler { return handlers.WrapError(err, "invalid request", http.StatusUnauthorized) } - payload, err := requestutils.Read(r.Context(), r.Body) + payload, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) if err != nil { l.Error().Err(err).Msg("failed to read payload") return handlers.WrapValidationError(err) @@ -1014,18 +1014,18 @@ func HandleAndroidWebhook(service *Service) handlers.AppHandler { l.Info().Msg("verify_developer_notification") - err = service.verifyDeveloperNotification(ctx, dn) - if err != nil { + if err := service.verifyDeveloperNotification(ctx, dn); err != nil { l.Error().Err(err).Msg("failed to verify subscription notification") + switch { - case errors.Is(err, errNotFound): - return handlers.WrapError(err, "failed to verify subscription notification", http.StatusNotFound) + case errors.Is(err, errNotFound), errors.Is(err, model.ErrOrderNotFound): + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) default: return handlers.WrapError(err, "failed to verify subscription notification", http.StatusInternalServerError) } } - return handlers.RenderContent(ctx, "event received", w, http.StatusOK) + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } } @@ -1095,68 +1095,51 @@ func (g *gcpPushNotificationValidator) validate(ctx context.Context, r *http.Req return nil } -// HandleIOSWebhook is the handler for ios iap webhooks -func HandleIOSWebhook(service *Service) handlers.AppHandler { +func handleIOSWebhook(service *Service) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() - var ( - ctx = r.Context() - req = new(IOSNotification) - validationErrMap = map[string]interface{}{} // for tracking our validation errors - ) - - // get logger - logger := logging.Logger(ctx, "payments").With(). - Str("func", "HandleIOSWebhook"). - Logger() + l := logging.Logger(ctx, "skus").With().Str("func", "handleIOSWebhook").Logger() - // read the payload - payload, err := requestutils.Read(r.Context(), r.Body) + data, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) if err != nil { - logger.Error().Err(err).Msg("failed to read the payload") - // no need to go further - return handlers.WrapValidationError(err) + l.Error().Err(err).Msg("error reading request body") + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } - // validate the payload - if err := inputs.DecodeAndValidate(context.Background(), req, payload); err != nil { - logger.Debug().Str("payload", string(payload)).Msg("failed to decode and validate the payload") - logger.Warn().Err(err).Msg("failed to decode and validate the payload") - validationErrMap["request-body-decode"] = err.Error() + req := &IOSNotification{} + if err := inputs.DecodeAndValidate(ctx, req, data); err != nil { + l.Warn().Err(err).Msg("failed to decode and validate the payload") + + return handlers.ValidationError("request", map[string]interface{}{"request-body-decode": err.Error()}) } - // transaction info txInfo, err := req.GetTransactionInfo(ctx) if err != nil { - logger.Warn().Err(err).Msg("failed to get transaction info from message") - validationErrMap["invalid-transaction-info"] = err.Error() + l.Warn().Err(err).Msg("failed to get transaction info from message") + + return handlers.ValidationError("request", map[string]interface{}{"invalid-transaction-info": err.Error()}) } - // renewal info renewalInfo, err := req.GetRenewalInfo(ctx) if err != nil { - logger.Warn().Err(err).Msg("failed to get renewal info from message") - validationErrMap["invalid-renewal-info"] = err.Error() - } + l.Warn().Err(err).Msg("failed to get renewal info from message") - // if we had any validation errors, return the validation error map to the caller - if len(validationErrMap) != 0 { - return handlers.ValidationError("Error validating request url", validationErrMap) + return handlers.ValidationError("request", map[string]interface{}{"invalid-renewal-info": err.Error()}) } - err = service.verifyIOSNotification(ctx, txInfo, renewalInfo) - if err != nil { - logger.Error().Err(err).Msg("failed to verify ios subscription notification") + if err := service.verifyIOSNotification(ctx, txInfo, renewalInfo); err != nil { + l.Error().Err(err).Msg("failed to verify ios subscription notification") + switch { - case errors.Is(err, errNotFound): - return handlers.WrapError(err, "failed to verify ios subscription notification", - http.StatusNotFound) + case errors.Is(err, errNotFound), errors.Is(err, model.ErrOrderNotFound): + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) default: - return handlers.WrapError(err, "failed to verify ios subscription notification", - http.StatusInternalServerError) + return handlers.WrapError(err, "failed to verify ios subscription notification", http.StatusInternalServerError) } } - return handlers.RenderContent(ctx, "event received", w, http.StatusOK) + + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } } @@ -1235,7 +1218,7 @@ func HandleRadomWebhook(service *Service) handlers.AppHandler { } // Set paymentProcessor to Radom. - if err := service.Datastore.AppendOrderMetadata(ctx, &orderID, paymentProcessor, model.RadomPaymentMethod); err != nil { + if err := service.Datastore.AppendOrderMetadata(ctx, &orderID, "paymentProcessor", model.RadomPaymentMethod); err != nil { lg.Error().Err(err).Msg("failed to update order to add the payment processor") return handlers.WrapError(err, "failed to update order to add the payment processor", http.StatusInternalServerError) } @@ -1245,124 +1228,115 @@ func HandleRadomWebhook(service *Service) handlers.AppHandler { } } -// HandleStripeWebhook handles webhook events from Stripe. -func HandleStripeWebhook(service *Service) handlers.AppHandler { +func handleStripeWebhook(svc *Service) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { ctx := r.Context() - lg := logging.Logger(ctx, "payments").With().Str("func", "HandleStripeWebhook").Logger() + lg := logging.Logger(ctx, "skus").With().Str("func", "HandleStripeWebhook").Logger() - endpointSecret, err := appctx.GetStringFromContext(ctx, appctx.StripeWebhookSecretCTXKey) + secret, err := appctx.GetStringFromContext(ctx, appctx.StripeWebhookSecretCTXKey) if err != nil { lg.Error().Err(err).Msg("failed to get stripe_webhook_secret from context") return handlers.WrapError(err, "error getting stripe_webhook_secret from context", http.StatusInternalServerError) } - b, err := requestutils.Read(ctx, r.Body) + data, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) if err != nil { lg.Error().Err(err).Msg("failed to read request body") return handlers.WrapError(err, "error reading request body", http.StatusServiceUnavailable) } - event, err := webhook.ConstructEvent(b, r.Header.Get("Stripe-Signature"), endpointSecret) + event, err := webhook.ConstructEvent(data, r.Header.Get("Stripe-Signature"), secret) if err != nil { - lg.Error().Err(err).Msg("failed to verify stripe signature") + lg.Error().Err(err).Msg("failed to verify Stripe signature") return handlers.WrapError(err, "error verifying webhook signature", http.StatusBadRequest) } switch event.Type { - case StripeInvoiceUpdated, StripeInvoicePaid: - // Handle invoice events. - - var invoice stripe.Invoice - if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil { - lg.Error().Err(err).Msg("error parsing webhook json") - return handlers.WrapError(err, "error parsing webhook JSON", http.StatusBadRequest) + case whStripeInvoiceUpdated, whStripeInvoicePaid: + invoice := &stripe.Invoice{} + if err := json.Unmarshal(event.Data.Raw, invoice); err != nil { + lg.Error().Err(err).Msg("failed to parse invoice") + return handlers.WrapError(err, "error parsing webhook invoice", http.StatusBadRequest) } - subscription, err := service.scClient.Subscriptions.Get(invoice.Subscription.ID, nil) + sub, err := svc.scClient.Subscriptions.Get(invoice.Subscription.ID, nil) if err != nil { - lg.Error().Err(err).Msg("error getting subscription") + lg.Error().Err(err).Msg("failed to get subscription") + + if isErrStripeNotFound(err) { + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) + } + return handlers.WrapError(err, "error retrieving subscription", http.StatusInternalServerError) } - orderID, err := uuid.FromString(subscription.Metadata["orderID"]) + orderID, err := uuid.FromString(sub.Metadata["orderID"]) if err != nil { - lg.Error().Err(err).Msg("error getting order id from subscription metadata") + lg.Error().Err(err).Msg("failed to parse orderID from Stripe metadata") return handlers.WrapError(err, "error retrieving orderID", http.StatusInternalServerError) } - // If the invoice is paid set order status to paid, otherwise - if invoice.Paid { - ok, subID, err := service.Datastore.IsStripeSub(orderID) - if err != nil { - lg.Error().Err(err).Msg("failed to tell if this is a stripe subscription") - return handlers.WrapError(err, "error looking up payment provider", http.StatusInternalServerError) + ord, err := svc.orderRepo.Get(ctx, svc.Datastore.RawDB(), orderID) + if err != nil { + lg.Error().Err(err).Msg("failed to get order") + + if errors.Is(err, model.ErrOrderNotFound) { + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } - // Handle renewal. - if ok && subID != "" { - if err := service.RenewOrder(ctx, orderID); err != nil { - lg.Error().Err(err).Msg("failed to renew the order") - return handlers.WrapError(err, "error renewing order", http.StatusInternalServerError) - } + return handlers.WrapError(err, "failed to get order", http.StatusInternalServerError) + } - return handlers.RenderContent(ctx, "subscription renewed", w, http.StatusOK) + if subID, ok := ord.StripeSubID(); !ok || subID != sub.ID { + if err := svc.Datastore.AppendOrderMetadata(ctx, &orderID, "stripeSubscriptionId", sub.ID); err != nil { + lg.Error().Err(err).Msg("failed to update order metadata stripeSubscriptionId") + return handlers.WrapError(err, "failed to update order metadata stripeSubscriptionId", http.StatusInternalServerError) } + } - // New subscription. - // Update the order's expires at as it was just paid. - if err := service.Datastore.UpdateOrder(orderID, OrderStatusPaid); err != nil { - lg.Error().Err(err).Msg("failed to update order status") - return handlers.WrapError(err, "error updating order status", http.StatusInternalServerError) - } + switch event.Type { + case whStripeInvoiceUpdated: + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) - if err := service.Datastore.AppendOrderMetadata(ctx, &orderID, "stripeSubscriptionId", subscription.ID); err != nil { - lg.Error().Err(err).Msg("failed to update order metadata") - return handlers.WrapError(err, "error updating order metadata", http.StatusInternalServerError) + case whStripeInvoicePaid: + if err := svc.RenewOrder(ctx, orderID); err != nil { + lg.Error().Err(err).Msg("failed to renew order") + return handlers.WrapError(err, "error renewing order", http.StatusInternalServerError) } - if err := service.Datastore.AppendOrderMetadata(ctx, &orderID, paymentProcessor, StripePaymentMethod); err != nil { - lg.Error().Err(err).Msg("failed to update order to add the payment processor") - return handlers.WrapError(err, "failed to update order to add the payment processor", http.StatusInternalServerError) + if err := svc.Datastore.AppendOrderMetadata(ctx, &orderID, "paymentProcessor", model.StripePaymentMethod); err != nil { + lg.Error().Err(err).Msg("failed to update order metadata paymentProcessor") + return handlers.WrapError(err, "failed to update order metadata paymentProcessor", http.StatusInternalServerError) } - return handlers.RenderContent(ctx, "payment successful", w, http.StatusOK) - } - - if err := service.Datastore.UpdateOrder(orderID, "pending"); err != nil { - lg.Error().Err(err).Msg("failed to update order status") - return handlers.WrapError(err, "error updating order status", http.StatusInternalServerError) - } + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) - if err := service.Datastore.AppendOrderMetadata(ctx, &orderID, "stripeSubscriptionId", subscription.ID); err != nil { - lg.Error().Err(err).Msg("failed to update order metadata") - return handlers.WrapError(err, "error updating order metadata", http.StatusInternalServerError) + default: + handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } - return handlers.RenderContent(ctx, "payment failed", w, http.StatusOK) + case whStripeCustSubscriptionDeleted: + // TODO: Enable it and handle properly. - case StripeCustomerSubscriptionDeleted: - // Handle subscription cancellations - - var subscription stripe.Subscription - if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil { - return handlers.WrapError(err, "error parsing webhook JSON", http.StatusBadRequest) + sub := &stripe.Subscription{} + if err := json.Unmarshal(event.Data.Raw, &sub); err != nil { + return handlers.WrapError(err, "failed to parse subscription", http.StatusBadRequest) } - orderID, err := uuid.FromString(subscription.Metadata["orderID"]) + orderID, err := uuid.FromString(sub.Metadata["orderID"]) if err != nil { - return handlers.WrapError(err, "error retrieving orderID", http.StatusInternalServerError) + return handlers.WrapError(err, "failed to parse orderID from Stripe metadata", http.StatusInternalServerError) } - if err := service.Datastore.UpdateOrder(orderID, OrderStatusCanceled); err != nil { - return handlers.WrapError(err, "error updating order status", http.StatusInternalServerError) + if err := svc.Datastore.UpdateOrder(orderID, OrderStatusCanceled); err != nil { + return handlers.WrapError(err, "failed to update order status canceled", http.StatusInternalServerError) } - return handlers.RenderContent(ctx, "subscription canceled", w, http.StatusOK) + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } - return handlers.RenderContent(ctx, "event received", w, http.StatusOK) + return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } } @@ -1380,7 +1354,7 @@ func handleSubmitReceipt(svc *Service, valid *validator.Validate) handlers.AppHa return handlers.ValidationError("request", map[string]interface{}{"orderID": inputs.ErrIDDecodeNotUUID}) } - payload, err := requestutils.Read(ctx, r.Body) + payload, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) if err != nil { l.Warn().Err(err).Msg("failed to read body") diff --git a/services/skus/controllers_test.go b/services/skus/controllers_test.go index e20575f52..bde29419a 100644 --- a/services/skus/controllers_test.go +++ b/services/skus/controllers_test.go @@ -345,7 +345,7 @@ func (suite *ControllersTestSuite) TestIOSWebhookCertFail() { err := suite.service.Datastore.AppendOrderMetadata(context.Background(), &order.ID, "externalID", "my external id") suite.Require().NoError(err) - handler := HandleIOSWebhook(suite.service) + handler := handleIOSWebhook(suite.service) // create a jws message to send body := []byte{} diff --git a/services/skus/input.go b/services/skus/input.go index 0b593fe7e..1539131db 100644 --- a/services/skus/input.go +++ b/services/skus/input.go @@ -244,12 +244,8 @@ type IOSNotification struct { // Decode - implement Decodable interface func (iosn *IOSNotification) Decode(ctx context.Context, data []byte) error { - logger := logging.Logger(ctx, "IOSNotification.Decode") - logger.Debug().Msg("starting IOSNotification.Decode") - // json unmarshal the notification if err := json.Unmarshal(data, iosn); err != nil { - logger.Error().Msg("failed to json unmarshal body") return errorutils.Wrap(err, "error unmarshalling body") } @@ -266,8 +262,6 @@ func (iosn *IOSNotification) Decode(ctx context.Context, data []byte) error { // Validate - implement Validable interface func (iosn *IOSNotification) Validate(ctx context.Context) error { - logger := logging.Logger(ctx, "IOSNotification.Validate") - // extract the public key from the jws pk, err := extractPublicKey(iosn.SignedPayload) if err != nil { @@ -278,7 +272,6 @@ func (iosn *IOSNotification) Validate(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to verify jws payload in request: %w", err) } - logger.Debug().Msg("validated ios notification") iosn.payload = payload diff --git a/services/skus/model/model.go b/services/skus/model/model.go index 8065683b1..0d1d159c6 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -64,8 +64,9 @@ const ( ) const ( - VendorApple Vendor = "ios" - VendorGoogle Vendor = "android" + VendorUnknown Vendor = "unknown" + VendorApple Vendor = "ios" + VendorGoogle Vendor = "android" ) var ( @@ -145,6 +146,7 @@ func CreateStripeCheckoutSession( } params := &stripe.CheckoutSessionParams{ + // TODO: Get rid of this stripe.* nonsense, and use ptrTo instead. PaymentMethodTypes: stripe.StringSlice([]string{"card"}), Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), SuccessURL: stripe.String(successURI), @@ -318,6 +320,55 @@ func (o *Order) HasItem(id uuid.UUID) (*OrderItem, bool) { return nil, false } +func (o *Order) StripeSubID() (string, bool) { + sid, ok := o.Metadata["stripeSubscriptionId"].(string) + + return sid, ok +} + +func (o *Order) IsIOS() bool { + pp, ok := o.PaymentProc() + if !ok { + return false + } + + vn, ok := o.Vendor() + if !ok { + return false + } + + return pp == "ios" && vn == VendorApple +} + +func (o *Order) IsAndroid() bool { + pp, ok := o.PaymentProc() + if !ok { + return false + } + + vn, ok := o.Vendor() + if !ok { + return false + } + + return pp == "android" && vn == VendorGoogle +} + +func (o *Order) PaymentProc() (string, bool) { + pp, ok := o.Metadata["paymentProcessor"].(string) + + return pp, ok +} + +func (o *Order) Vendor() (Vendor, bool) { + vn, ok := o.Metadata["vendor"].(string) + if !ok { + return VendorUnknown, false + } + + return Vendor(vn), true +} + // OrderItem represents a particular order item. type OrderItem struct { ID uuid.UUID `json:"id" db:"id"` diff --git a/services/skus/model/model_test.go b/services/skus/model/model_test.go index 237119485..a06dbd10f 100644 --- a/services/skus/model/model_test.go +++ b/services/skus/model/model_test.go @@ -698,6 +698,388 @@ func TestOrder_HasItem(t *testing.T) { } } +func TestOrder_StripeSubID(t *testing.T) { + type tcExpected struct { + val string + ok bool + } + + type testCase struct { + name string + given model.Order + exp tcExpected + } + + tests := []testCase{ + { + name: "no_metadata", + }, + + { + name: "no_field", + given: model.Order{ + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + + { + name: "not_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "stripeSubscriptionId": 42, + }, + }, + }, + + { + name: "empty_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "stripeSubscriptionId": "", + }, + }, + exp: tcExpected{ok: true}, + }, + + { + name: "sub_id", + given: model.Order{ + Metadata: datastore.Metadata{ + "stripeSubscriptionId": "sub_id", + }, + }, + exp: tcExpected{val: "sub_id", ok: true}, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, ok := tc.given.StripeSubID() + should.Equal(t, tc.exp.ok, ok) + should.Equal(t, tc.exp.val, actual) + }) + } +} + +func TestOrder_IsIOS(t *testing.T) { + type testCase struct { + name string + given model.Order + exp bool + } + + tests := []testCase{ + { + name: "no_metadata", + }, + + { + name: "no_pp", + given: model.Order{ + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + + { + name: "pp_stripe_no_vn", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + }, + }, + }, + + { + name: "pp_stripe_vn_android", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + "vendor": "android", + }, + }, + }, + + { + name: "pp_ios_vn_android", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "ios", + "vendor": "android", + }, + }, + }, + + { + name: "pp_stripe_vn_ios", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + "vendor": "ios", + }, + }, + }, + + { + name: "ios", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "ios", + "vendor": "ios", + }, + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.IsIOS() + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestOrder_IsAndroid(t *testing.T) { + type testCase struct { + name string + given model.Order + exp bool + } + + tests := []testCase{ + { + name: "no_metadata", + }, + + { + name: "no_pp", + given: model.Order{ + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + + { + name: "pp_stripe_no_vn", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + }, + }, + }, + + { + name: "pp_stripe_vn_ios", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + "vendor": "ios", + }, + }, + }, + + { + name: "pp_android_vn_ios", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "android", + "vendor": "ios", + }, + }, + }, + + { + name: "pp_stripe_vn_android", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + "vendor": "android", + }, + }, + }, + + { + name: "android", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "android", + "vendor": "android", + }, + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.IsAndroid() + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestOrder_PaymentProc(t *testing.T) { + type tcExpected struct { + val string + ok bool + } + + type testCase struct { + name string + given model.Order + exp tcExpected + } + + tests := []testCase{ + { + name: "no_metadata", + }, + + { + name: "no_field", + given: model.Order{ + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + + { + name: "not_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": 42, + }, + }, + }, + + { + name: "empty_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "", + }, + }, + exp: tcExpected{ok: true}, + }, + + { + name: "sub_id", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + }, + }, + exp: tcExpected{val: "stripe", ok: true}, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, ok := tc.given.PaymentProc() + should.Equal(t, tc.exp.ok, ok) + should.Equal(t, tc.exp.val, actual) + }) + } +} + +func TestOrder_Vendor(t *testing.T) { + type tcExpected struct { + val model.Vendor + ok bool + } + + type testCase struct { + name string + given model.Order + exp tcExpected + } + + tests := []testCase{ + { + name: "no_metadata", + exp: tcExpected{ + val: model.VendorUnknown, + }, + }, + + { + name: "no_field", + given: model.Order{ + Metadata: datastore.Metadata{"key": "value"}, + }, + exp: tcExpected{ + val: model.VendorUnknown, + }, + }, + + { + name: "not_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "vendor": 42, + }, + }, + exp: tcExpected{ + val: model.VendorUnknown, + }, + }, + + { + name: "empty_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "vendor": "", + }, + }, + exp: tcExpected{ + ok: true, + }, + }, + + { + name: "something_else", + given: model.Order{ + Metadata: datastore.Metadata{ + "vendor": "something_else", + }, + }, + exp: tcExpected{ + val: model.Vendor("something_else"), + ok: true, + }, + }, + + { + name: "apple", + given: model.Order{ + Metadata: datastore.Metadata{ + "vendor": "ios", + }, + }, + exp: tcExpected{ + val: model.VendorApple, + ok: true, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, ok := tc.given.Vendor() + should.Equal(t, tc.exp.ok, ok) + should.Equal(t, tc.exp.val, actual) + }) + } +} + func mustDecimalFromString(v string) decimal.Decimal { result, err := decimal.NewFromString(v) if err != nil { diff --git a/services/skus/order.go b/services/skus/order.go index 2b60ede03..50d642fe0 100644 --- a/services/skus/order.go +++ b/services/skus/order.go @@ -19,28 +19,16 @@ import ( ) const ( - paymentProcessor = "paymentProcessor" - // IOSPaymentMethod - indicating this used an ios payment method - IOSPaymentMethod = "ios" - // AndroidPaymentMethod - indicating this used an android payment method - AndroidPaymentMethod = "android" -) - -const ( - // TODO(pavelb): Gradually replace it everywhere. - StripePaymentMethod = model.StripePaymentMethod - - StripeInvoiceUpdated = "invoice.updated" - StripeInvoicePaid = "invoice.paid" - StripeCustomerSubscriptionDeleted = "customer.subscription.deleted" + whStripeInvoiceUpdated = "invoice.updated" + whStripeInvoicePaid = "invoice.paid" + whStripeCustSubscriptionDeleted = "customer.subscription.deleted" ) // TODO(pavelb): Gradually replace these everywhere. type ( - Order = model.Order - OrderItem = model.OrderItem - CreateCheckoutSessionResponse = model.CreateCheckoutSessionResponse - Issuer = model.Issuer + Order = model.Order + OrderItem = model.OrderItem + Issuer = model.Issuer ) func decodeAndUnmarshalSku(sku string) (*macaroon.Macaroon, error) { @@ -198,5 +186,5 @@ func (s *Service) RenewOrder(ctx context.Context, orderID uuid.UUID) error { return fmt.Errorf("failed to set order status to paid: %w", err) } - return s.DeleteOrderCreds(ctx, orderID, true) + return nil } diff --git a/services/skus/service.go b/services/skus/service.go index 0cc7a0a63..37f21736a 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -620,7 +620,7 @@ func (s *Service) TransformStripeOrder(order *Order) (*Order, error) { return nil, fmt.Errorf("failed to update order to add the subscription id") } - if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, paymentProcessor, StripePaymentMethod); err != nil { + if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, "paymentProcessor", model.StripePaymentMethod); err != nil { return nil, fmt.Errorf("failed to update order to add the payment processor") } } @@ -637,43 +637,55 @@ func (s *Service) TransformStripeOrder(order *Order) (*Order, error) { // CancelOrder cancels an order, propagates to stripe if needed. // -// TODO(pavelb): Refactor and make it precise. -// Currently, this method does something weird for the case when the order was not found in the DB. -// If we have an order id, but ended up without the order, that means either the id is wrong, -// or we somehow lost data. The latter is less likely. -// Yet we allow non-existing order ids to be searched for in Stripe, which is strange. +// TODO(pavelb): Refactor this. func (s *Service) CancelOrder(orderID uuid.UUID) error { - // Check the order, do we have a stripe subscription? - ok, subID, err := s.Datastore.IsStripeSub(orderID) - if err != nil && !errors.Is(err, model.ErrOrderNotFound) { - return fmt.Errorf("failed to check stripe subscription: %w", err) + // TODO: Refactor this later. Now here's a quick fix. + ord, err := s.Datastore.GetOrder(orderID) + if err != nil { + return err } + if ord == nil { + return model.ErrOrderNotFound + } + + subID, ok := ord.StripeSubID() if ok && subID != "" { // Cancel the stripe subscription. if _, err := sub.Cancel(subID, nil); err != nil { - return fmt.Errorf("failed to cancel stripe subscription: %w", err) + // Error out if it's not 404. + if !isErrStripeNotFound(err) { + return fmt.Errorf("failed to cancel stripe subscription: %w", err) + } } + // Cancel even for 404. return s.Datastore.UpdateOrder(orderID, OrderStatusCanceled) } - // Try to find order in Stripe. + if ord.IsIOS() || ord.IsAndroid() { + return s.Datastore.UpdateOrder(orderID, OrderStatusCanceled) + } + + // Try to find by order_id in Stripe. params := &stripe.SubscriptionSearchParams{} - params.Query = *stripe.String(fmt.Sprintf("status:'active' AND metadata['orderID']:'%s'", orderID.String())) + params.Query = fmt.Sprintf("status:'active' AND metadata['orderID']:'%s'", orderID.String()) ctx := context.TODO() iter := sub.Search(params) for iter.Next() { - // we have a result, fix the stripe sub on the db record, and then cancel sub - subscription := iter.Subscription() - // cancel the stripe subscription - if _, err := sub.Cancel(subscription.ID, nil); err != nil { + sb := iter.Subscription() + if _, err := sub.Cancel(sb.ID, nil); err != nil { + // It seems that already canceled subscriptions might return 404. + if isErrStripeNotFound(err) { + continue + } + return fmt.Errorf("failed to cancel stripe subscription: %w", err) } - if err := s.Datastore.AppendOrderMetadata(ctx, &orderID, "stripeSubscriptionId", subscription.ID); err != nil { + if err := s.Datastore.AppendOrderMetadata(ctx, &orderID, "stripeSubscriptionId", sb.ID); err != nil { return fmt.Errorf("failed to update order metadata with subscription id: %w", err) } } @@ -1588,33 +1600,37 @@ func (s *Service) verifyIOSNotification(ctx context.Context, txInfo *appstore.JW return errors.New("notification has no tx or renewal") } + // TODO: The documentation says nothing about these conditions. + // Shall this be gone? if !govalidator.IsAlphanumeric(txInfo.OriginalTransactionId) || len(txInfo.OriginalTransactionId) > 32 { return errors.New("original transaction id should be alphanumeric and less than 32 chars") } // lookup the order based on the token as externalID - o, err := s.Datastore.GetOrderByExternalID(txInfo.OriginalTransactionId) + ord, err := s.Datastore.GetOrderByExternalID(txInfo.OriginalTransactionId) if err != nil { return fmt.Errorf("failed to get order from db (%s): %w", txInfo.OriginalTransactionId, err) } - if o == nil { + if ord == nil { return fmt.Errorf("failed to get order from db (%s): %w", txInfo.OriginalTransactionId, errNotFound) } - // check if we are past the expiration date on transaction or the order was revoked + // Check if we are past the expiration date on transaction or the order was revoked. + now := time.Now() - if time.Now().After(time.Unix(0, txInfo.ExpiresDate*int64(time.Millisecond))) || - (txInfo.RevocationDate > 0 && time.Now().After(time.Unix(0, txInfo.RevocationDate*int64(time.Millisecond)))) { - // past our tx expires/renewal time - if err = s.CancelOrder(o.ID); err != nil { + if shouldCancelOrderIOS(txInfo, now) { + if err := s.CancelOrder(ord.ID); err != nil { return fmt.Errorf("failed to cancel subscription in skus: %w", err) } - } else { - if err = s.RenewOrder(ctx, o.ID); err != nil { - return fmt.Errorf("failed to renew subscription in skus: %w", err) - } + + return nil } + + if err := s.RenewOrder(ctx, ord.ID); err != nil { + return fmt.Errorf("failed to renew subscription in skus: %w", err) + } + return nil } @@ -2092,9 +2108,9 @@ func createOrderItem(req *model.OrderItemRequestNew) (*model.OrderItem, error) { func newMobileOrderMdata(req model.ReceiptRequest, extID string) datastore.Metadata { result := datastore.Metadata{ - "externalID": extID, - paymentProcessor: req.Type.String(), - "vendor": req.Type.String(), + "externalID": extID, + "paymentProcessor": req.Type.String(), + "vendor": req.Type.String(), } return result @@ -2128,3 +2144,36 @@ type tlv1CredPresentation struct { func ptrTo[T any](v T) *T { return &v } + +func shouldCancelOrderIOS(info *appstore.JWSTransactionDecodedPayload, now time.Time) bool { + tx := (*appStoreTransaction)(info) + + return tx.hasExpired(now) || tx.isRevoked(now) +} + +type appStoreTransaction appstore.JWSTransactionDecodedPayload + +func (x *appStoreTransaction) hasExpired(now time.Time) bool { + if x == nil { + return false + } + + return x.ExpiresDate > 0 && now.After(time.UnixMilli(x.ExpiresDate)) +} + +func (x *appStoreTransaction) isRevoked(now time.Time) bool { + if x == nil { + return false + } + + return x.RevocationDate > 0 && now.After(time.UnixMilli(x.RevocationDate)) +} + +func isErrStripeNotFound(err error) bool { + var serr *stripe.Error + if !errors.As(err, &serr) { + return false + } + + return serr.HTTPStatusCode == http.StatusNotFound && serr.Code == stripe.ErrorCodeResourceMissing +} diff --git a/services/skus/service_nonint_test.go b/services/skus/service_nonint_test.go index ceda3fa2d..86c4aa055 100644 --- a/services/skus/service_nonint_test.go +++ b/services/skus/service_nonint_test.go @@ -3,15 +3,18 @@ package skus import ( "context" "database/sql" + "net/http" "testing" "time" + "github.com/awa/go-iap/appstore" "github.com/jmoiron/sqlx" "github.com/lib/pq" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" should "github.com/stretchr/testify/assert" must "github.com/stretchr/testify/require" + "github.com/stripe/stripe-go/v72" "github.com/brave-intl/bat-go/libs/datastore" @@ -942,6 +945,177 @@ func TestService_checkOrderReceipt(t *testing.T) { } } +func TestShouldCancelOrderIOS(t *testing.T) { + type tcGiven struct { + now time.Time + info *appstore.JWSTransactionDecodedPayload + } + + type testCase struct { + name string + given tcGiven + exp bool + } + + tests := []testCase{ + { + name: "nil", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "empty_dates_not_expired", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + info: &appstore.JWSTransactionDecodedPayload{}, + }, + }, + + { + name: "expires_date_before_no_revocation_date", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + info: &appstore.JWSTransactionDecodedPayload{ + // 2023-12-31 23:59:59. + ExpiresDate: 1704067199000, + }, + }, + exp: true, + }, + + { + name: "expires_date_after_no_revocation_date", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + info: &appstore.JWSTransactionDecodedPayload{ + // 2024-01-01 01:00:01. + ExpiresDate: 1704070801000, + }, + }, + }, + + { + name: "expires_date_after_revocation_date_after", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + info: &appstore.JWSTransactionDecodedPayload{ + // 2024-01-01 01:00:01. + ExpiresDate: 1704070801000, + + // 2024-01-01 00:30:01. + RevocationDate: 1704069001000, + }, + }, + }, + + { + name: "expires_date_after_revocation_date_before", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + info: &appstore.JWSTransactionDecodedPayload{ + // 2024-01-01 01:00:01. + ExpiresDate: 1704070801000, + + // 2023-12-31 23:30:01. + RevocationDate: 1704065401000, + }, + }, + exp: true, + }, + + { + name: "no_expires_date_revocation_date_before", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + info: &appstore.JWSTransactionDecodedPayload{ + // 2023-12-31 23:59:59. + RevocationDate: 1704067199000, + }, + }, + exp: true, + }, + + { + name: "no_expires_date_revocation_date_after", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + info: &appstore.JWSTransactionDecodedPayload{ + // 2024-01-01 01:00:01. + RevocationDate: 1704070801000, + }, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := shouldCancelOrderIOS(tc.given.info, tc.given.now) + + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestIsErrStripeNotFound(t *testing.T) { + tests := []struct { + name string + given error + exp bool + }{ + { + name: "something_else", + given: model.Error("something else"), + }, + + { + name: "429_rate_limit", + given: &stripe.Error{ + HTTPStatusCode: http.StatusTooManyRequests, + Code: stripe.ErrorCodeRateLimit, + }, + }, + + { + name: "429_resource_missing", + given: &stripe.Error{ + HTTPStatusCode: http.StatusTooManyRequests, + Code: stripe.ErrorCodeResourceMissing, + }, + }, + + { + name: "404_rate_limit", + given: &stripe.Error{ + HTTPStatusCode: http.StatusNotFound, + Code: stripe.ErrorCodeRateLimit, + }, + }, + + { + name: "404_resource_missing", + given: &stripe.Error{ + HTTPStatusCode: http.StatusNotFound, + Code: stripe.ErrorCodeResourceMissing, + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := isErrStripeNotFound(tc.given) + + should.Equal(t, tc.exp, actual) + }) + } +} + type mockPaidOrderCreator struct { fnCreateOrder func(ctx context.Context, req *model.CreateOrderRequestNew, ordNew *model.OrderNew, items []model.OrderItem) (*model.Order, error) fnUpdateOrderStatusPaidWithMetadata func(ctx context.Context, oid *uuid.UUID, mdata datastore.Metadata) error