From ac49405aefedfda3e79dac21071874d953517195 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Tue, 9 Jan 2024 19:36:01 +0000 Subject: [PATCH 01/18] feat: implement solana address linking (#2280) * feat: implement solana address linking * feat: implement solana address linking go mod tidy * feat: implement solana address linking go mod tidy * feat: implement solana address linking add repo tests * feat: implement solana address linking address pr comments * feat: implement solana address linking address pr comments * feat: resolve pr comments implement solana address linking --- .gitignore | 3 + docker-compose.dev-refresh.yml | 1 + docker-compose.yml | 2 + libs/context/keys.go | 2 + libs/datastore/postgres.go | 6 +- libs/ptr/ptr.go | 7 - libs/wallet/wallet.go | 204 +++++++++++ libs/wallet/wallet_test.go | 335 ++++++++++++++++++ migrations/0064_challenges.down.sql | 1 + migrations/0064_challenges.up.sql | 6 + migrations/0065_allow_list.down.sql | 1 + migrations/0065_allow_list.up.sql | 4 + ...todian_check_custodian_add_solana.down.sql | 2 + ...ustodian_check_custodian_add_solana.up.sql | 2 + services/go.mod | 2 +- services/grant/cmd/grant.go | 14 +- services/grant/cmd/grant_test.go | 2 + services/wallet/cmd/rest_run.go | 18 +- ...keystore_test.go => controller_v3_test.go} | 208 +++++++++-- services/wallet/controllers_v3.go | 160 +++++++-- ..._v3_test.go => controllers_v3_pvt_test.go} | 177 +++------ services/wallet/controllers_v4_test.go | 51 ++- services/wallet/datastore.go | 110 ++++-- services/wallet/datastore_test.go | 8 +- services/wallet/instrumented_datastore.go | 5 +- services/wallet/model/model.go | 49 ++- services/wallet/model/model_test.go | 54 +++ services/wallet/service.go | 276 ++++++++++----- services/wallet/storage/storage.go | 91 +++++ services/wallet/storage/storage_test.go | 325 +++++++++++++++++ tools/go.mod | 1 + tools/go.sum | 2 + 32 files changed, 1790 insertions(+), 339 deletions(-) create mode 100644 libs/wallet/wallet_test.go create mode 100644 migrations/0064_challenges.down.sql create mode 100644 migrations/0064_challenges.up.sql create mode 100644 migrations/0065_allow_list.down.sql create mode 100644 migrations/0065_allow_list.up.sql create mode 100644 migrations/0066_wallet_custodian_check_custodian_add_solana.down.sql create mode 100644 migrations/0066_wallet_custodian_check_custodian_add_solana.up.sql rename services/wallet/{keystore_test.go => controller_v3_test.go} (75%) rename services/wallet/{controllers_v3_test.go => controllers_v3_pvt_test.go} (91%) create mode 100644 services/wallet/model/model_test.go create mode 100644 services/wallet/storage/storage.go create mode 100644 services/wallet/storage/storage_test.go diff --git a/.gitignore b/.gitignore index fbf196191..19725210b 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,6 @@ share-0.gpg fetch-quote.json main/main + +go.work +go.work.sum diff --git a/docker-compose.dev-refresh.yml b/docker-compose.dev-refresh.yml index ca5c07ea3..6f18e8cf6 100644 --- a/docker-compose.dev-refresh.yml +++ b/docker-compose.dev-refresh.yml @@ -56,3 +56,4 @@ services: - UPHOLD_ACCESS_TOKEN - "RATIOS_SERVICE=https://ratios.rewards.bravesoftware.com" - RATIOS_TOKEN + - "DAPP_ALLOWED_CORS_ORIGINS=https://my-dapp.com" diff --git a/docker-compose.yml b/docker-compose.yml index e93b4b75a..2cfc5c41f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -80,6 +80,7 @@ services: - TEST_RUN - TOKEN_LIST - UPHOLD_ACCESS_TOKEN + - "DAPP_ALLOWED_CORS_ORIGINS=https://my-dapp.com" volumes: - ./test/secrets:/etc/kafka/secrets - ./migrations:/src/migrations @@ -167,6 +168,7 @@ services: - TEST_RUN - TOKEN_LIST - UPHOLD_ACCESS_TOKEN + - "DAPP_ALLOWED_CORS_ORIGINS=https://my-dapp.com" volumes: - ./test/secrets:/etc/kafka/secrets - ./migrations:/src/migrations diff --git a/libs/context/keys.go b/libs/context/keys.go index 9418f611b..0104850a8 100644 --- a/libs/context/keys.go +++ b/libs/context/keys.go @@ -128,6 +128,8 @@ const ( DisableGeminiLinkingCTXKey CTXKey = "disable_gemini_linking" // DisableBitflyerLinkingCTXKey - this informs if bitflyer linking is enabled DisableBitflyerLinkingCTXKey CTXKey = "disable_bitflyer_linking" + // DisableSolanaLinkingCTXKey - this informs if solana linking is enabled + DisableSolanaLinkingCTXKey CTXKey = "disable_solana_linking" // RadomWebhookSecretCTXKey - the webhook secret key for radom integration RadomWebhookSecretCTXKey CTXKey = "radom_webhook_secret" diff --git a/libs/datastore/postgres.go b/libs/datastore/postgres.go index 1fde6af11..d9d8ac9b0 100644 --- a/libs/datastore/postgres.go +++ b/libs/datastore/postgres.go @@ -13,8 +13,8 @@ import ( appctx "github.com/brave-intl/bat-go/libs/context" "github.com/brave-intl/bat-go/libs/logging" "github.com/brave-intl/bat-go/libs/metrics" - sentry "github.com/getsentry/sentry-go" - migrate "github.com/golang-migrate/migrate/v4" + "github.com/getsentry/sentry-go" + "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" "github.com/jmoiron/sqlx" "github.com/prometheus/client_golang/prometheus" @@ -41,7 +41,7 @@ var ( } dbs = map[string]*sqlx.DB{} // CurrentMigrationVersion holds the default migration version - CurrentMigrationVersion = uint(63) + CurrentMigrationVersion = uint(66) // MigrationTracks holds the migration version for a given track (eyeshade, promotion, wallet) MigrationTracks = map[string]uint{ "eyeshade": 20, diff --git a/libs/ptr/ptr.go b/libs/ptr/ptr.go index 6c1667b17..46027145f 100644 --- a/libs/ptr/ptr.go +++ b/libs/ptr/ptr.go @@ -2,15 +2,8 @@ package ptr import ( "time" - - uuid "github.com/satori/go.uuid" ) -// FromUUID returns pointer to uuid -func FromUUID(u uuid.UUID) *uuid.UUID { - return &u -} - // FromString returns pointer to string func FromString(s string) *string { return &s diff --git a/libs/wallet/wallet.go b/libs/wallet/wallet.go index 1c3d89aac..40575bdbd 100644 --- a/libs/wallet/wallet.go +++ b/libs/wallet/wallet.go @@ -3,10 +3,15 @@ package wallet import ( "context" + "crypto/ed25519" + "encoding/base64" + "encoding/hex" "fmt" + "strings" "time" "github.com/brave-intl/bat-go/libs/altcurrency" + "github.com/btcsuite/btcutil/base58" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -89,3 +94,202 @@ type Wallet interface { // ListTransactions for this wallet, limit number of transactions returned ListTransactions(ctx context.Context, limit int, startDate time.Time) ([]TransactionInfo, error) } + +type SolanaLinkReq struct { + Pub string + Sig string + Msg string + Nonce string +} + +func (w *Info) LinkSolanaAddress(_ context.Context, s SolanaLinkReq) error { + if err := w.isLinkReqValid(s); err != nil { + return newLinkSolanaAddressError(err) + } + + w.UserDepositDestination = s.Pub + + return nil +} + +func (w *Info) isLinkReqValid(s SolanaLinkReq) error { + if err := verifySolanaSignature(s.Pub, s.Msg, s.Sig); err != nil { + return err + } + + p := newSolMsgParser(w.ID, s.Pub, s.Nonce) + rm, err := p.parse(s.Msg) + if err != nil { + return fmt.Errorf("parsing error: %w", err) + } + + if err := verifyRewardsSignature(w.PublicKey, rm.msg, rm.sig); err != nil { + return err + } + + return nil +} + +const publicKeyLength = 32 + +const ( + errBadPublicKeyLength Error = "wallet: bad public key length" + errInvalidSolanaSignature Error = "wallet: invalid solana signature for message and public key" +) + +func verifySolanaSignature(pub, msg, sig string) error { + b := base58.Decode(pub) + if len(b) != publicKeyLength { + return fmt.Errorf("error verifying solana signature: %w", errBadPublicKeyLength) + } + pubKey := ed25519.PublicKey(b) + + decSig, err := base64.URLEncoding.DecodeString(sig) + if err != nil { + return fmt.Errorf("error decoding solana signature: %w", err) + } + + if !ed25519.Verify(pubKey, []byte(msg), decSig) { + return errInvalidSolanaSignature + } + + return nil +} + +type rewMsg struct { + msg string + sig string +} + +const ( + errInvalidPartsLineBreak Error = "wallet: invalid number of lines" + errInvalidPartsColon Error = "wallet: invalid parts colon" + errInvalidParts Error = "wallet: invalid number of parts" + errInvalidPaymentID Error = "wallet: payment id does not match" + errInvalidSolanaPubKey Error = "wallet: solana public key does not match" + errInvalidRewardsMessage Error = "wallet: invalid rewards message" +) + +type solMsgParser struct { + paymentID string + solPub string + nonce string +} + +func newSolMsgParser(paymentID, solPub, nonce string) solMsgParser { + return solMsgParser{ + paymentID: paymentID, + solPub: solPub, + nonce: nonce, + } +} + +// parse parses a linking message and returns the rewards part. +// +// For a message to parse successfully it must be a valid format and successfully match the +// configured parser parameters paymentID, solPub and nonce. +// +// The message format should be three lines with a colon delimiter. For example, +// : +// : +// :.. +// +// The final part is referred to as the rewards message and has the +// format .. we only need to compare +// the .. when parsing. +// +// The returned rewMsg will contain both the message and signature. For example, +// +// rewMsg { +// msg: ., +// sig: , +// } +func (s *solMsgParser) parse(msg string) (rewMsg, error) { + lines := strings.Split(msg, "\n") + if len(lines) != 3 { + return rewMsg{}, errInvalidPartsLineBreak + } + + parts := make([]string, 0, 3) + for i := range lines { + p := strings.Split(lines[i], ":") + if len(p) != 2 { + return rewMsg{}, errInvalidPartsColon + } + parts = append(parts, p[1]) + } + + if len(parts) != 3 { + return rewMsg{}, errInvalidParts + } + + if parts[0] != s.paymentID { + return rewMsg{}, errInvalidPaymentID + } + + if parts[1] != s.solPub { + return rewMsg{}, errInvalidSolanaPubKey + } + + // Compare the final part minus the signature. + exp := s.paymentID + "." + s.nonce + "." + for i := range exp { + if parts[2][i] != exp[i] { + return rewMsg{}, errInvalidRewardsMessage + } + } + + rm := rewMsg{ + msg: parts[2][:len(exp)-1], // -1 removes the trailing . + sig: parts[2][len(exp):], + } + + return rm, nil +} + +const errInvalidRewardsSignature Error = "wallet: invalid rewards signature for message and public key" + +func verifyRewardsSignature(pub, msg, sig string) error { + b, err := hex.DecodeString(pub) + if err != nil { + return fmt.Errorf("error decoding rewards public key: %w", err) + } + + if len(b) != publicKeyLength { + return fmt.Errorf("error verifying rewards signature: %w", errBadPublicKeyLength) + } + pubKey := ed25519.PublicKey(b) + + decSig, err := base64.URLEncoding.DecodeString(sig) + if err != nil { + return fmt.Errorf("error decoding rewards signature: %w", err) + } + + if !ed25519.Verify(pubKey, []byte(msg), decSig) { + return errInvalidRewardsSignature + } + + return nil +} + +type LinkSolanaAddressError struct { + err error +} + +func newLinkSolanaAddressError(err error) error { + return &LinkSolanaAddressError{err: err} +} + +func (e *LinkSolanaAddressError) Error() string { + return e.err.Error() +} + +func (e *LinkSolanaAddressError) Unwrap() error { + return e.err +} + +type Error string + +func (e Error) Error() string { + return string(e) +} diff --git a/libs/wallet/wallet_test.go b/libs/wallet/wallet_test.go new file mode 100644 index 000000000..cf0f33daa --- /dev/null +++ b/libs/wallet/wallet_test.go @@ -0,0 +1,335 @@ +package wallet + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInfo_LinkSolanaAddress(t *testing.T) { + type tcGiven struct { + w Info + s SolanaLinkReq + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "success", + given: tcGiven{ + w: Info{ + ID: "5e99b1ae-6e91-481f-9021-7fb3c97327d4", + PublicKey: "f4205e18ea138efc59f18dfbb5c7c2a24b5724395a36e86e84c84b9ec9ccb1ec", + }, + s: SolanaLinkReq{ + Pub: "GysxUmKWkazQSUm3DCG7o5tPmsuNqFgQ2iG9imkzxzyG", + Sig: "X2nhxq-95ZR5QWk9R1m-Rqh8QVndDy2yL2NY5PSx0G-EyzX3xm7JKpPhILxZfc_cWwLtaPk6xRQBManPCKE6BQ==", + Msg: newMsg(msgParts{ + paymentID: "5e99b1ae-6e91-481f-9021-7fb3c97327d4", + solPub: "GysxUmKWkazQSUm3DCG7o5tPmsuNqFgQ2iG9imkzxzyG", + nonce: "86d6f240-df9b-4167-a66e-5df6da80ac24", + rewSig: "szPeTsDRUOFLS1y-k85yfvcI40OOWwTEn2mQ3cGcnZLPkDCXD1qJKYJNkNgdY5j5BA7pvj8AzEy8riKtdeRaAQ==", + }), + Nonce: "86d6f240-df9b-4167-a66e-5df6da80ac24", + }, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + { + name: "invalid_linking", + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + var expected *LinkSolanaAddressError + return assert.ErrorAs(t, err, &expected) + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + wallet := tc.given.w + err := wallet.LinkSolanaAddress(context.TODO(), tc.given.s) + tc.assertErr(t, err) + }) + } +} + +func TestVerifySolanaSignature(t *testing.T) { + type tcGiven struct { + solPub string + msg string + solSig string + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "invalid_public_key_length", + given: tcGiven{solPub: "123456789"}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errBadPublicKeyLength) + }, + }, + { + name: "signature_has_illegal_character", + given: tcGiven{solPub: "32rbMEtgTphzVnHuSsuHEv3hKpm92UsgMerjDjZr72T1", solSig: "+"}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorContains(t, err, "error decoding solana signature") + }, + }, + { + name: "invalid_signature", + given: tcGiven{ + solPub: "5zjTAqk1xbeYFzkMCnY3H52SLQpuM1GUFUDAwfhJs1wg", + msg: "invalid_message", + solSig: "zc2boTImAAhzraUplAlUy2L6hNF6l-DYGfOqq_4UfrDsJEBg26jaHIAXJF2i3tifCZxrvmu3ahqIdnm2kOwyBQ==", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errInvalidSolanaSignature) + }, + }, + { + name: "valid_signature", + given: tcGiven{ + solPub: "5zjTAqk1xbeYFzkMCnY3H52SLQpuM1GUFUDAwfhJs1wg", + msg: "test", + solSig: "zc2boTImAAhzraUplAlUy2L6hNF6l-DYGfOqq_4UfrDsJEBg26jaHIAXJF2i3tifCZxrvmu3ahqIdnm2kOwyBQ==", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + } + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + err := verifySolanaSignature(tc.given.solPub, tc.given.msg, tc.given.solSig) + tc.assertErr(t, err) + }) + } +} + +func TestSolMsgParser_parse(t *testing.T) { + type tcGiven struct { + msgParser solMsgParser + msg string + } + + type exp struct { + rewMsg rewMsg + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "invalid_number_of_line_breaks", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id-some-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidPartsLineBreak, + }, + }, + { + name: "invalid_parts_colon", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "payment-id\nsome-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidPartsColon, + }, + }, + { + name: "invalid_payment_id", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:another-id\nsome-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidPaymentID, + }, + }, + { + name: "invalid_solana_public_key", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:another-solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidSolanaPubKey, + }, + }, + { + name: "invalid_rewards_message", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:solana-address\nsome-text:another-payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidRewardsMessage, + }, + }, + { + name: "no_rewards_message_signature", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:solana-address\nsome-text:payment-id.nonce.", + }, + exp: exp{ + rewMsg: rewMsg{ + msg: "payment-id.nonce", + }, + err: nil, + }, + }, + { + name: "success", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + rewMsg: rewMsg{ + msg: "payment-id.nonce", + sig: "abcd", + }, + err: nil, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + msgParser := tc.given.msgParser + actual, err := msgParser.parse(tc.given.msg) + assert.Equal(t, tc.exp.err, err) + assert.Equal(t, tc.exp.rewMsg, actual) + }) + } +} + +func TestVerifyRewardsSignature(t *testing.T) { + type tcGiven struct { + pub string + sig string + msg string + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "error_decoding_public_key", + given: tcGiven{pub: "invalid_key"}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorContains(t, err, "error decoding rewards public key") + }, + }, + { + name: "invalid_public_key_length", + given: tcGiven{pub: hex.EncodeToString([]byte("key"))}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errBadPublicKeyLength) + }, + }, + { + name: "signature_has_illegal_character", + given: tcGiven{ + pub: "ac1e69da621a99cf29de8ac1b0ffc8ece154b98e99a0ebec2bfdf2af04b8ac53", + sig: "!", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorContains(t, err, "error decoding rewards signature") + }, + }, + { + name: "invalid_signature", + given: tcGiven{ + pub: "e0e9196cfb3c98f8912c011ff46193167b7df72a166c595408c6ca6c690bb707", + msg: "invalid_message", + sig: "gJJptSk0lGBjpJOx7Mq_AwVtNkW5tg4esgbtYesQXLfabDZP4K_bFxpEn40TIBRISQho9oLzGfOnzWH88ntdAg=="}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errInvalidRewardsSignature) + }, + }, + { + name: "valid_signature", + given: tcGiven{ + pub: "e0e9196cfb3c98f8912c011ff46193167b7df72a166c595408c6ca6c690bb707", + msg: "test", + sig: "gJJptSk0lGBjpJOx7Mq_AwVtNkW5tg4esgbtYesQXLfabDZP4K_bFxpEn40TIBRISQho9oLzGfOnzWH88ntdAg==", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + err := verifyRewardsSignature(tc.given.pub, tc.given.msg, tc.given.sig) + tc.assertErr(t, err) + }) + } +} + +func TestNewLinkSolanaAddressError(t *testing.T) { + err1 := errors.New("error text") + err2 := newLinkSolanaAddressError(err1) + err3 := fmt.Errorf("%w", err2) + + t.Run("error_text", func(t *testing.T) { + assert.ErrorContains(t, err2, err1.Error()) + }) + + t.Run("error_unwrap", func(t *testing.T) { + var target *LinkSolanaAddressError + assert.ErrorAs(t, err3, &target) + }) +} + +type msgParts struct { + paymentID string + solPub string + nonce string + rewSig string +} + +func newMsg(parts msgParts) string { + const msgTmpl = ":%s\n:%s\n:%s.%s.%s" + return fmt.Sprintf(msgTmpl, parts.paymentID, parts.solPub, parts.paymentID, parts.nonce, parts.rewSig) +} diff --git a/migrations/0064_challenges.down.sql b/migrations/0064_challenges.down.sql new file mode 100644 index 000000000..3efc2210d --- /dev/null +++ b/migrations/0064_challenges.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS challenge; diff --git a/migrations/0064_challenges.up.sql b/migrations/0064_challenges.up.sql new file mode 100644 index 000000000..d23b958cf --- /dev/null +++ b/migrations/0064_challenges.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE challenge ( + payment_id uuid PRIMARY KEY, + created_at timestamp WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + nonce text NOT NULL, + CONSTRAINT challenge_nonce UNIQUE (nonce) +); diff --git a/migrations/0065_allow_list.down.sql b/migrations/0065_allow_list.down.sql new file mode 100644 index 000000000..7c8c09e99 --- /dev/null +++ b/migrations/0065_allow_list.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS allow_list; diff --git a/migrations/0065_allow_list.up.sql b/migrations/0065_allow_list.up.sql new file mode 100644 index 000000000..8b83ddf70 --- /dev/null +++ b/migrations/0065_allow_list.up.sql @@ -0,0 +1,4 @@ +CREATE TABLE allow_list ( + payment_id uuid PRIMARY KEY, + created_at timestamp WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/migrations/0066_wallet_custodian_check_custodian_add_solana.down.sql b/migrations/0066_wallet_custodian_check_custodian_add_solana.down.sql new file mode 100644 index 000000000..7f3e0126c --- /dev/null +++ b/migrations/0066_wallet_custodian_check_custodian_add_solana.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE wallet_custodian DROP CONSTRAINT IF EXISTS check_custodian, + ADD CONSTRAINT check_custodian CHECK (custodian IN ('brave', 'uphold', 'bitflyer', 'gemini', 'zebpay')); diff --git a/migrations/0066_wallet_custodian_check_custodian_add_solana.up.sql b/migrations/0066_wallet_custodian_check_custodian_add_solana.up.sql new file mode 100644 index 000000000..06c40dd1d --- /dev/null +++ b/migrations/0066_wallet_custodian_check_custodian_add_solana.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE wallet_custodian DROP CONSTRAINT IF EXISTS check_custodian, + ADD CONSTRAINT check_custodian CHECK (custodian IN ('brave', 'uphold', 'bitflyer', 'gemini', 'zebpay', 'solana')); diff --git a/services/go.mod b/services/go.mod index 7a8cec66a..7eabc1f32 100644 --- a/services/go.mod +++ b/services/go.mod @@ -20,6 +20,7 @@ require ( github.com/brave-intl/bat-go v1.0.2 github.com/brave-intl/bat-go/libs v1.0.2 github.com/brave-intl/bat-go/tools v1.0.2 + github.com/btcsuite/btcutil v1.0.2 github.com/getsentry/sentry-go v0.14.0 github.com/go-chi/chi v4.1.2+incompatible github.com/go-chi/cors v1.2.1 @@ -72,7 +73,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.6 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.18.7 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/btcsuite/btcutil v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect diff --git a/services/grant/cmd/grant.go b/services/grant/cmd/grant.go index 56cc51470..f34dfa2f9 100644 --- a/services/grant/cmd/grant.go +++ b/services/grant/cmd/grant.go @@ -88,6 +88,11 @@ func init() { Bind("disable-zebpay-linking"). Env("DISABLE_ZEBPAY_LINKING") + flagBuilder.Flag().Bool("disable-solana-linking", false, + "disable address linking for solana"). + Bind("disable-solana-linking"). + Env("DISABLE_SOLANA_LINKING") + flagBuilder.Flag().StringSlice("brave-transfer-promotion-ids", []string{""}, "brave vg deposit destination promotion id"). Bind("brave-transfer-promotion-ids"). @@ -350,7 +355,11 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context, // this way we can have the wallet service completely separated from // grants service and easily deployable. ctx, walletService = wallet.SetupService(ctx) - r = wallet.RegisterRoutes(ctx, walletService, r) + origin := os.Getenv("DAPP_ALLOWED_CORS_ORIGINS") + if origin == "" { + logger.Panic().Msg("dapp origin env missing") + } + r = wallet.RegisterRoutes(ctx, walletService, r, middleware.InstrumentHandler, wallet.NewDAppCorsMw(origin)) promotionDB, promotionRODB, err := promotion.NewPostgres() if err != nil { @@ -615,6 +624,7 @@ func GrantServer( ctx = context.WithValue(ctx, appctx.DisableUpholdLinkingCTXKey, viper.GetBool("disable-uphold-linking")) ctx = context.WithValue(ctx, appctx.DisableGeminiLinkingCTXKey, viper.GetBool("disable-gemini-linking")) ctx = context.WithValue(ctx, appctx.DisableBitflyerLinkingCTXKey, viper.GetBool("disable-bitflyer-linking")) + ctx = context.WithValue(ctx, appctx.DisableSolanaLinkingCTXKey, viper.GetBool("disable-solana-linking")) // stripe variables ctx = context.WithValue(ctx, appctx.StripeEnabledCTXKey, viper.GetBool("stripe-enabled")) @@ -710,6 +720,7 @@ func newSrvStatusFromCtx(ctx context.Context) map[string]any { g, _ := ctx.Value(appctx.DisableGeminiLinkingCTXKey).(bool) bf, _ := ctx.Value(appctx.DisableBitflyerLinkingCTXKey).(bool) zp, _ := ctx.Value(appctx.DisableZebPayLinkingCTXKey).(bool) + s, _ := ctx.Value(appctx.DisableSolanaLinkingCTXKey).(bool) result := map[string]interface{}{ "wallet": map[string]bool{ @@ -717,6 +728,7 @@ func newSrvStatusFromCtx(ctx context.Context) map[string]any { "gemini": !g, "bitflyer": !bf, "zebpay": !zp, + "solana": !s, }, } diff --git a/services/grant/cmd/grant_test.go b/services/grant/cmd/grant_test.go index 481840159..714ce8388 100644 --- a/services/grant/cmd/grant_test.go +++ b/services/grant/cmd/grant_test.go @@ -16,6 +16,7 @@ func TestNewSrvStatusFromCtx(t *testing.T) { ctx = context.WithValue(ctx, appctx.DisableGeminiLinkingCTXKey, true) ctx = context.WithValue(ctx, appctx.DisableBitflyerLinkingCTXKey, true) ctx = context.WithValue(ctx, appctx.DisableZebPayLinkingCTXKey, true) + ctx = context.WithValue(ctx, appctx.DisableSolanaLinkingCTXKey, true) act := newSrvStatusFromCtx(ctx) exp := map[string]interface{}{ @@ -24,6 +25,7 @@ func TestNewSrvStatusFromCtx(t *testing.T) { "gemini": false, "bitflyer": false, "zebpay": false, + "solana": false, }, } diff --git a/services/wallet/cmd/rest_run.go b/services/wallet/cmd/rest_run.go index 940fc85bd..279233909 100644 --- a/services/wallet/cmd/rest_run.go +++ b/services/wallet/cmd/rest_run.go @@ -2,16 +2,17 @@ package cmd import ( "net/http" - "time" - // pprof imports _ "net/http/pprof" + "os" + "time" cmdutils "github.com/brave-intl/bat-go/cmd" appctx "github.com/brave-intl/bat-go/libs/context" + "github.com/brave-intl/bat-go/libs/middleware" "github.com/brave-intl/bat-go/services/cmd" "github.com/brave-intl/bat-go/services/wallet" - sentry "github.com/getsentry/sentry-go" + "github.com/getsentry/sentry-go" "github.com/go-chi/chi" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -22,12 +23,19 @@ import ( // wallets rest microservice. func WalletRestRun(command *cobra.Command, args []string) { ctx, service := wallet.SetupService(command.Context()) - router := cmd.SetupRouter(ctx) - wallet.RegisterRoutes(ctx, service, router) logger, err := appctx.GetLogger(ctx) cmdutils.Must(err) + router := cmd.SetupRouter(ctx) + + origin := os.Getenv("DAPP_ALLOWED_CORS_ORIGINS") + if origin == "" { + logger.Panic().Msg("dapp origin env missing") + } + + wallet.RegisterRoutes(ctx, service, router, middleware.InstrumentHandler, wallet.NewDAppCorsMw(origin)) + // add profiling flag to enable profiling routes if viper.GetString("pprof-enabled") != "" { // pprof attaches routes to default serve mux diff --git a/services/wallet/keystore_test.go b/services/wallet/controller_v3_test.go similarity index 75% rename from services/wallet/keystore_test.go rename to services/wallet/controller_v3_test.go index eef23498c..e3ada3402 100644 --- a/services/wallet/keystore_test.go +++ b/services/wallet/controller_v3_test.go @@ -7,12 +7,15 @@ import ( "context" "crypto" "crypto/ed25519" + "crypto/rand" + "encoding/base64" "encoding/hex" "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" + "time" "github.com/brave-intl/bat-go/libs/altcurrency" mock_reputation "github.com/brave-intl/bat-go/libs/clients/reputation/mock" @@ -22,10 +25,15 @@ import ( walletutils "github.com/brave-intl/bat-go/libs/wallet" "github.com/brave-intl/bat-go/libs/wallet/provider/uphold" "github.com/brave-intl/bat-go/services/wallet" + "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/brave-intl/bat-go/services/wallet/storage" + "github.com/btcsuite/btcutil/base58" "github.com/go-chi/chi" "github.com/golang/mock/gomock" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" + "github.com/spf13/viper" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -75,10 +83,6 @@ func (suite *WalletControllersTestSuite) CleanDB() { } } -func noUUID() *uuid.UUID { - return nil -} - func (suite *WalletControllersTestSuite) FundWallet(w *uphold.Wallet, probi decimal.Decimal) decimal.Decimal { ctx := context.Background() balanceBefore, err := w.GetBalance(ctx, true) @@ -103,7 +107,7 @@ func (suite *WalletControllersTestSuite) TestBalanceV3() { mockCtrl := gomock.NewController(suite.T()) defer mockCtrl.Finish() - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) w1 := suite.NewWallet(service, "uphold") @@ -163,7 +167,7 @@ func (suite *WalletControllersTestSuite) TestLinkWalletV3() { mockCtrl := gomock.NewController(suite.T()) defer mockCtrl.Finish() - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) w1 := suite.NewWallet(service, "uphold") w2 := suite.NewWallet(service, "uphold") @@ -269,9 +273,7 @@ func (suite *WalletControllersTestSuite) claimCardV3( return linked, rr.Body.String() } -func (suite *WalletControllersTestSuite) createBody( - tx string, -) string { +func (suite *WalletControllersTestSuite) createBody(tx string) string { reqBody, _ := json.Marshal(wallet.UpholdCreationRequest{ SignedCreationRequest: tx, }) @@ -319,7 +321,7 @@ func (suite *WalletControllersTestSuite) TestCreateBraveWalletV3() { pg, _, err := wallet.NewPostgres() suite.Require().NoError(err, "Failed to get postgres connection") - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) @@ -362,7 +364,7 @@ func (suite *WalletControllersTestSuite) TestCreateUpholdWalletV3() { pg, _, err := wallet.NewPostgres() suite.Require().NoError(err, "Failed to get postgres connection") - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) @@ -414,11 +416,165 @@ func (suite *WalletControllersTestSuite) TestCreateUpholdWalletV3() { }`, notSignedResponse, "field is not valid") } -func (suite *WalletControllersTestSuite) getWallet( - service *wallet.Service, - paymentId uuid.UUID, - code int, -) string { +func (suite *WalletControllersTestSuite) TestChallenges_Success() { + paymentID := uuid.NewV4() + + body := struct { + PaymentID uuid.UUID `json:"paymentId"` + }{ + PaymentID: paymentID, + } + + b, err := json.Marshal(body) + suite.Require().NoError(err) + + r := httptest.NewRequest(http.MethodPost, "/v3/wallet/challenges", bytes.NewBuffer(b)) + r.Header.Set("origin", "https://my-dapp.com") + + rw := httptest.NewRecorder() + + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) + + chlRep := storage.NewChallenge() + + dac := wallet.DAppConfig{ + AllowedOrigin: "https://my-dapp.com", + } + + s, err := wallet.InitService(pg, nil, chlRep, nil, nil, nil, nil, nil, nil, nil, dac) + suite.Require().NoError(err) + + svr := &http.Server{Addr: ":8080", Handler: setupRouter(s)} + svr.Handler.ServeHTTP(rw, r) + suite.Require().Equal(http.StatusCreated, rw.Code) + suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) + + type chlResp struct { + Nonce string `json:"challengeId"` + } + + var resp chlResp + err = json.Unmarshal(rw.Body.Bytes(), &resp) + suite.Require().NoError(err) + + chlRepo := storage.NewChallenge() + chl, err := chlRepo.Get(context.TODO(), pg.RawDB(), paymentID) + suite.Require().NoError(err) + + suite.Assert().Equal(chl.Nonce, resp.Nonce) +} + +func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Success() { + viper.Set("enable-link-drain-flag", "true") + + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) + + chlRep := storage.NewChallenge() + allowList := storage.NewAllowList() + + // create the wallet + pub, priv, err := ed25519.GenerateKey(nil) + suite.Require().NoError(err) + + paymentID := uuid.NewV4() + w := &walletutils.Info{ + ID: paymentID.String(), + Provider: "brave", + PublicKey: hex.EncodeToString(pub), + AltCurrency: ptrTo(altcurrency.BAT), + } + err = pg.InsertWallet(context.TODO(), w) + suite.Require().NoError(err) + + whitelistWallet(suite.T(), pg, w.ID) + + // create nonce + chl := model.NewChallenge(paymentID) + + err = chlRep.Upsert(context.TODO(), pg.RawDB(), chl) + suite.Require().NoError(err) + + dac := wallet.DAppConfig{ + AllowedOrigin: "https://my-dapp.com", + } + + s, err := wallet.InitService(pg, nil, chlRep, allowList, nil, nil, nil, nil, nil, nil, dac) + suite.Require().NoError(err) + + // create linking message + solPub, msg, solSig := createAndSignMessage(suite.T(), w.ID, priv, chl.Nonce) + + // make request + body := struct { + SolanaPublicKey string `json:"solanaPublicKey"` + Message string `json:"message"` + SolanaSignature string `json:"solanaSignature"` + }{ + SolanaPublicKey: solPub, + Message: msg, + SolanaSignature: solSig, + } + + b, err := json.Marshal(body) + suite.Require().NoError(err) + + r := httptest.NewRequest(http.MethodPost, "/v3/wallet/solana/"+w.ID+"/connect", bytes.NewBuffer(b)) + r.Header.Set("origin", "https://my-dapp.com") + + rw := httptest.NewRecorder() + + svr := &http.Server{Addr: ":8080", Handler: setupRouter(s)} + svr.Handler.ServeHTTP(rw, r) + suite.Require().Equal(http.StatusOK, rw.Code) + suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) + + // assert + actual, err := pg.GetWallet(context.TODO(), paymentID) + suite.Require().NoError(err) + + suite.Require().Equal(solPub, actual.UserDepositDestination) + + // after a successful linking the challenge should be removed from the database. + _, actualErr := chlRep.Get(context.TODO(), pg.RawDB(), paymentID) + suite.Assert().ErrorIs(actualErr, model.ErrNotFound) +} + +func whitelistWallet(t *testing.T, pg wallet.Datastore, Id string) { + const q = `insert into allow_list (payment_id, created_at) values($1, $2)` + _, err := pg.RawDB().Exec(q, Id, time.Now()) + require.NoError(t, err) +} + +func createAndSignMessage(t *testing.T, paymentID string, rewardsPrivKey ed25519.PrivateKey, nonce string) (solPub, msg, solSig string) { + // Create and sign the rewards message. + // The message has the format = . + rewardsMsg := paymentID + "." + nonce + sig, err := rewardsPrivKey.Sign(rand.Reader, []byte(rewardsMsg), crypto.Hash(0)) + require.NoError(t, err) + + rewardsSig := base64.URLEncoding.EncodeToString(sig) + rewardsPart := rewardsMsg + "." + rewardsSig + + // Create the linking message and sign with the Solana key. + pub, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + solPub = base58.Encode(pub) + + const msgTmpl = ":%s\n:%s\n:%s" + msg = fmt.Sprintf(msgTmpl, paymentID, solPub, rewardsPart) + + sig, err = priv.Sign(rand.Reader, []byte(msg), crypto.Hash(0)) + require.NoError(t, err) + + solSig = base64.URLEncoding.EncodeToString(sig) + + return +} + +func (suite *WalletControllersTestSuite) getWallet(service *wallet.Service, paymentId uuid.UUID, code int) string { handler := handlers.AppHandler(wallet.GetWalletV3) req, err := http.NewRequest("GET", "/v3/wallet/"+paymentId.String(), nil) @@ -515,11 +671,7 @@ func (suite *WalletControllersTestSuite) createUpholdWalletV3( return rr.Body.String() } -func (suite *WalletControllersTestSuite) SignRequest( - req *http.Request, - publicKey httpsignature.Ed25519PubKey, - privateKey ed25519.PrivateKey, -) { +func (suite *WalletControllersTestSuite) SignRequest(req *http.Request, publicKey httpsignature.Ed25519PubKey, privateKey ed25519.PrivateKey) { var s httpsignature.SignatureParams s.Algorithm = httpsignature.ED25519 s.KeyID = hex.EncodeToString(publicKey) @@ -528,3 +680,17 @@ func (suite *WalletControllersTestSuite) SignRequest( err := s.Sign(privateKey, crypto.Hash(0), req) suite.Require().NoError(err) } + +func setupRouter(service *wallet.Service) *chi.Mux { + mw := func(name string, h http.Handler) http.Handler { + return h + } + s := "https://my-dapp.com" + r := chi.NewRouter() + r.Mount("/v3", wallet.RegisterRoutes(context.TODO(), service, r, mw, wallet.NewDAppCorsMw(s))) + return r +} + +func ptrTo[T any](v T) *T { + return &v +} diff --git a/services/wallet/controllers_v3.go b/services/wallet/controllers_v3.go index a62da46d8..ad0697015 100644 --- a/services/wallet/controllers_v3.go +++ b/services/wallet/controllers_v3.go @@ -5,10 +5,13 @@ import ( "crypto/ed25519" "database/sql" "encoding/hex" + "encoding/json" "errors" + "io" "net/http" "strings" + "github.com/asaskevich/govalidator" "github.com/brave-intl/bat-go/libs/altcurrency" appctx "github.com/brave-intl/bat-go/libs/context" errorutils "github.com/brave-intl/bat-go/libs/errors" @@ -19,10 +22,15 @@ import ( "github.com/brave-intl/bat-go/libs/middleware" walletutils "github.com/brave-intl/bat-go/libs/wallet" "github.com/brave-intl/bat-go/libs/wallet/provider/uphold" + "github.com/brave-intl/bat-go/services/wallet/model" "github.com/go-chi/chi" uuid "github.com/satori/go.uuid" ) +const ( + reqBodyLimit10MB = 10 << 20 +) + // LinkDepositAccountResponse is the response returned by the linking endpoints. type LinkDepositAccountResponse struct { GeoCountry string `json:"geoCountry"` @@ -441,49 +449,142 @@ func LinkUpholdDepositAccountV3(s *Service) func(w http.ResponseWriter, r *http. } } -// GetWalletV3 - produces an http handler for the service s which handles getting of brave wallets +const errOriginForbidden model.Error = "request origin forbidden" + +type linkSolanaAddrRequest struct { + SolanaPublicKey string `json:"solanaPublicKey" valid:"length(32|44)"` + Message string `json:"message" valid:"required"` + SolanaSignature string `json:"solanaSignature" valid:"required"` +} + +func LinkSolanaAddress(s *Service) handlers.AppHandler { + return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() + + if dis, ok := ctx.Value(appctx.DisableSolanaLinkingCTXKey).(bool); ok && dis { + return handlers.ValidationError("Connecting Brave Rewards to Solana is temporarily unavailable. Please try again later", nil) + } + + l := logging.Logger(ctx, "wallet") + + if o := r.Header.Get("Origin"); o != s.dappConf.AllowedOrigin { + l.Error().Err(errOriginForbidden).Str("origin", strOr(o, "empty")).Msg("error linking solana address") + return handlers.WrapError(errOriginForbidden, "request origin forbidden", http.StatusForbidden) + } + + var paymentID inputs.ID + if err := inputs.DecodeAndValidateString(ctx, &paymentID, chi.URLParam(r, "paymentID")); err != nil { + return handlers.WrapError(err, "invalid paymentID", http.StatusBadRequest) + } + + b, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) + if err != nil { + return handlers.WrapError(err, "error reading body", http.StatusBadRequest) + } + + var solReq linkSolanaAddrRequest + if err := json.Unmarshal(b, &solReq); err != nil { + return handlers.WrapError(err, "error decoding body", http.StatusBadRequest) + } + + if _, err := govalidator.ValidateStruct(solReq); err != nil { + return handlers.WrapValidationError(err) + } + + if err := s.LinkSolanaAddress(ctx, *paymentID.UUID(), solReq); err != nil { + l.Error().Err(err).Msg("error linking solana address") + + var solErr *walletutils.LinkSolanaAddressError + switch { + case errors.Is(err, model.ErrWalletNotWhitelisted): + return handlers.WrapError(model.ErrWalletNotWhitelisted, "rewards wallet not whitelisted", http.StatusForbidden) + case errors.Is(err, model.ErrChallengeExpired): + return handlers.WrapError(model.ErrChallengeExpired, "linking challenge expired", http.StatusUnauthorized) + case errors.Is(err, model.ErrWalletNotFound): + return handlers.WrapError(model.ErrWalletNotFound, "rewards wallet not found", http.StatusNotFound) + case errors.As(err, &solErr): + return handlers.WrapError(solErr, "invalid solana linking message", http.StatusUnauthorized) + default: + return handlers.WrapError(model.ErrInternalServer, "internal server error", http.StatusInternalServerError) + } + } + + return handlers.RenderContent(ctx, nil, w, http.StatusOK) + } +} + +type challengeRequest struct { + PaymentID uuid.UUID `json:"paymentId" valid:"required"` +} + +type challengeResponse struct { + Nonce string `json:"challengeId"` +} + +func CreateChallenge(s *Service) handlers.AppHandler { + return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + b, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) + if err != nil { + return handlers.WrapError(err, "error reading body", http.StatusBadRequest) + } + + var chlReq challengeRequest + if err := json.Unmarshal(b, &chlReq); err != nil { + return handlers.WrapError(err, "error decoding body", http.StatusBadRequest) + } + + if _, err := govalidator.ValidateStruct(chlReq); err != nil { + return handlers.WrapValidationError(err) + } + + ctx := r.Context() + + chl, err := s.CreateChallenge(ctx, chlReq.PaymentID) + if err != nil { + logging.Logger(ctx, "wallet").Error().Err(err).Msg("error creating challenge") + return handlers.WrapError(model.ErrInternalServer, "error creating challenge", http.StatusInternalServerError) + } + + resp := challengeResponse{ + Nonce: chl.Nonce, + } + + return handlers.RenderContent(ctx, resp, w, http.StatusCreated) + } +} + +// GetWalletV3 returns a rewards wallet for the given paymentID. func GetWalletV3(w http.ResponseWriter, r *http.Request) *handlers.AppError { var ctx = r.Context() - // get logger from context - logger := logging.Logger(ctx, "wallet.GetWalletV3") + + l := logging.Logger(ctx, "wallet.GetWalletV3") var id = new(inputs.ID) if err := inputs.DecodeAndValidateString(ctx, id, chi.URLParam(r, "paymentID")); err != nil { - logger.Warn().Str("paymentId", err.Error()).Msg("failed to decode and validate paymentID from url") - return handlers.ValidationError( - "Error validating paymentID url parameter", - map[string]interface{}{ - "paymentId": err.Error(), - }, - ) + l.Warn().Err(err).Str("paymentID", id.String()).Msg("failed to decode and validate paymentID from url") + return handlers.ValidationError("Error validating paymentID url parameter", map[string]interface{}{ + "paymentId": err.Error(), + }) } - var ( - roDB ReadOnlyDatastore - ok bool - ) - - // get datastore from context - if roDB, ok = ctx.Value(appctx.RODatastoreCTXKey).(ReadOnlyDatastore); !ok { - logger.Error().Msg("unable to get read only datastore from context") + // TODO(clD11): this should be removed from ctx as part of wallet refactor. Note, the service would have already + // panicked at startup if the db is missing. However, as a precaution we should stop processing. + roDB, ok := ctx.Value(appctx.RODatastoreCTXKey).(ReadOnlyDatastore) + if !ok { + return handlers.WrapError(errorutils.ErrInternalServerError, "db missing from context", http.StatusInternalServerError) } - // get wallet from datastore info, err := roDB.GetWallet(ctx, *id.UUID()) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - logger.Info().Err(err).Str("id", id.String()).Msg("wallet not found") - return handlers.WrapError(err, "no such wallet", http.StatusNotFound) - } - logger.Warn().Err(err).Str("id", id.String()).Msg("unable to get wallet") + l.Error().Err(err).Str("paymentID", id.String()).Msg("error getting wallet") return handlers.WrapError(err, "error getting wallet from storage", http.StatusInternalServerError) } + if info == nil { - logger.Info().Err(err).Str("id", id.String()).Msg("wallet not found") + l.Info().Str("paymentID", id.String()).Msg("wallet not found") return handlers.WrapError(err, "no such wallet", http.StatusNotFound) } - // render the wallet return handlers.RenderContent(ctx, infoToResponseV3(info), w, http.StatusOK) } @@ -713,3 +814,10 @@ func DisconnectCustodianLinkV3(s *Service) func(w http.ResponseWriter, r *http.R return handlers.RenderContent(ctx, map[string]interface{}{}, w, http.StatusOK) } } + +func strOr(a string, b string) string { + if a == "" { + return b + } + return a +} diff --git a/services/wallet/controllers_v3_test.go b/services/wallet/controllers_v3_pvt_test.go similarity index 91% rename from services/wallet/controllers_v3_test.go rename to services/wallet/controllers_v3_pvt_test.go index 9179ead76..7f7cf7275 100644 --- a/services/wallet/controllers_v3_test.go +++ b/services/wallet/controllers_v3_pvt_test.go @@ -18,8 +18,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/brave-intl/bat-go/libs/clients/gemini" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" mockgemini "github.com/brave-intl/bat-go/libs/clients/gemini/mock" mockreputation "github.com/brave-intl/bat-go/libs/clients/reputation/mock" @@ -35,6 +33,8 @@ import ( "github.com/golang/mock/gomock" "github.com/jmoiron/sqlx" uuid "github.com/satori/go.uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -196,14 +196,6 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { } var ( - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - // add the datastore to the context ctx = middleware.AddKeyID(context.WithValue(context.Background(), appctx.BitFlyerJWTKeyCTXKey, []byte(secret)), idFrom.String()) r = httptest.NewRequest( @@ -215,15 +207,13 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { }`, tokenString)), ) mockReputation = mockreputation.NewMockClient(mockCtrl) - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, nil, nil) + s, mock = initSvcWithMockDB(t) handler = wallet.LinkBitFlyerDepositAccountV3(s) rw = httptest.NewRecorder() ) mock.ExpectExec("^insert (.+)").WithArgs("1").WillReturnResult(sqlmock.NewResult(1, 1)) - mockSQLCustodianLink(mock, "bitflyer") - // begin linking tx mock.ExpectBegin() @@ -239,6 +229,8 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "bitflyer").WillReturnRows(linkingIDRows) + mockSQLCustodianLink(mock, "bitflyer") + // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -256,7 +248,7 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { // commit transaction mock.ExpectCommit() - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) + ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, s.Datastore) ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputation) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") @@ -299,14 +291,8 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { accountID = uuid.NewV4() idTo = accountID - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) + s, mock = initSvcWithMockDB(t) + linkingInfo = "this is the fake jwt for linking_info" // setup mock clients @@ -324,14 +310,6 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { }`, linkingInfo, idTo)), ) - mtc = &mockMtc{} - gem = &mockGemini{ - fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { - return "US" - }, - } - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) handler = wallet.LinkGeminiDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -346,7 +324,6 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { nil, ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") @@ -391,6 +368,8 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { // not before linked mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + mockSQLCustodianLink(mock, "gemini") + var max = sqlmock.NewRows([]string{"max"}).AddRow(4) var open = sqlmock.NewRows([]string{"used"}).AddRow(0) @@ -485,6 +464,8 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { // not before linked mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + mockSQLCustodianLink(mock, "gemini") + // perform again, make sure we check haslinkedprio hasPriorRows := sqlmock.NewRows([]string{"result"}). AddRow(true) @@ -542,14 +523,8 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { accountID = uuid.NewV4() idTo = accountID - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) + s, mock = initSvcWithMockDB(t) + linkingInfo = "this is the fake jwt for linking_info" // setup mock clients @@ -567,14 +542,6 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { }`, linkingInfo, idTo)), ) - mtc = &mockMtc{} - gem = &mockGemini{ - fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { - return "US" - }, - } - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) handler = wallet.LinkGeminiDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -589,7 +556,6 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { nil, ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") @@ -625,6 +591,8 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { // not before linked mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + mockSQLCustodianLink(mock, "gemini") + var max = sqlmock.NewRows([]string{"max"}).AddRow(4) var open = sqlmock.NewRows([]string{"used"}).AddRow(0) @@ -688,30 +656,11 @@ func TestLinkZebPayWalletV3_InvalidKyc(t *testing.T) { ctx = middleware.AddKeyID(context.Background(), idFrom.String()) accountID = uuid.NewV4() idTo = accountID - - // setup db mocks - db, _, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - - mtc = &mockMtc{ - fnLinkFailureZP: func(cc string) { - assert.Equal(t, "IN", cc) - }, - } - - gem = &mockGemini{} - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) - handler = wallet.LinkZebPayDepositAccountV3(s) - rw = httptest.NewRecorder() + s, _ = initSvcWithMockDB(t) + handler = wallet.LinkZebPayDepositAccountV3(s) + rw = httptest.NewRecorder() ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") ctx = context.WithValue(ctx, appctx.ZebPayLinkingKeyCTXKey, base64.StdEncoding.EncodeToString(secret)) @@ -764,38 +713,19 @@ func TestLinkZebPayWalletV3(t *testing.T) { } var ( - // setup test variables idFrom = uuid.NewV4() ctx = middleware.AddKeyID(context.Background(), idFrom.String()) accountID = uuid.NewV4() idTo = accountID - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - - // setup mock clients mockReputationClient = mockreputation.NewMockClient(mockCtrl) - mtc = &mockMtc{ - fnLinkSuccessZP: func(cc string) { - assert.Equal(t, "IN", cc) - }, - } + s, mock = initSvcWithMockDB(t) - gem = &mockGemini{} - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) handler = wallet.LinkZebPayDepositAccountV3(s) rw = httptest.NewRecorder() ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") ctx = context.WithValue(ctx, appctx.ZebPayLinkingKeyCTXKey, base64.StdEncoding.EncodeToString(secret)) @@ -828,8 +758,6 @@ func TestLinkZebPayWalletV3(t *testing.T) { nil, ) - mockSQLCustodianLink(mock, "zebpay") - // begin linking tx mock.ExpectBegin() @@ -843,6 +771,8 @@ func TestLinkZebPayWalletV3(t *testing.T) { mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "zebpay").WillReturnRows(linkingIDRows) + mockSQLCustodianLink(mock, "zebpay") + // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -891,14 +821,6 @@ func TestLinkGeminiWalletV3(t *testing.T) { accountID = uuid.NewV4() idTo = accountID - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) linkingInfo = "this is the fake jwt for linking_info" // setup mock clients @@ -915,20 +837,12 @@ func TestLinkGeminiWalletV3(t *testing.T) { "recipient_id": "%s" }`, linkingInfo, idTo)), ) + s, mock = initSvcWithMockDB(t) - mtc = &mockMtc{} - gem = &mockGemini{ - fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { - return "GB" - }, - } - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) handler = wallet.LinkGeminiDepositAccountV3(s) rw = httptest.NewRecorder() ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") @@ -938,7 +852,7 @@ func TestLinkGeminiWalletV3(t *testing.T) { ValidDocuments: []gemini.ValidDocument{ { Type: "passport", - IssuingCountry: "GB", + IssuingCountry: "US", }, }, } @@ -974,6 +888,8 @@ func TestLinkGeminiWalletV3(t *testing.T) { mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnRows(linkingIDRows) + mockSQLCustodianLink(mock, "gemini") + // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -1006,7 +922,7 @@ func TestLinkGeminiWalletV3(t *testing.T) { err := json.Unmarshal(b, &l) require.NoError(t, err) - assert.Equal(t, "GB", l.GeoCountry) + assert.Equal(t, "US", l.GeoCountry) } func TestDisconnectCustodianLinkV3(t *testing.T) { @@ -1020,24 +936,13 @@ func TestDisconnectCustodianLinkV3(t *testing.T) { idFrom = uuid.NewV4() ctx = middleware.AddKeyID(context.Background(), idFrom.String()) - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - // this is our main request r = httptest.NewRequest( "DELETE", fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), nil) - mtc = &mockMtc{} - gem = &mockGemini{} + s, mock = initSvcWithMockDB(t) - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) handler = wallet.DisconnectCustodianLinkV3(s) w = httptest.NewRecorder() ) @@ -1054,7 +959,6 @@ func TestDisconnectCustodianLinkV3(t *testing.T) { // commit transaction because we are done disconnecting mock.ExpectCommit() - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") r = r.WithContext(ctx) @@ -1094,6 +998,31 @@ func mockSQLCustodianLink(mock sqlmock.Sqlmock, custodian string) { WillReturnRows(clRow) } +func initSvcWithMockDB(t *testing.T) (*wallet.Service, sqlmock.Sqlmock) { + db, mock, _ := sqlmock.New() + datastore := wallet.Datastore( + &wallet.Postgres{ + Postgres: datastoreutils.Postgres{ + DB: sqlx.NewDb(db, "postgres"), + }, + }) + + mtc := &mockMtc{} + + gem := &mockGemini{ + fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { + return "US" + }, + } + + dappConf := wallet.DAppConfig{} + + s, err := wallet.InitService(datastore, nil, nil, nil, nil, nil, nil, nil, mtc, gem, dappConf) + require.NoError(t, err) + + return s, mock +} + type mockGemini struct { fnGetIssuingCountry func(acc gemini.ValidatedAccount, fallback bool) string fnIsRegionAllowed func(ctx context.Context, issuingCountry string, custodianRegions custodian.Regions) error diff --git a/services/wallet/controllers_v4_test.go b/services/wallet/controllers_v4_test.go index cd0b0dcdd..3efbb8072 100644 --- a/services/wallet/controllers_v4_test.go +++ b/services/wallet/controllers_v4_test.go @@ -15,6 +15,7 @@ import ( "testing" errorutils "github.com/brave-intl/bat-go/libs/errors" + "github.com/brave-intl/bat-go/libs/middleware" "github.com/brave-intl/bat-go/libs/clients" @@ -73,11 +74,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_Success() { Validate(gomock.Any(), geoCountry). Return(true, nil) - service, err := wallet.InitService(storage, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: geoCountry, @@ -120,11 +121,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_GeoCountryDis Validate(gomock.Any(), gomock.Any()). Return(false, nil) - service, err := wallet.InitService(nil, nil, nil, nil, locationValidator, backoff.Retry, nil, nil) + service, err := wallet.InitService(nil, nil, nil, nil, nil, nil, locationValidator, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -172,11 +173,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_WalletAlready Validate(gomock.Any(), geoCountry). Return(true, nil) - service, err := wallet.InitService(storage, nil, nil, nil, locationValidator, nil, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, nil, nil, locationValidator, nil, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: geoCountry, @@ -243,11 +244,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_ReputationCal Validate(gomock.Any(), gomock.Any()). Return(true, nil) - service, err := wallet.InitService(storage, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -293,7 +294,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_Success() { UpsertReputationSummary(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) - service, err := wallet.InitService(storage, nil, reputationClient, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -314,7 +315,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_Success() { suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -343,7 +344,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_VerificationM storage, err := wallet.NewWritablePostgres("", false, "") suite.NoError(err) - service, err := wallet.InitService(storage, nil, nil, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, nil, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) publicKey, privateKey, err := httpsignature.GenerateEd25519Key(nil) @@ -352,7 +353,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_VerificationM paymentID := uuid.NewV5(wallet.ClaimNamespace, publicKey.String()).String() router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -381,7 +382,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_PaymentIDMism storage, err := wallet.NewWritablePostgres("", false, "") suite.NoError(err) - service, err := wallet.InitService(storage, nil, nil, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, nil, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -402,7 +403,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_PaymentIDMism suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -448,7 +449,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_GeoCountryAlr UpsertReputationSummary(gomock.Any(), gomock.Any(), gomock.Any()). Return(errorBundle) - service, err := wallet.InitService(storage, nil, reputationClient, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -469,7 +470,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_GeoCountryAlr suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -513,7 +514,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_ReputationCal UpsertReputationSummary(gomock.Any(), gomock.Any(), gomock.Any()). Return(errReputation) - service, err := wallet.InitService(storage, nil, reputationClient, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -534,7 +535,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_ReputationCal suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -564,3 +565,17 @@ func signUpdateRequest(req *http.Request, paymentID string, privateKey ed25519.P s.Headers = []string{"digest", "(request-target)"} return s.Sign(privateKey, crypto.Hash(0), req) } + +func noOpHandler() middleware.InstrumentHandlerDef { + return func(name string, h http.Handler) http.Handler { + return h + } +} + +func noOpMw() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } +} diff --git a/services/wallet/datastore.go b/services/wallet/datastore.go index 071897a82..8461d71cb 100644 --- a/services/wallet/datastore.go +++ b/services/wallet/datastore.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "strconv" + "strings" "time" "github.com/brave-intl/bat-go/libs/backoff" @@ -63,7 +64,7 @@ func init() { // Datastore holds the interface for the wallet datastore type Datastore interface { datastore.Datastore - LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider, country string) error + LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider string) error GetLinkingLimitInfo(ctx context.Context, providerLinkingID string) (map[string]LinkingInfo, error) HasPriorLinking(ctx context.Context, walletID uuid.UUID, providerLinkingID uuid.UUID) (bool, error) // GetLinkingsByProviderLinkingID gets the wallet linking info by provider linking id @@ -194,7 +195,9 @@ func (pg *Postgres) UpsertWallet(ctx context.Context, wallet *walletutils.Info) return nil } -// GetWallet by ID +// TODO(clD11): address GetWallet in wallet refactor. + +// GetWallet retrieves a wallet by its walletID, if no wallet is found then nil is returned. func (pg *Postgres) GetWallet(ctx context.Context, ID uuid.UUID) (*walletutils.Info, error) { statement := ` select @@ -553,23 +556,12 @@ var ( ) // LinkWallet links a rewards wallet to the given deposit provider. -func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestination string, providerLinkingID uuid.UUID, depositProvider, country string) error { +func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestination string, providerLinkingID uuid.UUID, depositProvider string) error { walletID, err := uuid.FromString(id) if err != nil { return fmt.Errorf("invalid wallet id, not uuid: %w", err) } - repClient, ok := ctx.Value(appctx.ReputationClientCTXKey).(reputation.Client) - if !ok { - return ErrNoReputationClient - } - - // TODO(clD11): We no longer need to act on the response and only require a successful call to reputation to - // continue linking. As part of the wallet refactor we should clean this up. - if _, _, err := repClient.IsLinkingReputable(ctx, walletID, country); err != nil { - return fmt.Errorf("failed to check wallet rep: %w", err) - } - ctx, tx, rollback, commit, err := getTx(ctx, pg) if err != nil { return fmt.Errorf("error getting tx: %w", err) @@ -600,9 +592,15 @@ func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestin } if directVerifiedWalletEnable { + repClient, ok := ctx.Value(appctx.ReputationClientCTXKey).(reputation.Client) + if !ok { + return ErrNoReputationClient + } + op := func() (interface{}, error) { return nil, repClient.UpdateReputationSummary(ctx, walletID.String(), true) } + if _, err := backoff.Retry(ctx, op, retryPolicy, canRetry(nonRetriableErrors)); err != nil { return fmt.Errorf("failed to update verified wallet: %w", err) } @@ -616,17 +614,29 @@ func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestin return nil } -// CustodianLink - representation of wallet_custodian record +// TODO(clD11): CustodianLink represent a wallet_custodian. Review during wallet refactor + +// CustodianLink representation a wallet_custodian record. type CustodianLink struct { - WalletID *uuid.UUID `json:"wallet_id" db:"wallet_id" valid:"uuidv4"` - Custodian string `json:"custodian" db:"custodian" valid:"in(uphold,brave,gemini,bitflyer)"` - CreatedAt time.Time `json:"created_at" db:"created_at" valid:"-"` - UpdatedAt *time.Time `json:"updated_at" db:"updated_at" valid:"-"` - LinkedAt time.Time `json:"linked_at" db:"linked_at" valid:"-"` - DisconnectedAt *time.Time `json:"disconnected_at" db:"disconnected_at" valid:"-"` - DepositDestination string `json:"deposit_destination" db:"deposit_destination" valid:"-"` - LinkingID *uuid.UUID `json:"linking_id" db:"linking_id" valid:"uuid"` - UnlinkedAt *time.Time `json:"unlinked_at" db:"unlinked_at" valid:"-"` + WalletID *uuid.UUID `json:"wallet_id" db:"wallet_id" valid:"uuidv4"` + Custodian string `json:"custodian" db:"custodian" valid:"in(uphold,brave,gemini,bitflyer)"` + CreatedAt time.Time `json:"created_at" db:"created_at" valid:"-"` + UpdatedAt *time.Time `json:"updated_at" db:"updated_at" valid:"-"` + LinkedAt time.Time `json:"linked_at" db:"linked_at" valid:"-"` + DisconnectedAt *time.Time `json:"disconnected_at" db:"disconnected_at" valid:"-"` + LinkingID *uuid.UUID `json:"linking_id" db:"linking_id" valid:"uuid"` + UnlinkedAt *time.Time `json:"unlinked_at" db:"unlinked_at" valid:"-"` +} + +// TODO(clD11): Wallet Refactor. These should not be nullable, fix pointers and raname fields for consistency. + +func NewSolanaCustodialLink(walletID uuid.UUID, depositDestination string) *CustodianLink { + const depositProviderSolana = "solana" + return &CustodianLink{ + WalletID: &walletID, + LinkingID: ptrFromUUID(uuid.NewV5(ClaimNamespace, depositDestination)), + Custodian: depositProviderSolana, + } } // GetWalletIDString - get string version of the WalletID @@ -716,16 +726,10 @@ func commitFn(ctx context.Context, tx *sqlx.Tx) func() error { // getTx will get or create a tx on the context, if created hands back rollback and commit functions func getTx(ctx context.Context, datastore Datastore) (context.Context, *sqlx.Tx, func(), func() error, error) { - // create a sublogger - sublogger := logger(ctx) - sublogger.Debug().Msg("getting tx from context") - // get tx tx, noContextTx := ctx.Value(appctx.DatabaseTransactionCTXKey).(*sqlx.Tx) if !noContextTx { - sublogger.Debug().Msg("no tx in context") tx, err := createTx(ctx, datastore) if err != nil || tx == nil { - sublogger.Error().Err(err).Msg("error creating tx") return ctx, nil, func() {}, func() error { return nil }, fmt.Errorf("failed to create tx: %w", err) } ctx = context.WithValue(ctx, appctx.DatabaseTransactionCTXKey, tx) @@ -861,8 +865,19 @@ func (pg *Postgres) ConnectCustodialWallet(ctx context.Context, cl *CustodianLin return fmt.Errorf("failed to get linking id from custodian record: %w", err) } - if !uuid.Equal(existingLinkingID, *new(uuid.UUID)) { - // check if the member matches the associated member + // TODO(clD11): WR. The relinking check below only considers the currently linked wallet and not + // the custodian. We can combine/refactor these checks for both linkings, custodians and limits. + + if err := validateCustodianLinking(ctx, pg, *cl.WalletID, cl.Custodian); err != nil { + if errors.Is(err, errCustodianLinkMismatch) { + return errCustodianLinkMismatch + } + return handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) + } + + // Relinking. + if !uuid.Equal(existingLinkingID, *new(uuid.UUID)) { // if not a new wallet + // check if the currently linked wallet matches the proposed linking. if !uuid.Equal(*cl.LinkingID, existingLinkingID) { return handlers.WrapError(errors.New("wallets do not match"), "mismatched provider accounts", http.StatusForbidden) } @@ -1015,20 +1030,33 @@ func (pg *Postgres) SendVerifiedWalletOutbox(ctx context.Context, client reputat return true, nil } -// helper to make logger easier +func validateCustodianLinking(ctx context.Context, storage Datastore, walletID uuid.UUID, depositProvider string) error { + c, err := storage.GetCustodianLinkByWalletID(ctx, walletID) + if err != nil && !errors.Is(err, model.ErrNoWalletCustodian) { + return err + } + + // if there are no instances of wallet custodian then it is + // considered a new linking and therefore valid. + if c == nil { + return nil + } + + if !strings.EqualFold(c.Custodian, depositProvider) { + return errCustodianLinkMismatch + } + + return nil +} + func logger(ctx context.Context) *zerolog.Logger { - // get logger return logging.Logger(ctx, "wallet") } // helper to create a tx -func createTx(ctx context.Context, datastore Datastore) (tx *sqlx.Tx, err error) { - logger(ctx).Debug(). - Msg("creating transaction") +func createTx(_ context.Context, datastore Datastore) (tx *sqlx.Tx, err error) { tx, err = datastore.RawDB().Beginx() if err != nil { - logger(ctx).Error().Err(err). - Msg("error creating transaction") return tx, fmt.Errorf("failed to create transaction: %w", err) } return tx, nil @@ -1043,3 +1071,7 @@ func waitAndLockTx(ctx context.Context, tx *sqlx.Tx, id uuid.UUID) error { } return nil } + +func ptrFromUUID(u uuid.UUID) *uuid.UUID { + return &u +} diff --git a/services/wallet/datastore_test.go b/services/wallet/datastore_test.go index 477ee9136..6cdbde3d3 100644 --- a/services/wallet/datastore_test.go +++ b/services/wallet/datastore_test.go @@ -232,7 +232,8 @@ func (suite *WalletPostgresTestSuite) TestLinkWallet_Concurrent_InsertUpdate() { for i := 0; i < runs; i++ { go func() { defer wg.Done() - err = pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, walletInfo.Provider, "") + err := pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, walletInfo.Provider) + suite.Require().NoError(err) }() } wg.Wait() @@ -272,7 +273,7 @@ func (suite *WalletPostgresTestSuite) seedWallet(pg Datastore) (string, uuid.UUI err := pg.UpsertWallet(ctx, walletInfo) suite.Require().NoError(err, "save wallet should succeed") - err = pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, "uphold", "") + err = pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, "uphold") suite.Require().NoError(err, "link wallet should succeed") } @@ -322,7 +323,8 @@ func (suite *WalletPostgresTestSuite) TestLinkWallet_Concurrent_MaxLinkCount() { for i := 0; i < len(wallets); i++ { go func(index int) { defer wg.Done() - err = pg.LinkWallet(ctx, wallets[index].ID, userDepositDestination, providerLinkingID, wallets[index].Provider, "") + // Once we reach the limit this will return an error which is expected hence we can ignore it. + _ = pg.LinkWallet(ctx, wallets[index].ID, userDepositDestination, providerLinkingID, wallets[index].Provider) }(i) } wg.Wait() diff --git a/services/wallet/instrumented_datastore.go b/services/wallet/instrumented_datastore.go index 7bcf2e1f0..18f962c4b 100644 --- a/services/wallet/instrumented_datastore.go +++ b/services/wallet/instrumented_datastore.go @@ -13,6 +13,7 @@ import ( "github.com/brave-intl/bat-go/libs/backoff" "github.com/brave-intl/bat-go/libs/clients/reputation" walletutils "github.com/brave-intl/bat-go/libs/wallet" + migrate "github.com/golang-migrate/migrate/v4" "github.com/jmoiron/sqlx" "github.com/prometheus/client_golang/prometheus" @@ -255,7 +256,7 @@ func (_d DatastoreWithPrometheus) InsertWalletTx(ctx context.Context, tx *sqlx.T } // LinkWallet implements Datastore -func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider string, country string) (err error) { +func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider string) (err error) { _since := time.Now() defer func() { result := "ok" @@ -265,7 +266,7 @@ func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, id string, pro datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "LinkWallet", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.LinkWallet(ctx, id, providerID, providerLinkingID, depositProvider, country) + return _d.base.LinkWallet(ctx, id, providerID, providerLinkingID, depositProvider) } // Migrate implements Datastore diff --git a/services/wallet/model/model.go b/services/wallet/model/model.go index e4220bd8c..27681abbe 100644 --- a/services/wallet/model/model.go +++ b/services/wallet/model/model.go @@ -1,8 +1,53 @@ package model -import "errors" +import ( + "encoding/base64" + "time" -var ErrNoWalletCustodian = errors.New("model: no linked wallet custodian") + uuid "github.com/satori/go.uuid" +) + +const ( + ErrWalletNotWhitelisted Error = "model: wallet not whitelisted" + ErrNotFound Error = "model: not found" + ErrChallengeExpired Error = "model: challenge expired" + ErrNoRowsDeleted Error = "model: no rows deleted" + ErrNotInserted Error = "model: not inserted" + ErrNoWalletCustodian Error = "model: no linked wallet custodian" + ErrInternalServer Error = "model: internal server error" + ErrWalletNotFound Error = "model: wallet not found" +) + +type AllowListEntry struct { + PaymentID uuid.UUID `db:"payment_id"` + CreatedAt time.Time `db:"created_at"` +} + +type Challenge struct { + PaymentID uuid.UUID `db:"payment_id"` + CreatedAt time.Time `db:"created_at"` + Nonce string `db:"nonce"` +} + +func NewChallenge(paymentID uuid.UUID) Challenge { + return Challenge{ + PaymentID: paymentID, + CreatedAt: time.Now(), + Nonce: base64.URLEncoding.EncodeToString(uuid.NewV4().Bytes()), + } +} + +func (c *Challenge) IsValid(now time.Time) error { + if c.hasExpired(now) { + return ErrChallengeExpired + } + return nil +} + +func (c *Challenge) hasExpired(now time.Time) bool { + expiresAt := c.CreatedAt.Add(5 * time.Minute) + return expiresAt.Before(now) +} type Error string diff --git a/services/wallet/model/model_test.go b/services/wallet/model/model_test.go new file mode 100644 index 000000000..6bdf173d8 --- /dev/null +++ b/services/wallet/model/model_test.go @@ -0,0 +1,54 @@ +package model + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestChallenge_IsValid(t *testing.T) { + type tcGiven struct { + chl Challenge + now time.Time + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "expired", + given: tcGiven{ + chl: Challenge{ + CreatedAt: time.Now(), + }, + now: time.Now().Add(6 * time.Minute), + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, ErrChallengeExpired) + }, + }, + { + name: "valid", + given: tcGiven{ + chl: Challenge{ + CreatedAt: time.Now(), + }, + now: time.Now(), + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.given.chl.IsValid(tc.given.now) + tc.assertErr(t, err) + }) + } +} diff --git a/services/wallet/service.go b/services/wallet/service.go index eef9c1817..36019122c 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -16,8 +16,11 @@ import ( "github.com/brave-intl/bat-go/services/wallet/metric" "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/brave-intl/bat-go/services/wallet/storage" "github.com/go-chi/chi" + "github.com/go-chi/cors" "github.com/go-jose/go-jose/v3/jwt" + "github.com/jmoiron/sqlx" "github.com/lib/pq" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" @@ -104,6 +107,16 @@ type GeoValidator interface { Validate(ctx context.Context, geolocation string) (bool, error) } +type challengeRepo interface { + Get(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.Challenge, error) + Upsert(ctx context.Context, dbi sqlx.ExecerContext, chl model.Challenge) error + Delete(ctx context.Context, dbi sqlx.ExecerContext, paymentID uuid.UUID) error +} + +type allowListRepo interface { + GetAllowListEntry(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.AllowListEntry, error) +} + type metricSvc interface { LinkSuccessZP(cc string) LinkFailureZP(cc string) @@ -121,6 +134,8 @@ type geminiSvc interface { type Service struct { Datastore Datastore RoDatastore ReadOnlyDatastore + chlRepo challengeRepo + allowListRepo allowListRepo repClient reputation.Client geminiClient gemini.Client geoValidator GeoValidator @@ -130,20 +145,39 @@ type Service struct { custodianRegions custodian.Regions metric metricSvc gemini geminiSvc + dappConf DAppConfig } -// InitService creates a service using the passed datastore and clients configured from the environment -func InitService(datastore Datastore, roDatastore ReadOnlyDatastore, repClient reputation.Client, geminiClient gemini.Client, geoCountryValidator GeoValidator, retry backoff.RetryFunc, metric metricSvc, gemini geminiSvc) (*Service, error) { +type DAppConfig struct { + AllowedOrigin string +} + +// InitService creates a new instances of the wallet service. +func InitService( + datastore Datastore, + roDatastore ReadOnlyDatastore, + chlRepo challengeRepo, + allowList allowListRepo, + repClient reputation.Client, + geminiClient gemini.Client, + geoCountryValidator GeoValidator, + retry backoff.RetryFunc, + metric metricSvc, + gemini geminiSvc, + dappConf DAppConfig) (*Service, error) { service := &Service{ - crMu: new(sync.RWMutex), - Datastore: datastore, - RoDatastore: roDatastore, - repClient: repClient, - geminiClient: geminiClient, - geoValidator: geoCountryValidator, - retry: retry, - metric: metric, - gemini: gemini, + Datastore: datastore, + RoDatastore: roDatastore, + chlRepo: chlRepo, + allowListRepo: allowList, + repClient: repClient, + geminiClient: geminiClient, + geoValidator: geoCountryValidator, + retry: retry, + metric: metric, + gemini: gemini, + dappConf: dappConf, + crMu: new(sync.RWMutex), } return service, nil } @@ -163,16 +197,19 @@ func (service *Service) ReadableDatastore() ReadOnlyDatastore { // SetupService - create a new wallet service func SetupService(ctx context.Context) (context.Context, *Service) { - logger := logging.Logger(ctx, "wallet.SetupService") + l := logging.Logger(ctx, "wallet.SetupService") + + chlRepo := storage.NewChallenge() + alRepo := storage.NewAllowList() db, err := NewWritablePostgres(viper.GetString("datastore"), false, "wallet_db") if err != nil { - logger.Panic().Err(err).Msg("unable connect to wallet db") + l.Panic().Err(err).Msg("unable connect to wallet db") } roDB, err := NewReadOnlyPostgres(viper.GetString("ro-datastore"), false, "wallet_ro_db") if err != nil { - logger.Panic().Err(err).Msg("unable connect to wallet db") + l.Panic().Err(err).Msg("unable connect to wallet db") } ctx = context.WithValue(ctx, appctx.RODatastoreCTXKey, roDB) @@ -184,7 +221,7 @@ func SetupService(ctx context.Context) (context.Context, *Service) { // jwt key is hex encoded string decodedBitFlyerJWTKey, err := hex.DecodeString(viper.GetString("bitflyer-jwt-key")) if err != nil { - logger.Error().Err(err).Msg("invalid bitflyer jwt key") + l.Error().Err(err).Msg("invalid bitflyer jwt key") } ctx = context.WithValue(ctx, appctx.BitFlyerJWTKeyCTXKey, decodedBitFlyerJWTKey) @@ -192,7 +229,7 @@ func SetupService(ctx context.Context) (context.Context, *Service) { repClient, err := reputation.New() // it's okay to not fatally fail if this environment is local and we cant make a rep client if err != nil && os.Getenv("ENV") != "local" { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, repClient) @@ -201,19 +238,19 @@ func SetupService(ctx context.Context) (context.Context, *Service) { if os.Getenv("GEMINI_ENABLED") == "true" { geminiClient, err = gemini.New() if err != nil { - logger.Panic().Err(err).Msg("failed to create gemini client") + l.Panic().Err(err).Msg("failed to create gemini client") } ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, geminiClient) } - cfg, err := appaws.BaseAWSConfig(ctx, logger) + cfg, err := appaws.BaseAWSConfig(ctx, l) if err != nil { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } awsClient, err := appaws.NewClient(cfg) if err != nil { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } // put the configured aws client on ctx @@ -222,33 +259,40 @@ func SetupService(ctx context.Context) (context.Context, *Service) { // get the s3 bucket and object bucket, bucketOK := ctx.Value(appctx.ParametersMergeBucketCTXKey).(string) if !bucketOK { - logger.Panic().Err(errors.New("bucket not in context")). + l.Panic().Err(errors.New("bucket not in context")). Msg("failed to initialize wallet service") } object, ok := ctx.Value(appctx.DisabledWalletGeoCountriesCTXKey).(string) if !ok { - logger.Panic().Err(errors.New("wallet geo countries disabled ctx key value not found")). + l.Panic().Err(errors.New("wallet geo countries disabled ctx key value not found")). Msg("failed to initialize wallet service") } - config := Config{ + geoCountryValidator := NewGeoCountryValidator(awsClient, Config{ bucket: bucket, object: object, - } - - geoCountryValidator := NewGeoCountryValidator(awsClient, config) + }) mtc := metric.New() gemx := newGeminix("passport", "drivers_license", "national_identity_card", "passport_card") - s, err := InitService(db, roDB, repClient, geminiClient, geoCountryValidator, backoff.Retry, mtc, gemx) + dappAO := os.Getenv("DAPP_ALLOWED_CORS_ORIGINS") + if dappAO == "" { + l.Panic().Err(errors.New("dapp allowed origins missing")).Msg("failed to initialize wallet service") + } + + dappConf := DAppConfig{ + AllowedOrigin: dappAO, + } + + s, err := InitService(db, roDB, chlRepo, alRepo, repClient, geminiClient, geoCountryValidator, backoff.Retry, mtc, gemx, dappConf) if err != nil { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } _, err = s.RefreshCustodianRegionsWorker(ctx) if err != nil { - logger.Error().Err(err).Msg("failed to initialize custodian regions") + l.Error().Err(err).Msg("failed to initialize custodian regions") } s.jobs = []srv.Job{ @@ -269,7 +313,7 @@ func SetupService(ctx context.Context) (context.Context, *Service) { err = cmd.SetupJobWorkers(ctx, s.Jobs()) if err != nil { - logger.Error().Err(err).Msg("error initializing job workers") + l.Error().Err(err).Msg("error initializing job workers") } return ctx, s @@ -288,7 +332,7 @@ func (service *Service) getCustodianRegions() custodian.Regions { } // RegisterRoutes - register the wallet api routes given a chi.Mux -func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux) *chi.Mux { +func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux, metricsMw middleware.InstrumentHandlerDef, dAppCorsMw func(next http.Handler) http.Handler) *chi.Mux { // setup our wallet routes r.Route("/v3/wallet", func(r chi.Router) { // rate limited to 2 per minute... @@ -319,20 +363,20 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux) *chi.Mux { "LinkGeminiDepositAccount", LinkGeminiDepositAccountV3(s))).ServeHTTP) r.Post("/zebpay/{paymentID}/connect", middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( "LinkZebPayDepositAccount", LinkZebPayDepositAccountV3(s))).ServeHTTP) + + r.Method(http.MethodPost, "/solana/{paymentID}/connect", metricsMw("LinkSolanaAddress", dAppCorsMw(LinkSolanaAddress(s)))) } - r.Get("/linking-info", middleware.SimpleTokenAuthorizedOnly( - middleware.InstrumentHandlerFunc("GetLinkingInfo", GetLinkingInfoV3(s))).ServeHTTP) + r.Get("/linking-info", middleware.SimpleTokenAuthorizedOnly(middleware.InstrumentHandlerFunc("GetLinkingInfo", GetLinkingInfoV3(s))).ServeHTTP) // get wallet routes - r.Get("/{paymentID}", middleware.InstrumentHandlerFunc( - "GetWallet", GetWalletV3)) - r.Get("/recover/{publicKey}", middleware.InstrumentHandlerFunc( - "RecoverWallet", RecoverWalletV3)) + r.Get("/{paymentID}", middleware.InstrumentHandlerFunc("GetWallet", GetWalletV3)) + r.Get("/recover/{publicKey}", middleware.InstrumentHandlerFunc("RecoverWallet", RecoverWalletV3)) // get wallet balance routes - r.Get("/uphold/{paymentID}", middleware.InstrumentHandlerFunc( - "GetUpholdWalletBalance", GetUpholdWalletBalanceV3)) + r.Get("/uphold/{paymentID}", middleware.InstrumentHandlerFunc("GetUpholdWalletBalance", GetUpholdWalletBalanceV3)) + + r.Post("/challenges", middleware.RateLimiter(ctx, 2)(metricsMw("CreateChallenge", dAppCorsMw(CreateChallenge(s)))).ServeHTTP) }) r.Route("/v4/wallets", func(r chi.Router) { @@ -352,6 +396,21 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux) *chi.Mux { return r } +// TODO(clD11): WR. Move once we address the rest_run.go and grant.go start functions. + +func NewDAppCorsMw(origin string) func(next http.Handler) http.Handler { + opts := cors.Options{ + Debug: false, + AllowedOrigins: []string{origin}, + AllowedHeaders: []string{"Accept", "Content-Type"}, + ExposedHeaders: []string{""}, + AllowedMethods: []string{http.MethodPost}, + AllowCredentials: false, + MaxAge: 300, + } + return cors.Handler(opts) +} + // SubmitAnonCardTransaction validates and submits a transaction on behalf of an anonymous card func (service *Service) SubmitAnonCardTransaction( ctx context.Context, @@ -421,20 +480,11 @@ func (service *Service) LinkBitFlyerWallet(ctx context.Context, walletID uuid.UU country = "JP" ) - err := validateCustodianLinking(ctx, service.Datastore, walletID, depositProvider) - if err != nil { - if errors.Is(err, errCustodianLinkMismatch) { - return "", errCustodianLinkMismatch - } - return "", handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) - } - // In the controller validation, we verified that the account hash and deposit id were signed by bitflyer // we also validated that this "info" signed the request to perform the linking with http signature // we assume that since we got linkingInfo signed from BF that they are KYC providerLinkingID := uuid.NewV5(ClaimNamespace, accountHash) - err = service.Datastore.LinkWallet(ctx, walletID.String(), depositID, providerLinkingID, depositProvider, country) - if err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), depositID, providerLinkingID, depositProvider, country); err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) } @@ -468,18 +518,8 @@ func (service *Service) LinkZebPayWallet(ctx context.Context, walletID uuid.UUID return "", err } - err = validateCustodianLinking(ctx, service.Datastore, walletID, depositProvider) - if err != nil { - service.metric.LinkFailureZP(claims.CountryCode) - - if errors.Is(err, errCustodianLinkMismatch) { - return "", errCustodianLinkMismatch - } - return "", handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) - } - providerLinkingID := uuid.NewV5(ClaimNamespace, claims.AccountID) - if err := service.Datastore.LinkWallet(ctx, walletID.String(), claims.DepositID, providerLinkingID, depositProvider, claims.CountryCode); err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), claims.DepositID, providerLinkingID, depositProvider, claims.CountryCode); err != nil { service.metric.LinkFailureZP(claims.CountryCode) if errors.Is(err, ErrUnusualActivity) { @@ -560,7 +600,7 @@ func (service *Service) LinkGeminiWallet(ctx context.Context, walletID uuid.UUID } service.metric.LinkSuccessGemini(issuingCountry) - if err := service.Datastore.LinkWallet(ctx, walletID.String(), depositID, linkingID, depositProvider, issuingCountry); err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), depositID, linkingID, depositProvider, issuingCountry); err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) } @@ -610,14 +650,6 @@ func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wall return "", fmt.Errorf("failed to parse uphold id: %w", err) } - err = validateCustodianLinking(ctx, service.Datastore, walletID, depositProvider) - if err != nil { - if errors.Is(err, errCustodianLinkMismatch) { - return "", errCustodianLinkMismatch - } - return "", handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) - } - // verify that the user is kyc from uphold. (for all wallet provider cases) if uID, ok, c, err := wallet.IsUserKYC(ctx, transactionInfo.Destination); err != nil { // check if this gemini accountID has already been linked to this wallet, @@ -659,9 +691,7 @@ func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wall probi = transactionInfo.Probi providerLinkingID := uuid.NewV5(ClaimNamespace, userID) - // tx.Destination will be stored as UserDepositDestination in the wallet info upon linking - err = service.Datastore.LinkWallet(ctx, info.ID, transactionInfo.Destination, providerLinkingID, depositProvider, country) - if err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), transactionInfo.Destination, providerLinkingID, depositProvider, country); err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) } @@ -689,6 +719,87 @@ func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wall return country, nil } +func (service *Service) LinkSolanaAddress(ctx context.Context, paymentID uuid.UUID, req linkSolanaAddrRequest) error { + if err := isWalletWhitelisted(ctx, service.Datastore.RawDB(), service.allowListRepo, paymentID); err != nil { + return err + } + + ctx, txn, rollback, commit, err := getTx(ctx, service.Datastore) + if err != nil { + return err + } + defer rollback() + + chl, err := service.chlRepo.Get(ctx, txn, paymentID) + if err != nil { + return err + } + + if err := chl.IsValid(time.Now()); err != nil { + return err + } + + w, err := service.Datastore.GetWallet(ctx, paymentID) + if err != nil { + return err + } + + if w == nil { + return model.ErrWalletNotFound + } + + if err := w.LinkSolanaAddress(ctx, walletutils.SolanaLinkReq{ + Pub: req.SolanaPublicKey, + Sig: req.SolanaSignature, + Msg: req.Message, + Nonce: chl.Nonce, + }); err != nil { + return err + } + + cl := NewSolanaCustodialLink(paymentID, w.UserDepositDestination) + + if err := service.Datastore.LinkWallet(ctx, w.ID, w.UserDepositDestination, *cl.LinkingID, cl.Custodian); err != nil { + return err + } + + if err := service.chlRepo.Delete(ctx, txn, chl.PaymentID); err != nil { + return err + } + + if err := commit(); err != nil { + return err + } + + return nil +} + +func (service *Service) linkCustodialAccount(ctx context.Context, wID string, userDepositDestination string, providerLinkingID uuid.UUID, depositProvider, country string) error { + walletID, err := uuid.FromString(wID) + if err != nil { + return fmt.Errorf("invalid wallet id, not uuid: %w", err) + } + + repClient, ok := ctx.Value(appctx.ReputationClientCTXKey).(reputation.Client) + if !ok { + return ErrNoReputationClient + } + + if _, _, err := repClient.IsLinkingReputable(ctx, walletID, country); err != nil { + return fmt.Errorf("failed to check wallet rep: %w", err) + } + + return service.Datastore.LinkWallet(ctx, walletID.String(), userDepositDestination, providerLinkingID, depositProvider) +} + +func (service *Service) CreateChallenge(ctx context.Context, paymentID uuid.UUID) (model.Challenge, error) { + chl := model.NewChallenge(paymentID) + if err := service.chlRepo.Upsert(ctx, service.Datastore.RawDB(), chl); err != nil { + return model.Challenge{}, fmt.Errorf("error creating challenge: %w", err) + } + return chl, nil +} + // DisconnectCustodianLink - removes the link to the custodian wallet that is active func (service *Service) DisconnectCustodianLink(ctx context.Context, custodian string, walletID uuid.UUID) error { if err := service.Datastore.DisconnectCustodialWallet(ctx, walletID); err != nil { @@ -880,22 +991,13 @@ func (c *claimsZP) validateTime(now time.Time) error { return nil } -func validateCustodianLinking(ctx context.Context, storage Datastore, walletID uuid.UUID, depositProvider string) error { - c, err := storage.GetCustodianLinkByWalletID(ctx, walletID) - if err != nil && !errors.Is(err, model.ErrNoWalletCustodian) { - return err - } - - // if there are no instances of wallet custodian then it is - // considered a new linking and therefore valid. - if c == nil { - return nil - } - - if !strings.EqualFold(c.Custodian, depositProvider) { - return errCustodianLinkMismatch +func isWalletWhitelisted(ctx context.Context, dbi sqlx.QueryerContext, alRepo allowListRepo, paymentID uuid.UUID) error { + if _, err := alRepo.GetAllowListEntry(ctx, dbi, paymentID); err != nil { + if errors.Is(err, model.ErrNotFound) { + return model.ErrWalletNotWhitelisted + } + return fmt.Errorf("error checking allow list entry: %w", err) } - return nil } diff --git a/services/wallet/storage/storage.go b/services/wallet/storage/storage.go new file mode 100644 index 000000000..a09a46683 --- /dev/null +++ b/services/wallet/storage/storage.go @@ -0,0 +1,91 @@ +package storage + +import ( + "context" + "database/sql" + "errors" + + "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/jmoiron/sqlx" + uuid "github.com/satori/go.uuid" +) + +type Challenge struct{} + +func NewChallenge() *Challenge { return &Challenge{} } + +// Get retrieves a model.Challenge from the database by the given paymentID. +func (c *Challenge) Get(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.Challenge, error) { + const q = `select * from challenge where payment_id = $1` + + var result model.Challenge + if err := sqlx.GetContext(ctx, dbi, &result, q, paymentID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return result, model.ErrNotFound + } + return result, err + } + + return result, nil +} + +// Upsert persists a model.Challenge to the database. +func (c *Challenge) Upsert(ctx context.Context, dbi sqlx.ExecerContext, chl model.Challenge) error { + const q = `insert into challenge (payment_id, created_at, nonce) values($1, $2, $3) on conflict (payment_id) do update set created_at = $2, nonce = $3` + + result, err := dbi.ExecContext(ctx, q, chl.PaymentID, chl.CreatedAt, chl.Nonce) + if err != nil { + return err + } + + row, err := result.RowsAffected() + if err != nil { + return err + } + + if row != 1 { + return model.ErrNotInserted + } + + return nil +} + +// Delete removes a model.Challenge from the database identified by the paymentID. +func (c *Challenge) Delete(ctx context.Context, dbi sqlx.ExecerContext, paymentID uuid.UUID) error { + const q = `delete from challenge where payment_id = $1` + + result, err := dbi.ExecContext(ctx, q, paymentID) + if err != nil { + return err + } + + row, err := result.RowsAffected() + if err != nil { + return err + } + + if row == 0 { + return model.ErrNoRowsDeleted + } + + return nil +} + +type AllowList struct{} + +func NewAllowList() *AllowList { return &AllowList{} } + +// GetAllowListEntry retrieves a model.AllowListEntry from the database for the given paymentID. +func (a *AllowList) GetAllowListEntry(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.AllowListEntry, error) { + const q = `select * from allow_list where payment_id = $1` + + var result model.AllowListEntry + if err := sqlx.GetContext(ctx, dbi, &result, q, paymentID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return result, model.ErrNotFound + } + return result, err + } + + return result, nil +} diff --git a/services/wallet/storage/storage_test.go b/services/wallet/storage/storage_test.go new file mode 100644 index 000000000..fff574359 --- /dev/null +++ b/services/wallet/storage/storage_test.go @@ -0,0 +1,325 @@ +//go:build integration + +package storage + +import ( + "context" + "testing" + "time" + + "github.com/brave-intl/bat-go/libs/datastore" + "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/jmoiron/sqlx" + uuid "github.com/satori/go.uuid" + should "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" +) + +func TestChallenge_Get(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + type tcGiven struct { + paymentID uuid.UUID + chal model.Challenge + } + + type exp struct { + chal model.Challenge + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "get", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-1", + }}, + exp: exp{ + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-1", + }, + err: nil, + }, + }, + { + name: "not_found", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("1b8c218f-2585-49c1-90cd-b82006eb9865"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("54e4a78e-2c69-4fb2-8d72-47921bb0b374"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-2", + }}, + exp: exp{ + err: model.ErrNotFound, + }, + }, + } + + const q = `insert into challenge (payment_id, created_at, nonce) values($1, $2, $3)` + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + _, err1 := dbi.ExecContext(ctx, q, tc.given.chal.PaymentID, tc.given.chal.CreatedAt, tc.given.chal.Nonce) + must.NoError(t, err1) + + c := Challenge{} + actual, err2 := c.Get(ctx, dbi, tc.given.paymentID) + + should.Equal(t, tc.exp.err, err2) + should.Equal(t, tc.exp.chal.PaymentID, actual.PaymentID) + should.Equal(t, tc.exp.chal.CreatedAt, actual.CreatedAt) + should.Equal(t, tc.exp.chal.Nonce, actual.Nonce) + }) + } +} + +func TestChallenge_Upsert(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + const q = `select * from challenge where payment_id = $1` + + chlRepo := &Challenge{} + + t.Run("insert", func(t *testing.T) { + exp := model.Challenge{ + PaymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "a", + } + + err1 := chlRepo.Upsert(context.TODO(), dbi, exp) + must.NoError(t, err1) + + var actual model.Challenge + err2 := sqlx.GetContext(context.TODO(), dbi, &actual, q, exp.PaymentID) + must.NoError(t, err2) + + should.Equal(t, exp.PaymentID, actual.PaymentID) + should.Equal(t, exp.CreatedAt, actual.CreatedAt) + should.Equal(t, exp.Nonce, actual.Nonce) + }) + + t.Run("upsert", func(t *testing.T) { + ctx := context.Background() + + exp := model.Challenge{ + PaymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + CreatedAt: time.Date(2024, 12, 1, 1, 1, 1, 0, time.UTC), + Nonce: "b", + } + + err1 := chlRepo.Upsert(ctx, dbi, exp) + must.NoError(t, err1) + + var actual model.Challenge + err2 := sqlx.GetContext(ctx, dbi, &actual, q, exp.PaymentID) + must.NoError(t, err2) + + should.Equal(t, exp.PaymentID, actual.PaymentID) + should.Equal(t, exp.CreatedAt, actual.CreatedAt) + should.Equal(t, exp.Nonce, actual.Nonce) + }) +} + +func TestChallenge_Delete(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + type tcGiven struct { + paymentID uuid.UUID + chal model.Challenge + } + + type exp struct { + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "delete", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + Nonce: "nonce-1", + }}, + exp: exp{ + err: nil, + }, + }, + { + name: "no_rows_deleted", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("34fe675b-aebf-4209-90b6-a7ba4452087a"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-2", + }}, + exp: exp{ + err: model.ErrNoRowsDeleted, + }, + }, + } + + const q = `insert into challenge (payment_id, created_at, nonce) values($1, $2, $3)` + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + _, err1 := dbi.ExecContext(ctx, q, tc.given.chal.PaymentID, tc.given.chal.CreatedAt, tc.given.chal.Nonce) + must.NoError(t, err1) + + c := Challenge{} + err2 := c.Delete(ctx, dbi, tc.given.paymentID) + should.Equal(t, tc.exp.err, err2) + }) + } +} + +func TestAllowList_GetAllowListEntry(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE allow_list;") + }() + + type tcGiven struct { + paymentID uuid.UUID + allow model.AllowListEntry + } + + type exp struct { + allow model.AllowListEntry + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "success", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("6d85a314-0fa8-4594-9cb9-c9141b61a887"), + allow: model.AllowListEntry{ + PaymentID: uuid.FromStringOrNil("6d85a314-0fa8-4594-9cb9-c9141b61a887"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + }, + }, + exp: exp{ + allow: model.AllowListEntry{ + PaymentID: uuid.FromStringOrNil("6d85a314-0fa8-4594-9cb9-c9141b61a887"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + }, + err: nil, + }, + }, + { + name: "not_found", + given: tcGiven{ + paymentID: uuid.NewV4(), + }, + exp: exp{ + err: model.ErrNotFound, + }, + }, + } + + const q = `insert into allow_list (payment_id, created_at) values($1, $2)` + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + _, err1 := dbi.ExecContext(ctx, q, tc.given.allow.PaymentID, tc.given.allow.CreatedAt) + must.NoError(t, err1) + + a := &AllowList{} + actual, err2 := a.GetAllowListEntry(ctx, dbi, tc.given.paymentID) + should.Equal(t, tc.exp.err, err2) + should.Equal(t, tc.exp.allow, actual) + }) + } +} + +func setupDBI() (*sqlx.DB, error) { + pg, err := datastore.NewPostgres("", false, "") + if err != nil { + return nil, err + } + + mg, err := pg.NewMigrate() + if err != nil { + return nil, err + } + + ver, dirty, err := mg.Version() + if err != nil { + return nil, err + } + + if dirty { + if err := mg.Force(int(ver)); err != nil { + return nil, err + } + } + + if ver > 0 { + if err := mg.Down(); err != nil { + return nil, err + } + } + + if err := pg.Migrate(); err != nil { + return nil, err + } + + return pg.RawDB(), nil +} diff --git a/tools/go.mod b/tools/go.mod index 968705fe8..1b0f5c17b 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -97,6 +97,7 @@ require ( github.com/fatih/color v1.15.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-chi/chi v4.1.2+incompatible // indirect + github.com/go-chi/cors v1.2.1 // indirect github.com/go-jose/go-jose/v3 v3.0.1 // indirect github.com/go-openapi/analysis v0.21.4 // indirect github.com/go-openapi/errors v0.20.3 // indirect diff --git a/tools/go.sum b/tools/go.sum index d64e2a4f7..86753a598 100644 --- a/tools/go.sum +++ b/tools/go.sum @@ -565,6 +565,8 @@ github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2H github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= From dcb328fccdb76b47c70eeb45256b3443f2c1ba57 Mon Sep 17 00:00:00 2001 From: Anirudha Bose Date: Wed, 10 Jan 2024 20:52:19 +0530 Subject: [PATCH 02/18] Add bonk mapping for ratios service (#2283) --- services/ratios/mapping.go | 1 + 1 file changed, 1 insertion(+) diff --git a/services/ratios/mapping.go b/services/ratios/mapping.go index d15ff32c7..3318525d5 100644 --- a/services/ratios/mapping.go +++ b/services/ratios/mapping.go @@ -39,6 +39,7 @@ var ( "bnb": "binancecoin", "boa": "bosagora", "bob": "bobs_repair", + "bonk": "bonk", "boo": "boo", "booty": "candybooty", "box": "box-token", From 45f5c2760c06f0257ff8318d7745c9b7dfa5605f Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Wed, 10 Jan 2024 17:39:41 +0000 Subject: [PATCH 03/18] fix: add submit receipt validation error check (#2285) --- services/skus/controllers.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index c1642d835..8cec0f3c4 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -14,7 +14,7 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/cors" uuid "github.com/satori/go.uuid" - stripe "github.com/stripe/stripe-go/v72" + "github.com/stripe/stripe-go/v72" "github.com/stripe/stripe-go/v72/webhook" "github.com/brave-intl/bat-go/libs/clients/radom" @@ -1296,6 +1296,10 @@ func SubmitReceipt(service *Service) handlers.AppHandler { validationErrMap["request-body"] = err.Error() } + if len(validationErrMap) != 0 { + return handlers.ValidationError("Error validating request", validationErrMap) + } + // validate the receipt externalID, err := service.validateReceipt(ctx, orderID.UUID(), req) if err != nil { From 875466d5f550f8660228341be9ac81b2ccd50a4f Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Wed, 10 Jan 2024 21:27:58 +0000 Subject: [PATCH 04/18] Add self custody available field to get wallet endpoint (#2287) * feat: add self custody available field v4 get wallet * feat: add self custody available field v4 get wallet --- services/wallet/controllers_v4.go | 46 +++++++++++++++-- services/wallet/controllers_v4_test.go | 48 ++++++++++++++++++ services/wallet/outputs.go | 68 ++++++++++++++++++++++++++ services/wallet/service.go | 2 +- 4 files changed, 158 insertions(+), 6 deletions(-) diff --git a/services/wallet/controllers_v4.go b/services/wallet/controllers_v4.go index 8e876cd8f..e4a27b688 100644 --- a/services/wallet/controllers_v4.go +++ b/services/wallet/controllers_v4.go @@ -11,8 +11,10 @@ import ( errorutils "github.com/brave-intl/bat-go/libs/errors" "github.com/brave-intl/bat-go/libs/handlers" "github.com/brave-intl/bat-go/libs/httpsignature" + "github.com/brave-intl/bat-go/libs/inputs" "github.com/brave-intl/bat-go/libs/logging" "github.com/brave-intl/bat-go/libs/middleware" + "github.com/brave-intl/bat-go/services/wallet/model" "github.com/go-chi/chi" ) @@ -167,12 +169,46 @@ func UpdateWalletV4(s *Service) func(w http.ResponseWriter, r *http.Request) *ha } } -// GetWalletV4 is the same as get wallet v3, but we are now requiring http signatures for get wallet requests -func GetWalletV4(w http.ResponseWriter, r *http.Request) *handlers.AppError { - return GetWalletV3(w, r) -} - // GetUpholdWalletBalanceV4 produces an http handler for the service s which handles balance inquiries of uphold wallets func GetUpholdWalletBalanceV4(w http.ResponseWriter, r *http.Request) *handlers.AppError { return GetUpholdWalletBalanceV3(w, r) } + +func GetWalletV4(s *Service) func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + var ctx = r.Context() + + l := logging.Logger(ctx, "wallet") + + var id inputs.ID + if err := inputs.DecodeAndValidateString(ctx, &id, chi.URLParam(r, "paymentID")); err != nil { + l.Warn().Err(err).Str("paymentID", id.String()).Msg("failed to decode and validate paymentID from url") + return handlers.ValidationError("Error validating paymentID url parameter", map[string]interface{}{ + "paymentId": err.Error(), + }) + } + + paymentID := *id.UUID() + + info, err := s.Datastore.GetWallet(ctx, paymentID) + if err != nil { + l.Error().Err(err).Str("paymentID", id.String()).Msg("error getting wallet") + return handlers.WrapError(err, "error getting wallet from storage", http.StatusInternalServerError) + } + + if info == nil { + l.Info().Interface("paymentID", paymentID).Msg("wallet not found") + return handlers.WrapError(err, "no such wallet", http.StatusNotFound) + } + + if _, err := s.allowListRepo.GetAllowListEntry(ctx, s.Datastore.RawDB(), paymentID); err != nil && !errors.Is(err, model.ErrNotFound) { + return handlers.WrapError(err, "error getting allow list entry from storage", http.StatusInternalServerError) + } + + solSelfCustody := !errors.Is(err, model.ErrNotFound) + + resp := infoToResponseV4(info, solSelfCustody) + + return handlers.RenderContent(ctx, resp, w, http.StatusOK) + } +} diff --git a/services/wallet/controllers_v4_test.go b/services/wallet/controllers_v4_test.go index 3efbb8072..f47eee140 100644 --- a/services/wallet/controllers_v4_test.go +++ b/services/wallet/controllers_v4_test.go @@ -7,6 +7,7 @@ import ( "context" "crypto" "crypto/ed25519" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -14,8 +15,10 @@ import ( "net/http/httptest" "testing" + appctx "github.com/brave-intl/bat-go/libs/context" errorutils "github.com/brave-intl/bat-go/libs/errors" "github.com/brave-intl/bat-go/libs/middleware" + "github.com/brave-intl/bat-go/services/wallet/storage" "github.com/brave-intl/bat-go/libs/clients" @@ -558,6 +561,51 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_ReputationCal suite.Require().Equal(http.StatusInternalServerError, rw.Code) } +func (suite *WalletControllersTestSuite) TestGetWalletV4() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) + + pub, _, err := ed25519.GenerateKey(nil) + suite.Require().NoError(err) + + paymentID := uuid.NewV4() + w := &walletutils.Info{ + ID: paymentID.String(), + Provider: "brave", + PublicKey: hex.EncodeToString(pub), + AltCurrency: ptrTo(altcurrency.BAT), + } + err = pg.InsertWallet(context.TODO(), w) + suite.Require().NoError(err) + + whitelistWallet(suite.T(), pg, w.ID) + + allowList := storage.NewAllowList() + + service, _ := wallet.InitService(pg, nil, nil, allowList, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) + + handler := handlers.AppHandler(wallet.GetWalletV4(service)) + + req, err := http.NewRequest("GET", "/v4/wallets/"+paymentID.String(), nil) + suite.Require().NoError(err, "a request should be created") + + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("paymentID", paymentID.String()) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + suite.Assert().Equal(http.StatusOK, rr.Code) + + var resp wallet.ResponseV4 + err = json.Unmarshal(rr.Body.Bytes(), &resp) + suite.Require().NoError(err) + + suite.Assert().Equal(true, resp.SelfCustodyAvailable["solana"]) +} + func signUpdateRequest(req *http.Request, paymentID string, privateKey ed25519.PrivateKey) error { var s httpsignature.SignatureParams s.Algorithm = httpsignature.ED25519 diff --git a/services/wallet/outputs.go b/services/wallet/outputs.go index 48dde81a6..33eb8f20a 100644 --- a/services/wallet/outputs.go +++ b/services/wallet/outputs.go @@ -163,6 +163,74 @@ func infoToResponseV3(info *walletutils.Info) ResponseV3 { return resp } +// ResponseV4 - wallet creation response +type ResponseV4 struct { + PaymentID string `json:"paymentId"` + DepositAccountProvider *DepositAccountProviderDetailsV3 `json:"depositAccountProvider,omitempty"` + WalletProvider *ProviderDetailsV3 `json:"walletProvider,omitempty"` + AltCurrency string `json:"altcurrency"` + PublicKey string `json:"publicKey"` + SelfCustodyAvailable map[string]bool `json:"selfCustodyAvailable"` +} + +func infoToResponseV4(info *walletutils.Info, selfCustody bool) ResponseV4 { + var ( + linkingID string + anonymousAddress string + ) + if info == nil { + return ResponseV4{} + } + + var altCurrency = convertAltCurrency(info.AltCurrency) + + if info.ProviderLinkingID == nil { + linkingID = "" + } else { + linkingID = info.ProviderLinkingID.String() + } + + if info.AnonymousAddress == nil { + anonymousAddress = "" + } else { + anonymousAddress = info.AnonymousAddress.String() + } + + // common to all wallets + resp := ResponseV4{ + PaymentID: info.ID, + AltCurrency: altCurrency, + PublicKey: info.PublicKey, + WalletProvider: &ProviderDetailsV3{ + Name: info.Provider, + }, + SelfCustodyAvailable: map[string]bool{ + "solana": selfCustody, + }, + } + + // setup the wallet provider (anon card uphold) + if info.Provider == "uphold" { + // this is a uphold provided wallet (anon card based) + resp.WalletProvider.ID = info.ProviderID + resp.WalletProvider.AnonymousAddress = anonymousAddress + resp.WalletProvider.LinkingID = linkingID + } + + // now setup user deposit account + if info.UserDepositAccountProvider != nil { + // this brave wallet has a linked deposit account + resp.DepositAccountProvider = &DepositAccountProviderDetailsV3{ + Name: info.UserDepositAccountProvider, + ID: &info.UserDepositDestination, + LinkingID: linkingID, + AnonymousAddress: anonymousAddress, + } + } + + return resp +} + // BalanceResponseV3 - wallet creation response type BalanceResponseV3 struct { Total float64 `json:"total"` diff --git a/services/wallet/service.go b/services/wallet/service.go index 36019122c..aea81c473 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -386,7 +386,7 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux, metricsMw middl "UpdateWalletV4", UpdateWalletV4(s))).ServeHTTP) r.Get("/{paymentID}", middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( - "GetWalletV4", GetWalletV4)).ServeHTTP) + "GetWalletV4", GetWalletV4(s))).ServeHTTP) // get wallet balance routes r.Get("/uphold/{paymentID}", middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( From 8d59e8baee1b9edc2303f702d9a5762840f4c031 Mon Sep 17 00:00:00 2001 From: Pavel Brm <5097196+pavelbrm@users.noreply.github.com> Date: Thu, 11 Jan 2024 21:37:24 +1300 Subject: [PATCH 05/18] Set proper type (#2289) --- services/skus/input.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/skus/input.go b/services/skus/input.go index 7ade6c47a..de6615d6c 100644 --- a/services/skus/input.go +++ b/services/skus/input.go @@ -130,7 +130,7 @@ func credentialOpaqueFromString(s string) (*VerifyCredentialOpaque, error) { const ( appleVendor Vendor = "ios" - googleVendor = "android" + googleVendor Vendor = "android" ) var errInvalidVendor = errors.New("invalid vendor") From d70fe73489674a26a27a39c702020a1d70509c55 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Thu, 11 Jan 2024 09:54:20 +0000 Subject: [PATCH 06/18] refactor: remove extra validation check from submit receipt endpoint (#2290) --- services/skus/controllers.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 8cec0f3c4..c7f3aafa4 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -1323,10 +1323,6 @@ func SubmitReceipt(service *Service) handlers.AppHandler { } } - // if we had any validation errors, return the validation error map to the caller - if len(validationErrMap) != 0 { - return handlers.ValidationError("error validating request", validationErrMap) - } // does this external id exist already exists, err := service.ExternalIDExists(ctx, externalID) if err != nil { From 360e2a33d4afef0f789a720d20be45dcedf34e7c Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Thu, 11 Jan 2024 19:58:20 +0000 Subject: [PATCH 07/18] fix: add logging to submit receipt (#2293) --- services/skus/controllers.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index c7f3aafa4..9c564fe9c 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -1289,6 +1289,9 @@ func SubmitReceipt(service *Service) handlers.AppHandler { validationErrMap["request-body"] = err.Error() } + // TODO(clD11): remove when no longer needed + logger.Info().Interface("payload_byte", payload).Str("payload_str", string(payload)).Msg("payload") + // validate the payload if err := inputs.DecodeAndValidate(context.Background(), &req, payload); err != nil { logger.Debug().Str("payload", string(payload)).Msg("Failed to decode and validate the payload") @@ -1296,6 +1299,9 @@ func SubmitReceipt(service *Service) handlers.AppHandler { validationErrMap["request-body"] = err.Error() } + // TODO(clD11): remove when no longer needed + logger.Info().Interface("req_decoded", req).Msg("req decoded") + if len(validationErrMap) != 0 { return handlers.ValidationError("Error validating request", validationErrMap) } From 79ae3a26789b5e35071196d9a75f99e6c6eebca2 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Thu, 11 Jan 2024 20:09:59 +0000 Subject: [PATCH 08/18] refactor: add specific challenge not found error (#2291) --- services/wallet/controller_v3_test.go | 2 +- services/wallet/controllers_v3.go | 4 +++- services/wallet/model/model.go | 1 + services/wallet/storage/storage.go | 2 +- services/wallet/storage/storage_test.go | 4 ++-- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/services/wallet/controller_v3_test.go b/services/wallet/controller_v3_test.go index e3ada3402..aae26fb42 100644 --- a/services/wallet/controller_v3_test.go +++ b/services/wallet/controller_v3_test.go @@ -538,7 +538,7 @@ func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Success() { // after a successful linking the challenge should be removed from the database. _, actualErr := chlRep.Get(context.TODO(), pg.RawDB(), paymentID) - suite.Assert().ErrorIs(actualErr, model.ErrNotFound) + suite.Assert().ErrorIs(actualErr, model.ErrChallengeNotFound) } func whitelistWallet(t *testing.T, pg wallet.Datastore, Id string) { diff --git a/services/wallet/controllers_v3.go b/services/wallet/controllers_v3.go index ad0697015..bbefc2285 100644 --- a/services/wallet/controllers_v3.go +++ b/services/wallet/controllers_v3.go @@ -498,6 +498,8 @@ func LinkSolanaAddress(s *Service) handlers.AppHandler { switch { case errors.Is(err, model.ErrWalletNotWhitelisted): return handlers.WrapError(model.ErrWalletNotWhitelisted, "rewards wallet not whitelisted", http.StatusForbidden) + case errors.Is(err, model.ErrChallengeNotFound): + return handlers.WrapError(model.ErrChallengeNotFound, "linking challenge not found", http.StatusNotFound) case errors.Is(err, model.ErrChallengeExpired): return handlers.WrapError(model.ErrChallengeExpired, "linking challenge expired", http.StatusUnauthorized) case errors.Is(err, model.ErrWalletNotFound): @@ -505,7 +507,7 @@ func LinkSolanaAddress(s *Service) handlers.AppHandler { case errors.As(err, &solErr): return handlers.WrapError(solErr, "invalid solana linking message", http.StatusUnauthorized) default: - return handlers.WrapError(model.ErrInternalServer, "internal server error", http.StatusInternalServerError) + return handlers.WrapError(err, "internal server error", http.StatusInternalServerError) } } diff --git a/services/wallet/model/model.go b/services/wallet/model/model.go index 27681abbe..ea76f1170 100644 --- a/services/wallet/model/model.go +++ b/services/wallet/model/model.go @@ -10,6 +10,7 @@ import ( const ( ErrWalletNotWhitelisted Error = "model: wallet not whitelisted" ErrNotFound Error = "model: not found" + ErrChallengeNotFound Error = "model: challenge not found" ErrChallengeExpired Error = "model: challenge expired" ErrNoRowsDeleted Error = "model: no rows deleted" ErrNotInserted Error = "model: not inserted" diff --git a/services/wallet/storage/storage.go b/services/wallet/storage/storage.go index a09a46683..c223e3752 100644 --- a/services/wallet/storage/storage.go +++ b/services/wallet/storage/storage.go @@ -21,7 +21,7 @@ func (c *Challenge) Get(ctx context.Context, dbi sqlx.QueryerContext, paymentID var result model.Challenge if err := sqlx.GetContext(ctx, dbi, &result, q, paymentID); err != nil { if errors.Is(err, sql.ErrNoRows) { - return result, model.ErrNotFound + return result, model.ErrChallengeNotFound } return result, err } diff --git a/services/wallet/storage/storage_test.go b/services/wallet/storage/storage_test.go index fff574359..d69e5473f 100644 --- a/services/wallet/storage/storage_test.go +++ b/services/wallet/storage/storage_test.go @@ -59,7 +59,7 @@ func TestChallenge_Get(t *testing.T) { }, }, { - name: "not_found", + name: "challenge_not_found", given: tcGiven{ paymentID: uuid.FromStringOrNil("1b8c218f-2585-49c1-90cd-b82006eb9865"), chal: model.Challenge{ @@ -68,7 +68,7 @@ func TestChallenge_Get(t *testing.T) { Nonce: "nonce-2", }}, exp: exp{ - err: model.ErrNotFound, + err: model.ErrChallengeNotFound, }, }, } From df58e91bfe61eac0fbd5fc9160c8e8352a73eb06 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Mon, 15 Jan 2024 12:16:26 +0000 Subject: [PATCH 09/18] refactor: use generic internal server error on sol link endpoint (#2300) --- services/wallet/controllers_v3.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/wallet/controllers_v3.go b/services/wallet/controllers_v3.go index bbefc2285..d2fa1a3a0 100644 --- a/services/wallet/controllers_v3.go +++ b/services/wallet/controllers_v3.go @@ -507,7 +507,7 @@ func LinkSolanaAddress(s *Service) handlers.AppHandler { case errors.As(err, &solErr): return handlers.WrapError(solErr, "invalid solana linking message", http.StatusUnauthorized) default: - return handlers.WrapError(err, "internal server error", http.StatusInternalServerError) + return handlers.WrapError(model.ErrInternalServer, "internal server error", http.StatusInternalServerError) } } From 29be8e728bfca6076037aee22efbbefb9c5e3bd1 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Mon, 15 Jan 2024 12:49:02 +0000 Subject: [PATCH 10/18] feat: allow multiple cors origins for dapp (#2299) * feat: allow multiple cors origins for dapp * chore: rename test filename to match src filename * refactor: improve is allowed origin test * refactor: rename dapp cors config as plural --- services/grant/cmd/grant.go | 6 +- services/wallet/cmd/rest_run.go | 7 +- services/wallet/controllers_v3.go | 17 +- services/wallet/controllers_v3_pvt_test.go | 157 +++++++++++------- ...ller_v3_test.go => controllers_v3_test.go} | 10 +- services/wallet/controllers_v4_test.go | 8 + services/wallet/datastore_test.go | 2 +- services/wallet/service.go | 12 +- 8 files changed, 142 insertions(+), 77 deletions(-) rename services/wallet/{controller_v3_test.go => controllers_v3_test.go} (98%) diff --git a/services/grant/cmd/grant.go b/services/grant/cmd/grant.go index f34dfa2f9..f0fdb45c2 100644 --- a/services/grant/cmd/grant.go +++ b/services/grant/cmd/grant.go @@ -355,11 +355,11 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context, // this way we can have the wallet service completely separated from // grants service and easily deployable. ctx, walletService = wallet.SetupService(ctx) - origin := os.Getenv("DAPP_ALLOWED_CORS_ORIGINS") - if origin == "" { + dappAO := strings.Split(os.Getenv("DAPP_ALLOWED_CORS_ORIGINS"), ",") + if len(dappAO) == 0 { logger.Panic().Msg("dapp origin env missing") } - r = wallet.RegisterRoutes(ctx, walletService, r, middleware.InstrumentHandler, wallet.NewDAppCorsMw(origin)) + r = wallet.RegisterRoutes(ctx, walletService, r, middleware.InstrumentHandler, wallet.NewDAppCorsMw(dappAO)) promotionDB, promotionRODB, err := promotion.NewPostgres() if err != nil { diff --git a/services/wallet/cmd/rest_run.go b/services/wallet/cmd/rest_run.go index 279233909..900694e7a 100644 --- a/services/wallet/cmd/rest_run.go +++ b/services/wallet/cmd/rest_run.go @@ -5,6 +5,7 @@ import ( // pprof imports _ "net/http/pprof" "os" + "strings" "time" cmdutils "github.com/brave-intl/bat-go/cmd" @@ -29,12 +30,12 @@ func WalletRestRun(command *cobra.Command, args []string) { router := cmd.SetupRouter(ctx) - origin := os.Getenv("DAPP_ALLOWED_CORS_ORIGINS") - if origin == "" { + dappAO := strings.Split(os.Getenv("DAPP_ALLOWED_CORS_ORIGINS"), ",") + if len(dappAO) == 0 { logger.Panic().Msg("dapp origin env missing") } - wallet.RegisterRoutes(ctx, service, router, middleware.InstrumentHandler, wallet.NewDAppCorsMw(origin)) + wallet.RegisterRoutes(ctx, service, router, middleware.InstrumentHandler, wallet.NewDAppCorsMw(dappAO)) // add profiling flag to enable profiling routes if viper.GetString("pprof-enabled") != "" { diff --git a/services/wallet/controllers_v3.go b/services/wallet/controllers_v3.go index d2fa1a3a0..8bbae9a5d 100644 --- a/services/wallet/controllers_v3.go +++ b/services/wallet/controllers_v3.go @@ -467,7 +467,8 @@ func LinkSolanaAddress(s *Service) handlers.AppHandler { l := logging.Logger(ctx, "wallet") - if o := r.Header.Get("Origin"); o != s.dappConf.AllowedOrigin { + o := r.Header.Get("Origin") + if !isAllowedOrigin(o, s.dappConf.AllowedOrigins) { l.Error().Err(errOriginForbidden).Str("origin", strOr(o, "empty")).Msg("error linking solana address") return handlers.WrapError(errOriginForbidden, "request origin forbidden", http.StatusForbidden) } @@ -817,6 +818,20 @@ func DisconnectCustodianLinkV3(s *Service) func(w http.ResponseWriter, r *http.R } } +func isAllowedOrigin(origin string, allowedOrigins []string) bool { + if origin == "" { + return false + } + + for i := range allowedOrigins { + if allowedOrigins[i] == origin { + return true + } + } + + return false +} + func strOr(a string, b string) string { if a == "" { return b diff --git a/services/wallet/controllers_v3_pvt_test.go b/services/wallet/controllers_v3_pvt_test.go index 7f7cf7275..952e80ffb 100644 --- a/services/wallet/controllers_v3_pvt_test.go +++ b/services/wallet/controllers_v3_pvt_test.go @@ -1,4 +1,4 @@ -package wallet_test +package wallet import ( "bytes" @@ -28,7 +28,6 @@ import ( "github.com/brave-intl/bat-go/libs/httpsignature" "github.com/brave-intl/bat-go/libs/logging" "github.com/brave-intl/bat-go/libs/middleware" - "github.com/brave-intl/bat-go/services/wallet" "github.com/go-chi/chi" "github.com/golang/mock/gomock" "github.com/jmoiron/sqlx" @@ -44,15 +43,15 @@ func TestCreateBraveWalletV3(t *testing.T) { defer mockCtrl.Finish() var ( db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ + datastore = Datastore( + &Postgres{ Postgres: datastoreutils.Postgres{ DB: sqlx.NewDb(db, "postgres"), }, }) // add the datastore to the context ctx = context.Background() - handler = wallet.CreateBraveWalletV3 + handler = CreateBraveWalletV3 r = httptest.NewRequest("POST", "/v3/wallet/brave", nil) ) // no logger, setup @@ -66,10 +65,10 @@ func TestCreateBraveWalletV3(t *testing.T) { // setup keypair publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) - must(t, "failed to generate keypair", err) + require.NoError(t, err) err = signRequest(r, publicKey, privKey) - must(t, "failed to sign request", err) + require.NoError(t, err) r = r.WithContext(ctx) @@ -85,15 +84,15 @@ func TestCreateUpholdWalletV3(t *testing.T) { defer mockCtrl.Finish() var ( db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ + datastore = Datastore( + &Postgres{ Postgres: datastoreutils.Postgres{ DB: sqlx.NewDb(db, "postgres"), }, }) // add the datastore to the context ctx = context.Background() - handler = wallet.CreateUpholdWalletV3 + handler = CreateUpholdWalletV3 r = httptest.NewRequest("POST", "/v3/wallet/uphold", bytes.NewBufferString(`{ "signedCreationRequest": "eyJib2R5Ijp7ImRlbm9taW5hdGlvbiI6eyJhbW91bnQiOiIwIiwiY3VycmVuY3kiOiJCQVQifSwiZGVzdGluYXRpb24iOiJhNmRmZjJiYS1kMGQxLTQxYzQtOGU1Ni1hMjYwNWJjYWY0YWYifSwiaGVhZGVycyI6eyJkaWdlc3QiOiJTSEEtMjU2PWR2RTAzVHdpRmFSR0c0MUxLSkR4aUk2a3c5M0h0cTNsclB3VllldE5VY1E9Iiwic2lnbmF0dXJlIjoia2V5SWQ9XCJwcmltYXJ5XCIsYWxnb3JpdGhtPVwiZWQyNTUxOVwiLGhlYWRlcnM9XCJkaWdlc3RcIixzaWduYXR1cmU9XCJkcXBQdERESXE0djNiS1V5eHB6Q3Vyd01nSzRmTWk1MUJjakRLc2pTak90K1h1MElZZlBTMWxEZ01aRkhiaWJqcGh0MVd3V3l5enFad3lVNW0yN1FDUT09XCIifSwib2N0ZXRzIjoie1wiZGVub21pbmF0aW9uXCI6e1wiYW1vdW50XCI6XCIwXCIsXCJjdXJyZW5jeVwiOlwiQkFUXCJ9LFwiZGVzdGluYXRpb25cIjpcImE2ZGZmMmJhLWQwZDEtNDFjNC04ZTU2LWEyNjA1YmNhZjRhZlwifSJ9"}`)) ) @@ -118,14 +117,14 @@ func TestCreateUpholdWalletV3(t *testing.T) { func TestGetWalletV3(t *testing.T) { var ( db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ + datastore = Datastore( + &Postgres{ Postgres: datastoreutils.Postgres{ DB: sqlx.NewDb(db, "postgres"), }, }) - roDatastore = wallet.ReadOnlyDatastore( - &wallet.Postgres{ + roDatastore = ReadOnlyDatastore( + &Postgres{ Postgres: datastoreutils.Postgres{ DB: sqlx.NewDb(db, "postgres"), }, @@ -134,7 +133,7 @@ func TestGetWalletV3(t *testing.T) { ctx = context.Background() id = uuid.NewV4() r = httptest.NewRequest("GET", fmt.Sprintf("/v3/wallet/%s", id), nil) - handler = wallet.GetWalletV3 + handler = GetWalletV3 rw = httptest.NewRecorder() rows = sqlmock.NewRows([]string{"id", "provider", "provider_id", "public_key", "provider_linking_id", "anonymous_address"}). AddRow(id, "brave", "", "12345", id, id) @@ -157,7 +156,7 @@ func TestGetWalletV3(t *testing.T) { } func TestLinkBitFlyerWalletV3(t *testing.T) { - wallet.VerifiedWalletEnable = true + VerifiedWalletEnable = true mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -182,7 +181,7 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { externalAccountID := hex.EncodeToString(h.Sum(nil)) - linkingInfo := wallet.BitFlyerLinkingInfo{ + linkingInfo := BitFlyerLinkingInfo{ DepositID: idTo.String(), RequestID: "1", AccountHash: accountHash.String(), @@ -208,7 +207,7 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { ) mockReputation = mockreputation.NewMockClient(mockCtrl) s, mock = initSvcWithMockDB(t) - handler = wallet.LinkBitFlyerDepositAccountV3(s) + handler = LinkBitFlyerDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -218,7 +217,7 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { mock.ExpectBegin() // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, accountHash.String()) + linkingID := uuid.NewV5(ClaimNamespace, accountHash.String()) // acquire lock for linkingID mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). @@ -238,7 +237,7 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { AddRow(time.Now(), time.Now()) // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "bitflyer", uuid.NewV5(wallet.ClaimNamespace, accountHash.String())).WillReturnRows(clRows) + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "bitflyer", uuid.NewV5(ClaimNamespace, accountHash.String())).WillReturnRows(clRows) // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "bitflyer", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -271,7 +270,7 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { b := rw.Body.Bytes() require.Equal(t, http.StatusOK, rw.Code, string(b)) - var l wallet.LinkDepositAccountResponse + var l LinkDepositAccountResponse err = json.Unmarshal(b, &l) require.NoError(t, err) @@ -279,7 +278,7 @@ func TestLinkBitFlyerWalletV3(t *testing.T) { } func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { - wallet.VerifiedWalletEnable = true + VerifiedWalletEnable = true mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -310,7 +309,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { }`, linkingInfo, idTo)), ) - handler = wallet.LinkGeminiDepositAccountV3(s) + handler = LinkGeminiDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -359,7 +358,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { mock.ExpectBegin() // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) // acquire lock for linkingID mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). @@ -391,7 +390,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { AddRow(time.Now(), time.Now()) // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -410,7 +409,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { b := rw.Body.Bytes() require.Equal(t, http.StatusOK, rw.Code, string(b)) - var l wallet.LinkDepositAccountResponse + var l LinkDepositAccountResponse err := json.Unmarshal(b, &l) require.NoError(t, err) @@ -421,7 +420,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { "DELETE", fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), nil) - handler = wallet.DisconnectCustodianLinkV3(s) + handler = DisconnectCustodianLinkV3(s) rw = httptest.NewRecorder() // create transaction @@ -443,7 +442,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { router.ServeHTTP(rw, r) if resp := rw.Result(); resp.StatusCode != http.StatusOK { - must(t, "invalid response", fmt.Errorf("expected %d, got %d", http.StatusOK, resp.StatusCode)) + t.Errorf("invalid response expected %d, got %d", http.StatusOK, resp.StatusCode) } // ban the country now @@ -469,7 +468,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { // perform again, make sure we check haslinkedprio hasPriorRows := sqlmock.NewRows([]string{"result"}). AddRow(true) - mock.ExpectQuery("^select exists(select 1 from wallet_custodian (.+)").WithArgs(uuid.NewV5(wallet.ClaimNamespace, accountID.String()), idFrom).WillReturnRows(hasPriorRows) + mock.ExpectQuery("^select exists(select 1 from wallet_custodian (.+)").WithArgs(uuid.NewV5(ClaimNamespace, accountID.String()), idFrom).WillReturnRows(hasPriorRows) max = sqlmock.NewRows([]string{"max"}).AddRow(4) open = sqlmock.NewRows([]string{"used"}).AddRow(0) @@ -492,7 +491,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { AddRow(time.Now(), time.Now()) // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -511,7 +510,7 @@ func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { } func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { - wallet.VerifiedWalletEnable = true + VerifiedWalletEnable = true mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -542,7 +541,7 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { }`, linkingInfo, idTo)), ) - handler = wallet.LinkGeminiDepositAccountV3(s) + handler = LinkGeminiDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -582,7 +581,7 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { mock.ExpectBegin() // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) // acquire lock for linkingID mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). @@ -614,7 +613,7 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { AddRow(time.Now(), time.Now()) // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -633,7 +632,7 @@ func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { b := rw.Body.Bytes() require.Equal(t, http.StatusOK, rw.Code, string(b)) - var l wallet.LinkDepositAccountResponse + var l LinkDepositAccountResponse err := json.Unmarshal(b, &l) require.NoError(t, err) } @@ -657,7 +656,7 @@ func TestLinkZebPayWalletV3_InvalidKyc(t *testing.T) { accountID = uuid.NewV4() idTo = accountID s, _ = initSvcWithMockDB(t) - handler = wallet.LinkZebPayDepositAccountV3(s) + handler = LinkZebPayDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -694,13 +693,13 @@ func TestLinkZebPayWalletV3_InvalidKyc(t *testing.T) { b := rw.Body.Bytes() require.Equal(t, http.StatusForbidden, rw.Code, string(b)) - var l wallet.LinkDepositAccountResponse + var l LinkDepositAccountResponse err = json.Unmarshal(b, &l) require.NoError(t, err) } func TestLinkZebPayWalletV3(t *testing.T) { - wallet.VerifiedWalletEnable = true + VerifiedWalletEnable = true mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -722,7 +721,7 @@ func TestLinkZebPayWalletV3(t *testing.T) { s, mock = initSvcWithMockDB(t) - handler = wallet.LinkZebPayDepositAccountV3(s) + handler = LinkZebPayDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -762,7 +761,7 @@ func TestLinkZebPayWalletV3(t *testing.T) { mock.ExpectBegin() // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) // acquire lock for linkingID @@ -782,7 +781,7 @@ func TestLinkZebPayWalletV3(t *testing.T) { AddRow(time.Now(), time.Now()) // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "zebpay", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "zebpay", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "zebpay", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -801,7 +800,7 @@ func TestLinkZebPayWalletV3(t *testing.T) { b := rw.Body.Bytes() require.Equal(t, http.StatusOK, rw.Code, string(b)) - var l wallet.LinkDepositAccountResponse + var l LinkDepositAccountResponse err = json.Unmarshal(b, &l) require.NoError(t, err) @@ -809,7 +808,7 @@ func TestLinkZebPayWalletV3(t *testing.T) { } func TestLinkGeminiWalletV3(t *testing.T) { - wallet.VerifiedWalletEnable = true + VerifiedWalletEnable = true mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -839,7 +838,7 @@ func TestLinkGeminiWalletV3(t *testing.T) { ) s, mock = initSvcWithMockDB(t) - handler = wallet.LinkGeminiDepositAccountV3(s) + handler = LinkGeminiDepositAccountV3(s) rw = httptest.NewRecorder() ) @@ -879,7 +878,7 @@ func TestLinkGeminiWalletV3(t *testing.T) { mock.ExpectBegin() // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) // acquire lock for linkingID @@ -899,7 +898,7 @@ func TestLinkGeminiWalletV3(t *testing.T) { AddRow(time.Now(), time.Now()) // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) // updates the link to the wallet_custodian record in wallets mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -918,7 +917,7 @@ func TestLinkGeminiWalletV3(t *testing.T) { b := rw.Body.Bytes() require.Equal(t, http.StatusOK, rw.Code, string(b)) - var l wallet.LinkDepositAccountResponse + var l LinkDepositAccountResponse err := json.Unmarshal(b, &l) require.NoError(t, err) @@ -926,7 +925,7 @@ func TestLinkGeminiWalletV3(t *testing.T) { } func TestDisconnectCustodianLinkV3(t *testing.T) { - wallet.VerifiedWalletEnable = true + VerifiedWalletEnable = true mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -943,7 +942,7 @@ func TestDisconnectCustodianLinkV3(t *testing.T) { s, mock = initSvcWithMockDB(t) - handler = wallet.DisconnectCustodianLinkV3(s) + handler = DisconnectCustodianLinkV3(s) w = httptest.NewRecorder() ) @@ -968,13 +967,55 @@ func TestDisconnectCustodianLinkV3(t *testing.T) { router.ServeHTTP(w, r) if resp := w.Result(); resp.StatusCode != http.StatusOK { - must(t, "invalid response", fmt.Errorf("expected %d, got %d", http.StatusOK, resp.StatusCode)) + t.Errorf("invalid response expected %d, got %d", http.StatusOK, resp.StatusCode) } } -func must(t *testing.T, msg string, err error) { - if err != nil { - t.Errorf("%s: %s\n", msg, err) +func TestIsAllowedOrigin(t *testing.T) { + type tcGiven struct { + origin string + allowedOrigins []string + } + + type testCase struct { + name string + given tcGiven + exp bool + } + + tests := []testCase{ + { + name: "allowed", + given: tcGiven{ + origin: "test", + allowedOrigins: []string{"random-1", "random-2", "random-3", "test"}, + }, + exp: true, + }, + { + name: "empty_origin", + given: tcGiven{ + allowedOrigins: []string{"random-1", "random-2", "random-3", "test"}, + }, + exp: false, + }, + { + name: "origin_not_in_allowed_origins", + given: tcGiven{ + origin: "test", + allowedOrigins: []string{"random-1", "random-2", "random-3", "random-4"}, + }, + exp: false, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := isAllowedOrigin(tc.given.origin, tc.given.allowedOrigins) + assert.Equal(t, tc.exp, actual) + }) } } @@ -998,10 +1039,10 @@ func mockSQLCustodianLink(mock sqlmock.Sqlmock, custodian string) { WillReturnRows(clRow) } -func initSvcWithMockDB(t *testing.T) (*wallet.Service, sqlmock.Sqlmock) { +func initSvcWithMockDB(t *testing.T) (*Service, sqlmock.Sqlmock) { db, mock, _ := sqlmock.New() - datastore := wallet.Datastore( - &wallet.Postgres{ + datastore := Datastore( + &Postgres{ Postgres: datastoreutils.Postgres{ DB: sqlx.NewDb(db, "postgres"), }, @@ -1015,9 +1056,9 @@ func initSvcWithMockDB(t *testing.T) (*wallet.Service, sqlmock.Sqlmock) { }, } - dappConf := wallet.DAppConfig{} + dappConf := DAppConfig{} - s, err := wallet.InitService(datastore, nil, nil, nil, nil, nil, nil, nil, mtc, gem, dappConf) + s, err := InitService(datastore, nil, nil, nil, nil, nil, nil, nil, mtc, gem, dappConf) require.NoError(t, err) return s, mock diff --git a/services/wallet/controller_v3_test.go b/services/wallet/controllers_v3_test.go similarity index 98% rename from services/wallet/controller_v3_test.go rename to services/wallet/controllers_v3_test.go index aae26fb42..3b2c21457 100644 --- a/services/wallet/controller_v3_test.go +++ b/services/wallet/controllers_v3_test.go @@ -439,7 +439,7 @@ func (suite *WalletControllersTestSuite) TestChallenges_Success() { chlRep := storage.NewChallenge() dac := wallet.DAppConfig{ - AllowedOrigin: "https://my-dapp.com", + AllowedOrigins: []string{"https://my-dapp.com", "https://my-dapp-2.com"}, } s, err := wallet.InitService(pg, nil, chlRep, nil, nil, nil, nil, nil, nil, nil, dac) @@ -497,7 +497,7 @@ func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Success() { suite.Require().NoError(err) dac := wallet.DAppConfig{ - AllowedOrigin: "https://my-dapp.com", + AllowedOrigins: []string{"https://my-dapp.com", "https://my-dapp-2.com"}, } s, err := wallet.InitService(pg, nil, chlRep, allowList, nil, nil, nil, nil, nil, nil, dac) @@ -521,14 +521,14 @@ func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Success() { suite.Require().NoError(err) r := httptest.NewRequest(http.MethodPost, "/v3/wallet/solana/"+w.ID+"/connect", bytes.NewBuffer(b)) - r.Header.Set("origin", "https://my-dapp.com") + r.Header.Set("origin", "https://my-dapp-2.com") rw := httptest.NewRecorder() svr := &http.Server{Addr: ":8080", Handler: setupRouter(s)} svr.Handler.ServeHTTP(rw, r) suite.Require().Equal(http.StatusOK, rw.Code) - suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) + suite.Require().Equal("https://my-dapp-2.com", rw.Header().Get("Access-Control-Allow-Origin")) // assert actual, err := pg.GetWallet(context.TODO(), paymentID) @@ -685,7 +685,7 @@ func setupRouter(service *wallet.Service) *chi.Mux { mw := func(name string, h http.Handler) http.Handler { return h } - s := "https://my-dapp.com" + s := []string{"https://my-dapp.com", "https://my-dapp-2.com"} r := chi.NewRouter() r.Mount("/v3", wallet.RegisterRoutes(context.TODO(), service, r, mw, wallet.NewDAppCorsMw(s))) return r diff --git a/services/wallet/controllers_v4_test.go b/services/wallet/controllers_v4_test.go index f47eee140..900968a28 100644 --- a/services/wallet/controllers_v4_test.go +++ b/services/wallet/controllers_v4_test.go @@ -627,3 +627,11 @@ func noOpMw() func(next http.Handler) http.Handler { }) } } + +func signRequest(req *http.Request, publicKey httpsignature.Ed25519PubKey, privateKey ed25519.PrivateKey) error { + var s httpsignature.SignatureParams + s.Algorithm = httpsignature.ED25519 + s.KeyID = hex.EncodeToString(publicKey) + s.Headers = []string{"digest", "(request-target)"} + return s.Sign(privateKey, crypto.Hash(0), req) +} diff --git a/services/wallet/datastore_test.go b/services/wallet/datastore_test.go index 6cdbde3d3..93984ed41 100644 --- a/services/wallet/datastore_test.go +++ b/services/wallet/datastore_test.go @@ -58,7 +58,7 @@ func (suite *WalletPostgresTestSuite) TearDownTest() { } func (suite *WalletPostgresTestSuite) CleanDB() { - tables := []string{"claim_creds", "claims", "wallets", "issuers", "promotions"} + tables := []string{"claim_creds", "claims", "wallets", "issuers", "promotions", "verified_wallet_outbox"} pg, _, err := NewPostgres() suite.Require().NoError(err, "Failed to get postgres conn") diff --git a/services/wallet/service.go b/services/wallet/service.go index aea81c473..08e54bf43 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -149,7 +149,7 @@ type Service struct { } type DAppConfig struct { - AllowedOrigin string + AllowedOrigins []string } // InitService creates a new instances of the wallet service. @@ -276,13 +276,13 @@ func SetupService(ctx context.Context) (context.Context, *Service) { mtc := metric.New() gemx := newGeminix("passport", "drivers_license", "national_identity_card", "passport_card") - dappAO := os.Getenv("DAPP_ALLOWED_CORS_ORIGINS") - if dappAO == "" { + dappAO := strings.Split(os.Getenv("DAPP_ALLOWED_CORS_ORIGINS"), ",") + if len(dappAO) == 0 { l.Panic().Err(errors.New("dapp allowed origins missing")).Msg("failed to initialize wallet service") } dappConf := DAppConfig{ - AllowedOrigin: dappAO, + AllowedOrigins: dappAO, } s, err := InitService(db, roDB, chlRepo, alRepo, repClient, geminiClient, geoCountryValidator, backoff.Retry, mtc, gemx, dappConf) @@ -398,10 +398,10 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux, metricsMw middl // TODO(clD11): WR. Move once we address the rest_run.go and grant.go start functions. -func NewDAppCorsMw(origin string) func(next http.Handler) http.Handler { +func NewDAppCorsMw(origins []string) func(next http.Handler) http.Handler { opts := cors.Options{ Debug: false, - AllowedOrigins: []string{origin}, + AllowedOrigins: origins, AllowedHeaders: []string{"Accept", "Content-Type"}, ExposedHeaders: []string{""}, AllowedMethods: []string{http.MethodPost}, From 418c59ac867c2f6c5b0a59b226c6f84e70631825 Mon Sep 17 00:00:00 2001 From: Pavel Brm <5097196+pavelbrm@users.noreply.github.com> Date: Thu, 18 Jan 2024 14:37:53 +1300 Subject: [PATCH 11/18] fix: use proper handler method in local env (#2304) --- services/grant/cmd/grant.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/grant/cmd/grant.go b/services/grant/cmd/grant.go index f0fdb45c2..dba2dbbd6 100644 --- a/services/grant/cmd/grant.go +++ b/services/grant/cmd/grant.go @@ -497,7 +497,7 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context, "/", middleware.InstrumentHandler( "CreateOrderNew", - corsMwrPost(handlers.AppHandler(orderh.Create)), + corsMwrPost(handlers.AppHandler(orderh.CreateNew)), ), ) } else { From d0f21ddd7baebc07c956b72233a098adf8976764 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Thu, 18 Jan 2024 08:49:16 +0000 Subject: [PATCH 12/18] feat: add option endpoints for challenge and connect solana (#2302) --- services/wallet/controllers_v3_test.go | 40 ++++++++++++++++++++++++++ services/wallet/service.go | 8 ++++++ 2 files changed, 48 insertions(+) diff --git a/services/wallet/controllers_v3_test.go b/services/wallet/controllers_v3_test.go index 3b2c21457..d9414458d 100644 --- a/services/wallet/controllers_v3_test.go +++ b/services/wallet/controllers_v3_test.go @@ -465,6 +465,25 @@ func (suite *WalletControllersTestSuite) TestChallenges_Success() { suite.Assert().Equal(chl.Nonce, resp.Nonce) } +func (suite *WalletControllersTestSuite) TestChallenges_Options() { + req := httptest.NewRequest(http.MethodOptions, "/v3/wallet/challenges", nil) + req.Header.Add("Access-Control-Request-Method", http.MethodPost) + req.Header.Add("Access-Control-Request-Headers", "Content-Type") + req.Header.Set("origin", "https://my-dapp.com") + + rw := httptest.NewRecorder() + + s := wallet.Service{} + + svr := &http.Server{Addr: ":8080", Handler: setupRouter(&s)} + svr.Handler.ServeHTTP(rw, req) + + suite.Require().Equal(http.StatusOK, rw.Code) + suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) + suite.Require().Equal(http.MethodPost, rw.Header().Get("Access-Control-Allow-Methods")) + suite.Require().Equal("Content-Type", rw.Header().Get("Access-Control-Allow-Headers")) +} + func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Success() { viper.Set("enable-link-drain-flag", "true") @@ -541,6 +560,27 @@ func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Success() { suite.Assert().ErrorIs(actualErr, model.ErrChallengeNotFound) } +func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Options() { + viper.Set("enable-link-drain-flag", "true") + + req := httptest.NewRequest(http.MethodOptions, "/v3/wallet/solana/ae51dce3-08e9-4beb-8a70-c51d064bb7d1/connect", nil) + req.Header.Add("Access-Control-Request-Method", http.MethodPost) + req.Header.Add("Access-Control-Request-Headers", "Content-Type") + req.Header.Set("origin", "https://my-dapp.com") + + rw := httptest.NewRecorder() + + s := wallet.Service{} + + svr := &http.Server{Addr: ":8080", Handler: setupRouter(&s)} + svr.Handler.ServeHTTP(rw, req) + + suite.Require().Equal(http.StatusOK, rw.Code) + suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) + suite.Require().Equal(http.MethodPost, rw.Header().Get("Access-Control-Allow-Methods")) + suite.Require().Equal("Content-Type", rw.Header().Get("Access-Control-Allow-Headers")) +} + func whitelistWallet(t *testing.T, pg wallet.Datastore, Id string) { const q = `insert into allow_list (payment_id, created_at) values($1, $2)` _, err := pg.RawDB().Exec(q, Id, time.Now()) diff --git a/services/wallet/service.go b/services/wallet/service.go index 08e54bf43..627c3eef5 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -365,6 +365,7 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux, metricsMw middl "LinkZebPayDepositAccount", LinkZebPayDepositAccountV3(s))).ServeHTTP) r.Method(http.MethodPost, "/solana/{paymentID}/connect", metricsMw("LinkSolanaAddress", dAppCorsMw(LinkSolanaAddress(s)))) + r.Method(http.MethodOptions, "/solana/{paymentID}/connect", metricsMw("LinkSolanaAddressOptions", dAppCorsMw(noOpHandler()))) } r.Get("/linking-info", middleware.SimpleTokenAuthorizedOnly(middleware.InstrumentHandlerFunc("GetLinkingInfo", GetLinkingInfoV3(s))).ServeHTTP) @@ -377,6 +378,7 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux, metricsMw middl r.Get("/uphold/{paymentID}", middleware.InstrumentHandlerFunc("GetUpholdWalletBalance", GetUpholdWalletBalanceV3)) r.Post("/challenges", middleware.RateLimiter(ctx, 2)(metricsMw("CreateChallenge", dAppCorsMw(CreateChallenge(s)))).ServeHTTP) + r.Options("/challenges", middleware.RateLimiter(ctx, 2)(metricsMw("CreateChallengeOptions", dAppCorsMw(noOpHandler()))).ServeHTTP) }) r.Route("/v4/wallets", func(r chi.Router) { @@ -411,6 +413,12 @@ func NewDAppCorsMw(origins []string) func(next http.Handler) http.Handler { return cors.Handler(opts) } +func noOpHandler() http.Handler { + return http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + return + }) +} + // SubmitAnonCardTransaction validates and submits a transaction on behalf of an anonymous card func (service *Service) SubmitAnonCardTransaction( ctx context.Context, From d83f0e57afc34be0861957ba8994888b2b5bfad0 Mon Sep 17 00:00:00 2001 From: eV <8796196+evq@users.noreply.github.com> Date: Thu, 18 Jan 2024 23:21:45 +0000 Subject: [PATCH 13/18] fix: fix legacy creds by item id (#2308) * fix legacy creds by item id * try fix lint * fix: tidy up * refactor: return 500 instead of 404 on type assertion failure --------- Co-authored-by: PavelBrm --- services/skus/controllers.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 9c564fe9c..ea9273ff3 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -742,7 +742,8 @@ func getOrderCredsByID(svc *Service, legacyMode bool) handlers.AppHandler { reqID = *reqIDRaw.UUID() } - creds, status, err := svc.GetItemCredentials(ctx, *orderID.UUID(), *itemID.UUID(), reqID) + itemIDv := *itemID.UUID() + creds, status, err := svc.GetItemCredentials(ctx, *orderID.UUID(), itemIDv, reqID) if err != nil { if !errors.Is(err, errSetRetryAfter) { return handlers.WrapError(err, "Error getting credentials", status) @@ -757,6 +758,21 @@ func getOrderCredsByID(svc *Service, legacyMode bool) handlers.AppHandler { w.Header().Set("Retry-After", strconv.FormatInt(avg, 10)) } + if legacyMode { + suCreds, ok := creds.([]OrderCreds) + if !ok { + return handlers.WrapError(err, "Error getting credentials", http.StatusInternalServerError) + } + + for i := range suCreds { + if uuid.Equal(suCreds[i].ID, itemIDv) { + return handlers.RenderContent(ctx, suCreds[i], w, status) + } + } + + return handlers.WrapError(err, "Error getting credentials", http.StatusNotFound) + } + if creds == nil { return handlers.RenderContent(ctx, map[string]interface{}{}, w, status) } From a935a0c90bda7e2433d2735958fdf267118d5760 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Fri, 19 Jan 2024 08:53:41 +0000 Subject: [PATCH 14/18] feat: remove expired challenges from the database (#2307) * feat: remove expired challenges from the database * refactor: use distinct loop variables for readability in test challenge --- services/wallet/service.go | 24 +++++ services/wallet/storage/storage.go | 23 +++++ services/wallet/storage/storage_test.go | 112 ++++++++++++++++++++++++ 3 files changed, 159 insertions(+) diff --git a/services/wallet/service.go b/services/wallet/service.go index 627c3eef5..992e72730 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -295,12 +295,19 @@ func SetupService(ctx context.Context) (context.Context, *Service) { l.Error().Err(err).Msg("failed to initialize custodian regions") } + decJob := deleteExpiredChallengeTask{exec: db.RawDB(), deleter: chlRepo, deleteAfterMin: 5} + s.jobs = []srv.Job{ { Func: s.RefreshCustodianRegionsWorker, Cadence: 15 * time.Minute, Workers: 1, }, + { + Func: decJob.deleteExpiredChallenges, + Cadence: 10 * time.Minute, + Workers: 1, + }, } if VerifiedWalletEnable { @@ -1050,3 +1057,20 @@ func parseZebPayClaims(ctx context.Context, verificationToken string) (claimsZP, return claims, nil } + +type deleter interface { + DeleteAfter(ctx context.Context, dbi sqlx.ExecerContext, interval time.Duration) error +} + +type deleteExpiredChallengeTask struct { + exec sqlx.ExecerContext + deleter deleter + deleteAfterMin time.Duration +} + +func (d *deleteExpiredChallengeTask) deleteExpiredChallenges(ctx context.Context) (bool, error) { + if err := d.deleter.DeleteAfter(ctx, d.exec, d.deleteAfterMin); err != nil && !errors.Is(err, model.ErrNoRowsDeleted) { + return false, fmt.Errorf("error deleting expired challenges: %w", err) + } + return true, nil +} diff --git a/services/wallet/storage/storage.go b/services/wallet/storage/storage.go index c223e3752..dc6353a83 100644 --- a/services/wallet/storage/storage.go +++ b/services/wallet/storage/storage.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "time" "github.com/brave-intl/bat-go/services/wallet/model" "github.com/jmoiron/sqlx" @@ -71,6 +72,28 @@ func (c *Challenge) Delete(ctx context.Context, dbi sqlx.ExecerContext, paymentI return nil } +// DeleteAfter removes model.Challenge's from the database where their created at plus the specified interval it +// less than the current time. The interval should be specified in minutes. +func (c *Challenge) DeleteAfter(ctx context.Context, dbi sqlx.ExecerContext, interval time.Duration) error { + const q = `delete from challenge where created_at + interval '1 min' * $1 < now()` + + result, err := dbi.ExecContext(ctx, q, interval) + if err != nil { + return err + } + + row, err := result.RowsAffected() + if err != nil { + return err + } + + if row == 0 { + return model.ErrNoRowsDeleted + } + + return nil +} + type AllowList struct{} func NewAllowList() *AllowList { return &AllowList{} } diff --git a/services/wallet/storage/storage_test.go b/services/wallet/storage/storage_test.go index d69e5473f..270579f4a 100644 --- a/services/wallet/storage/storage_test.go +++ b/services/wallet/storage/storage_test.go @@ -217,6 +217,118 @@ func TestChallenge_Delete(t *testing.T) { } } +func TestChallenge_DeleteAfter(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + type tcGiven struct { + interval time.Duration + challenges []model.Challenge + } + + type exp struct { + errDel error + errGet error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "delete_single", + given: tcGiven{ + interval: 1, + challenges: []model.Challenge{ + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now().Add(-6 * time.Minute), + Nonce: "nonce-1", + }, + }, + }, + exp: exp{ + errDel: nil, + errGet: model.ErrChallengeNotFound, + }, + }, + { + name: "delete_multiple", + given: tcGiven{ + interval: 1, + challenges: []model.Challenge{ + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now().Add(-6 * time.Minute), + Nonce: "nonce-1", + }, + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now().Add(-10 * time.Minute), + Nonce: "nonce-2", + }, + }, + }, + exp: exp{ + errDel: nil, + errGet: model.ErrChallengeNotFound, + }, + }, + { + name: "delete_none", + given: tcGiven{ + interval: 1, + challenges: []model.Challenge{ + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now(), + Nonce: "nonce-1", + }, + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now(), + Nonce: "nonce-2", + }, + }, + }, + exp: exp{ + errDel: model.ErrNoRowsDeleted, + errGet: nil, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + c := Challenge{} + + for j := range tc.given.challenges { + err := c.Upsert(ctx, dbi, tc.given.challenges[j]) + must.NoError(t, err) + } + + err := c.DeleteAfter(ctx, dbi, tc.given.interval) + must.Equal(t, tc.exp.errDel, err) + + for j := range tc.given.challenges { + _, err := c.Get(ctx, dbi, tc.given.challenges[j].PaymentID) + must.Equal(t, tc.exp.errGet, err) + } + }) + } +} + func TestAllowList_GetAllowListEntry(t *testing.T) { dbi, err := setupDBI() must.NoError(t, err) From 83b8cd22c5a9b58d9b3df5bc31486d68c5ff619b Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Mon, 22 Jan 2024 21:13:15 +0000 Subject: [PATCH 15/18] feat: increase rate limit on get v4 wallets (#2311) --- services/wallet/service.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/services/wallet/service.go b/services/wallet/service.go index 992e72730..5741d1a25 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -389,17 +389,17 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux, metricsMw middl }) r.Route("/v4/wallets", func(r chi.Router) { - r.Use(middleware.RateLimiter(ctx, 2)) - r.Post("/", middleware.InstrumentHandlerFunc("CreateWalletV4", CreateWalletV4(s))) - r.Patch("/{paymentID}", middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( - "UpdateWalletV4", UpdateWalletV4(s))).ServeHTTP) - r.Get("/{paymentID}", - middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( - "GetWalletV4", GetWalletV4(s))).ServeHTTP) - // get wallet balance routes - r.Get("/uphold/{paymentID}", - middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( - "GetUpholdWalletBalanceV4", GetUpholdWalletBalanceV4)).ServeHTTP) + r.Post("/", middleware.RateLimiter(ctx, 2)( + middleware.InstrumentHandlerFunc("CreateWalletV4", CreateWalletV4(s))).ServeHTTP) + + r.Patch("/{paymentID}", middleware.RateLimiter(ctx, 2)(middleware.HTTPSignedOnly(s)( + middleware.InstrumentHandlerFunc("UpdateWalletV4", UpdateWalletV4(s)))).ServeHTTP) + + r.Get("/{paymentID}", middleware.RateLimiter(ctx, 7)(middleware.HTTPSignedOnly(s)( + middleware.InstrumentHandlerFunc("GetWalletV4", GetWalletV4(s)))).ServeHTTP) + + r.Get("/uphold/{paymentID}", middleware.RateLimiter(ctx, 2)(middleware.HTTPSignedOnly(s)( + middleware.InstrumentHandlerFunc("GetUpholdWalletBalanceV4", GetUpholdWalletBalanceV4))).ServeHTTP) }) return r From c3de1bcfe20f2dea87e3bcbdf36b5a1455a93367 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Mon, 22 Jan 2024 22:18:22 +0000 Subject: [PATCH 16/18] fix: add is allowed check to allow list entry (#2310) --- services/wallet/controllers_v4.go | 7 +-- services/wallet/controllers_v4_test.go | 43 ++++++++++++++++++ services/wallet/model/model.go | 4 ++ services/wallet/model/model_test.go | 63 ++++++++++++++++++++++++++ 4 files changed, 114 insertions(+), 3 deletions(-) diff --git a/services/wallet/controllers_v4.go b/services/wallet/controllers_v4.go index e4a27b688..af6fbd89d 100644 --- a/services/wallet/controllers_v4.go +++ b/services/wallet/controllers_v4.go @@ -201,13 +201,14 @@ func GetWalletV4(s *Service) func(w http.ResponseWriter, r *http.Request) *handl return handlers.WrapError(err, "no such wallet", http.StatusNotFound) } - if _, err := s.allowListRepo.GetAllowListEntry(ctx, s.Datastore.RawDB(), paymentID); err != nil && !errors.Is(err, model.ErrNotFound) { + allow, err := s.allowListRepo.GetAllowListEntry(ctx, s.Datastore.RawDB(), paymentID) + if err != nil && !errors.Is(err, model.ErrNotFound) { return handlers.WrapError(err, "error getting allow list entry from storage", http.StatusInternalServerError) } - solSelfCustody := !errors.Is(err, model.ErrNotFound) + isSelfCustAvail := allow.IsAllowed(paymentID) - resp := infoToResponseV4(info, solSelfCustody) + resp := infoToResponseV4(info, isSelfCustAvail) return handlers.RenderContent(ctx, resp, w, http.StatusOK) } diff --git a/services/wallet/controllers_v4_test.go b/services/wallet/controllers_v4_test.go index 900968a28..37474abab 100644 --- a/services/wallet/controllers_v4_test.go +++ b/services/wallet/controllers_v4_test.go @@ -606,6 +606,49 @@ func (suite *WalletControllersTestSuite) TestGetWalletV4() { suite.Assert().Equal(true, resp.SelfCustodyAvailable["solana"]) } +func (suite *WalletControllersTestSuite) TestGetWalletV4_Not_Whitelisted() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) + + pub, _, err := ed25519.GenerateKey(nil) + suite.Require().NoError(err) + + paymentID := uuid.NewV4() + w := &walletutils.Info{ + ID: paymentID.String(), + Provider: "brave", + PublicKey: hex.EncodeToString(pub), + AltCurrency: ptrTo(altcurrency.BAT), + } + err = pg.InsertWallet(context.TODO(), w) + suite.Require().NoError(err) + + allowList := storage.NewAllowList() + + service, _ := wallet.InitService(pg, nil, nil, allowList, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) + + handler := handlers.AppHandler(wallet.GetWalletV4(service)) + + req, err := http.NewRequest("GET", "/v4/wallets/"+paymentID.String(), nil) + suite.Require().NoError(err, "a request should be created") + + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("paymentID", paymentID.String()) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + suite.Assert().Equal(http.StatusOK, rr.Code) + + var resp wallet.ResponseV4 + err = json.Unmarshal(rr.Body.Bytes(), &resp) + suite.Require().NoError(err) + + suite.Assert().Equal(false, resp.SelfCustodyAvailable["solana"]) +} + func signUpdateRequest(req *http.Request, paymentID string, privateKey ed25519.PrivateKey) error { var s httpsignature.SignatureParams s.Algorithm = httpsignature.ED25519 diff --git a/services/wallet/model/model.go b/services/wallet/model/model.go index ea76f1170..ea3907d01 100644 --- a/services/wallet/model/model.go +++ b/services/wallet/model/model.go @@ -24,6 +24,10 @@ type AllowListEntry struct { CreatedAt time.Time `db:"created_at"` } +func (a AllowListEntry) IsAllowed(paymentID uuid.UUID) bool { + return !uuid.Equal(a.PaymentID, uuid.Nil) && uuid.Equal(a.PaymentID, paymentID) +} + type Challenge struct { PaymentID uuid.UUID `db:"payment_id"` CreatedAt time.Time `db:"created_at"` diff --git a/services/wallet/model/model_test.go b/services/wallet/model/model_test.go index 6bdf173d8..55313a6d7 100644 --- a/services/wallet/model/model_test.go +++ b/services/wallet/model/model_test.go @@ -4,9 +4,72 @@ import ( "testing" "time" + uuid "github.com/satori/go.uuid" "github.com/stretchr/testify/assert" ) +func TestAllowListEntry_IsAllowed(t *testing.T) { + type tcGiven struct { + allow AllowListEntry + paymentID uuid.UUID + } + + type exp struct { + isAllowed bool + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "default_allow_list_entry", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("dc5a802a-87e9-47a5-9d9c-af4c8f171bf3"), + }, + exp: exp{isAllowed: false}, + }, + { + name: "payment_id_nil", + given: tcGiven{ + paymentID: uuid.Nil, + }, + exp: exp{isAllowed: false}, + }, + { + name: "payment_ids_not_equal", + given: tcGiven{ + allow: AllowListEntry{ + PaymentID: uuid.FromStringOrNil("d1359406-42f1-4364-99b7-77840e8594e8"), + }, + paymentID: uuid.FromStringOrNil("dc5a802a-87e9-47a5-9d9c-af4c8f171bf3"), + }, + exp: exp{isAllowed: false}, + }, + { + name: "payment_ids_are_equal", + given: tcGiven{ + allow: AllowListEntry{ + PaymentID: uuid.FromStringOrNil("356a634a-dbae-4f95-b276-f3f0f0a53509"), + }, + paymentID: uuid.FromStringOrNil("356a634a-dbae-4f95-b276-f3f0f0a53509"), + }, + exp: exp{isAllowed: true}, + }, + } + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.allow.IsAllowed(tc.given.paymentID) + assert.Equal(t, tc.exp.isAllowed, actual) + }) + } +} + func TestChallenge_IsValid(t *testing.T) { type tcGiven struct { chl Challenge From 9f0b7aa487086dd4b0769f437a2ae5114a29a520 Mon Sep 17 00:00:00 2001 From: Pavel Brm <5097196+pavelbrm@users.noreply.github.com> Date: Tue, 23 Jan 2024 11:52:22 +1300 Subject: [PATCH 17/18] Refactor: Improve Receipt Validation (#2306) * refactor: refactor and improve receipt validation * fix: delete unused var * refactor: self-review --- services/go.mod | 2 +- services/skus/controllers.go | 126 +++++++++--------- services/skus/controllers_test.go | 27 +++- services/skus/receipt.go | 213 ++++++++++++++++-------------- services/skus/service.go | 45 ++++--- 5 files changed, 228 insertions(+), 185 deletions(-) diff --git a/services/go.mod b/services/go.mod index 7eabc1f32..f6200d5fc 100644 --- a/services/go.mod +++ b/services/go.mod @@ -46,6 +46,7 @@ require ( github.com/stripe/stripe-go/v72 v72.122.0 golang.org/x/crypto v0.15.0 golang.org/x/exp v0.0.0-20230223210539-50820d90acfd + google.golang.org/api v0.134.0 gopkg.in/macaroon.v2 v2.1.0 gopkg.in/square/go-jose.v2 v2.6.0 ) @@ -131,7 +132,6 @@ require ( golang.org/x/sync v0.3.0 // indirect golang.org/x/sys v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/api v0.134.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect google.golang.org/grpc v1.58.3 // indirect diff --git a/services/skus/controllers.go b/services/skus/controllers.go index ea9273ff3..a661b7aff 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -1279,97 +1279,95 @@ func HandleStripeWebhook(service *Service) handlers.AppHandler { } } -// SubmitReceipt submit a vendor verifiable receipt that proves order is paid -func SubmitReceipt(service *Service) handlers.AppHandler { +// SubmitReceipt handles receipt submission requests. +func SubmitReceipt(svc *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() - var ( - ctx = r.Context() - req SubmitReceiptRequestV1 // the body of the request - orderID = new(inputs.ID) // the order id - validationErrMap = map[string]interface{}{} // for tracking our validation errors - ) + l := logging.Logger(ctx, "skus").With().Str("func", "SubmitReceipt").Logger() - logger := logging.Logger(ctx, "skus").With().Str("func", "SubmitReceipt").Logger() + orderID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParam(r, "orderID")); err != nil { + l.Warn().Err(err).Msg("failed to decode orderID") - // validate the order id - if err := inputs.DecodeAndValidateString(context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { - logger.Warn().Err(err).Msg("Failed to decode/validate order id from url") - validationErrMap["orderID"] = err.Error() + return handlers.ValidationError("Error validating request", map[string]interface{}{"orderID": err.Error()}) } - // read the payload - payload, err := requestutils.Read(r.Context(), r.Body) + payload, err := requestutils.Read(ctx, r.Body) if err != nil { - logger.Warn().Err(err).Msg("Failed to read the payload") - validationErrMap["request-body"] = err.Error() - } - - // TODO(clD11): remove when no longer needed - logger.Info().Interface("payload_byte", payload).Str("payload_str", string(payload)).Msg("payload") + l.Warn().Err(err).Msg("failed to read body") - // validate the payload - if err := inputs.DecodeAndValidate(context.Background(), &req, payload); err != nil { - logger.Debug().Str("payload", string(payload)).Msg("Failed to decode and validate the payload") - logger.Warn().Err(err).Msg("Failed to decode and validate the payload") - validationErrMap["request-body"] = err.Error() + return handlers.ValidationError("Error validating request", map[string]interface{}{"request-body": err.Error()}) } - // TODO(clD11): remove when no longer needed - logger.Info().Interface("req_decoded", req).Msg("req decoded") + // TODO(clD11): remove when no longer needed. + payloadS := string(payload) + l.Info().Interface("payload_byte", payload).Str("payload_str", payloadS).Msg("payload") - if len(validationErrMap) != 0 { - return handlers.ValidationError("Error validating request", validationErrMap) + req := SubmitReceiptRequestV1{} + if err := inputs.DecodeAndValidate(ctx, &req, payload); err != nil { + l.Debug().Str("payload", payloadS).Msg("failed to decode payload") + l.Warn().Err(err).Msg("failed to decode payload") + + return handlers.ValidationError("Error validating request", map[string]interface{}{"request-body": err.Error()}) } - // validate the receipt - externalID, err := service.validateReceipt(ctx, orderID.UUID(), req) + // TODO(clD11): remove when no longer needed. + l.Info().Interface("req_decoded", req).Msg("req decoded") + + extID, err := svc.validateReceipt(ctx, req) if err != nil { if errors.Is(err, errNotFound) { return handlers.WrapError(err, "order not found", http.StatusNotFound) } - logger.Warn().Err(err).Msg("Failed to validate the receipt with vendor") - validationErrMap["receiptErrors"] = err.Error() - // return codified errors for application - if errors.Is(err, errPurchaseFailed) { - return handlers.CodedValidationError(err.Error(), purchaseFailedErrCode, validationErrMap) - } else if errors.Is(err, errPurchasePending) { - return handlers.CodedValidationError(err.Error(), purchasePendingErrCode, validationErrMap) - } else if errors.Is(err, errPurchaseDeferred) { - return handlers.CodedValidationError(err.Error(), purchaseDeferredErrCode, validationErrMap) - } else if errors.Is(err, errPurchaseStatusUnknown) { - return handlers.CodedValidationError(err.Error(), purchaseStatusUnknownErrCode, validationErrMap) - } else { - // unknown error - return handlers.CodedValidationError("error validating receipt", purchaseValidationErrCode, validationErrMap) + + l.Warn().Err(err).Msg("failed to validate receipt with vendor") + + errStr := err.Error() + verrs := map[string]interface{}{"receiptErrors": errStr} + + switch { + case errors.Is(err, errPurchaseFailed): + return handlers.CodedValidationError(errStr, purchaseFailedErrCode, verrs) + case errors.Is(err, errPurchasePending): + return handlers.CodedValidationError(errStr, purchasePendingErrCode, verrs) + case errors.Is(err, errPurchaseDeferred): + return handlers.CodedValidationError(errStr, purchaseDeferredErrCode, verrs) + case errors.Is(err, errPurchaseStatusUnknown): + return handlers.CodedValidationError(errStr, purchaseStatusUnknownErrCode, verrs) + default: + return handlers.CodedValidationError("error validating receipt", purchaseValidationErrCode, verrs) } } - // does this external id exist already - exists, err := service.ExternalIDExists(ctx, externalID) - if err != nil { - logger.Warn().Err(err).Msg("failed to lookup external id existance") - return handlers.WrapError(err, "failed to lookup external id", http.StatusInternalServerError) + { + exists, err := svc.ExternalIDExists(ctx, extID) + if err != nil { + l.Warn().Err(err).Msg("failed to lookup external id") + + return handlers.WrapError(err, "failed to lookup external id", http.StatusInternalServerError) + } + + if exists { + return handlers.WrapError(err, "receipt has already been submitted", http.StatusBadRequest) + } } - if exists { - return handlers.WrapError(err, "receipt has already been submitted", http.StatusBadRequest) + vnd := req.Type.String() + mdata := datastore.Metadata{ + "vendor": vnd, + "externalID": extID, + paymentProcessor: vnd, } - // set order paid and include the vendor and external id to metadata - if err := service.UpdateOrderStatusPaidWithMetadata(ctx, orderID.UUID(), datastore.Metadata{ - "vendor": req.Type.String(), - "externalID": externalID, - paymentProcessor: req.Type.String(), - }); err != nil { - logger.Warn().Err(err).Msg("Failed to update the order with appropriate metadata") + if err := svc.UpdateOrderStatusPaidWithMetadata(ctx, orderID.UUID(), mdata); err != nil { + l.Warn().Err(err).Msg("failed to update order with vendor metadata") return handlers.WrapError(err, "failed to store status of order", http.StatusInternalServerError) } - return handlers.RenderContent(r.Context(), SubmitReceiptResponseV1{ - ExternalID: externalID, - Vendor: req.Type.String(), - }, w, http.StatusOK) + result := SubmitReceiptResponseV1{ExternalID: extID, Vendor: vnd} + + return handlers.RenderContent(ctx, result, w, http.StatusOK) }) } diff --git a/services/skus/controllers_test.go b/services/skus/controllers_test.go index 1cb5fa7a3..eb37ebfaa 100644 --- a/services/skus/controllers_test.go +++ b/services/skus/controllers_test.go @@ -373,10 +373,8 @@ func (suite *ControllersTestSuite) TestAndroidWebhook() { err := suite.storage.AppendOrderMetadata(context.Background(), &order.ID, "externalID", "my external id") suite.Require().NoError(err) - // overwrite the receipt validation function for this test - receiptValidationFns = map[Vendor]func(context.Context, interface{}) (string, error){ - appleVendor: validateIOSReceipt, - googleVendor: func(ctx context.Context, v interface{}) (string, error) { + suite.service.vendorReceiptValid = &mockVendorReceiptValidator{ + fnValidateGoogle: func(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { return "my external id", nil }, } @@ -1890,3 +1888,24 @@ func newCORSOptsEnv() cors.Options { return NewCORSOpts(origins, dbg) } + +type mockVendorReceiptValidator struct { + fnValidateApple func(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) + fnValidateGoogle func(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) +} + +func (v *mockVendorReceiptValidator) validateApple(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + if v.fnValidateApple == nil { + return "apple_defaul", nil + } + + return v.fnValidateApple(ctx, receipt) +} + +func (v *mockVendorReceiptValidator) validateGoogle(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + if v.fnValidateGoogle == nil { + return "google_default", nil + } + + return v.fnValidateGoogle(ctx, receipt) +} diff --git a/services/skus/receipt.go b/services/skus/receipt.go index 127c02ab0..89a9e2f0b 100644 --- a/services/skus/receipt.go +++ b/services/skus/receipt.go @@ -15,8 +15,12 @@ import ( "github.com/awa/go-iap/appstore" "github.com/awa/go-iap/playstore" + "google.golang.org/api/androidpublisher/v3" + appctx "github.com/brave-intl/bat-go/libs/context" "github.com/brave-intl/bat-go/libs/logging" + + "github.com/brave-intl/bat-go/services/skus/model" ) const ( @@ -24,22 +28,26 @@ const ( androidPaymentStatePaid androidPaymentStateTrial androidPaymentStatePendingDeferred +) +const ( androidCancelReasonUser int64 = 0 androidCancelReasonSystem int64 = 1 androidCancelReasonReplaced int64 = 2 androidCancelReasonDeveloper int64 = 3 + + purchasePendingErrCode = "purchase_pending" + purchaseDeferredErrCode = "purchase_deferred" + purchaseStatusUnknownErrCode = "purchase_status_unknown" + purchaseFailedErrCode = "purchase_failed" + purchaseValidationErrCode = "validation_failed" ) -var ( - receiptValidationFns = map[Vendor]func(context.Context, interface{}) (string, error){ - appleVendor: validateIOSReceipt, - googleVendor: validateAndroidReceipt, - } - iosClient *appstore.Client - androidClient *playstore.Client - errClientMisconfigured = errors.New("misconfigured client") +const ( + errNoInAppTx model.Error = "no in app info in response" +) +var ( errPurchaseUserCanceled = errors.New("purchase is canceled by user") errPurchaseSystemCanceled = errors.New("purchase is canceled by google playstore") errPurchaseReplacedCanceled = errors.New("purchase is canceled and replaced") @@ -51,12 +59,6 @@ var ( errPurchaseFailed = errors.New("purchase failed") errPurchaseExpired = errors.New("purchase expired") - - purchasePendingErrCode = "purchase_pending" - purchaseDeferredErrCode = "purchase_deferred" - purchaseStatusUnknownErrCode = "purchase_status_unknown" - purchaseFailedErrCode = "purchase_failed" - purchaseValidationErrCode = "validation_failed" ) type dumpTransport struct{} @@ -81,109 +83,118 @@ func (dt *dumpTransport) RoundTrip(r *http.Request) (*http.Response, error) { return resp, rtErr } -func initClients(ctx context.Context) { +type appStoreVerifier interface { + Verify(ctx context.Context, req appstore.IAPRequest, result interface{}) error +} - var logClient = &http.Client{ - Transport: &dumpTransport{}, - } +type playStoreVerifier interface { + VerifySubscription(ctx context.Context, pkgName, subID, token string) (*androidpublisher.SubscriptionPurchase, error) +} - logger := logging.Logger(ctx, "skus").With().Str("func", "initClients").Logger() - iosClient = appstore.New() +type receiptVerifier struct { + appStoreCl appStoreVerifier + playStoreCl playStoreVerifier +} - if jsonKey, ok := ctx.Value(appctx.PlaystoreJSONKeyCTXKey).([]byte); ok { - var err error - androidClient, err = playstore.NewWithClient(jsonKey, logClient) +func newReceiptVerifier(cl *http.Client, playKey []byte) (*receiptVerifier, error) { + result := &receiptVerifier{ + appStoreCl: appstore.NewWithClient(cl), + } + + if playKey != nil && len(playKey) != 0 { + gpCl, err := playstore.NewWithClient(playKey, cl) if err != nil { - logger.Error().Err(err).Msg("failed to initialize android client") + return nil, err } + + result.playStoreCl = gpCl } + + return result, nil } -// validateIOSReceipt - validate apple receipt with their apis -func validateIOSReceipt(ctx context.Context, receipt interface{}) (string, error) { - logger := logging.Logger(ctx, "skus").With().Str("func", "validateIOSReceipt").Logger() +// validateApple validates Apple App Store receipt. +func (v *receiptVerifier) validateApple(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + l := logging.Logger(ctx, "skus").With().Str("func", "validateReceiptApple").Logger() - // get the shared key from the context sharedKey, sharedKeyOK := ctx.Value(appctx.AppleReceiptSharedKeyCTXKey).(string) - if iosClient != nil { - // handle v1 receipt type - if v, ok := receipt.(SubmitReceiptRequestV1); ok { - req := appstore.IAPRequest{ - ReceiptData: v.Blob, - ExcludeOldTransactions: true, - } - if sharedKeyOK && len(sharedKey) > 0 { - req.Password = sharedKey - } - resp := &appstore.IAPResponse{} - if err := iosClient.Verify(ctx, req, resp); err != nil { - logger.Error().Err(err).Msg("failed to verify receipt") - return "", fmt.Errorf("failed to verify receipt: %w", err) - } - logger.Debug().Msg(fmt.Sprintf("%+v", resp)) - // get the transaction id back - if len(resp.Receipt.InApp) < 1 { - logger.Error().Msg("failed to verify receipt, no in app info") - return "", fmt.Errorf("failed to verify receipt, no in app info in response") - } - return resp.Receipt.InApp[0].OriginalTransactionID, nil - } + req := appstore.IAPRequest{ + ReceiptData: receipt.Blob, + ExcludeOldTransactions: true, } - logger.Error().Msg("client is not configured") - return "", errClientMisconfigured + + if sharedKeyOK && len(sharedKey) > 0 { + req.Password = sharedKey + } + + resp := &appstore.IAPResponse{} + if err := v.appStoreCl.Verify(ctx, req, resp); err != nil { + l.Error().Err(err).Msg("failed to verify receipt") + + return "", fmt.Errorf("failed to verify receipt: %w", err) + } + + l.Debug().Msg(fmt.Sprintf("%+v", resp)) + + if len(resp.Receipt.InApp) < 1 { + l.Error().Msg("failed to verify receipt: no in app info") + return "", errNoInAppTx + } + + return resp.Receipt.InApp[0].OriginalTransactionID, nil } -// validateAndroidReceipt - validate android receipt with their apis -func validateAndroidReceipt(ctx context.Context, receipt interface{}) (string, error) { - logger := logging.Logger(ctx, "skus").With().Str("func", "validateAndroidReceipt").Logger() - if androidClient != nil { - if v, ok := receipt.(SubmitReceiptRequestV1); ok { - logger.Debug().Str("receipt", fmt.Sprintf("%+v", v)).Msg("about to verify subscription") - // handle v1 receipt type - resp, err := androidClient.VerifySubscription(ctx, v.Package, v.SubscriptionID, v.Blob) - if err != nil { - logger.Error().Err(err).Msg("failed to verify subscription") - return "", errPurchaseFailed - } - - // is order expired? - if time.Unix(0, resp.ExpiryTimeMillis*int64(time.Millisecond)).Before(time.Now()) { - return "", errPurchaseExpired - } - - logger.Debug().Msgf("resp: %+v", resp) - if resp.PaymentState != nil { - // check that the order was paid - switch *resp.PaymentState { - case androidPaymentStatePaid, androidPaymentStateTrial: - break - case androidPaymentStatePending: - // is there a cancel reason? - switch resp.CancelReason { - case androidCancelReasonUser: - return "", errPurchaseUserCanceled - case androidCancelReasonSystem: - return "", errPurchaseSystemCanceled - case androidCancelReasonReplaced: - return "", errPurchaseReplacedCanceled - case androidCancelReasonDeveloper: - return "", errPurchaseDeveloperCanceled - } - return "", errPurchasePending - case androidPaymentStatePendingDeferred: - return "", errPurchaseDeferred - default: - return "", errPurchaseStatusUnknown - } - return v.Blob, nil - } - logger.Error().Err(err).Msg("failed to verify subscription: no payment state") - return "", errPurchaseFailed +// validateGoogle validates Google Store receipt. +func (v *receiptVerifier) validateGoogle(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + l := logging.Logger(ctx, "skus").With().Str("func", "validateReceiptGoogle").Logger() + + l.Debug().Str("receipt", fmt.Sprintf("%+v", receipt)).Msg("about to verify subscription") + + resp, err := v.playStoreCl.VerifySubscription(ctx, receipt.Package, receipt.SubscriptionID, receipt.Blob) + if err != nil { + l.Error().Err(err).Msg("failed to verify subscription") + return "", errPurchaseFailed + } + + // Check order expiration. + // There seems to be a mistake here? + // Unix expects nanoseconds as the second param, but ms are passed. + if time.Unix(0, resp.ExpiryTimeMillis*int64(time.Millisecond)).Before(time.Now()) { + return "", errPurchaseExpired + } + + l.Debug().Msgf("resp: %+v", resp) + + if resp.PaymentState == nil { + l.Error().Err(err).Msg("failed to verify subscription: no payment state") + return "", errPurchaseFailed + } + + // Check that the order was paid. + switch *resp.PaymentState { + case androidPaymentStatePaid, androidPaymentStateTrial: + return receipt.Blob, nil + + case androidPaymentStatePending: + // Checl for cancel reason. + switch resp.CancelReason { + case androidCancelReasonUser: + return "", errPurchaseUserCanceled + case androidCancelReasonSystem: + return "", errPurchaseSystemCanceled + case androidCancelReasonReplaced: + return "", errPurchaseReplacedCanceled + case androidCancelReasonDeveloper: + return "", errPurchaseDeveloperCanceled } + return "", errPurchasePending + + case androidPaymentStatePendingDeferred: + return "", errPurchaseDeferred + default: + return "", errPurchaseStatusUnknown } - logger.Error().Msg("client is not configured") - return "", errClientMisconfigured } // get the public key from the jws header diff --git a/services/skus/service.go b/services/skus/service.go index 9625cc459..a93607bbc 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -91,6 +91,11 @@ type orderStoreSvc interface { Get(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) } +type vendorReceiptValidator interface { + validateApple(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) + validateGoogle(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) +} + // Service contains datastore type Service struct { orderRepo orderStoreSvc @@ -113,6 +118,8 @@ type Service struct { retry backoff.RetryFunc radomClient *radom.InstrumentedClient radomSellerAddress string + + vendorReceiptValid vendorReceiptValidator } // PauseWorker - pause worker until time specified @@ -159,8 +166,6 @@ func (s *Service) InitKafka(ctx context.Context) error { // InitService creates a service using the passed datastore and clients configured from the environment. func InitService(ctx context.Context, datastore Datastore, walletService *wallet.Service, orderRepo orderStoreSvc, issuerRepo issuerStore) (*Service, error) { sublogger := logging.Logger(ctx, "payments").With().Str("func", "InitService").Logger() - // setup the in app purchase clients - initClients(ctx) // setup stripe if exists in context and enabled scClient := &client.API{} @@ -233,6 +238,17 @@ func InitService(ctx context.Context, datastore Datastore, walletService *wallet } } + cl := &http.Client{ + Transport: &dumpTransport{}, + Timeout: 30 * time.Second, + } + + playKey, _ := ctx.Value(appctx.PlaystoreJSONKeyCTXKey).([]byte) + rcptValidator, err := newReceiptVerifier(cl, playKey) + if err != nil { + return nil, err + } + service := &Service{ orderRepo: orderRepo, issuerRepo: issuerRepo, @@ -247,6 +263,7 @@ func InitService(ctx context.Context, datastore Datastore, walletService *wallet retry: backoff.Retry, radomClient: radomClient, radomSellerAddress: radomSellerAddress, + vendorReceiptValid: rcptValidator, } // setup runnable jobs @@ -1559,13 +1576,12 @@ func (s *Service) verifyDeveloperNotification(ctx context.Context, dn *Developer } // have order, now validate the receipt from the notification - _, err = s.validateReceipt(ctx, &o.ID, SubmitReceiptRequestV1{ + if _, err := s.vendorReceiptValid.validateGoogle(ctx, SubmitReceiptRequestV1{ Type: "android", Blob: dn.SubscriptionNotification.PurchaseToken, Package: dn.PackageName, SubscriptionID: dn.SubscriptionNotification.SubscriptionID, - }) - if err != nil { + }); err != nil { return fmt.Errorf("failed to validate purchase token: %w", err) } @@ -1597,17 +1613,16 @@ func (s *Service) verifyDeveloperNotification(ctx context.Context, dn *Developer return nil } -// validateReceipt - perform receipt validation -func (s *Service) validateReceipt(ctx context.Context, orderID *uuid.UUID, receipt interface{}) (string, error) { - // based on the vendor call the vendor specific apis to check the status of the receipt, - if v, ok := receipt.(SubmitReceiptRequestV1); ok { - // and get back the external id - if fn, ok := receiptValidationFns[v.Type]; ok { - return fn(ctx, receipt) - } +// validateReceipt validates receipt. +func (s *Service) validateReceipt(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + switch receipt.Type { + case appleVendor: + return s.vendorReceiptValid.validateApple(ctx, receipt) + case googleVendor: + return s.vendorReceiptValid.validateGoogle(ctx, receipt) + default: + return "", errorutils.ErrNotImplemented } - - return "", errorutils.ErrNotImplemented } // UpdateOrderStatusPaidWithMetadata - update the order status with metadata From 2d2552aade535585334b66c3e5cc1dc83cfe7231 Mon Sep 17 00:00:00 2001 From: clD11 <23483715+clD11@users.noreply.github.com> Date: Thu, 25 Jan 2024 21:44:30 +0000 Subject: [PATCH 18/18] feat: add solana to country code allow block list (#2315) * feat: add solana to country code allow block list * refactor: rename test regions decode * fix: add solana to parameters v1 schema test --- libs/custodian/regions.go | 6 ++-- libs/custodian/regions_test.go | 51 +++++++++++++++++++++++++++++++++- schema/rewards/ParametersV1 | 2 +- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/libs/custodian/regions.go b/libs/custodian/regions.go index 8fb7dc0fb..130ab58af 100644 --- a/libs/custodian/regions.go +++ b/libs/custodian/regions.go @@ -117,7 +117,6 @@ func contains(countries, allowblock []string) bool { for _, ab := range allowblock { for _, country := range countries { if strings.EqualFold(ab, country) { - fmt.Println("contains ", ab, country) return true } } @@ -141,6 +140,7 @@ type Regions struct { Gemini GeoAllowBlockMap `json:"gemini" valid:"-"` Bitflyer GeoAllowBlockMap `json:"bitflyer" valid:"-"` Zebpay GeoAllowBlockMap `json:"zebpay" valid:"-"` + Solana GeoAllowBlockMap `json:"solana" valid:"-"` } // HandleErrors - handle any errors in input @@ -149,12 +149,12 @@ func (cr *Regions) HandleErrors(err error) *handlers.AppError { } // Decode - implement decodable -func (cr *Regions) Decode(ctx context.Context, input []byte) error { +func (cr *Regions) Decode(_ context.Context, input []byte) error { return json.Unmarshal(input, cr) } // Validate - implement validatable -func (cr *Regions) Validate(ctx context.Context) error { +func (cr *Regions) Validate(_ context.Context) error { isValid, err := govalidator.ValidateStruct(cr) if err != nil { return err diff --git a/libs/custodian/regions_test.go b/libs/custodian/regions_test.go index 0b2889cce..1975b0c47 100644 --- a/libs/custodian/regions_test.go +++ b/libs/custodian/regions_test.go @@ -1,6 +1,12 @@ package custodian -import "testing" +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) func TestVerdictAllowList(t *testing.T) { gabm := GeoAllowBlockMap{ @@ -27,3 +33,46 @@ func TestVerdictBlockList(t *testing.T) { t.Error("should have been false, US in block list") } } + +func TestRegions_Decode(t *testing.T) { + type tcGiven struct { + input []byte + } + + type exp struct { + allow []string + block []string + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + testCases := []testCase{ + { + name: "solana", + given: tcGiven{ + input: []byte(`{"solana":{"allow":["AA"],"block":["AB"]}}`), + }, + exp: exp{ + allow: []string{"AA"}, + block: []string{"AB"}, + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + regions := Regions{} + err := regions.Decode(context.Background(), tc.given.input) + require.NoError(t, err) + + assert.Equal(t, tc.exp.allow, regions.Solana.Allow) + assert.Equal(t, tc.exp.block, regions.Solana.Block) + }) + } +} diff --git a/schema/rewards/ParametersV1 b/schema/rewards/ParametersV1 index 9cb9f5be2..437aafdc8 100644 --- a/schema/rewards/ParametersV1 +++ b/schema/rewards/ParametersV1 @@ -1 +1 @@ -{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/ParametersV1","definitions":{"AutoContribute":{"properties":{"choices":{"items":{"type":"number"},"type":"array"},"defaultChoice":{"type":"number"}},"additionalProperties":false,"type":"object"},"GeoAllowBlockMap":{"required":["allow","block"],"properties":{"allow":{"items":{"type":"string"},"type":"array"},"block":{"items":{"type":"string"},"type":"array"}},"additionalProperties":false,"type":"object"},"ParametersV1":{"required":["payoutStatus","custodianRegions","vbatExpired"],"properties":{"payoutStatus":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/PayoutStatus"},"custodianRegions":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Regions"},"batRate":{"type":"number"},"autocontribute":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/AutoContribute"},"tips":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Tips"},"vbatExpired":{"type":"boolean"},"vbatDeadline":{"type":"string","format":"date-time"}},"additionalProperties":false,"type":"object"},"PayoutStatus":{"required":["unverified","uphold","gemini","bitflyer","zebpay","payoutDate"],"properties":{"unverified":{"type":"string"},"uphold":{"type":"string"},"gemini":{"type":"string"},"bitflyer":{"type":"string"},"zebpay":{"type":"string"},"payoutDate":{"type":"string"}},"additionalProperties":false,"type":"object"},"Regions":{"required":["uphold","gemini","bitflyer","zebpay"],"properties":{"uphold":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/GeoAllowBlockMap"},"gemini":{"$ref":"#/definitions/GeoAllowBlockMap"},"bitflyer":{"$ref":"#/definitions/GeoAllowBlockMap"},"zebpay":{"$ref":"#/definitions/GeoAllowBlockMap"}},"additionalProperties":false,"type":"object"},"Tips":{"properties":{"defaultTipChoices":{"items":{"type":"number"},"type":"array"},"defaultMonthlyChoices":{"items":{"type":"number"},"type":"array"}},"additionalProperties":false,"type":"object"}}} \ No newline at end of file +{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/ParametersV1","definitions":{"AutoContribute":{"properties":{"choices":{"items":{"type":"number"},"type":"array"},"defaultChoice":{"type":"number"}},"additionalProperties":false,"type":"object"},"GeoAllowBlockMap":{"required":["allow","block"],"properties":{"allow":{"items":{"type":"string"},"type":"array"},"block":{"items":{"type":"string"},"type":"array"}},"additionalProperties":false,"type":"object"},"ParametersV1":{"required":["payoutStatus","custodianRegions","vbatExpired"],"properties":{"payoutStatus":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/PayoutStatus"},"custodianRegions":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Regions"},"batRate":{"type":"number"},"autocontribute":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/AutoContribute"},"tips":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Tips"},"vbatExpired":{"type":"boolean"},"vbatDeadline":{"type":"string","format":"date-time"}},"additionalProperties":false,"type":"object"},"PayoutStatus":{"required":["unverified","uphold","gemini","bitflyer","zebpay","payoutDate"],"properties":{"unverified":{"type":"string"},"uphold":{"type":"string"},"gemini":{"type":"string"},"bitflyer":{"type":"string"},"zebpay":{"type":"string"},"payoutDate":{"type":"string"}},"additionalProperties":false,"type":"object"},"Regions":{"required":["uphold","gemini","bitflyer","zebpay","solana"],"properties":{"uphold":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/GeoAllowBlockMap"},"gemini":{"$ref":"#/definitions/GeoAllowBlockMap"},"bitflyer":{"$ref":"#/definitions/GeoAllowBlockMap"},"zebpay":{"$ref":"#/definitions/GeoAllowBlockMap"},"solana":{"$ref":"#/definitions/GeoAllowBlockMap"}},"additionalProperties":false,"type":"object"},"Tips":{"properties":{"defaultTipChoices":{"items":{"type":"number"},"type":"array"},"defaultMonthlyChoices":{"items":{"type":"number"},"type":"array"}},"additionalProperties":false,"type":"object"}}} \ No newline at end of file