From f1282f356d64f6e450aa292ca0c24bd599a6036d Mon Sep 17 00:00:00 2001 From: eV <8796196+evq@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:50:05 +0000 Subject: [PATCH 1/4] remove call to check reputation during wallet linking (#2069) * remove call to check reputation during wallet linking * feat: allow non-reputable users to link (#2101) * remove unused parameters --------- Co-authored-by: clD11 <23483715+clD11@users.noreply.github.com> --- services/wallet/controllers_v3_test.go | 54 +---------------------- services/wallet/datastore.go | 41 +---------------- services/wallet/datastore_test.go | 6 +-- services/wallet/instrumented_datastore.go | 8 ++-- services/wallet/service.go | 10 ++--- 5 files changed, 16 insertions(+), 103 deletions(-) diff --git a/services/wallet/controllers_v3_test.go b/services/wallet/controllers_v3_test.go index 890d583fc..d1ea2e629 100644 --- a/services/wallet/controllers_v3_test.go +++ b/services/wallet/controllers_v3_test.go @@ -16,7 +16,7 @@ import ( "testing" "time" - "github.com/DATA-DOG/go-sqlmock" + sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,7 +34,7 @@ import ( "github.com/golang/mock/gomock" "github.com/jmoiron/sqlx" uuid "github.com/satori/go.uuid" - "gopkg.in/square/go-jose.v2" + jose "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -259,16 +259,6 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputation) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - mockReputation.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - r = r.WithContext(ctx) router := chi.NewRouter() @@ -327,16 +317,6 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { rw = httptest.NewRecorder() ) - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) @@ -557,16 +537,6 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { rw = httptest.NewRecorder() ) - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) @@ -770,16 +740,6 @@ func TestLinkZebPayWalletV3(t *testing.T) { )), ) - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - mockSQLCustodianLink(mock, "zebpay") // begin linking tx @@ -887,16 +847,6 @@ func TestLinkGeminiWalletV3(t *testing.T) { nil, ) - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - mockSQLCustodianLink(mock, "gemini") // begin linking tx diff --git a/services/wallet/datastore.go b/services/wallet/datastore.go index 7e6d68ba6..4b545c233 100644 --- a/services/wallet/datastore.go +++ b/services/wallet/datastore.go @@ -71,7 +71,7 @@ func init() { // Datastore holds the interface for the wallet datastore type Datastore interface { datastore.Datastore - LinkWallet(ctx context.Context, ID string, providerID string, providerLinkingID uuid.UUID, anonymousAddress *uuid.UUID, depositProvider, country string) error + LinkWallet(ctx context.Context, ID string, providerID string, providerLinkingID uuid.UUID, depositProvider string) error GetLinkingLimitInfo(ctx context.Context, providerLinkingID string) (map[string]LinkingInfo, error) HasPriorLinking(ctx context.Context, walletID uuid.UUID, providerLinkingID uuid.UUID) (bool, error) // GetLinkingsByProviderLinkingID gets the wallet linking info by provider linking id @@ -561,47 +561,10 @@ var ( ) // LinkWallet links a wallet together -func (pg *Postgres) LinkWallet(ctx context.Context, ID string, userDepositDestination string, providerLinkingID uuid.UUID, anonymousAddress *uuid.UUID, depositProvider, country string) error { +func (pg *Postgres) LinkWallet(ctx context.Context, ID string, userDepositDestination string, providerLinkingID uuid.UUID, depositProvider string) error { sublogger := logger(ctx).With().Str("wallet_id", ID).Logger() sublogger.Debug().Msg("linking wallet") - // rep check - if repClient, ok := ctx.Value(appctx.ReputationClientCTXKey).(reputation.Client); ok { - walletID, err := uuid.FromString(ID) - if err != nil { - sublogger.Warn().Err(err).Msg("invalid wallet id") - return fmt.Errorf("invalid wallet id, not uuid: %w", err) - } - // we have a client, check the value for ID - reputable, cohorts, err := repClient.IsLinkingReputable(ctx, walletID, country) - if err != nil { - sublogger.Warn().Err(err).Msg("failed to check reputation") - return fmt.Errorf("failed to check wallet rep: %w", err) - } - - var ( - isTooYoung = false - geoResetDifferent = false - ) - for _, v := range cohorts { - if isTooYoung = (v == reputation.CohortTooYoung); isTooYoung { - break - } - if geoResetDifferent = (v == reputation.CohortGeoResetDifferent); geoResetDifferent { - break - } - } - - if !reputable && !isTooYoung && !geoResetDifferent { - sublogger.Info().Msg("wallet linking attempt failed - unusual activity") - countLinkingFlaggedUnusual.Inc() - return ErrUnusualActivity - } else if geoResetDifferent { - sublogger.Info().Msg("wallet linking attempt failed - geo reset is different") - return ErrGeoResetDifferent - } - } - ctx, tx, rollback, commit, err := getTx(ctx, pg) if err != nil { sublogger.Error().Err(err).Msg("error getting tx") diff --git a/services/wallet/datastore_test.go b/services/wallet/datastore_test.go index 6370a8560..54b126369 100644 --- a/services/wallet/datastore_test.go +++ b/services/wallet/datastore_test.go @@ -226,7 +226,7 @@ func (suite *WalletPostgresTestSuite) TestLinkWallet_Concurrent_InsertUpdate() { go func() { defer wg.Done() err = pg.LinkWallet(context.WithValue(context.Background(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D"), - walletInfo.ID, userDepositDestination, providerLinkingID, walletInfo.AnonymousAddress, walletInfo.Provider, "") + walletInfo.ID, userDepositDestination, providerLinkingID, walletInfo.Provider) }() } @@ -260,7 +260,7 @@ func (suite *WalletPostgresTestSuite) seedWallet(pg Datastore) (string, uuid.UUI suite.Require().NoError(err, "save wallet should succeed") err = pg.LinkWallet(context.WithValue(context.Background(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D"), - walletInfo.ID, userDepositDestination, providerLinkingID, walletInfo.AnonymousAddress, "uphold", "") + walletInfo.ID, userDepositDestination, providerLinkingID, "uphold") suite.Require().NoError(err, "link wallet should succeed") } @@ -303,7 +303,7 @@ func (suite *WalletPostgresTestSuite) TestLinkWallet_Concurrent_MaxLinkCount() { go func(index int) { defer wg.Done() err = pg.LinkWallet(context.WithValue(context.Background(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D"), - wallets[index].ID, userDepositDestination, providerLinkingID, wallets[index].AnonymousAddress, wallets[index].Provider, "") + wallets[index].ID, userDepositDestination, providerLinkingID, wallets[index].Provider) }(i) } diff --git a/services/wallet/instrumented_datastore.go b/services/wallet/instrumented_datastore.go index 837783d57..a5c4c7b9c 100644 --- a/services/wallet/instrumented_datastore.go +++ b/services/wallet/instrumented_datastore.go @@ -1,9 +1,9 @@ +package wallet + // Code generated by gowrap. DO NOT EDIT. // template: ../../.prom-gowrap.tmpl // gowrap: http://github.com/hexdigest/gowrap -package wallet - //go:generate gowrap gen -p github.com/brave-intl/bat-go/services/wallet -i Datastore -t ../../.prom-gowrap.tmpl -o instrumented_datastore.go -l "" import ( @@ -255,7 +255,7 @@ func (_d DatastoreWithPrometheus) InsertWalletTx(ctx context.Context, tx *sqlx.T } // LinkWallet implements Datastore -func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, ID string, providerID string, providerLinkingID uuid.UUID, anonymousAddress *uuid.UUID, depositProvider string, country string) (err error) { +func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, ID string, providerID string, providerLinkingID uuid.UUID, depositProvider string) (err error) { _since := time.Now() defer func() { result := "ok" @@ -265,7 +265,7 @@ func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, ID string, pro datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "LinkWallet", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.LinkWallet(ctx, ID, providerID, providerLinkingID, anonymousAddress, depositProvider, country) + return _d.base.LinkWallet(ctx, ID, providerID, providerLinkingID, depositProvider) } // Migrate implements Datastore diff --git a/services/wallet/service.go b/services/wallet/service.go index 90819c180..900b295dc 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -415,7 +415,7 @@ func (service *Service) LinkBitFlyerWallet(ctx context.Context, walletID uuid.UU // we also validated that this "info" signed the request to perform the linking with http signature // we assume that since we got linkingInfo signed from BF that they are KYC providerLinkingID := uuid.NewV5(ClaimNamespace, accountHash) - err = service.Datastore.LinkWallet(ctx, walletID.String(), depositID, providerLinkingID, nil, depositProvider, country) + err = service.Datastore.LinkWallet(ctx, walletID.String(), depositID, providerLinkingID, depositProvider) if err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) @@ -497,7 +497,7 @@ func (service *Service) LinkZebPayWallet(ctx context.Context, walletID uuid.UUID } providerLinkingID := uuid.NewV5(ClaimNamespace, claims.AccountID) - if err := service.Datastore.LinkWallet(ctx, walletID.String(), claims.DepositID, providerLinkingID, nil, depositProvider, country); err != nil { + if err := service.Datastore.LinkWallet(ctx, walletID.String(), claims.DepositID, providerLinkingID, depositProvider); err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) } @@ -567,7 +567,7 @@ func (service *Service) LinkGeminiWallet(ctx context.Context, walletID uuid.UUID // we assume that since we got linking_info(VerificationToken) signed from Gemini that they are KYC providerLinkingID := uuid.NewV5(ClaimNamespace, accountID) - err = service.Datastore.LinkWallet(ctx, walletID.String(), depositID, providerLinkingID, nil, depositProvider, country) + err = service.Datastore.LinkWallet(ctx, walletID.String(), depositID, providerLinkingID, depositProvider) if err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) @@ -589,7 +589,7 @@ func (service *Service) LinkGeminiWallet(ctx context.Context, walletID uuid.UUID } // LinkUpholdWallet links an uphold.Wallet and transfers funds. -func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wallet, transaction string, anonymousAddress *uuid.UUID) (string, error) { +func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wallet, transaction string, _ *uuid.UUID) (string, error) { const depositProvider = "uphold" // do not confirm this transaction yet info := wallet.GetWalletInfo() @@ -669,7 +669,7 @@ func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wall providerLinkingID := uuid.NewV5(ClaimNamespace, userID) // tx.Destination will be stored as UserDepositDestination in the wallet info upon linking - err = service.Datastore.LinkWallet(ctx, info.ID, transactionInfo.Destination, providerLinkingID, anonymousAddress, depositProvider, country) + err = service.Datastore.LinkWallet(ctx, info.ID, transactionInfo.Destination, providerLinkingID, depositProvider) if err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) From b64db37ea4d5397ca5ac7e77f698484387cd2986 Mon Sep 17 00:00:00 2001 From: Pavel Brm <5097196+pavelbrm@users.noreply.github.com> Date: Fri, 13 Oct 2023 00:13:49 +1300 Subject: [PATCH 2/4] =?UTF-8?q?Bundles=20=E2=80=93=20Handle=20Location=20f?= =?UTF-8?q?or=20Multi-Item=20Orders=20(#1982)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Handle Location for Multi-Item Orders * Backport changes from 1998 * Rename req to oreq in datastore and repository --- .github/PULL_REQUEST_TEMPLATE.md | 12 +- services/skus/datastore.go | 41 ++-- services/skus/datastore_test.go | 73 +++++--- services/skus/instrumented_datastore.go | 15 +- services/skus/mockdatastore.go | 11 +- services/skus/model/model.go | 74 +++++--- services/skus/service.go | 175 ++++++++++-------- .../storage/repository/order_item_test.go | 37 ++-- .../skus/storage/repository/repository.go | 27 +-- .../storage/repository/repository_test.go | 113 +++++------ 10 files changed, 302 insertions(+), 276 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 744d1f5c9..d8c974e1c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,7 +2,8 @@ -### Type of change ( select one ) + +### Type of Change - [ ] Product feature - [ ] Bug fix @@ -12,13 +13,15 @@ + ### Tested Environments - [ ] Development - [ ] Staging - [ ] Production -### Before submitting this PR: + +### Before Requesting Review - [ ] Does your code build cleanly without any errors or warnings? - [ ] Have you used auto closing keywords? @@ -28,9 +31,10 @@ - [ ] Have you squashed all intermediate commits? - [ ] Is there a clear title that explains what the PR does? - [ ] Have you used intuitive function, variable and other naming? -- [ ] Have you requested security / privacy review if needed +- [ ] Have you requested security and/or privacy review if needed - [ ] Have you performed a self review of this PR? -### Manual Test Plan: + +### Manual Test Plan diff --git a/services/skus/datastore.go b/services/skus/datastore.go index 9114105be..e1a80c7bc 100644 --- a/services/skus/datastore.go +++ b/services/skus/datastore.go @@ -32,11 +32,11 @@ const ( errNotFound = model.Error("not found") ) -// Datastore abstracts over the underlying datastore +// Datastore abstracts over the underlying datastore. type Datastore interface { datastore.Datastore - // CreateOrder is used to create an order for payments - CreateOrder(totalPrice decimal.Decimal, merchantID string, status string, currency string, location string, validFor *time.Duration, orderItems []OrderItem, allowedPaymentMethods []string) (*Order, error) + + CreateOrder(ctx context.Context, dbi sqlx.ExtContext, oreq *model.OrderNew, items []model.OrderItem) (*model.Order, error) // SetOrderTrialDays - set the number of days of free trial for this order SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, days int64) (*Order, error) // GetOrder by ID @@ -101,14 +101,7 @@ type Datastore interface { type orderStore interface { Get(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) GetByExternalID(ctx context.Context, dbi sqlx.QueryerContext, extID string) (*model.Order, error) - Create( - ctx context.Context, - dbi sqlx.QueryerContext, - totalPrice decimal.Decimal, - merchantID, status, currency, location string, - paymentMethods []string, - validFor *time.Duration, - ) (*model.Order, error) + Create(ctx context.Context, dbi sqlx.QueryerContext, oreq *model.OrderNew) (*model.Order, error) SetLastPaidAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error SetTrialDays(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID, ndays int64) (*model.Order, error) SetStatus(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error @@ -301,38 +294,26 @@ func (pg *Postgres) SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, d return result, nil } -// CreateOrder creates an order with the given total price, merchant ID, status and orderItems. -func (pg *Postgres) CreateOrder(totalPrice decimal.Decimal, merchantID, status, currency, location string, validFor *time.Duration, orderItems []OrderItem, allowedPaymentMethods []string) (*Order, error) { - tx, err := pg.RawDB().Beginx() +// CreateOrder creates an order from the given prototype, and inserts items. +func (pg *Postgres) CreateOrder(ctx context.Context, dbi sqlx.ExtContext, oreq *model.OrderNew, items []model.OrderItem) (*model.Order, error) { + result, err := pg.orderRepo.Create(ctx, dbi, oreq) if err != nil { return nil, err } - defer pg.RollbackTx(tx) - - ctx := context.TODO() - result, err := pg.orderRepo.Create(ctx, tx, totalPrice, merchantID, status, currency, location, allowedPaymentMethods, validFor) - if err != nil { - return nil, err - } - - if status == OrderStatusPaid { - if err := pg.recordOrderPayment(ctx, tx, result.ID, time.Now()); err != nil { + if oreq.Status == OrderStatusPaid { + if err := pg.recordOrderPayment(ctx, dbi, result.ID, time.Now()); err != nil { return nil, fmt.Errorf("failed to record order payment: %w", err) } } - model.OrderItemList(orderItems).SetOrderID(result.ID) + model.OrderItemList(items).SetOrderID(result.ID) - result.Items, err = pg.orderItemRepo.InsertMany(ctx, tx, orderItems...) + result.Items, err = pg.orderItemRepo.InsertMany(ctx, dbi, items...) if err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return result, nil } diff --git a/services/skus/datastore_test.go b/services/skus/datastore_test.go index f20fc1f17..6ff26089c 100644 --- a/services/skus/datastore_test.go +++ b/services/skus/datastore_test.go @@ -4,6 +4,7 @@ package skus import ( "context" + "database/sql" "encoding/json" "os" "strings" @@ -13,6 +14,7 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/golang/mock/gomock" "github.com/jmoiron/sqlx" + "github.com/lib/pq" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" must "github.com/stretchr/testify/require" @@ -471,16 +473,21 @@ func createOrderAndIssuer(t *testing.T, ctx context.Context, storage Datastore, } validFor := 3600 * time.Second * 24 - order, err := storage.CreateOrder( - decimal.NewFromInt32(int32(test.RandomInt())), - test.RandomString(), - OrderStatusPaid, - test.RandomString(), - test.RandomString(), - &validFor, - orderItems, - methods, - ) + + oreq := &model.OrderNew{ + MerchantID: test.RandomString(), + Currency: test.RandomString(), + Status: model.OrderStatusPaid, + TotalPrice: decimal.NewFromInt(int64(test.RandomInt())), + Location: sql.NullString{ + Valid: true, + String: test.RandomString(), + }, + AllowedPaymentMethods: pq.StringArray(methods), + ValidFor: &validFor, + } + + order, err := storage.CreateOrder(ctx, storage.RawDB(), oreq, orderItems) must.NoError(t, err) { @@ -514,16 +521,19 @@ func (suite *PostgresTestSuite) createTimeLimitedV2OrderCreds(t *testing.T, ctx methods = append(methods, method...) } - order, err := suite.storage.CreateOrder( - decimal.NewFromInt32(int32(test.RandomInt())), - test.RandomString(), - OrderStatusPaid, - test.RandomString(), - test.RandomString(), - nil, - orderItems, - methods, - ) + oreq := &model.OrderNew{ + MerchantID: test.RandomString(), + Currency: test.RandomString(), + Status: model.OrderStatusPaid, + TotalPrice: decimal.NewFromInt(int64(test.RandomInt())), + Location: sql.NullString{ + Valid: true, + String: test.RandomString(), + }, + AllowedPaymentMethods: pq.StringArray(methods), + } + + order, err := suite.storage.CreateOrder(ctx, suite.storage.RawDB(), oreq, orderItems) must.NoError(t, err) repo := repository.NewIssuer() @@ -595,16 +605,19 @@ func (suite *PostgresTestSuite) createOrderCreds(t *testing.T, ctx context.Conte methods = append(methods, method...) } - order, err := suite.storage.CreateOrder( - decimal.NewFromInt32(int32(test.RandomInt())), - test.RandomString(), - OrderStatusPaid, - test.RandomString(), - test.RandomString(), - nil, - orderItems, - methods, - ) + oreq := &model.OrderNew{ + MerchantID: test.RandomString(), + Currency: test.RandomString(), + Status: model.OrderStatusPaid, + TotalPrice: decimal.NewFromInt(int64(test.RandomInt())), + Location: sql.NullString{ + Valid: true, + String: test.RandomString(), + }, + AllowedPaymentMethods: pq.StringArray(methods), + } + + order, err := suite.storage.CreateOrder(ctx, suite.storage.RawDB(), oreq, orderItems) must.NoError(t, err) pk := test.RandomString() diff --git a/services/skus/instrumented_datastore.go b/services/skus/instrumented_datastore.go index 5235797b4..bc6a28657 100644 --- a/services/skus/instrumented_datastore.go +++ b/services/skus/instrumented_datastore.go @@ -11,6 +11,7 @@ import ( "time" "github.com/brave-intl/bat-go/libs/inputs" + "github.com/brave-intl/bat-go/services/skus/model" migrate "github.com/golang-migrate/migrate/v4" "github.com/jmoiron/sqlx" "github.com/prometheus/client_golang/prometheus" @@ -155,18 +156,8 @@ func (_d DatastoreWithPrometheus) CreateKey(merchant string, name string, encryp return _d.base.CreateKey(merchant, name, encryptedSecretKey, nonce) } -// CreateOrder implements Datastore -func (_d DatastoreWithPrometheus) CreateOrder(totalPrice decimal.Decimal, merchantID string, status string, currency string, location string, validFor *time.Duration, orderItems []OrderItem, allowedPaymentMethods []string) (op1 *Order, err error) { - _since := time.Now() - defer func() { - result := "ok" - if err != nil { - result = "error" - } - - datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "CreateOrder", result).Observe(time.Since(_since).Seconds()) - }() - return _d.base.CreateOrder(totalPrice, merchantID, status, currency, location, validFor, orderItems, allowedPaymentMethods) +func (_d DatastoreWithPrometheus) CreateOrder(ctx context.Context, dbi sqlx.ExtContext, req *model.OrderNew, items []model.OrderItem) (op1 *model.Order, err error) { + return _d.base.CreateOrder(ctx, dbi, req, items) } // CreateTransaction implements Datastore diff --git a/services/skus/mockdatastore.go b/services/skus/mockdatastore.go index 8a4bb49b3..88cd749f3 100644 --- a/services/skus/mockdatastore.go +++ b/services/skus/mockdatastore.go @@ -10,6 +10,7 @@ import ( time "time" inputs "github.com/brave-intl/bat-go/libs/inputs" + model "github.com/brave-intl/bat-go/services/skus/model" v4 "github.com/golang-migrate/migrate/v4" gomock "github.com/golang/mock/gomock" sqlx "github.com/jmoiron/sqlx" @@ -163,18 +164,18 @@ func (mr *MockDatastoreMockRecorder) CreateKey(merchant, name, encryptedSecretKe } // CreateOrder mocks base method. -func (m *MockDatastore) CreateOrder(totalPrice decimal.Decimal, merchantID, status, currency, location string, validFor *time.Duration, orderItems []OrderItem, allowedPaymentMethods []string) (*Order, error) { +func (m *MockDatastore) CreateOrder(ctx context.Context, dbi sqlx.ExtContext, req *model.OrderNew, items []model.OrderItem) (*model.Order, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrder", totalPrice, merchantID, status, currency, location, validFor, orderItems, allowedPaymentMethods) - ret0, _ := ret[0].(*Order) + ret := m.ctrl.Call(m, "CreateOrder", ctx, dbi, req, items) + ret0, _ := ret[0].(*model.Order) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateOrder indicates an expected call of CreateOrder. -func (mr *MockDatastoreMockRecorder) CreateOrder(totalPrice, merchantID, status, currency, location, validFor, orderItems, allowedPaymentMethods interface{}) *gomock.Call { +func (mr *MockDatastoreMockRecorder) CreateOrder(ctx, dbi, req, items interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrder", reflect.TypeOf((*MockDatastore)(nil).CreateOrder), totalPrice, merchantID, status, currency, location, validFor, orderItems, allowedPaymentMethods) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrder", reflect.TypeOf((*MockDatastore)(nil).CreateOrder), ctx, dbi, req, items) } // CreateTransaction mocks base method. diff --git a/services/skus/model/model.go b/services/skus/model/model.go index ebcbaf9cb..d13859935 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -93,52 +93,59 @@ func (o *Order) IsRadomPayable() bool { } // CreateStripeCheckoutSession creates a Stripe checkout session for the order. +// +// Deprecated: Use CreateStripeCheckoutSession function instead of this method. func (o *Order) CreateStripeCheckoutSession( email, successURI, cancelURI string, freeTrialDays int64, +) (CreateCheckoutSessionResponse, error) { + return CreateStripeCheckoutSession(o.ID.String(), email, successURI, cancelURI, freeTrialDays, o.Items) +} + +// CreateStripeCheckoutSession creates a Stripe checkout session for the order. +func CreateStripeCheckoutSession( + oid, email, successURI, cancelURI string, + trialDays int64, + items []OrderItem, ) (CreateCheckoutSessionResponse, error) { var custID string if email != "" { - // find the existing customer by email - // so we can use the customer id instead of a customer email - i := customer.List(&stripe.CustomerListParams{ + // Find the existing customer by email to use the customer id instead email. + l := customer.List(&stripe.CustomerListParams{ Email: stripe.String(email), }) - for i.Next() { - custID = i.Customer().ID + for l.Next() { + custID = l.Customer().ID } } - sd := &stripe.CheckoutSessionSubscriptionDataParams{} - // If a free trial is set, apply it. - if freeTrialDays > 0 { - sd.TrialPeriodDays = &freeTrialDays + params := &stripe.CheckoutSessionParams{ + PaymentMethodTypes: stripe.StringSlice([]string{"card"}), + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + SuccessURL: stripe.String(successURI), + CancelURL: stripe.String(cancelURI), + ClientReferenceID: stripe.String(oid), + SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{}, + LineItems: OrderItemList(items).stripeLineItems(), } - params := &stripe.CheckoutSessionParams{ - PaymentMethodTypes: stripe.StringSlice([]string{ - "card", - }), - Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), - SuccessURL: stripe.String(successURI), - CancelURL: stripe.String(cancelURI), - ClientReferenceID: stripe.String(o.ID.String()), - SubscriptionData: sd, - LineItems: OrderItemList(o.Items).stripeLineItems(), + // If a free trial is set, apply it. + if trialDays > 0 { + params.SubscriptionData.TrialPeriodDays = &trialDays } if custID != "" { - // try to use existing customer we found by email + // Use existing customer if found. params.Customer = stripe.String(custID) } else if email != "" { - // if we dont have an existing customer, this CustomerEmail param will create a new one + // Otherwise, create a new using email. params.CustomerEmail = stripe.String(email) } - // else we have no record of this email for this checkout session - // the user will be asked for the email, we cannot send an empty customer email as a param + // Otherwise, we have no record of this email for this checkout session. + // ? The user will be asked for the email, we cannot send an empty customer email as a param. - params.SubscriptionData.AddMetadata("orderID", o.ID.String()) + params.SubscriptionData.AddMetadata("orderID", oid) params.AddExtra("allow_promotion_codes", "true") session, err := session.New(params) @@ -264,6 +271,25 @@ type OrderItem struct { IssuerConfig *IssuerConfig `json:"-" db:"-"` } +func (x *OrderItem) IsLeo() bool { + if x == nil { + return false + } + + return x.SKU == "brave-leo-premium" +} + +// OrderNew represents a request to create an order in the database. +type OrderNew struct { + MerchantID string `db:"merchant_id"` + Currency string `db:"currency"` + Status string `db:"status"` + Location sql.NullString `db:"location"` + TotalPrice decimal.Decimal `db:"total_price"` + AllowedPaymentMethods pq.StringArray `db:"allowed_payment_methods"` + ValidFor *time.Duration `db:"valid_for"` +} + // CreateCheckoutSessionResponse represents a checkout session response. type CreateCheckoutSessionResponse struct { SessionID string `json:"checkoutSessionId"` diff --git a/services/skus/service.go b/services/skus/service.go index 0504b49b0..4ca82a44d 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -18,6 +18,7 @@ import ( "github.com/awa/go-iap/appstore" "github.com/getsentry/sentry-go" "github.com/jmoiron/sqlx" + "github.com/lib/pq" "github.com/linkedin/goavro" uuid "github.com/satori/go.uuid" "github.com/segmentio/kafka-go" @@ -279,17 +280,17 @@ func (s *Service) ExternalIDExists(ctx context.Context, externalID string) (bool // CreateOrderFromRequest creates an order from the request func (s *Service) CreateOrderFromRequest(ctx context.Context, req model.CreateOrderRequest) (*Order, error) { - totalPrice := decimal.New(0, 0) + const merchantID = "brave.com" + var ( + totalPrice = decimal.New(0, 0) currency string orderItems []OrderItem location string validFor *time.Duration stripeSuccessURI string stripeCancelURI string - status string allowedPaymentMethods []string - merchantID = "brave.com" numIntervals int numPerInterval = 2 // two per interval credentials to be submitted for signing ) @@ -302,8 +303,9 @@ func (s *Service) CreateOrderFromRequest(ctx context.Context, req model.CreateOr // TODO: we ultimately need to figure out how to provision numPerInterval and numIntervals // on the order item instead of the order itself to support multiple orders with - // different time limited v2 issuers. For now leo sku needs 192 as num per interval - if orderItem.SKU == "brave-leo-premium" { + // different time limited v2 issuers. + // For now leo sku needs 192 as num per interval. + if orderItem.IsLeo() { numPerInterval = 192 // 192 credentials per day for leo } @@ -378,27 +380,40 @@ func (s *Service) CreateOrderFromRequest(ctx context.Context, req model.CreateOr orderItems = append(orderItems, *orderItem) } - // If order consists entirely of zero cost items ( e.g. trials ), we can consider it paid - if totalPrice.IsZero() { - status = OrderStatusPaid - } else { - status = OrderStatusPending - } - - order, err := s.Datastore.CreateOrder( - totalPrice, - merchantID, - status, - currency, - location, - validFor, - orderItems, - allowedPaymentMethods, - ) + oreq := &model.OrderNew{ + MerchantID: merchantID, + Currency: currency, + Status: OrderStatusPending, + TotalPrice: totalPrice, + AllowedPaymentMethods: pq.StringArray(allowedPaymentMethods), + ValidFor: validFor, + } + + // Consider the order paid if it consists entirely of zero cost items (e.g. trials). + if oreq.TotalPrice.IsZero() { + oreq.Status = OrderStatusPaid + } + + if location != "" { + oreq.Location.Valid = true + oreq.Location.String = location + } + + tx, err := s.Datastore.BeginTx() + if err != nil { + return nil, err + } + defer func() { _ = tx.Rollback() }() + + order, err := s.Datastore.CreateOrder(ctx, tx, oreq, orderItems) if err != nil { return nil, fmt.Errorf("failed to create order: %w", err) } + if err := tx.Commit(); err != nil { + return nil, err + } + if !order.IsPaid() { switch { case order.IsStripePayable(): @@ -1633,12 +1648,12 @@ func (s *Service) CreateOrder(ctx context.Context, req *model.CreateOrderRequest return nil, err } - // Check for number of items to be at least one. + // Check for number of items to be above 0. // // Validation should already have taken care of this. - // However, this method does not know about it. - // Therefore, an explicit check is necessary. - if len(items) == 0 { + // This method does not know about it, hence the explicit check. + nitems := len(items) + if nitems == 0 { return nil, model.ErrInvalidOrderRequest } @@ -1655,57 +1670,60 @@ func (s *Service) CreateOrder(ctx context.Context, req *model.CreateOrderRequest return nil, err } - // TODO: Gradually use this tx for other database operations. - // Eventually, move this call to the end of the method. - if err := tx.Commit(); err != nil { - return nil, err + oreq := &model.OrderNew{ + MerchantID: merchID, + Currency: req.Currency, + Status: model.OrderStatusPending, + TotalPrice: model.OrderItemList(items).TotalCost(), + AllowedPaymentMethods: pq.StringArray(req.PaymentMethods), } - totalCost := model.OrderItemList(items).TotalCost() - - status := model.OrderStatusPending - if totalCost.IsZero() { - status = model.OrderStatusPaid + if oreq.TotalPrice.IsZero() { + oreq.Status = model.OrderStatusPaid } - // Use validFor from the first item. + // Location on the order is only defined when there is only one item. // - // TODO: Deprecate the use of valid_for: - // valid_for_iso is now used instead of valid_for for calculating order's expiration time. - // - // The old code in CreateOrderFromRequest does a contradictory thing – it takes validFor from last item. - // It does not make any sense, but it's working because there is only one item normally. - var validFor time.Duration - if items[0].ValidFor != nil { - validFor = *items[0].ValidFor - } - - order, err := s.Datastore.CreateOrder( - totalCost, - merchID, - status, - req.Currency, - // FIXME: Location. + // Multi-item orders have NULL location. + if nitems == 1 && items[0].Location.Valid { + oreq.Location.Valid = true + oreq.Location.String = items[0].Location.String + } + + { + // Use validFor from the first item. // - // The old code in CreateOrderFromRequest contradictory things: - // - it looks as though it supports multiple items (mind the loop) - // - it requires all items to have the same location, at the same time. - // For this to work with bundles, this has to change. - // At this stage (i.e. just adding this new endpoint and switching over to it) - // using the location of the first (and the only) item accomplishes the same result. - items[0].Location.String, - &validFor, - items, - req.PaymentMethods, - ) + // TODO: Deprecate the use of valid_for: + // valid_for_iso is now used instead of valid_for for calculating order's expiration time. + // + // The old code in CreateOrderFromRequest does a contradictory thing – it takes validFor from last item. + // It does not make any sense, but it's working because there is only one item normally. + var vf time.Duration + if items[0].ValidFor != nil { + vf = *items[0].ValidFor + } + + oreq.ValidFor = &vf + } + + order, err := s.Datastore.CreateOrder(ctx, tx, oreq, items) if err != nil { return nil, fmt.Errorf("failed to create order: %w", err) } + if err := tx.Commit(); err != nil { + return nil, err + } + if !order.IsPaid() && order.IsStripePayable() { - if err := s.createStripeSessID(ctx, req, order); err != nil { + ssid, err := s.createStripeSessID(ctx, req, order) + if err != nil { return nil, err } + + if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, "stripeCheckoutSessionId", ssid); err != nil { + return nil, fmt.Errorf("failed to update order metadata: %w", err) + } } if numIntervals > 0 { @@ -1714,8 +1732,16 @@ func (s *Service) CreateOrder(ctx context.Context, req *model.CreateOrderRequest } } - if err := s.Datastore.AppendOrderMetadataInt(ctx, &order.ID, "numPerInterval", 2); err != nil { - return nil, fmt.Errorf("failed to update order metadata: %w", err) + // Backporting changes from https://github.com/brave-intl/bat-go/pull/1998. + { + numPerInterval := 2 + if nitems == 1 && items[0].IsLeo() { + numPerInterval = 192 + } + + if err := s.Datastore.AppendOrderMetadataInt(ctx, &order.ID, "numPerInterval", numPerInterval); err != nil { + return nil, fmt.Errorf("failed to update order metadata: %w", err) + } } return order, nil @@ -1746,30 +1772,25 @@ func (s *Service) createOrderIssuers(ctx context.Context, dbi sqlx.QueryerContex return numIntervals, nil } -func (s *Service) createStripeSessID(ctx context.Context, req *model.CreateOrderRequestNew, order *model.Order) error { +func (s *Service) createStripeSessID(ctx context.Context, req *model.CreateOrderRequestNew, order *model.Order) (string, error) { oid := order.ID.String() - // This should not happen, but enforce the check anyway. surl, err := req.StripeMetadata.SuccessURL(oid) if err != nil { - return err + return "", err } curl, err := req.StripeMetadata.CancelURL(oid) if err != nil { - return err + return "", err } - sess, err := order.CreateStripeCheckoutSession(req.Email, surl, curl, order.GetTrialDays()) + sess, err := model.CreateStripeCheckoutSession(oid, req.Email, surl, curl, order.GetTrialDays(), order.Items) if err != nil { - return fmt.Errorf("failed to create checkout session: %w", err) + return "", fmt.Errorf("failed to create checkout session: %w", err) } - if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, "stripeCheckoutSessionId", sess.SessionID); err != nil { - return fmt.Errorf("failed to update order metadata: %w", err) - } - - return nil + return sess.SessionID, nil } func (s *Service) redeemBlindedCred(ctx context.Context, w http.ResponseWriter, kind string, cred *cbr.CredentialRedemption) *handlers.AppError { diff --git a/services/skus/storage/repository/order_item_test.go b/services/skus/storage/repository/order_item_test.go index 250fc1c58..4b1427aff 100644 --- a/services/skus/storage/repository/order_item_test.go +++ b/services/skus/storage/repository/order_item_test.go @@ -6,9 +6,9 @@ import ( "context" "database/sql" "testing" - "time" "github.com/jmoiron/sqlx" + "github.com/lib/pq" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" should "github.com/stretchr/testify/assert" @@ -184,14 +184,7 @@ func setupDBI() (*sqlx.DB, error) { } type orderCreator interface { - Create( - ctx context.Context, - dbi sqlx.QueryerContext, - totalPrice decimal.Decimal, - merchantID, status, currency, location string, - paymentMethods []string, - validFor *time.Duration, - ) (*model.Order, error) + Create(ctx context.Context, dbi sqlx.QueryerContext, req *model.OrderNew) (*model.Order, error) } func createOrderForTest(ctx context.Context, dbi sqlx.QueryerContext, repo orderCreator) (*model.Order, error) { @@ -200,19 +193,19 @@ func createOrderForTest(ctx context.Context, dbi sqlx.QueryerContext, repo order return nil, err } - methods := []string{"stripe"} - - result, err := repo.Create( - ctx, - dbi, - price, - "brave.com", - "pending", - "USD", - "somelocation", - methods, - nil, - ) + req := &model.OrderNew{ + MerchantID: "brave.com", + Currency: "USD", + Status: "pending", + Location: sql.NullString{ + Valid: true, + String: "somelocation", + }, + TotalPrice: price, + AllowedPaymentMethods: pq.StringArray{"stripe"}, + } + + result, err := repo.Create(ctx, dbi, req) if err != nil { return nil, err } diff --git a/services/skus/storage/repository/repository.go b/services/skus/storage/repository/repository.go index 618373d4e..f4132d12b 100644 --- a/services/skus/storage/repository/repository.go +++ b/services/skus/storage/repository/repository.go @@ -8,9 +8,7 @@ import ( "time" "github.com/jmoiron/sqlx" - "github.com/lib/pq" uuid "github.com/satori/go.uuid" - "github.com/shopspring/decimal" "github.com/brave-intl/bat-go/libs/datastore" @@ -61,15 +59,8 @@ func (r *Order) GetByExternalID(ctx context.Context, dbi sqlx.QueryerContext, ex return result, nil } -// Create creates an order with the given inputs. -func (r *Order) Create( - ctx context.Context, - dbi sqlx.QueryerContext, - totalPrice decimal.Decimal, - merchantID, status, currency, location string, - paymentMethods []string, - validFor *time.Duration, -) (*model.Order, error) { +// Create creates an order with the data in req. +func (r *Order) Create(ctx context.Context, dbi sqlx.QueryerContext, oreq *model.OrderNew) (*model.Order, error) { const q = `INSERT INTO orders (total_price, merchant_id, status, currency, location, allowed_payment_methods, valid_for) VALUES ($1, $2, $3, $4, $5, $6, $7) @@ -79,13 +70,13 @@ func (r *Order) Create( if err := dbi.QueryRowxContext( ctx, q, - totalPrice, - merchantID, - status, - currency, - location, - pq.StringArray(paymentMethods), - validFor, + oreq.TotalPrice, + oreq.MerchantID, + oreq.Status, + oreq.Currency, + oreq.Location, + oreq.AllowedPaymentMethods, + oreq.ValidFor, ).StructScan(result); err != nil { return nil, err } diff --git a/services/skus/storage/repository/repository_test.go b/services/skus/storage/repository/repository_test.go index ea883422b..86c318576 100644 --- a/services/skus/storage/repository/repository_test.go +++ b/services/skus/storage/repository/repository_test.go @@ -689,7 +689,7 @@ func TestOrder_CreateGet(t *testing.T) { }() type tcGiven struct { - order *model.Order + req *model.OrderNew } type tcExpected struct { @@ -707,15 +707,13 @@ func TestOrder_CreateGet(t *testing.T) { { name: "nil_allowed_payment_methods", given: tcGiven{ - order: &model.Order{ + req: &model.OrderNew{ MerchantID: "brave.com", Currency: "USD", Status: "pending", - Location: datastore.NullString{ - NullString: sql.NullString{ - Valid: true, - String: "https://somewhere.brave.software", - }, + Location: sql.NullString{ + Valid: true, + String: "https://somewhere.brave.software", }, TotalPrice: mustDecimalFromString("5"), }, @@ -739,15 +737,13 @@ func TestOrder_CreateGet(t *testing.T) { { name: "empty_allowed_payment_methods", given: tcGiven{ - order: &model.Order{ + req: &model.OrderNew{ MerchantID: "brave.com", Currency: "USD", Status: "pending", - Location: datastore.NullString{ - NullString: sql.NullString{ - Valid: true, - String: "https://somewhere.brave.software", - }, + Location: sql.NullString{ + Valid: true, + String: "https://somewhere.brave.software", }, TotalPrice: mustDecimalFromString("5"), AllowedPaymentMethods: pq.StringArray{}, @@ -773,15 +769,13 @@ func TestOrder_CreateGet(t *testing.T) { { name: "single_allowed_payment_methods", given: tcGiven{ - order: &model.Order{ + req: &model.OrderNew{ MerchantID: "brave.com", Currency: "USD", Status: "pending", - Location: datastore.NullString{ - NullString: sql.NullString{ - Valid: true, - String: "https://somewhere.brave.software", - }, + Location: sql.NullString{ + Valid: true, + String: "https://somewhere.brave.software", }, TotalPrice: mustDecimalFromString("5"), AllowedPaymentMethods: pq.StringArray{"stripe"}, @@ -807,15 +801,13 @@ func TestOrder_CreateGet(t *testing.T) { { name: "many_allowed_payment_methods", given: tcGiven{ - order: &model.Order{ + req: &model.OrderNew{ MerchantID: "brave.com", Currency: "USD", Status: "pending", - Location: datastore.NullString{ - NullString: sql.NullString{ - Valid: true, - String: "https://somewhere.brave.software", - }, + Location: sql.NullString{ + Valid: true, + String: "https://somewhere.brave.software", }, TotalPrice: mustDecimalFromString("5"), AllowedPaymentMethods: pq.StringArray{"stripe", "cash"}, @@ -837,6 +829,28 @@ func TestOrder_CreateGet(t *testing.T) { }, }, }, + + { + name: "empty_location", + given: tcGiven{ + req: &model.OrderNew{ + MerchantID: "brave.com", + Currency: "USD", + Status: "pending", + TotalPrice: mustDecimalFromString("5"), + AllowedPaymentMethods: pq.StringArray{"stripe"}, + }, + }, + exp: tcExpected{ + result: &model.Order{ + MerchantID: "brave.com", + Currency: "USD", + Status: "pending", + TotalPrice: mustDecimalFromString("5"), + AllowedPaymentMethods: pq.StringArray{"stripe"}, + }, + }, + }, } repo := repository.NewOrder() @@ -852,44 +866,35 @@ func TestOrder_CreateGet(t *testing.T) { t.Cleanup(func() { _ = tx.Rollback() }) - act1, err := repo.Create( - ctx, - tx, - tc.given.order.TotalPrice, - tc.given.order.MerchantID, - tc.given.order.Status, - tc.given.order.Currency, - tc.given.order.Location.String, - tc.given.order.AllowedPaymentMethods, - tc.given.order.ValidFor, - ) + actual1, err := repo.Create(ctx, tx, tc.given.req) must.Equal(t, true, errors.Is(err, tc.exp.err)) if tc.exp.err != nil { return } - should.Equal(t, tc.exp.result.MerchantID, act1.MerchantID) - should.Equal(t, tc.exp.result.Currency, act1.Currency) - should.Equal(t, tc.exp.result.Status, act1.Status) - should.Equal(t, tc.exp.result.Location, act1.Location) - should.Equal(t, true, tc.exp.result.TotalPrice.Equal(act1.TotalPrice)) - should.Equal(t, tc.exp.result.AllowedPaymentMethods, act1.AllowedPaymentMethods) - should.Equal(t, tc.exp.result.ValidFor, act1.ValidFor) + should.Equal(t, tc.exp.result.MerchantID, actual1.MerchantID) + should.Equal(t, tc.exp.result.Currency, actual1.Currency) + should.Equal(t, tc.exp.result.Status, actual1.Status) + should.Equal(t, tc.exp.result.Location.Valid, actual1.Location.Valid) + should.Equal(t, tc.exp.result.Location.String, actual1.Location.String) + should.Equal(t, true, tc.exp.result.TotalPrice.Equal(actual1.TotalPrice)) + should.Equal(t, tc.exp.result.AllowedPaymentMethods, actual1.AllowedPaymentMethods) + should.Equal(t, tc.exp.result.ValidFor, actual1.ValidFor) - act2, err := repo.Get(ctx, tx, act1.ID) + actual2, err := repo.Get(ctx, tx, actual1.ID) must.Equal(t, nil, err) - should.Equal(t, act1.ID, act2.ID) - should.Equal(t, act1.MerchantID, act2.MerchantID) - should.Equal(t, act1.Currency, act2.Currency) - should.Equal(t, act1.Status, act2.Status) - should.Equal(t, act1.Location, act2.Location) - should.Equal(t, true, act1.TotalPrice.Equal(act2.TotalPrice)) - should.Equal(t, act1.AllowedPaymentMethods, act2.AllowedPaymentMethods) - should.Equal(t, act1.ValidFor, act2.ValidFor) - should.Equal(t, act1.CreatedAt, act2.CreatedAt) - should.Equal(t, act1.UpdatedAt, act2.UpdatedAt) + should.Equal(t, actual1.ID, actual2.ID) + should.Equal(t, actual1.MerchantID, actual2.MerchantID) + should.Equal(t, actual1.Currency, actual2.Currency) + should.Equal(t, actual1.Status, actual2.Status) + should.Equal(t, actual1.Location, actual2.Location) + should.Equal(t, true, actual1.TotalPrice.Equal(actual2.TotalPrice)) + should.Equal(t, actual1.AllowedPaymentMethods, actual2.AllowedPaymentMethods) + should.Equal(t, actual1.ValidFor, actual2.ValidFor) + should.Equal(t, actual1.CreatedAt, actual2.CreatedAt) + should.Equal(t, actual1.UpdatedAt, actual2.UpdatedAt) }) } } From 3de8c63b05ca64a7e06fc718c1bbad2716310eb7 Mon Sep 17 00:00:00 2001 From: Pavel Brm <5097196+pavelbrm@users.noreply.github.com> Date: Fri, 13 Oct 2023 00:50:00 +1300 Subject: [PATCH 3/4] =?UTF-8?q?Bundles=20=E2=80=93=20Refactor=20validateOr?= =?UTF-8?q?derMerchantAndCaveats=20(#1988)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Handle Location for Multi-Item Orders * Refactor validateOrderMerchantAndCaveats * Keep old placement of some funcs for ease of reviewing * Keep old placement of some funcs for ease of reviewing * Keep old placement of some funcs for ease of reviewing * Fix grammar * Trigger CI * Unexport and rename setTrialDaysRequest * Fix panic with missing datastore * Add comment * Backport changes from 1998 --- services/skus/controllers.go | 78 +++--- services/skus/key.go | 81 +++--- services/skus/key_test.go | 300 ++++++++++++++++------- services/skus/service.go | 12 +- services/skus/storage/repository/mock.go | 16 ++ 5 files changed, 319 insertions(+), 168 deletions(-) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 93acffd41..9c7901e49 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -275,81 +275,70 @@ func VoteRouter(service *Service, instrumentHandler middleware.InstrumentHandler return r } -// SetOrderTrialDaysInput - SetOrderTrialDays handler input -type SetOrderTrialDaysInput struct { +type setTrialDaysRequest struct { TrialDays int64 `json:"trialDays" valid:"int"` } -// SetOrderTrialDays is the handler for cancelling an order +// SetOrderTrialDays handles requests for setting trial days on orders. func SetOrderTrialDays(service *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { - var ( - ctx = r.Context() - orderID = new(inputs.ID) - ) - if err := inputs.DecodeAndValidateString(context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { + ctx := r.Context() + orderID := &inputs.ID{} + + if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParam(r, "orderID")); err != nil { return handlers.ValidationError( "Error validating request url parameter", - map[string]interface{}{ - "orderID": err.Error(), - }, + map[string]interface{}{"orderID": err.Error()}, ) } // validate order merchant and caveats (to make sure this is the right merch) - if err := service.ValidateOrderMerchantAndCaveats(r, *orderID.UUID()); err != nil { + if err := service.validateOrderMerchantAndCaveats(ctx, *orderID.UUID()); err != nil { return handlers.ValidationError( "Error validating request merchant and caveats", - map[string]interface{}{ - "orderMerchantAndCaveats": err.Error(), - }, + map[string]interface{}{"orderMerchantAndCaveats": err.Error()}, ) } - var input SetOrderTrialDaysInput - err := requestutils.ReadJSON(r.Context(), r.Body, &input) - if err != nil { + req := &setTrialDaysRequest{} + if err := requestutils.ReadJSON(ctx, r.Body, req); err != nil { return handlers.WrapError(err, "Error in request body", http.StatusBadRequest) } - _, err = govalidator.ValidateStruct(input) - if err != nil { + if _, err := govalidator.ValidateStruct(req); err != nil { return handlers.WrapValidationError(err) } - err = service.SetOrderTrialDays(ctx, orderID.UUID(), input.TrialDays) - if err != nil { + if err := service.SetOrderTrialDays(ctx, orderID.UUID(), req.TrialDays); err != nil { return handlers.WrapError(err, "Error setting the trial days on the order", http.StatusInternalServerError) } - return handlers.RenderContent(r.Context(), nil, w, http.StatusOK) + return handlers.RenderContent(ctx, nil, w, http.StatusOK) }) } -// CancelOrder is the handler for cancelling an order +// CancelOrder handles requests for cancelling orders. func CancelOrder(service *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { - var orderID = new(inputs.ID) - if err := inputs.DecodeAndValidateString(context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { + ctx := r.Context() + orderID := &inputs.ID{} + + if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParam(r, "orderID")); err != nil { return handlers.ValidationError( "Error validating request url parameter", - map[string]interface{}{ - "orderID": err.Error(), - }, + map[string]interface{}{"orderID": err.Error()}, ) } - err := service.ValidateOrderMerchantAndCaveats(r, *orderID.UUID()) - if err != nil { + if err := service.validateOrderMerchantAndCaveats(ctx, *orderID.UUID()); err != nil { return handlers.WrapError(err, "Error validating auth merchant and caveats", http.StatusForbidden) } - err = service.CancelOrder(*orderID.UUID()) - if err != nil { + if err := service.CancelOrder(*orderID.UUID()); err != nil { return handlers.WrapError(err, "Error retrieving the order", http.StatusInternalServerError) } - return handlers.RenderContent(r.Context(), nil, w, http.StatusOK) + return handlers.RenderContent(ctx, nil, w, http.StatusOK) }) } @@ -656,33 +645,28 @@ func GetOrderCreds(service *Service) handlers.AppHandler { } } -// DeleteOrderCreds is the handler for deleting order credentials +// DeleteOrderCreds handles requests for deleting order credentials. func DeleteOrderCreds(service *Service) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { - var orderID = new(inputs.ID) - if err := inputs.DecodeAndValidateString(context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { + ctx := r.Context() + orderID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParam(r, "orderID")); err != nil { return handlers.ValidationError( "Error validating request url parameter", - map[string]interface{}{ - "orderID": err.Error(), - }, + map[string]interface{}{"orderID": err.Error()}, ) } - err := service.ValidateOrderMerchantAndCaveats(r, *orderID.UUID()) - if err != nil { + if err := service.validateOrderMerchantAndCaveats(ctx, *orderID.UUID()); err != nil { return handlers.WrapError(err, "Error validating auth merchant and caveats", http.StatusForbidden) } - // is signed param isSigned := r.URL.Query().Get("isSigned") == "true" - - err = service.DeleteOrderCreds(r.Context(), *orderID.UUID(), isSigned) - if err != nil { + if err := service.DeleteOrderCreds(ctx, *orderID.UUID(), isSigned); err != nil { return handlers.WrapError(err, "Error deleting credentials", http.StatusBadRequest) } - return handlers.RenderContent(r.Context(), "Order credentials successfully deleted", w, http.StatusOK) + return handlers.RenderContent(ctx, "Order credentials successfully deleted", w, http.StatusOK) } } diff --git a/services/skus/key.go b/services/skus/key.go index e32819476..9b53c57c5 100644 --- a/services/skus/key.go +++ b/services/skus/key.go @@ -12,18 +12,31 @@ import ( "os" "time" + uuid "github.com/satori/go.uuid" + "github.com/brave-intl/bat-go/libs/cryptography" "github.com/brave-intl/bat-go/libs/httpsignature" "github.com/brave-intl/bat-go/libs/middleware" - uuid "github.com/satori/go.uuid" + + "github.com/brave-intl/bat-go/services/skus/model" +) + +const ( + // What the merchant key length should be. + keyLength = 24 + + errInvalidMerchant model.Error = "merchant was missing from context" + errMerchantMismatch model.Error = "Order merchant does not match authentication" + errLocationMismatch model.Error = "Order location does not match authentication" + errUnexpectedSKUCvt model.Error = "SKU caveat is not supported on order endpoints" ) -// EncryptionKey for encrypting secrets -var EncryptionKey = os.Getenv("ENCRYPTION_KEY") -var byteEncryptionKey [32]byte +var ( + // EncryptionKey for encrypting secrets. + EncryptionKey = os.Getenv("ENCRYPTION_KEY") -// What the merchant key length should be -var keyLength = 24 + byteEncryptionKey [32]byte +) type caveatsCtxKey struct{} type merchantCtxKey struct{} @@ -66,9 +79,9 @@ func (key *Key) GetSecretKey() (*string, error) { func randomString(n int) (string, error) { b := make([]byte, n) - _, err := rand.Read(b) + // Note that err == nil only if we read len(b) bytes. - if err != nil { + if _, err := rand.Read(b); err != nil { return "", err } @@ -130,54 +143,44 @@ func (s *Service) LookupVerifier(ctx context.Context, keyID string) (context.Con return ctx, &verifier, nil } -// GetCaveats returns any authorized caveats that have been stored in the context -func GetCaveats(ctx context.Context) map[string]string { +// caveatsFromCtx returns authorized caveats from ctx. +func caveatsFromCtx(ctx context.Context) map[string]string { caveats, ok := ctx.Value(caveatsCtxKey{}).(map[string]string) if !ok { return nil } + return caveats } -// GetMerchant returns any authorized merchant that has been stored in the context -func GetMerchant(ctx context.Context) (string, error) { +// merchantFromCtx returns an authorized merchant from ctx. +func merchantFromCtx(ctx context.Context) (string, error) { merchant, ok := ctx.Value(merchantCtxKey{}).(string) if !ok { - return "", errors.New("merchant was missing from context") + return "", errInvalidMerchant } + return merchant, nil } -// ValidateOrderMerchantAndCaveats checks that the current authentication of the request has -// permissions to this order by cross-checking the merchant and caveats in context -func (s *Service) ValidateOrderMerchantAndCaveats(r *http.Request, orderID uuid.UUID) error { - merchant, err := GetMerchant(r.Context()) +// validateOrderMerchantAndCaveats checks that the current authentication of the request has +// permissions to this order by cross-checking the merchant and caveats in context. +func (s *Service) validateOrderMerchantAndCaveats(ctx context.Context, oid uuid.UUID) error { + merchant, err := merchantFromCtx(ctx) if err != nil { return err } - caveats := GetCaveats(r.Context()) - order, err := s.Datastore.GetOrder(orderID) + order, err := s.orderRepo.Get(ctx, s.Datastore.RawDB(), oid) if err != nil { return err } if order.MerchantID != merchant { - return errors.New("Order merchant does not match authentication") + return errMerchantMismatch } - if caveats != nil { - if location, ok := caveats["location"]; ok { - if order.Location.Valid && order.Location.String != location { - return errors.New("Order location does not match authentication") - } - } - - if _, ok := caveats["sku"]; ok { - return errors.New("SKU caveat is not supported on order endpoints") - } - } - return nil + return validateOrderCvt(order, caveatsFromCtx(ctx)) } // NewAuthMwr returns a handler that authorises requests via http signature or simple tokens. @@ -212,3 +215,17 @@ func NewAuthMwr(ks httpsignature.Keystore) func(http.Handler) http.Handler { }) } } + +func validateOrderCvt(ord *model.Order, cvt map[string]string) error { + if loc, ok := cvt["location"]; ok && ord.Location.Valid { + if ord.Location.String != loc { + return errLocationMismatch + } + } + + if _, ok := cvt["sku"]; ok { + return errUnexpectedSKUCvt + } + + return nil +} diff --git a/services/skus/key_test.go b/services/skus/key_test.go index 48321364f..19a30623b 100644 --- a/services/skus/key_test.go +++ b/services/skus/key_test.go @@ -3,6 +3,7 @@ package skus import ( "context" "crypto" + "database/sql" "encoding/base64" "encoding/hex" "fmt" @@ -16,11 +17,14 @@ import ( "github.com/jmoiron/sqlx" uuid "github.com/satori/go.uuid" "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" "github.com/brave-intl/bat-go/libs/cryptography" "github.com/brave-intl/bat-go/libs/datastore" "github.com/brave-intl/bat-go/libs/httpsignature" "github.com/brave-intl/bat-go/libs/middleware" + + "github.com/brave-intl/bat-go/services/skus/model" "github.com/brave-intl/bat-go/services/skus/storage/repository" ) @@ -159,11 +163,11 @@ func TestMerchantSignedMiddleware(t *testing.T) { fn2 := func(w http.ResponseWriter, r *http.Request) { // with simple auth legacy mode there are no caveats - caveats := GetCaveats(r.Context()) + caveats := caveatsFromCtx(r.Context()) assert.Nil(t, caveats) // and the merchant is always brave.com - merchant, err := GetMerchant(r.Context()) + merchant, err := merchantFromCtx(r.Context()) assert.NoError(t, err) assert.Equal(t, merchant, "brave.com") } @@ -184,10 +188,10 @@ func TestMerchantSignedMiddleware(t *testing.T) { // Test that merchant signed works and sets caveats / merchant correctly fn3 := func(w http.ResponseWriter, r *http.Request) { - caveats := GetCaveats(r.Context()) + caveats := caveatsFromCtx(r.Context()) assert.Equal(t, caveats, expectedCaveats) - merchant, err := GetMerchant(r.Context()) + merchant, err := merchantFromCtx(r.Context()) assert.NoError(t, err) assert.Equal(t, merchant, expectedMerchant) } @@ -255,96 +259,222 @@ func TestMerchantSignedMiddleware(t *testing.T) { } func TestValidateOrderMerchantAndCaveats(t *testing.T) { - db, mock, _ := sqlmock.New() - service := &Service{} - service.Datastore = Datastore( - &Postgres{ - Postgres: datastore.Postgres{ - DB: sqlx.NewDb(db, "postgres"), + type tcGiven struct { + orderID uuid.UUID + merch string + cvt map[string]string + repo *repository.MockOrder + } + + type testCase struct { + name string + given tcGiven + exp error + } + + tests := []testCase{ + { + name: "invalid_order", + given: tcGiven{ + orderID: uuid.Must(uuid.FromString("0fb1d6ba-5d39-4f69-830b-c92c4640c86e")), + merch: "brave.com", + + repo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + return nil, model.ErrOrderNotFound + }, + }, + }, + exp: model.ErrOrderNotFound, + }, + + { + name: "merchant_no_caveats", + given: tcGiven{ + orderID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + merch: "brave.com", + + repo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + Currency: "BAT", + MerchantID: "brave.com", + Location: datastore.NullString{ + NullString: sql.NullString{ + Valid: true, + String: "test.brave.com", + }, + }, + Status: "paid", + } + + return result, nil + }, + }, + }, + }, + + { + name: "incorrect_merchant", + given: tcGiven{ + orderID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + merch: "brave.software", + + repo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + Currency: "BAT", + MerchantID: "brave.com", + Location: datastore.NullString{ + NullString: sql.NullString{ + Valid: true, + String: "test.brave.com", + }, + }, + Status: "paid", + } + + return result, nil + }, + }, + }, + exp: errMerchantMismatch, + }, + + { + name: "merchant_location", + given: tcGiven{ + orderID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + merch: "brave.com", + cvt: map[string]string{"location": "test.brave.com"}, + + repo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + Currency: "BAT", + MerchantID: "brave.com", + Location: datastore.NullString{ + NullString: sql.NullString{ + Valid: true, + String: "test.brave.com", + }, + }, + Status: "paid", + } + + return result, nil + }, + }, + }, + }, + + { + name: "incorrect_location", + given: tcGiven{ + orderID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + merch: "brave.com", + cvt: map[string]string{"location": "test.brave.software"}, + + repo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + Currency: "BAT", + MerchantID: "brave.com", + Location: datastore.NullString{ + NullString: sql.NullString{ + Valid: true, + String: "test.brave.com", + }, + }, + Status: "paid", + } + + return result, nil + }, + }, + }, + exp: errLocationMismatch, + }, + + { + name: "unexpected_sku", + given: tcGiven{ + orderID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + merch: "brave.com", + cvt: map[string]string{"location": "test.brave.com", "sku": "some_sku"}, + + repo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + Currency: "BAT", + MerchantID: "brave.com", + Location: datastore.NullString{ + NullString: sql.NullString{ + Valid: true, + String: "test.brave.com", + }, + }, + Status: "paid", + } + + return result, nil + }, + }, + }, + exp: errUnexpectedSKUCvt, + }, + + { + name: "empty_order_location", + given: tcGiven{ + orderID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + merch: "brave.com", + cvt: map[string]string{"location": "test.brave.com"}, + + repo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.Must(uuid.FromString("056bf179-1c07-4787-bd36-db51a83ad139")), + Currency: "BAT", + MerchantID: "brave.com", + Status: "paid", + } + + return result, nil + }, + }, }, - orderRepo: repository.NewOrder(), - orderItemRepo: repository.NewOrderItem(), - orderPayHistory: repository.NewOrderPayHistory(), }, - ) - expectedOrderID := uuid.NewV4() - - cases := []validateOrderMerchantAndCaveatsTestCase{ - {"brave.com", nil, uuid.NewV4(), false, "invalid order should fail"}, - {"brave.com", nil, expectedOrderID, true, "correct merchant and no caveats should succeed"}, - {"brave.software", nil, expectedOrderID, false, "incorrect merchant should fail"}, - {"brave.com", map[string]string{"location": "test.brave.com"}, expectedOrderID, true, "correct merchant and location caveat should succeed"}, - {"brave.com", map[string]string{"location": "test.brave.software"}, expectedOrderID, false, "incorrect location caveat should fail"}, - {"brave.com", map[string]string{"sku": "example-sku"}, expectedOrderID, false, "sku caveat is not supported"}, } - for _, testCase := range cases { - itemRows := sqlmock.NewRows([]string{}) - orderRows := sqlmock.NewRows([]string{ - "id", - "created_at", - "currency", - "updated_at", - "total_price", - "merchant_id", - "location", - "status", - "allowed_payment_methods", - "metadata", - "valid_for", - "last_paid_at", - "expires_at", - }) - orderRows.AddRow( - expectedOrderID.String(), - time.Now(), - "BAT", - time.Now(), - "0", - "brave.com", - "test.brave.com", - "paid", - nil, - nil, - nil, - nil, - nil, - ) - - mock.ExpectQuery(` -^SELECT (.+) FROM orders* -`). - WithArgs(expectedOrderID). - WillReturnRows(orderRows) - mock.ExpectQuery(` -^SELECT (.+) FROM order_items* -`). - WithArgs(expectedOrderID). - WillReturnRows(itemRows) + // Need a database instance in Datastore. + // Not using mocks (as the suppressed return value suggests). + dbi, _, err := sqlmock.New() + must.Equal(t, nil, err) - ValidateOrderMerchantAndCaveats(t, service, testCase) + ds := &Postgres{ + Postgres: datastore.Postgres{DB: sqlx.NewDb(dbi, "postgres")}, } -} -type validateOrderMerchantAndCaveatsTestCase struct { - merchant string - caveats map[string]string - orderID uuid.UUID - expectedSuccess bool - explanation string -} + for i := range tests { + tc := tests[i] -func ValidateOrderMerchantAndCaveats(t *testing.T, service *Service, testCase validateOrderMerchantAndCaveatsTestCase) { - ctx := context.WithValue(context.Background(), merchantCtxKey{}, testCase.merchant) - ctx = context.WithValue(ctx, caveatsCtxKey{}, testCase.caveats) + t.Run(tc.name, func(t *testing.T) { + ctx := context.WithValue(context.Background(), merchantCtxKey{}, tc.given.merch) + ctx = context.WithValue(ctx, caveatsCtxKey{}, tc.given.cvt) - req, err := http.NewRequestWithContext(ctx, "GET", "/hello-world", nil) - assert.NoError(t, err) + svc := &Service{ + Datastore: ds, + orderRepo: tc.given.repo, + } - err = service.ValidateOrderMerchantAndCaveats(req, testCase.orderID) - if testCase.expectedSuccess { - assert.NoError(t, err, testCase.explanation) - } else { - assert.Error(t, err, testCase.explanation) + err := svc.validateOrderMerchantAndCaveats(ctx, tc.given.orderID) + assert.Equal(t, tc.exp, err) + }) } } diff --git a/services/skus/service.go b/services/skus/service.go index 4ca82a44d..3222ee2ee 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -84,9 +84,13 @@ const ( defaultOverlap = 5 ) +type orderStoreSvc interface { + Get(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) +} + // Service contains datastore type Service struct { - orderRepo orderStore + orderRepo orderStoreSvc issuerRepo issuerStore // TODO: Eventually remove it. @@ -150,7 +154,7 @@ func (s *Service) InitKafka(ctx context.Context) error { } // InitService creates a service using the passed datastore and clients configured from the environment. -func InitService(ctx context.Context, datastore Datastore, walletService *wallet.Service, orderRepo orderStore, issuerRepo issuerStore) (*Service, error) { +func InitService(ctx context.Context, datastore Datastore, walletService *wallet.Service, orderRepo orderStoreSvc, issuerRepo issuerStore) (*Service, error) { sublogger := logging.Logger(ctx, "payments").With().Str("func", "InitService").Logger() // setup the in app purchase clients initClients(ctx) @@ -1316,7 +1320,7 @@ type credential interface { func (s *Service) verifyCredential(ctx context.Context, cred credential, w http.ResponseWriter) *handlers.AppError { logger := logging.Logger(ctx, "verifyCredential") - merchant, err := GetMerchant(ctx) + merchant, err := merchantFromCtx(ctx) if err != nil { logger.Error().Err(err).Msg("failed to get the merchant from the context") return handlers.WrapError(err, "Error getting auth merchant", http.StatusInternalServerError) @@ -1324,7 +1328,7 @@ func (s *Service) verifyCredential(ctx context.Context, cred credential, w http. logger.Debug().Str("merchant", merchant).Msg("got merchant from the context") - caveats := GetCaveats(ctx) + caveats := caveatsFromCtx(ctx) if cred.GetMerchantID(ctx) != merchant { logger.Warn(). diff --git a/services/skus/storage/repository/mock.go b/services/skus/storage/repository/mock.go index 71fe9e8af..d12fdca58 100644 --- a/services/skus/storage/repository/mock.go +++ b/services/skus/storage/repository/mock.go @@ -10,6 +10,22 @@ import ( "github.com/brave-intl/bat-go/services/skus/model" ) +type MockOrder struct { + FnGet func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) +} + +func (r *MockOrder) Get(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + if r.FnGet == nil { + result := &model.Order{ + ID: uuid.NewV4(), + } + + return result, nil + } + + return r.FnGet(ctx, dbi, id) +} + type MockIssuer struct { FnGetByMerchID func(ctx context.Context, dbi sqlx.QueryerContext, merchID string) (*model.Issuer, error) FnGetByPubKey func(ctx context.Context, dbi sqlx.QueryerContext, pubKey string) (*model.Issuer, error) From 7cdb520d6926008bbfe33f1dcb619c270ccc8a5f Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Fri, 13 Oct 2023 12:12:43 +0100 Subject: [PATCH 4/4] fix: multi redeem tlv2 success response (#2129) change tlv2 response --- services/skus/service.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/services/skus/service.go b/services/skus/service.go index 3222ee2ee..f4844ec85 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -1829,7 +1829,11 @@ func (s *Service) redeemBlindedCred(ctx context.Context, w http.ResponseWriter, return handlers.WrapError(err, "Error verifying credentials", http.StatusInternalServerError) } - return handlers.RenderContent(ctx, &blindedCredVrfResult{ID: cred.TokenPreimage}, w, http.StatusOK) + // TODO(clD11): cleanup after quick fix + if kind == timeLimitedV2 { + return handlers.RenderContent(ctx, &blindedCredVrfResult{ID: cred.TokenPreimage}, w, http.StatusOK) + } + return handlers.RenderContent(ctx, "Credentials successfully verified", w, http.StatusOK) } func createOrderItems(req *model.CreateOrderRequestNew) ([]model.OrderItem, error) {