diff --git a/.gitignore b/.gitignore index fbf196191..19725210b 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,6 @@ share-0.gpg fetch-quote.json main/main + +go.work +go.work.sum diff --git a/docker-compose.dev-refresh.yml b/docker-compose.dev-refresh.yml index ca5c07ea3..6f18e8cf6 100644 --- a/docker-compose.dev-refresh.yml +++ b/docker-compose.dev-refresh.yml @@ -56,3 +56,4 @@ services: - UPHOLD_ACCESS_TOKEN - "RATIOS_SERVICE=https://ratios.rewards.bravesoftware.com" - RATIOS_TOKEN + - "DAPP_ALLOWED_CORS_ORIGINS=https://my-dapp.com" diff --git a/docker-compose.yml b/docker-compose.yml index e93b4b75a..2cfc5c41f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -80,6 +80,7 @@ services: - TEST_RUN - TOKEN_LIST - UPHOLD_ACCESS_TOKEN + - "DAPP_ALLOWED_CORS_ORIGINS=https://my-dapp.com" volumes: - ./test/secrets:/etc/kafka/secrets - ./migrations:/src/migrations @@ -167,6 +168,7 @@ services: - TEST_RUN - TOKEN_LIST - UPHOLD_ACCESS_TOKEN + - "DAPP_ALLOWED_CORS_ORIGINS=https://my-dapp.com" volumes: - ./test/secrets:/etc/kafka/secrets - ./migrations:/src/migrations diff --git a/libs/context/keys.go b/libs/context/keys.go index 9418f611b..0104850a8 100644 --- a/libs/context/keys.go +++ b/libs/context/keys.go @@ -128,6 +128,8 @@ const ( DisableGeminiLinkingCTXKey CTXKey = "disable_gemini_linking" // DisableBitflyerLinkingCTXKey - this informs if bitflyer linking is enabled DisableBitflyerLinkingCTXKey CTXKey = "disable_bitflyer_linking" + // DisableSolanaLinkingCTXKey - this informs if solana linking is enabled + DisableSolanaLinkingCTXKey CTXKey = "disable_solana_linking" // RadomWebhookSecretCTXKey - the webhook secret key for radom integration RadomWebhookSecretCTXKey CTXKey = "radom_webhook_secret" diff --git a/libs/custodian/regions.go b/libs/custodian/regions.go index 8fb7dc0fb..130ab58af 100644 --- a/libs/custodian/regions.go +++ b/libs/custodian/regions.go @@ -117,7 +117,6 @@ func contains(countries, allowblock []string) bool { for _, ab := range allowblock { for _, country := range countries { if strings.EqualFold(ab, country) { - fmt.Println("contains ", ab, country) return true } } @@ -141,6 +140,7 @@ type Regions struct { Gemini GeoAllowBlockMap `json:"gemini" valid:"-"` Bitflyer GeoAllowBlockMap `json:"bitflyer" valid:"-"` Zebpay GeoAllowBlockMap `json:"zebpay" valid:"-"` + Solana GeoAllowBlockMap `json:"solana" valid:"-"` } // HandleErrors - handle any errors in input @@ -149,12 +149,12 @@ func (cr *Regions) HandleErrors(err error) *handlers.AppError { } // Decode - implement decodable -func (cr *Regions) Decode(ctx context.Context, input []byte) error { +func (cr *Regions) Decode(_ context.Context, input []byte) error { return json.Unmarshal(input, cr) } // Validate - implement validatable -func (cr *Regions) Validate(ctx context.Context) error { +func (cr *Regions) Validate(_ context.Context) error { isValid, err := govalidator.ValidateStruct(cr) if err != nil { return err diff --git a/libs/custodian/regions_test.go b/libs/custodian/regions_test.go index 0b2889cce..1975b0c47 100644 --- a/libs/custodian/regions_test.go +++ b/libs/custodian/regions_test.go @@ -1,6 +1,12 @@ package custodian -import "testing" +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) func TestVerdictAllowList(t *testing.T) { gabm := GeoAllowBlockMap{ @@ -27,3 +33,46 @@ func TestVerdictBlockList(t *testing.T) { t.Error("should have been false, US in block list") } } + +func TestRegions_Decode(t *testing.T) { + type tcGiven struct { + input []byte + } + + type exp struct { + allow []string + block []string + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + testCases := []testCase{ + { + name: "solana", + given: tcGiven{ + input: []byte(`{"solana":{"allow":["AA"],"block":["AB"]}}`), + }, + exp: exp{ + allow: []string{"AA"}, + block: []string{"AB"}, + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + regions := Regions{} + err := regions.Decode(context.Background(), tc.given.input) + require.NoError(t, err) + + assert.Equal(t, tc.exp.allow, regions.Solana.Allow) + assert.Equal(t, tc.exp.block, regions.Solana.Block) + }) + } +} diff --git a/libs/datastore/postgres.go b/libs/datastore/postgres.go index 1fde6af11..d9d8ac9b0 100644 --- a/libs/datastore/postgres.go +++ b/libs/datastore/postgres.go @@ -13,8 +13,8 @@ import ( appctx "github.com/brave-intl/bat-go/libs/context" "github.com/brave-intl/bat-go/libs/logging" "github.com/brave-intl/bat-go/libs/metrics" - sentry "github.com/getsentry/sentry-go" - migrate "github.com/golang-migrate/migrate/v4" + "github.com/getsentry/sentry-go" + "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" "github.com/jmoiron/sqlx" "github.com/prometheus/client_golang/prometheus" @@ -41,7 +41,7 @@ var ( } dbs = map[string]*sqlx.DB{} // CurrentMigrationVersion holds the default migration version - CurrentMigrationVersion = uint(63) + CurrentMigrationVersion = uint(66) // MigrationTracks holds the migration version for a given track (eyeshade, promotion, wallet) MigrationTracks = map[string]uint{ "eyeshade": 20, diff --git a/libs/ptr/ptr.go b/libs/ptr/ptr.go index 6c1667b17..46027145f 100644 --- a/libs/ptr/ptr.go +++ b/libs/ptr/ptr.go @@ -2,15 +2,8 @@ package ptr import ( "time" - - uuid "github.com/satori/go.uuid" ) -// FromUUID returns pointer to uuid -func FromUUID(u uuid.UUID) *uuid.UUID { - return &u -} - // FromString returns pointer to string func FromString(s string) *string { return &s diff --git a/libs/wallet/wallet.go b/libs/wallet/wallet.go index 1c3d89aac..40575bdbd 100644 --- a/libs/wallet/wallet.go +++ b/libs/wallet/wallet.go @@ -3,10 +3,15 @@ package wallet import ( "context" + "crypto/ed25519" + "encoding/base64" + "encoding/hex" "fmt" + "strings" "time" "github.com/brave-intl/bat-go/libs/altcurrency" + "github.com/btcsuite/btcutil/base58" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -89,3 +94,202 @@ type Wallet interface { // ListTransactions for this wallet, limit number of transactions returned ListTransactions(ctx context.Context, limit int, startDate time.Time) ([]TransactionInfo, error) } + +type SolanaLinkReq struct { + Pub string + Sig string + Msg string + Nonce string +} + +func (w *Info) LinkSolanaAddress(_ context.Context, s SolanaLinkReq) error { + if err := w.isLinkReqValid(s); err != nil { + return newLinkSolanaAddressError(err) + } + + w.UserDepositDestination = s.Pub + + return nil +} + +func (w *Info) isLinkReqValid(s SolanaLinkReq) error { + if err := verifySolanaSignature(s.Pub, s.Msg, s.Sig); err != nil { + return err + } + + p := newSolMsgParser(w.ID, s.Pub, s.Nonce) + rm, err := p.parse(s.Msg) + if err != nil { + return fmt.Errorf("parsing error: %w", err) + } + + if err := verifyRewardsSignature(w.PublicKey, rm.msg, rm.sig); err != nil { + return err + } + + return nil +} + +const publicKeyLength = 32 + +const ( + errBadPublicKeyLength Error = "wallet: bad public key length" + errInvalidSolanaSignature Error = "wallet: invalid solana signature for message and public key" +) + +func verifySolanaSignature(pub, msg, sig string) error { + b := base58.Decode(pub) + if len(b) != publicKeyLength { + return fmt.Errorf("error verifying solana signature: %w", errBadPublicKeyLength) + } + pubKey := ed25519.PublicKey(b) + + decSig, err := base64.URLEncoding.DecodeString(sig) + if err != nil { + return fmt.Errorf("error decoding solana signature: %w", err) + } + + if !ed25519.Verify(pubKey, []byte(msg), decSig) { + return errInvalidSolanaSignature + } + + return nil +} + +type rewMsg struct { + msg string + sig string +} + +const ( + errInvalidPartsLineBreak Error = "wallet: invalid number of lines" + errInvalidPartsColon Error = "wallet: invalid parts colon" + errInvalidParts Error = "wallet: invalid number of parts" + errInvalidPaymentID Error = "wallet: payment id does not match" + errInvalidSolanaPubKey Error = "wallet: solana public key does not match" + errInvalidRewardsMessage Error = "wallet: invalid rewards message" +) + +type solMsgParser struct { + paymentID string + solPub string + nonce string +} + +func newSolMsgParser(paymentID, solPub, nonce string) solMsgParser { + return solMsgParser{ + paymentID: paymentID, + solPub: solPub, + nonce: nonce, + } +} + +// parse parses a linking message and returns the rewards part. +// +// For a message to parse successfully it must be a valid format and successfully match the +// configured parser parameters paymentID, solPub and nonce. +// +// The message format should be three lines with a colon delimiter. For example, +// : +// : +// :.. +// +// The final part is referred to as the rewards message and has the +// format .. we only need to compare +// the .. when parsing. +// +// The returned rewMsg will contain both the message and signature. For example, +// +// rewMsg { +// msg: ., +// sig: , +// } +func (s *solMsgParser) parse(msg string) (rewMsg, error) { + lines := strings.Split(msg, "\n") + if len(lines) != 3 { + return rewMsg{}, errInvalidPartsLineBreak + } + + parts := make([]string, 0, 3) + for i := range lines { + p := strings.Split(lines[i], ":") + if len(p) != 2 { + return rewMsg{}, errInvalidPartsColon + } + parts = append(parts, p[1]) + } + + if len(parts) != 3 { + return rewMsg{}, errInvalidParts + } + + if parts[0] != s.paymentID { + return rewMsg{}, errInvalidPaymentID + } + + if parts[1] != s.solPub { + return rewMsg{}, errInvalidSolanaPubKey + } + + // Compare the final part minus the signature. + exp := s.paymentID + "." + s.nonce + "." + for i := range exp { + if parts[2][i] != exp[i] { + return rewMsg{}, errInvalidRewardsMessage + } + } + + rm := rewMsg{ + msg: parts[2][:len(exp)-1], // -1 removes the trailing . + sig: parts[2][len(exp):], + } + + return rm, nil +} + +const errInvalidRewardsSignature Error = "wallet: invalid rewards signature for message and public key" + +func verifyRewardsSignature(pub, msg, sig string) error { + b, err := hex.DecodeString(pub) + if err != nil { + return fmt.Errorf("error decoding rewards public key: %w", err) + } + + if len(b) != publicKeyLength { + return fmt.Errorf("error verifying rewards signature: %w", errBadPublicKeyLength) + } + pubKey := ed25519.PublicKey(b) + + decSig, err := base64.URLEncoding.DecodeString(sig) + if err != nil { + return fmt.Errorf("error decoding rewards signature: %w", err) + } + + if !ed25519.Verify(pubKey, []byte(msg), decSig) { + return errInvalidRewardsSignature + } + + return nil +} + +type LinkSolanaAddressError struct { + err error +} + +func newLinkSolanaAddressError(err error) error { + return &LinkSolanaAddressError{err: err} +} + +func (e *LinkSolanaAddressError) Error() string { + return e.err.Error() +} + +func (e *LinkSolanaAddressError) Unwrap() error { + return e.err +} + +type Error string + +func (e Error) Error() string { + return string(e) +} diff --git a/libs/wallet/wallet_test.go b/libs/wallet/wallet_test.go new file mode 100644 index 000000000..cf0f33daa --- /dev/null +++ b/libs/wallet/wallet_test.go @@ -0,0 +1,335 @@ +package wallet + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInfo_LinkSolanaAddress(t *testing.T) { + type tcGiven struct { + w Info + s SolanaLinkReq + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "success", + given: tcGiven{ + w: Info{ + ID: "5e99b1ae-6e91-481f-9021-7fb3c97327d4", + PublicKey: "f4205e18ea138efc59f18dfbb5c7c2a24b5724395a36e86e84c84b9ec9ccb1ec", + }, + s: SolanaLinkReq{ + Pub: "GysxUmKWkazQSUm3DCG7o5tPmsuNqFgQ2iG9imkzxzyG", + Sig: "X2nhxq-95ZR5QWk9R1m-Rqh8QVndDy2yL2NY5PSx0G-EyzX3xm7JKpPhILxZfc_cWwLtaPk6xRQBManPCKE6BQ==", + Msg: newMsg(msgParts{ + paymentID: "5e99b1ae-6e91-481f-9021-7fb3c97327d4", + solPub: "GysxUmKWkazQSUm3DCG7o5tPmsuNqFgQ2iG9imkzxzyG", + nonce: "86d6f240-df9b-4167-a66e-5df6da80ac24", + rewSig: "szPeTsDRUOFLS1y-k85yfvcI40OOWwTEn2mQ3cGcnZLPkDCXD1qJKYJNkNgdY5j5BA7pvj8AzEy8riKtdeRaAQ==", + }), + Nonce: "86d6f240-df9b-4167-a66e-5df6da80ac24", + }, + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + { + name: "invalid_linking", + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + var expected *LinkSolanaAddressError + return assert.ErrorAs(t, err, &expected) + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + wallet := tc.given.w + err := wallet.LinkSolanaAddress(context.TODO(), tc.given.s) + tc.assertErr(t, err) + }) + } +} + +func TestVerifySolanaSignature(t *testing.T) { + type tcGiven struct { + solPub string + msg string + solSig string + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "invalid_public_key_length", + given: tcGiven{solPub: "123456789"}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errBadPublicKeyLength) + }, + }, + { + name: "signature_has_illegal_character", + given: tcGiven{solPub: "32rbMEtgTphzVnHuSsuHEv3hKpm92UsgMerjDjZr72T1", solSig: "+"}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorContains(t, err, "error decoding solana signature") + }, + }, + { + name: "invalid_signature", + given: tcGiven{ + solPub: "5zjTAqk1xbeYFzkMCnY3H52SLQpuM1GUFUDAwfhJs1wg", + msg: "invalid_message", + solSig: "zc2boTImAAhzraUplAlUy2L6hNF6l-DYGfOqq_4UfrDsJEBg26jaHIAXJF2i3tifCZxrvmu3ahqIdnm2kOwyBQ==", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errInvalidSolanaSignature) + }, + }, + { + name: "valid_signature", + given: tcGiven{ + solPub: "5zjTAqk1xbeYFzkMCnY3H52SLQpuM1GUFUDAwfhJs1wg", + msg: "test", + solSig: "zc2boTImAAhzraUplAlUy2L6hNF6l-DYGfOqq_4UfrDsJEBg26jaHIAXJF2i3tifCZxrvmu3ahqIdnm2kOwyBQ==", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + } + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + err := verifySolanaSignature(tc.given.solPub, tc.given.msg, tc.given.solSig) + tc.assertErr(t, err) + }) + } +} + +func TestSolMsgParser_parse(t *testing.T) { + type tcGiven struct { + msgParser solMsgParser + msg string + } + + type exp struct { + rewMsg rewMsg + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "invalid_number_of_line_breaks", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id-some-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidPartsLineBreak, + }, + }, + { + name: "invalid_parts_colon", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "payment-id\nsome-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidPartsColon, + }, + }, + { + name: "invalid_payment_id", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:another-id\nsome-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidPaymentID, + }, + }, + { + name: "invalid_solana_public_key", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:another-solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidSolanaPubKey, + }, + }, + { + name: "invalid_rewards_message", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:solana-address\nsome-text:another-payment-id.nonce.abcd", + }, + exp: exp{ + err: errInvalidRewardsMessage, + }, + }, + { + name: "no_rewards_message_signature", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:solana-address\nsome-text:payment-id.nonce.", + }, + exp: exp{ + rewMsg: rewMsg{ + msg: "payment-id.nonce", + }, + err: nil, + }, + }, + { + name: "success", + given: tcGiven{ + msgParser: solMsgParser{paymentID: "payment-id", solPub: "solana-address", nonce: "nonce"}, + msg: "some-text:payment-id\nsome-text:solana-address\nsome-text:payment-id.nonce.abcd", + }, + exp: exp{ + rewMsg: rewMsg{ + msg: "payment-id.nonce", + sig: "abcd", + }, + err: nil, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + msgParser := tc.given.msgParser + actual, err := msgParser.parse(tc.given.msg) + assert.Equal(t, tc.exp.err, err) + assert.Equal(t, tc.exp.rewMsg, actual) + }) + } +} + +func TestVerifyRewardsSignature(t *testing.T) { + type tcGiven struct { + pub string + sig string + msg string + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "error_decoding_public_key", + given: tcGiven{pub: "invalid_key"}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorContains(t, err, "error decoding rewards public key") + }, + }, + { + name: "invalid_public_key_length", + given: tcGiven{pub: hex.EncodeToString([]byte("key"))}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errBadPublicKeyLength) + }, + }, + { + name: "signature_has_illegal_character", + given: tcGiven{ + pub: "ac1e69da621a99cf29de8ac1b0ffc8ece154b98e99a0ebec2bfdf2af04b8ac53", + sig: "!", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorContains(t, err, "error decoding rewards signature") + }, + }, + { + name: "invalid_signature", + given: tcGiven{ + pub: "e0e9196cfb3c98f8912c011ff46193167b7df72a166c595408c6ca6c690bb707", + msg: "invalid_message", + sig: "gJJptSk0lGBjpJOx7Mq_AwVtNkW5tg4esgbtYesQXLfabDZP4K_bFxpEn40TIBRISQho9oLzGfOnzWH88ntdAg=="}, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, errInvalidRewardsSignature) + }, + }, + { + name: "valid_signature", + given: tcGiven{ + pub: "e0e9196cfb3c98f8912c011ff46193167b7df72a166c595408c6ca6c690bb707", + msg: "test", + sig: "gJJptSk0lGBjpJOx7Mq_AwVtNkW5tg4esgbtYesQXLfabDZP4K_bFxpEn40TIBRISQho9oLzGfOnzWH88ntdAg==", + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + err := verifyRewardsSignature(tc.given.pub, tc.given.msg, tc.given.sig) + tc.assertErr(t, err) + }) + } +} + +func TestNewLinkSolanaAddressError(t *testing.T) { + err1 := errors.New("error text") + err2 := newLinkSolanaAddressError(err1) + err3 := fmt.Errorf("%w", err2) + + t.Run("error_text", func(t *testing.T) { + assert.ErrorContains(t, err2, err1.Error()) + }) + + t.Run("error_unwrap", func(t *testing.T) { + var target *LinkSolanaAddressError + assert.ErrorAs(t, err3, &target) + }) +} + +type msgParts struct { + paymentID string + solPub string + nonce string + rewSig string +} + +func newMsg(parts msgParts) string { + const msgTmpl = ":%s\n:%s\n:%s.%s.%s" + return fmt.Sprintf(msgTmpl, parts.paymentID, parts.solPub, parts.paymentID, parts.nonce, parts.rewSig) +} diff --git a/migrations/0064_challenges.down.sql b/migrations/0064_challenges.down.sql new file mode 100644 index 000000000..3efc2210d --- /dev/null +++ b/migrations/0064_challenges.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS challenge; diff --git a/migrations/0064_challenges.up.sql b/migrations/0064_challenges.up.sql new file mode 100644 index 000000000..d23b958cf --- /dev/null +++ b/migrations/0064_challenges.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE challenge ( + payment_id uuid PRIMARY KEY, + created_at timestamp WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + nonce text NOT NULL, + CONSTRAINT challenge_nonce UNIQUE (nonce) +); diff --git a/migrations/0065_allow_list.down.sql b/migrations/0065_allow_list.down.sql new file mode 100644 index 000000000..7c8c09e99 --- /dev/null +++ b/migrations/0065_allow_list.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS allow_list; diff --git a/migrations/0065_allow_list.up.sql b/migrations/0065_allow_list.up.sql new file mode 100644 index 000000000..8b83ddf70 --- /dev/null +++ b/migrations/0065_allow_list.up.sql @@ -0,0 +1,4 @@ +CREATE TABLE allow_list ( + payment_id uuid PRIMARY KEY, + created_at timestamp WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/migrations/0066_wallet_custodian_check_custodian_add_solana.down.sql b/migrations/0066_wallet_custodian_check_custodian_add_solana.down.sql new file mode 100644 index 000000000..7f3e0126c --- /dev/null +++ b/migrations/0066_wallet_custodian_check_custodian_add_solana.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE wallet_custodian DROP CONSTRAINT IF EXISTS check_custodian, + ADD CONSTRAINT check_custodian CHECK (custodian IN ('brave', 'uphold', 'bitflyer', 'gemini', 'zebpay')); diff --git a/migrations/0066_wallet_custodian_check_custodian_add_solana.up.sql b/migrations/0066_wallet_custodian_check_custodian_add_solana.up.sql new file mode 100644 index 000000000..06c40dd1d --- /dev/null +++ b/migrations/0066_wallet_custodian_check_custodian_add_solana.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE wallet_custodian DROP CONSTRAINT IF EXISTS check_custodian, + ADD CONSTRAINT check_custodian CHECK (custodian IN ('brave', 'uphold', 'bitflyer', 'gemini', 'zebpay', 'solana')); diff --git a/schema/rewards/ParametersV1 b/schema/rewards/ParametersV1 index 9cb9f5be2..437aafdc8 100644 --- a/schema/rewards/ParametersV1 +++ b/schema/rewards/ParametersV1 @@ -1 +1 @@ -{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/ParametersV1","definitions":{"AutoContribute":{"properties":{"choices":{"items":{"type":"number"},"type":"array"},"defaultChoice":{"type":"number"}},"additionalProperties":false,"type":"object"},"GeoAllowBlockMap":{"required":["allow","block"],"properties":{"allow":{"items":{"type":"string"},"type":"array"},"block":{"items":{"type":"string"},"type":"array"}},"additionalProperties":false,"type":"object"},"ParametersV1":{"required":["payoutStatus","custodianRegions","vbatExpired"],"properties":{"payoutStatus":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/PayoutStatus"},"custodianRegions":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Regions"},"batRate":{"type":"number"},"autocontribute":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/AutoContribute"},"tips":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Tips"},"vbatExpired":{"type":"boolean"},"vbatDeadline":{"type":"string","format":"date-time"}},"additionalProperties":false,"type":"object"},"PayoutStatus":{"required":["unverified","uphold","gemini","bitflyer","zebpay","payoutDate"],"properties":{"unverified":{"type":"string"},"uphold":{"type":"string"},"gemini":{"type":"string"},"bitflyer":{"type":"string"},"zebpay":{"type":"string"},"payoutDate":{"type":"string"}},"additionalProperties":false,"type":"object"},"Regions":{"required":["uphold","gemini","bitflyer","zebpay"],"properties":{"uphold":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/GeoAllowBlockMap"},"gemini":{"$ref":"#/definitions/GeoAllowBlockMap"},"bitflyer":{"$ref":"#/definitions/GeoAllowBlockMap"},"zebpay":{"$ref":"#/definitions/GeoAllowBlockMap"}},"additionalProperties":false,"type":"object"},"Tips":{"properties":{"defaultTipChoices":{"items":{"type":"number"},"type":"array"},"defaultMonthlyChoices":{"items":{"type":"number"},"type":"array"}},"additionalProperties":false,"type":"object"}}} \ No newline at end of file +{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/ParametersV1","definitions":{"AutoContribute":{"properties":{"choices":{"items":{"type":"number"},"type":"array"},"defaultChoice":{"type":"number"}},"additionalProperties":false,"type":"object"},"GeoAllowBlockMap":{"required":["allow","block"],"properties":{"allow":{"items":{"type":"string"},"type":"array"},"block":{"items":{"type":"string"},"type":"array"}},"additionalProperties":false,"type":"object"},"ParametersV1":{"required":["payoutStatus","custodianRegions","vbatExpired"],"properties":{"payoutStatus":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/PayoutStatus"},"custodianRegions":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Regions"},"batRate":{"type":"number"},"autocontribute":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/AutoContribute"},"tips":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/Tips"},"vbatExpired":{"type":"boolean"},"vbatDeadline":{"type":"string","format":"date-time"}},"additionalProperties":false,"type":"object"},"PayoutStatus":{"required":["unverified","uphold","gemini","bitflyer","zebpay","payoutDate"],"properties":{"unverified":{"type":"string"},"uphold":{"type":"string"},"gemini":{"type":"string"},"bitflyer":{"type":"string"},"zebpay":{"type":"string"},"payoutDate":{"type":"string"}},"additionalProperties":false,"type":"object"},"Regions":{"required":["uphold","gemini","bitflyer","zebpay","solana"],"properties":{"uphold":{"$schema":"http://json-schema.org/draft-04/schema#","$ref":"#/definitions/GeoAllowBlockMap"},"gemini":{"$ref":"#/definitions/GeoAllowBlockMap"},"bitflyer":{"$ref":"#/definitions/GeoAllowBlockMap"},"zebpay":{"$ref":"#/definitions/GeoAllowBlockMap"},"solana":{"$ref":"#/definitions/GeoAllowBlockMap"}},"additionalProperties":false,"type":"object"},"Tips":{"properties":{"defaultTipChoices":{"items":{"type":"number"},"type":"array"},"defaultMonthlyChoices":{"items":{"type":"number"},"type":"array"}},"additionalProperties":false,"type":"object"}}} \ No newline at end of file diff --git a/services/go.mod b/services/go.mod index 7a8cec66a..f6200d5fc 100644 --- a/services/go.mod +++ b/services/go.mod @@ -20,6 +20,7 @@ require ( github.com/brave-intl/bat-go v1.0.2 github.com/brave-intl/bat-go/libs v1.0.2 github.com/brave-intl/bat-go/tools v1.0.2 + github.com/btcsuite/btcutil v1.0.2 github.com/getsentry/sentry-go v0.14.0 github.com/go-chi/chi v4.1.2+incompatible github.com/go-chi/cors v1.2.1 @@ -45,6 +46,7 @@ require ( github.com/stripe/stripe-go/v72 v72.122.0 golang.org/x/crypto v0.15.0 golang.org/x/exp v0.0.0-20230223210539-50820d90acfd + google.golang.org/api v0.134.0 gopkg.in/macaroon.v2 v2.1.0 gopkg.in/square/go-jose.v2 v2.6.0 ) @@ -72,7 +74,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.6 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.18.7 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/btcsuite/btcutil v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect @@ -131,7 +132,6 @@ require ( golang.org/x/sync v0.3.0 // indirect golang.org/x/sys v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/api v0.134.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect google.golang.org/grpc v1.58.3 // indirect diff --git a/services/grant/cmd/grant.go b/services/grant/cmd/grant.go index 56cc51470..dba2dbbd6 100644 --- a/services/grant/cmd/grant.go +++ b/services/grant/cmd/grant.go @@ -88,6 +88,11 @@ func init() { Bind("disable-zebpay-linking"). Env("DISABLE_ZEBPAY_LINKING") + flagBuilder.Flag().Bool("disable-solana-linking", false, + "disable address linking for solana"). + Bind("disable-solana-linking"). + Env("DISABLE_SOLANA_LINKING") + flagBuilder.Flag().StringSlice("brave-transfer-promotion-ids", []string{""}, "brave vg deposit destination promotion id"). Bind("brave-transfer-promotion-ids"). @@ -350,7 +355,11 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context, // this way we can have the wallet service completely separated from // grants service and easily deployable. ctx, walletService = wallet.SetupService(ctx) - r = wallet.RegisterRoutes(ctx, walletService, r) + dappAO := strings.Split(os.Getenv("DAPP_ALLOWED_CORS_ORIGINS"), ",") + if len(dappAO) == 0 { + logger.Panic().Msg("dapp origin env missing") + } + r = wallet.RegisterRoutes(ctx, walletService, r, middleware.InstrumentHandler, wallet.NewDAppCorsMw(dappAO)) promotionDB, promotionRODB, err := promotion.NewPostgres() if err != nil { @@ -488,7 +497,7 @@ func setupRouter(ctx context.Context, logger *zerolog.Logger) (context.Context, "/", middleware.InstrumentHandler( "CreateOrderNew", - corsMwrPost(handlers.AppHandler(orderh.Create)), + corsMwrPost(handlers.AppHandler(orderh.CreateNew)), ), ) } else { @@ -615,6 +624,7 @@ func GrantServer( ctx = context.WithValue(ctx, appctx.DisableUpholdLinkingCTXKey, viper.GetBool("disable-uphold-linking")) ctx = context.WithValue(ctx, appctx.DisableGeminiLinkingCTXKey, viper.GetBool("disable-gemini-linking")) ctx = context.WithValue(ctx, appctx.DisableBitflyerLinkingCTXKey, viper.GetBool("disable-bitflyer-linking")) + ctx = context.WithValue(ctx, appctx.DisableSolanaLinkingCTXKey, viper.GetBool("disable-solana-linking")) // stripe variables ctx = context.WithValue(ctx, appctx.StripeEnabledCTXKey, viper.GetBool("stripe-enabled")) @@ -710,6 +720,7 @@ func newSrvStatusFromCtx(ctx context.Context) map[string]any { g, _ := ctx.Value(appctx.DisableGeminiLinkingCTXKey).(bool) bf, _ := ctx.Value(appctx.DisableBitflyerLinkingCTXKey).(bool) zp, _ := ctx.Value(appctx.DisableZebPayLinkingCTXKey).(bool) + s, _ := ctx.Value(appctx.DisableSolanaLinkingCTXKey).(bool) result := map[string]interface{}{ "wallet": map[string]bool{ @@ -717,6 +728,7 @@ func newSrvStatusFromCtx(ctx context.Context) map[string]any { "gemini": !g, "bitflyer": !bf, "zebpay": !zp, + "solana": !s, }, } diff --git a/services/grant/cmd/grant_test.go b/services/grant/cmd/grant_test.go index 481840159..714ce8388 100644 --- a/services/grant/cmd/grant_test.go +++ b/services/grant/cmd/grant_test.go @@ -16,6 +16,7 @@ func TestNewSrvStatusFromCtx(t *testing.T) { ctx = context.WithValue(ctx, appctx.DisableGeminiLinkingCTXKey, true) ctx = context.WithValue(ctx, appctx.DisableBitflyerLinkingCTXKey, true) ctx = context.WithValue(ctx, appctx.DisableZebPayLinkingCTXKey, true) + ctx = context.WithValue(ctx, appctx.DisableSolanaLinkingCTXKey, true) act := newSrvStatusFromCtx(ctx) exp := map[string]interface{}{ @@ -24,6 +25,7 @@ func TestNewSrvStatusFromCtx(t *testing.T) { "gemini": false, "bitflyer": false, "zebpay": false, + "solana": false, }, } diff --git a/services/ratios/mapping.go b/services/ratios/mapping.go index d15ff32c7..3318525d5 100644 --- a/services/ratios/mapping.go +++ b/services/ratios/mapping.go @@ -39,6 +39,7 @@ var ( "bnb": "binancecoin", "boa": "bosagora", "bob": "bobs_repair", + "bonk": "bonk", "boo": "boo", "booty": "candybooty", "box": "box-token", diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 612f6c034..a1efb1efd 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -14,7 +14,7 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/cors" uuid "github.com/satori/go.uuid" - stripe "github.com/stripe/stripe-go/v72" + "github.com/stripe/stripe-go/v72" "github.com/stripe/stripe-go/v72/webhook" "github.com/brave-intl/bat-go/libs/clients/radom" @@ -784,7 +784,8 @@ func getOrderCredsByID(svc *Service, legacyMode bool) handlers.AppHandler { reqID = *reqIDRaw.UUID() } - creds, status, err := svc.GetItemCredentials(ctx, *orderID.UUID(), *itemID.UUID(), reqID) + itemIDv := *itemID.UUID() + creds, status, err := svc.GetItemCredentials(ctx, *orderID.UUID(), itemIDv, reqID) if err != nil { if !errors.Is(err, errSetRetryAfter) { return handlers.WrapError(err, "Error getting credentials", status) @@ -799,6 +800,21 @@ func getOrderCredsByID(svc *Service, legacyMode bool) handlers.AppHandler { w.Header().Set("Retry-After", strconv.FormatInt(avg, 10)) } + if legacyMode { + suCreds, ok := creds.([]OrderCreds) + if !ok { + return handlers.WrapError(err, "Error getting credentials", http.StatusInternalServerError) + } + + for i := range suCreds { + if uuid.Equal(suCreds[i].ID, itemIDv) { + return handlers.RenderContent(ctx, suCreds[i], w, status) + } + } + + return handlers.WrapError(err, "Error getting credentials", http.StatusNotFound) + } + if creds == nil { return handlers.RenderContent(ctx, map[string]interface{}{}, w, status) } @@ -1305,91 +1321,95 @@ func HandleStripeWebhook(service *Service) handlers.AppHandler { } } -// SubmitReceipt submit a vendor verifiable receipt that proves order is paid -func SubmitReceipt(service *Service) handlers.AppHandler { +// SubmitReceipt handles receipt submission requests. +func SubmitReceipt(svc *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() - var ( - ctx = r.Context() - req SubmitReceiptRequestV1 // the body of the request - orderID = new(inputs.ID) // the order id - validationErrMap = map[string]interface{}{} // for tracking our validation errors - ) + l := logging.Logger(ctx, "skus").With().Str("func", "SubmitReceipt").Logger() - logger := logging.Logger(ctx, "skus").With().Str("func", "SubmitReceipt").Logger() + orderID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParam(r, "orderID")); err != nil { + l.Warn().Err(err).Msg("failed to decode orderID") - // validate the order id - if err := inputs.DecodeAndValidateString(context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { - logger.Warn().Err(err).Msg("Failed to decode/validate order id from url") - validationErrMap["orderID"] = err.Error() + return handlers.ValidationError("Error validating request", map[string]interface{}{"orderID": err.Error()}) } - // read the payload - payload, err := requestutils.Read(r.Context(), r.Body) + payload, err := requestutils.Read(ctx, r.Body) if err != nil { - logger.Warn().Err(err).Msg("Failed to read the payload") - validationErrMap["request-body"] = err.Error() + l.Warn().Err(err).Msg("failed to read body") + + return handlers.ValidationError("Error validating request", map[string]interface{}{"request-body": err.Error()}) } - // validate the payload - if err := inputs.DecodeAndValidate(context.Background(), &req, payload); err != nil { - logger.Debug().Str("payload", string(payload)).Msg("Failed to decode and validate the payload") - logger.Warn().Err(err).Msg("Failed to decode and validate the payload") - validationErrMap["request-body"] = err.Error() + // TODO(clD11): remove when no longer needed. + payloadS := string(payload) + l.Info().Interface("payload_byte", payload).Str("payload_str", payloadS).Msg("payload") + + req := SubmitReceiptRequestV1{} + if err := inputs.DecodeAndValidate(ctx, &req, payload); err != nil { + l.Debug().Str("payload", payloadS).Msg("failed to decode payload") + l.Warn().Err(err).Msg("failed to decode payload") + + return handlers.ValidationError("Error validating request", map[string]interface{}{"request-body": err.Error()}) } - // validate the receipt - externalID, err := service.validateReceipt(ctx, orderID.UUID(), req) + // TODO(clD11): remove when no longer needed. + l.Info().Interface("req_decoded", req).Msg("req decoded") + + extID, err := svc.validateReceipt(ctx, req) if err != nil { if errors.Is(err, errNotFound) { return handlers.WrapError(err, "order not found", http.StatusNotFound) } - logger.Warn().Err(err).Msg("Failed to validate the receipt with vendor") - validationErrMap["receiptErrors"] = err.Error() - // return codified errors for application - if errors.Is(err, errPurchaseFailed) { - return handlers.CodedValidationError(err.Error(), purchaseFailedErrCode, validationErrMap) - } else if errors.Is(err, errPurchasePending) { - return handlers.CodedValidationError(err.Error(), purchasePendingErrCode, validationErrMap) - } else if errors.Is(err, errPurchaseDeferred) { - return handlers.CodedValidationError(err.Error(), purchaseDeferredErrCode, validationErrMap) - } else if errors.Is(err, errPurchaseStatusUnknown) { - return handlers.CodedValidationError(err.Error(), purchaseStatusUnknownErrCode, validationErrMap) - } else { - // unknown error - return handlers.CodedValidationError("error validating receipt", purchaseValidationErrCode, validationErrMap) + + l.Warn().Err(err).Msg("failed to validate receipt with vendor") + + errStr := err.Error() + verrs := map[string]interface{}{"receiptErrors": errStr} + + switch { + case errors.Is(err, errPurchaseFailed): + return handlers.CodedValidationError(errStr, purchaseFailedErrCode, verrs) + case errors.Is(err, errPurchasePending): + return handlers.CodedValidationError(errStr, purchasePendingErrCode, verrs) + case errors.Is(err, errPurchaseDeferred): + return handlers.CodedValidationError(errStr, purchaseDeferredErrCode, verrs) + case errors.Is(err, errPurchaseStatusUnknown): + return handlers.CodedValidationError(errStr, purchaseStatusUnknownErrCode, verrs) + default: + return handlers.CodedValidationError("error validating receipt", purchaseValidationErrCode, verrs) } } - // if we had any validation errors, return the validation error map to the caller - if len(validationErrMap) != 0 { - return handlers.ValidationError("error validating request", validationErrMap) - } - // does this external id exist already - exists, err := service.ExternalIDExists(ctx, externalID) - if err != nil { - logger.Warn().Err(err).Msg("failed to lookup external id existance") - return handlers.WrapError(err, "failed to lookup external id", http.StatusInternalServerError) + { + exists, err := svc.ExternalIDExists(ctx, extID) + if err != nil { + l.Warn().Err(err).Msg("failed to lookup external id") + + return handlers.WrapError(err, "failed to lookup external id", http.StatusInternalServerError) + } + + if exists { + return handlers.WrapError(err, "receipt has already been submitted", http.StatusBadRequest) + } } - if exists { - return handlers.WrapError(err, "receipt has already been submitted", http.StatusBadRequest) + vnd := req.Type.String() + mdata := datastore.Metadata{ + "vendor": vnd, + "externalID": extID, + paymentProcessor: vnd, } - // set order paid and include the vendor and external id to metadata - if err := service.UpdateOrderStatusPaidWithMetadata(ctx, orderID.UUID(), datastore.Metadata{ - "vendor": req.Type.String(), - "externalID": externalID, - paymentProcessor: req.Type.String(), - }); err != nil { - logger.Warn().Err(err).Msg("Failed to update the order with appropriate metadata") + if err := svc.UpdateOrderStatusPaidWithMetadata(ctx, orderID.UUID(), mdata); err != nil { + l.Warn().Err(err).Msg("failed to update order with vendor metadata") return handlers.WrapError(err, "failed to store status of order", http.StatusInternalServerError) } - return handlers.RenderContent(r.Context(), SubmitReceiptResponseV1{ - ExternalID: externalID, - Vendor: req.Type.String(), - }, w, http.StatusOK) + result := SubmitReceiptResponseV1{ExternalID: extID, Vendor: vnd} + + return handlers.RenderContent(ctx, result, w, http.StatusOK) }) } diff --git a/services/skus/controllers_test.go b/services/skus/controllers_test.go index 1cb5fa7a3..eb37ebfaa 100644 --- a/services/skus/controllers_test.go +++ b/services/skus/controllers_test.go @@ -373,10 +373,8 @@ func (suite *ControllersTestSuite) TestAndroidWebhook() { err := suite.storage.AppendOrderMetadata(context.Background(), &order.ID, "externalID", "my external id") suite.Require().NoError(err) - // overwrite the receipt validation function for this test - receiptValidationFns = map[Vendor]func(context.Context, interface{}) (string, error){ - appleVendor: validateIOSReceipt, - googleVendor: func(ctx context.Context, v interface{}) (string, error) { + suite.service.vendorReceiptValid = &mockVendorReceiptValidator{ + fnValidateGoogle: func(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { return "my external id", nil }, } @@ -1890,3 +1888,24 @@ func newCORSOptsEnv() cors.Options { return NewCORSOpts(origins, dbg) } + +type mockVendorReceiptValidator struct { + fnValidateApple func(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) + fnValidateGoogle func(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) +} + +func (v *mockVendorReceiptValidator) validateApple(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + if v.fnValidateApple == nil { + return "apple_defaul", nil + } + + return v.fnValidateApple(ctx, receipt) +} + +func (v *mockVendorReceiptValidator) validateGoogle(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + if v.fnValidateGoogle == nil { + return "google_default", nil + } + + return v.fnValidateGoogle(ctx, receipt) +} diff --git a/services/skus/input.go b/services/skus/input.go index 7ade6c47a..de6615d6c 100644 --- a/services/skus/input.go +++ b/services/skus/input.go @@ -130,7 +130,7 @@ func credentialOpaqueFromString(s string) (*VerifyCredentialOpaque, error) { const ( appleVendor Vendor = "ios" - googleVendor = "android" + googleVendor Vendor = "android" ) var errInvalidVendor = errors.New("invalid vendor") diff --git a/services/skus/receipt.go b/services/skus/receipt.go index 127c02ab0..89a9e2f0b 100644 --- a/services/skus/receipt.go +++ b/services/skus/receipt.go @@ -15,8 +15,12 @@ import ( "github.com/awa/go-iap/appstore" "github.com/awa/go-iap/playstore" + "google.golang.org/api/androidpublisher/v3" + appctx "github.com/brave-intl/bat-go/libs/context" "github.com/brave-intl/bat-go/libs/logging" + + "github.com/brave-intl/bat-go/services/skus/model" ) const ( @@ -24,22 +28,26 @@ const ( androidPaymentStatePaid androidPaymentStateTrial androidPaymentStatePendingDeferred +) +const ( androidCancelReasonUser int64 = 0 androidCancelReasonSystem int64 = 1 androidCancelReasonReplaced int64 = 2 androidCancelReasonDeveloper int64 = 3 + + purchasePendingErrCode = "purchase_pending" + purchaseDeferredErrCode = "purchase_deferred" + purchaseStatusUnknownErrCode = "purchase_status_unknown" + purchaseFailedErrCode = "purchase_failed" + purchaseValidationErrCode = "validation_failed" ) -var ( - receiptValidationFns = map[Vendor]func(context.Context, interface{}) (string, error){ - appleVendor: validateIOSReceipt, - googleVendor: validateAndroidReceipt, - } - iosClient *appstore.Client - androidClient *playstore.Client - errClientMisconfigured = errors.New("misconfigured client") +const ( + errNoInAppTx model.Error = "no in app info in response" +) +var ( errPurchaseUserCanceled = errors.New("purchase is canceled by user") errPurchaseSystemCanceled = errors.New("purchase is canceled by google playstore") errPurchaseReplacedCanceled = errors.New("purchase is canceled and replaced") @@ -51,12 +59,6 @@ var ( errPurchaseFailed = errors.New("purchase failed") errPurchaseExpired = errors.New("purchase expired") - - purchasePendingErrCode = "purchase_pending" - purchaseDeferredErrCode = "purchase_deferred" - purchaseStatusUnknownErrCode = "purchase_status_unknown" - purchaseFailedErrCode = "purchase_failed" - purchaseValidationErrCode = "validation_failed" ) type dumpTransport struct{} @@ -81,109 +83,118 @@ func (dt *dumpTransport) RoundTrip(r *http.Request) (*http.Response, error) { return resp, rtErr } -func initClients(ctx context.Context) { +type appStoreVerifier interface { + Verify(ctx context.Context, req appstore.IAPRequest, result interface{}) error +} - var logClient = &http.Client{ - Transport: &dumpTransport{}, - } +type playStoreVerifier interface { + VerifySubscription(ctx context.Context, pkgName, subID, token string) (*androidpublisher.SubscriptionPurchase, error) +} - logger := logging.Logger(ctx, "skus").With().Str("func", "initClients").Logger() - iosClient = appstore.New() +type receiptVerifier struct { + appStoreCl appStoreVerifier + playStoreCl playStoreVerifier +} - if jsonKey, ok := ctx.Value(appctx.PlaystoreJSONKeyCTXKey).([]byte); ok { - var err error - androidClient, err = playstore.NewWithClient(jsonKey, logClient) +func newReceiptVerifier(cl *http.Client, playKey []byte) (*receiptVerifier, error) { + result := &receiptVerifier{ + appStoreCl: appstore.NewWithClient(cl), + } + + if playKey != nil && len(playKey) != 0 { + gpCl, err := playstore.NewWithClient(playKey, cl) if err != nil { - logger.Error().Err(err).Msg("failed to initialize android client") + return nil, err } + + result.playStoreCl = gpCl } + + return result, nil } -// validateIOSReceipt - validate apple receipt with their apis -func validateIOSReceipt(ctx context.Context, receipt interface{}) (string, error) { - logger := logging.Logger(ctx, "skus").With().Str("func", "validateIOSReceipt").Logger() +// validateApple validates Apple App Store receipt. +func (v *receiptVerifier) validateApple(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + l := logging.Logger(ctx, "skus").With().Str("func", "validateReceiptApple").Logger() - // get the shared key from the context sharedKey, sharedKeyOK := ctx.Value(appctx.AppleReceiptSharedKeyCTXKey).(string) - if iosClient != nil { - // handle v1 receipt type - if v, ok := receipt.(SubmitReceiptRequestV1); ok { - req := appstore.IAPRequest{ - ReceiptData: v.Blob, - ExcludeOldTransactions: true, - } - if sharedKeyOK && len(sharedKey) > 0 { - req.Password = sharedKey - } - resp := &appstore.IAPResponse{} - if err := iosClient.Verify(ctx, req, resp); err != nil { - logger.Error().Err(err).Msg("failed to verify receipt") - return "", fmt.Errorf("failed to verify receipt: %w", err) - } - logger.Debug().Msg(fmt.Sprintf("%+v", resp)) - // get the transaction id back - if len(resp.Receipt.InApp) < 1 { - logger.Error().Msg("failed to verify receipt, no in app info") - return "", fmt.Errorf("failed to verify receipt, no in app info in response") - } - return resp.Receipt.InApp[0].OriginalTransactionID, nil - } + req := appstore.IAPRequest{ + ReceiptData: receipt.Blob, + ExcludeOldTransactions: true, } - logger.Error().Msg("client is not configured") - return "", errClientMisconfigured + + if sharedKeyOK && len(sharedKey) > 0 { + req.Password = sharedKey + } + + resp := &appstore.IAPResponse{} + if err := v.appStoreCl.Verify(ctx, req, resp); err != nil { + l.Error().Err(err).Msg("failed to verify receipt") + + return "", fmt.Errorf("failed to verify receipt: %w", err) + } + + l.Debug().Msg(fmt.Sprintf("%+v", resp)) + + if len(resp.Receipt.InApp) < 1 { + l.Error().Msg("failed to verify receipt: no in app info") + return "", errNoInAppTx + } + + return resp.Receipt.InApp[0].OriginalTransactionID, nil } -// validateAndroidReceipt - validate android receipt with their apis -func validateAndroidReceipt(ctx context.Context, receipt interface{}) (string, error) { - logger := logging.Logger(ctx, "skus").With().Str("func", "validateAndroidReceipt").Logger() - if androidClient != nil { - if v, ok := receipt.(SubmitReceiptRequestV1); ok { - logger.Debug().Str("receipt", fmt.Sprintf("%+v", v)).Msg("about to verify subscription") - // handle v1 receipt type - resp, err := androidClient.VerifySubscription(ctx, v.Package, v.SubscriptionID, v.Blob) - if err != nil { - logger.Error().Err(err).Msg("failed to verify subscription") - return "", errPurchaseFailed - } - - // is order expired? - if time.Unix(0, resp.ExpiryTimeMillis*int64(time.Millisecond)).Before(time.Now()) { - return "", errPurchaseExpired - } - - logger.Debug().Msgf("resp: %+v", resp) - if resp.PaymentState != nil { - // check that the order was paid - switch *resp.PaymentState { - case androidPaymentStatePaid, androidPaymentStateTrial: - break - case androidPaymentStatePending: - // is there a cancel reason? - switch resp.CancelReason { - case androidCancelReasonUser: - return "", errPurchaseUserCanceled - case androidCancelReasonSystem: - return "", errPurchaseSystemCanceled - case androidCancelReasonReplaced: - return "", errPurchaseReplacedCanceled - case androidCancelReasonDeveloper: - return "", errPurchaseDeveloperCanceled - } - return "", errPurchasePending - case androidPaymentStatePendingDeferred: - return "", errPurchaseDeferred - default: - return "", errPurchaseStatusUnknown - } - return v.Blob, nil - } - logger.Error().Err(err).Msg("failed to verify subscription: no payment state") - return "", errPurchaseFailed +// validateGoogle validates Google Store receipt. +func (v *receiptVerifier) validateGoogle(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + l := logging.Logger(ctx, "skus").With().Str("func", "validateReceiptGoogle").Logger() + + l.Debug().Str("receipt", fmt.Sprintf("%+v", receipt)).Msg("about to verify subscription") + + resp, err := v.playStoreCl.VerifySubscription(ctx, receipt.Package, receipt.SubscriptionID, receipt.Blob) + if err != nil { + l.Error().Err(err).Msg("failed to verify subscription") + return "", errPurchaseFailed + } + + // Check order expiration. + // There seems to be a mistake here? + // Unix expects nanoseconds as the second param, but ms are passed. + if time.Unix(0, resp.ExpiryTimeMillis*int64(time.Millisecond)).Before(time.Now()) { + return "", errPurchaseExpired + } + + l.Debug().Msgf("resp: %+v", resp) + + if resp.PaymentState == nil { + l.Error().Err(err).Msg("failed to verify subscription: no payment state") + return "", errPurchaseFailed + } + + // Check that the order was paid. + switch *resp.PaymentState { + case androidPaymentStatePaid, androidPaymentStateTrial: + return receipt.Blob, nil + + case androidPaymentStatePending: + // Checl for cancel reason. + switch resp.CancelReason { + case androidCancelReasonUser: + return "", errPurchaseUserCanceled + case androidCancelReasonSystem: + return "", errPurchaseSystemCanceled + case androidCancelReasonReplaced: + return "", errPurchaseReplacedCanceled + case androidCancelReasonDeveloper: + return "", errPurchaseDeveloperCanceled } + return "", errPurchasePending + + case androidPaymentStatePendingDeferred: + return "", errPurchaseDeferred + default: + return "", errPurchaseStatusUnknown } - logger.Error().Msg("client is not configured") - return "", errClientMisconfigured } // get the public key from the jws header diff --git a/services/skus/service.go b/services/skus/service.go index 9625cc459..a93607bbc 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -91,6 +91,11 @@ type orderStoreSvc interface { Get(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) } +type vendorReceiptValidator interface { + validateApple(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) + validateGoogle(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) +} + // Service contains datastore type Service struct { orderRepo orderStoreSvc @@ -113,6 +118,8 @@ type Service struct { retry backoff.RetryFunc radomClient *radom.InstrumentedClient radomSellerAddress string + + vendorReceiptValid vendorReceiptValidator } // PauseWorker - pause worker until time specified @@ -159,8 +166,6 @@ func (s *Service) InitKafka(ctx context.Context) error { // InitService creates a service using the passed datastore and clients configured from the environment. func InitService(ctx context.Context, datastore Datastore, walletService *wallet.Service, orderRepo orderStoreSvc, issuerRepo issuerStore) (*Service, error) { sublogger := logging.Logger(ctx, "payments").With().Str("func", "InitService").Logger() - // setup the in app purchase clients - initClients(ctx) // setup stripe if exists in context and enabled scClient := &client.API{} @@ -233,6 +238,17 @@ func InitService(ctx context.Context, datastore Datastore, walletService *wallet } } + cl := &http.Client{ + Transport: &dumpTransport{}, + Timeout: 30 * time.Second, + } + + playKey, _ := ctx.Value(appctx.PlaystoreJSONKeyCTXKey).([]byte) + rcptValidator, err := newReceiptVerifier(cl, playKey) + if err != nil { + return nil, err + } + service := &Service{ orderRepo: orderRepo, issuerRepo: issuerRepo, @@ -247,6 +263,7 @@ func InitService(ctx context.Context, datastore Datastore, walletService *wallet retry: backoff.Retry, radomClient: radomClient, radomSellerAddress: radomSellerAddress, + vendorReceiptValid: rcptValidator, } // setup runnable jobs @@ -1559,13 +1576,12 @@ func (s *Service) verifyDeveloperNotification(ctx context.Context, dn *Developer } // have order, now validate the receipt from the notification - _, err = s.validateReceipt(ctx, &o.ID, SubmitReceiptRequestV1{ + if _, err := s.vendorReceiptValid.validateGoogle(ctx, SubmitReceiptRequestV1{ Type: "android", Blob: dn.SubscriptionNotification.PurchaseToken, Package: dn.PackageName, SubscriptionID: dn.SubscriptionNotification.SubscriptionID, - }) - if err != nil { + }); err != nil { return fmt.Errorf("failed to validate purchase token: %w", err) } @@ -1597,17 +1613,16 @@ func (s *Service) verifyDeveloperNotification(ctx context.Context, dn *Developer return nil } -// validateReceipt - perform receipt validation -func (s *Service) validateReceipt(ctx context.Context, orderID *uuid.UUID, receipt interface{}) (string, error) { - // based on the vendor call the vendor specific apis to check the status of the receipt, - if v, ok := receipt.(SubmitReceiptRequestV1); ok { - // and get back the external id - if fn, ok := receiptValidationFns[v.Type]; ok { - return fn(ctx, receipt) - } +// validateReceipt validates receipt. +func (s *Service) validateReceipt(ctx context.Context, receipt SubmitReceiptRequestV1) (string, error) { + switch receipt.Type { + case appleVendor: + return s.vendorReceiptValid.validateApple(ctx, receipt) + case googleVendor: + return s.vendorReceiptValid.validateGoogle(ctx, receipt) + default: + return "", errorutils.ErrNotImplemented } - - return "", errorutils.ErrNotImplemented } // UpdateOrderStatusPaidWithMetadata - update the order status with metadata diff --git a/services/wallet/cmd/rest_run.go b/services/wallet/cmd/rest_run.go index 940fc85bd..900694e7a 100644 --- a/services/wallet/cmd/rest_run.go +++ b/services/wallet/cmd/rest_run.go @@ -2,16 +2,18 @@ package cmd import ( "net/http" - "time" - // pprof imports _ "net/http/pprof" + "os" + "strings" + "time" cmdutils "github.com/brave-intl/bat-go/cmd" appctx "github.com/brave-intl/bat-go/libs/context" + "github.com/brave-intl/bat-go/libs/middleware" "github.com/brave-intl/bat-go/services/cmd" "github.com/brave-intl/bat-go/services/wallet" - sentry "github.com/getsentry/sentry-go" + "github.com/getsentry/sentry-go" "github.com/go-chi/chi" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -22,12 +24,19 @@ import ( // wallets rest microservice. func WalletRestRun(command *cobra.Command, args []string) { ctx, service := wallet.SetupService(command.Context()) - router := cmd.SetupRouter(ctx) - wallet.RegisterRoutes(ctx, service, router) logger, err := appctx.GetLogger(ctx) cmdutils.Must(err) + router := cmd.SetupRouter(ctx) + + dappAO := strings.Split(os.Getenv("DAPP_ALLOWED_CORS_ORIGINS"), ",") + if len(dappAO) == 0 { + logger.Panic().Msg("dapp origin env missing") + } + + wallet.RegisterRoutes(ctx, service, router, middleware.InstrumentHandler, wallet.NewDAppCorsMw(dappAO)) + // add profiling flag to enable profiling routes if viper.GetString("pprof-enabled") != "" { // pprof attaches routes to default serve mux diff --git a/services/wallet/controllers_v3.go b/services/wallet/controllers_v3.go index a62da46d8..8bbae9a5d 100644 --- a/services/wallet/controllers_v3.go +++ b/services/wallet/controllers_v3.go @@ -5,10 +5,13 @@ import ( "crypto/ed25519" "database/sql" "encoding/hex" + "encoding/json" "errors" + "io" "net/http" "strings" + "github.com/asaskevich/govalidator" "github.com/brave-intl/bat-go/libs/altcurrency" appctx "github.com/brave-intl/bat-go/libs/context" errorutils "github.com/brave-intl/bat-go/libs/errors" @@ -19,10 +22,15 @@ import ( "github.com/brave-intl/bat-go/libs/middleware" walletutils "github.com/brave-intl/bat-go/libs/wallet" "github.com/brave-intl/bat-go/libs/wallet/provider/uphold" + "github.com/brave-intl/bat-go/services/wallet/model" "github.com/go-chi/chi" uuid "github.com/satori/go.uuid" ) +const ( + reqBodyLimit10MB = 10 << 20 +) + // LinkDepositAccountResponse is the response returned by the linking endpoints. type LinkDepositAccountResponse struct { GeoCountry string `json:"geoCountry"` @@ -441,49 +449,145 @@ func LinkUpholdDepositAccountV3(s *Service) func(w http.ResponseWriter, r *http. } } -// GetWalletV3 - produces an http handler for the service s which handles getting of brave wallets +const errOriginForbidden model.Error = "request origin forbidden" + +type linkSolanaAddrRequest struct { + SolanaPublicKey string `json:"solanaPublicKey" valid:"length(32|44)"` + Message string `json:"message" valid:"required"` + SolanaSignature string `json:"solanaSignature" valid:"required"` +} + +func LinkSolanaAddress(s *Service) handlers.AppHandler { + return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() + + if dis, ok := ctx.Value(appctx.DisableSolanaLinkingCTXKey).(bool); ok && dis { + return handlers.ValidationError("Connecting Brave Rewards to Solana is temporarily unavailable. Please try again later", nil) + } + + l := logging.Logger(ctx, "wallet") + + o := r.Header.Get("Origin") + if !isAllowedOrigin(o, s.dappConf.AllowedOrigins) { + l.Error().Err(errOriginForbidden).Str("origin", strOr(o, "empty")).Msg("error linking solana address") + return handlers.WrapError(errOriginForbidden, "request origin forbidden", http.StatusForbidden) + } + + var paymentID inputs.ID + if err := inputs.DecodeAndValidateString(ctx, &paymentID, chi.URLParam(r, "paymentID")); err != nil { + return handlers.WrapError(err, "invalid paymentID", http.StatusBadRequest) + } + + b, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) + if err != nil { + return handlers.WrapError(err, "error reading body", http.StatusBadRequest) + } + + var solReq linkSolanaAddrRequest + if err := json.Unmarshal(b, &solReq); err != nil { + return handlers.WrapError(err, "error decoding body", http.StatusBadRequest) + } + + if _, err := govalidator.ValidateStruct(solReq); err != nil { + return handlers.WrapValidationError(err) + } + + if err := s.LinkSolanaAddress(ctx, *paymentID.UUID(), solReq); err != nil { + l.Error().Err(err).Msg("error linking solana address") + + var solErr *walletutils.LinkSolanaAddressError + switch { + case errors.Is(err, model.ErrWalletNotWhitelisted): + return handlers.WrapError(model.ErrWalletNotWhitelisted, "rewards wallet not whitelisted", http.StatusForbidden) + case errors.Is(err, model.ErrChallengeNotFound): + return handlers.WrapError(model.ErrChallengeNotFound, "linking challenge not found", http.StatusNotFound) + case errors.Is(err, model.ErrChallengeExpired): + return handlers.WrapError(model.ErrChallengeExpired, "linking challenge expired", http.StatusUnauthorized) + case errors.Is(err, model.ErrWalletNotFound): + return handlers.WrapError(model.ErrWalletNotFound, "rewards wallet not found", http.StatusNotFound) + case errors.As(err, &solErr): + return handlers.WrapError(solErr, "invalid solana linking message", http.StatusUnauthorized) + default: + return handlers.WrapError(model.ErrInternalServer, "internal server error", http.StatusInternalServerError) + } + } + + return handlers.RenderContent(ctx, nil, w, http.StatusOK) + } +} + +type challengeRequest struct { + PaymentID uuid.UUID `json:"paymentId" valid:"required"` +} + +type challengeResponse struct { + Nonce string `json:"challengeId"` +} + +func CreateChallenge(s *Service) handlers.AppHandler { + return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + b, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) + if err != nil { + return handlers.WrapError(err, "error reading body", http.StatusBadRequest) + } + + var chlReq challengeRequest + if err := json.Unmarshal(b, &chlReq); err != nil { + return handlers.WrapError(err, "error decoding body", http.StatusBadRequest) + } + + if _, err := govalidator.ValidateStruct(chlReq); err != nil { + return handlers.WrapValidationError(err) + } + + ctx := r.Context() + + chl, err := s.CreateChallenge(ctx, chlReq.PaymentID) + if err != nil { + logging.Logger(ctx, "wallet").Error().Err(err).Msg("error creating challenge") + return handlers.WrapError(model.ErrInternalServer, "error creating challenge", http.StatusInternalServerError) + } + + resp := challengeResponse{ + Nonce: chl.Nonce, + } + + return handlers.RenderContent(ctx, resp, w, http.StatusCreated) + } +} + +// GetWalletV3 returns a rewards wallet for the given paymentID. func GetWalletV3(w http.ResponseWriter, r *http.Request) *handlers.AppError { var ctx = r.Context() - // get logger from context - logger := logging.Logger(ctx, "wallet.GetWalletV3") + + l := logging.Logger(ctx, "wallet.GetWalletV3") var id = new(inputs.ID) if err := inputs.DecodeAndValidateString(ctx, id, chi.URLParam(r, "paymentID")); err != nil { - logger.Warn().Str("paymentId", err.Error()).Msg("failed to decode and validate paymentID from url") - return handlers.ValidationError( - "Error validating paymentID url parameter", - map[string]interface{}{ - "paymentId": err.Error(), - }, - ) + l.Warn().Err(err).Str("paymentID", id.String()).Msg("failed to decode and validate paymentID from url") + return handlers.ValidationError("Error validating paymentID url parameter", map[string]interface{}{ + "paymentId": err.Error(), + }) } - var ( - roDB ReadOnlyDatastore - ok bool - ) - - // get datastore from context - if roDB, ok = ctx.Value(appctx.RODatastoreCTXKey).(ReadOnlyDatastore); !ok { - logger.Error().Msg("unable to get read only datastore from context") + // TODO(clD11): this should be removed from ctx as part of wallet refactor. Note, the service would have already + // panicked at startup if the db is missing. However, as a precaution we should stop processing. + roDB, ok := ctx.Value(appctx.RODatastoreCTXKey).(ReadOnlyDatastore) + if !ok { + return handlers.WrapError(errorutils.ErrInternalServerError, "db missing from context", http.StatusInternalServerError) } - // get wallet from datastore info, err := roDB.GetWallet(ctx, *id.UUID()) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - logger.Info().Err(err).Str("id", id.String()).Msg("wallet not found") - return handlers.WrapError(err, "no such wallet", http.StatusNotFound) - } - logger.Warn().Err(err).Str("id", id.String()).Msg("unable to get wallet") + l.Error().Err(err).Str("paymentID", id.String()).Msg("error getting wallet") return handlers.WrapError(err, "error getting wallet from storage", http.StatusInternalServerError) } + if info == nil { - logger.Info().Err(err).Str("id", id.String()).Msg("wallet not found") + l.Info().Str("paymentID", id.String()).Msg("wallet not found") return handlers.WrapError(err, "no such wallet", http.StatusNotFound) } - // render the wallet return handlers.RenderContent(ctx, infoToResponseV3(info), w, http.StatusOK) } @@ -713,3 +817,24 @@ func DisconnectCustodianLinkV3(s *Service) func(w http.ResponseWriter, r *http.R return handlers.RenderContent(ctx, map[string]interface{}{}, w, http.StatusOK) } } + +func isAllowedOrigin(origin string, allowedOrigins []string) bool { + if origin == "" { + return false + } + + for i := range allowedOrigins { + if allowedOrigins[i] == origin { + return true + } + } + + return false +} + +func strOr(a string, b string) string { + if a == "" { + return b + } + return a +} diff --git a/services/wallet/controllers_v3_pvt_test.go b/services/wallet/controllers_v3_pvt_test.go new file mode 100644 index 000000000..952e80ffb --- /dev/null +++ b/services/wallet/controllers_v3_pvt_test.go @@ -0,0 +1,1105 @@ +package wallet + +import ( + "bytes" + "context" + "crypto" + "crypto/ed25519" + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/brave-intl/bat-go/libs/clients/gemini" + + mockgemini "github.com/brave-intl/bat-go/libs/clients/gemini/mock" + mockreputation "github.com/brave-intl/bat-go/libs/clients/reputation/mock" + appctx "github.com/brave-intl/bat-go/libs/context" + "github.com/brave-intl/bat-go/libs/custodian" + datastoreutils "github.com/brave-intl/bat-go/libs/datastore" + "github.com/brave-intl/bat-go/libs/handlers" + "github.com/brave-intl/bat-go/libs/httpsignature" + "github.com/brave-intl/bat-go/libs/logging" + "github.com/brave-intl/bat-go/libs/middleware" + "github.com/go-chi/chi" + "github.com/golang/mock/gomock" + "github.com/jmoiron/sqlx" + uuid "github.com/satori/go.uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +func TestCreateBraveWalletV3(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + var ( + db, mock, _ = sqlmock.New() + datastore = Datastore( + &Postgres{ + Postgres: datastoreutils.Postgres{ + DB: sqlx.NewDb(db, "postgres"), + }, + }) + // add the datastore to the context + ctx = context.Background() + handler = CreateBraveWalletV3 + r = httptest.NewRequest("POST", "/v3/wallet/brave", nil) + ) + // no logger, setup + ctx, _ = logging.SetupLogger(ctx) + + // setup sqlmock + mock.ExpectExec("^INSERT INTO wallets (.+)").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(result{}) + + ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + + // setup keypair + publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) + require.NoError(t, err) + + err = signRequest(r, publicKey, privKey) + require.NoError(t, err) + + r = r.WithContext(ctx) + + var rw = httptest.NewRecorder() + handlers.AppHandler(handler).ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusCreated, rw.Code, string(b)) +} + +func TestCreateUpholdWalletV3(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + var ( + db, mock, _ = sqlmock.New() + datastore = Datastore( + &Postgres{ + Postgres: datastoreutils.Postgres{ + DB: sqlx.NewDb(db, "postgres"), + }, + }) + // add the datastore to the context + ctx = context.Background() + handler = CreateUpholdWalletV3 + r = httptest.NewRequest("POST", "/v3/wallet/uphold", bytes.NewBufferString(`{ + "signedCreationRequest": "eyJib2R5Ijp7ImRlbm9taW5hdGlvbiI6eyJhbW91bnQiOiIwIiwiY3VycmVuY3kiOiJCQVQifSwiZGVzdGluYXRpb24iOiJhNmRmZjJiYS1kMGQxLTQxYzQtOGU1Ni1hMjYwNWJjYWY0YWYifSwiaGVhZGVycyI6eyJkaWdlc3QiOiJTSEEtMjU2PWR2RTAzVHdpRmFSR0c0MUxLSkR4aUk2a3c5M0h0cTNsclB3VllldE5VY1E9Iiwic2lnbmF0dXJlIjoia2V5SWQ9XCJwcmltYXJ5XCIsYWxnb3JpdGhtPVwiZWQyNTUxOVwiLGhlYWRlcnM9XCJkaWdlc3RcIixzaWduYXR1cmU9XCJkcXBQdERESXE0djNiS1V5eHB6Q3Vyd01nSzRmTWk1MUJjakRLc2pTak90K1h1MElZZlBTMWxEZ01aRkhiaWJqcGh0MVd3V3l5enFad3lVNW0yN1FDUT09XCIifSwib2N0ZXRzIjoie1wiZGVub21pbmF0aW9uXCI6e1wiYW1vdW50XCI6XCIwXCIsXCJjdXJyZW5jeVwiOlwiQkFUXCJ9LFwiZGVzdGluYXRpb25cIjpcImE2ZGZmMmJhLWQwZDEtNDFjNC04ZTU2LWEyNjA1YmNhZjRhZlwifSJ9"}`)) + ) + // no logger, setup + ctx, _ = logging.SetupLogger(ctx) + + // setup sqlmock + mock.ExpectExec("^INSERT INTO wallets (.+)").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(result{}) + + ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + + r = r.WithContext(ctx) + + var rw = httptest.NewRecorder() + handlers.AppHandler(handler).ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusBadRequest, rw.Code, string(b)) +} + +func TestGetWalletV3(t *testing.T) { + var ( + db, mock, _ = sqlmock.New() + datastore = Datastore( + &Postgres{ + Postgres: datastoreutils.Postgres{ + DB: sqlx.NewDb(db, "postgres"), + }, + }) + roDatastore = ReadOnlyDatastore( + &Postgres{ + Postgres: datastoreutils.Postgres{ + DB: sqlx.NewDb(db, "postgres"), + }, + }) + // add the datastore to the context + ctx = context.Background() + id = uuid.NewV4() + r = httptest.NewRequest("GET", fmt.Sprintf("/v3/wallet/%s", id), nil) + handler = GetWalletV3 + rw = httptest.NewRecorder() + rows = sqlmock.NewRows([]string{"id", "provider", "provider_id", "public_key", "provider_linking_id", "anonymous_address"}). + AddRow(id, "brave", "", "12345", id, id) + ) + + mock.ExpectQuery("^select (.+)").WithArgs(id).WillReturnRows(rows) + + ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) + ctx = context.WithValue(ctx, appctx.RODatastoreCTXKey, roDatastore) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Get("/v3/wallet/{paymentID}", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusOK, rw.Code, string(b)) +} + +func TestLinkBitFlyerWalletV3(t *testing.T) { + VerifiedWalletEnable = true + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + // setup jwt token for the test + var secret = []byte("a jwt secret") + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: secret}, (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + panic(err) + } + + var ( + idFrom = uuid.NewV4() + idTo = uuid.NewV4() + accountHash = uuid.NewV4() + timestamp = time.Now() + ) + + h := sha256.New() + if _, err := h.Write([]byte(idFrom.String())); err != nil { + panic(err) + } + + externalAccountID := hex.EncodeToString(h.Sum(nil)) + + linkingInfo := BitFlyerLinkingInfo{ + DepositID: idTo.String(), + RequestID: "1", + AccountHash: accountHash.String(), + ExternalAccountID: externalAccountID, + Timestamp: timestamp, + } + + tokenString, err := jwt.Signed(sig).Claims(linkingInfo).CompactSerialize() + if err != nil { + panic(err) + } + + var ( + // add the datastore to the context + ctx = middleware.AddKeyID(context.WithValue(context.Background(), appctx.BitFlyerJWTKeyCTXKey, []byte(secret)), idFrom.String()) + r = httptest.NewRequest( + "POST", + fmt.Sprintf("/v3/wallet/bitflyer/%s/claim", idFrom), + bytes.NewBufferString(fmt.Sprintf(` + { + "linkingInfo": "%s" + }`, tokenString)), + ) + mockReputation = mockreputation.NewMockClient(mockCtrl) + s, mock = initSvcWithMockDB(t) + handler = LinkBitFlyerDepositAccountV3(s) + rw = httptest.NewRecorder() + ) + + mock.ExpectExec("^insert (.+)").WithArgs("1").WillReturnResult(sqlmock.NewResult(1, 1)) + + // begin linking tx + mock.ExpectBegin() + + // make sure old linking id matches new one for same custodian + linkingID := uuid.NewV5(ClaimNamespace, accountHash.String()) + + // acquire lock for linkingID + mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // this wallet has been linked prior, with the same linking id that the request is with + // SHOULD SKIP THE linking limit checks + var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) + mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "bitflyer").WillReturnRows(linkingIDRows) + + mockSQLCustodianLink(mock, "bitflyer") + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). + AddRow(time.Now(), time.Now()) + + // insert into wallet custodian + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "bitflyer", uuid.NewV5(ClaimNamespace, accountHash.String())).WillReturnRows(clRows) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "bitflyer", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction + mock.ExpectCommit() + + ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, s.Datastore) + ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputation) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + + mockReputation.EXPECT().IsLinkingReputable( + gomock.Any(), // ctx + gomock.Any(), // wallet id + gomock.Any(), // country + ).Return( + true, + []int{}, + nil, + ) + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Post("/v3/wallet/bitflyer/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusOK, rw.Code, string(b)) + + var l LinkDepositAccountResponse + err = json.Unmarshal(b, &l) + require.NoError(t, err) + + assert.Equal(t, "JP", l.GeoCountry) +} + +func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { + VerifiedWalletEnable = true + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + var ( + // setup test variables + idFrom = uuid.NewV4() + ctx = middleware.AddKeyID(context.Background(), idFrom.String()) + accountID = uuid.NewV4() + idTo = accountID + + s, mock = initSvcWithMockDB(t) + + linkingInfo = "this is the fake jwt for linking_info" + + // setup mock clients + mockReputationClient = mockreputation.NewMockClient(mockCtrl) + mockGeminiClient = mockgemini.NewMockClient(mockCtrl) + + // this is our main request + r = httptest.NewRequest( + "POST", + fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), + bytes.NewBufferString(fmt.Sprintf(` + { + "linking_info": "%s", + "recipient_id": "%s" + }`, linkingInfo, idTo)), + ) + + handler = LinkGeminiDepositAccountV3(s) + rw = httptest.NewRecorder() + ) + + mockReputationClient.EXPECT().IsLinkingReputable( + gomock.Any(), // ctx + gomock.Any(), // wallet id + gomock.Any(), // country + ).Return( + true, + []int{}, + nil, + ) + + ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) + ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + // turn on region check + ctx = context.WithValue(ctx, appctx.UseCustodianRegionsCTXKey, true) + // configure allow region + custodianRegions := custodian.Regions{ + Gemini: custodian.GeoAllowBlockMap{ + Allow: []string{"US"}, + }, + } + ctx = context.WithValue(ctx, appctx.CustodianRegionsCTXKey, custodianRegions) + + validateAccountRes := gemini.ValidatedAccount{ + ID: accountID.String(), + ValidDocuments: []gemini.ValidDocument{ + { + Type: "passport", + IssuingCountry: "US", + }, + }, + } + + mockGeminiClient.EXPECT().FetchValidateAccount( + gomock.Any(), + gomock.Any(), + gomock.Any(), + ).Return(validateAccountRes, nil) + + mockSQLCustodianLink(mock, "gemini") + + // begin linking tx + mock.ExpectBegin() + + // make sure old linking id matches new one for same custodian + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) + + // acquire lock for linkingID + mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // not before linked + mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + + mockSQLCustodianLink(mock, "gemini") + + var max = sqlmock.NewRows([]string{"max"}).AddRow(4) + var open = sqlmock.NewRows([]string{"used"}).AddRow(0) + + var custLinks = sqlmock.NewRows([]string{"custodian", "linking_id"}).AddRow("gemini", linkingID.String()) + + // linking limit checks + mock.ExpectQuery("^select wc1.custodian, wc1.linking_id from wallet_custodian (.+)").WithArgs(linkingID).WillReturnRows(custLinks) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID, 4).WillReturnRows(max) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(open) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(sqlmock.NewRows([]string{"wallet_id"}).AddRow(uuid.NewV4().String())) + // get last un linking + var lastUnlink = sqlmock.NewRows([]string{"last_unlinking"}).AddRow(time.Now()) + mock.ExpectQuery("^select max(.+)").WithArgs(linkingID).WillReturnRows(lastUnlink) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). + AddRow(time.Now(), time.Now()) + + // insert into wallet custodian + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction + mock.ExpectCommit() + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusOK, rw.Code, string(b)) + + var l LinkDepositAccountResponse + err := json.Unmarshal(b, &l) + require.NoError(t, err) + + assert.Equal(t, "US", l.GeoCountry) + + // delete linking + r = httptest.NewRequest( + "DELETE", + fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), nil) + + handler = DisconnectCustodianLinkV3(s) + rw = httptest.NewRecorder() + + // create transaction + mock.ExpectBegin() + + // removes the link to the user_deposit_destination record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + // updates the disconnected date on the record, and returns no error and one changed row + mock.ExpectExec("^update wallet_custodian(.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction because we are done disconnecting + mock.ExpectCommit() + + r = r.WithContext(ctx) + + router = chi.NewRouter() + router.Delete("/v3/wallet/{custodian}/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + if resp := rw.Result(); resp.StatusCode != http.StatusOK { + t.Errorf("invalid response expected %d, got %d", http.StatusOK, resp.StatusCode) + } + + // ban the country now + custodianRegions = custodian.Regions{ + Gemini: custodian.GeoAllowBlockMap{ + Allow: []string{}, + }, + } + ctx = context.WithValue(ctx, appctx.CustodianRegionsCTXKey, custodianRegions) + + // begin linking tx + mock.ExpectBegin() + + // acquire lock for linkingID + mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // not before linked + mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + + mockSQLCustodianLink(mock, "gemini") + + // perform again, make sure we check haslinkedprio + hasPriorRows := sqlmock.NewRows([]string{"result"}). + AddRow(true) + mock.ExpectQuery("^select exists(select 1 from wallet_custodian (.+)").WithArgs(uuid.NewV5(ClaimNamespace, accountID.String()), idFrom).WillReturnRows(hasPriorRows) + + max = sqlmock.NewRows([]string{"max"}).AddRow(4) + open = sqlmock.NewRows([]string{"used"}).AddRow(0) + + custLinks = sqlmock.NewRows([]string{"custodian", "linking_id"}).AddRow("gemini", linkingID.String()) + + // linking limit checks + mock.ExpectQuery("^select wc1.custodian, wc1.linking_id from wallet_custodian (.+)").WithArgs(linkingID).WillReturnRows(custLinks) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID, 4).WillReturnRows(max) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(open) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(sqlmock.NewRows([]string{"wallet_id"}).AddRow(uuid.NewV4().String())) + // get last un linking + lastUnlink = sqlmock.NewRows([]string{"last_unlinking"}).AddRow(time.Now()) + mock.ExpectQuery("^select max(.+)").WithArgs(linkingID).WillReturnRows(lastUnlink) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + clRows = sqlmock.NewRows([]string{"created_at", "linked_at"}). + AddRow(time.Now(), time.Now()) + + // insert into wallet custodian + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction + mock.ExpectCommit() + + r = r.WithContext(ctx) + + router = chi.NewRouter() + router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b = rw.Body.Bytes() + require.Equal(t, http.StatusOK, rw.Code, string(b)) +} + +func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { + VerifiedWalletEnable = true + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + var ( + // setup test variables + idFrom = uuid.NewV4() + ctx = middleware.AddKeyID(context.Background(), idFrom.String()) + accountID = uuid.NewV4() + idTo = accountID + + s, mock = initSvcWithMockDB(t) + + linkingInfo = "this is the fake jwt for linking_info" + + // setup mock clients + mockReputationClient = mockreputation.NewMockClient(mockCtrl) + mockGeminiClient = mockgemini.NewMockClient(mockCtrl) + + // this is our main request + r = httptest.NewRequest( + "POST", + fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), + bytes.NewBufferString(fmt.Sprintf(` + { + "linking_info": "%s", + "recipient_id": "%s" + }`, linkingInfo, idTo)), + ) + + handler = LinkGeminiDepositAccountV3(s) + rw = httptest.NewRecorder() + ) + + mockReputationClient.EXPECT().IsLinkingReputable( + gomock.Any(), // ctx + gomock.Any(), // wallet id + gomock.Any(), // country + ).Return( + true, + []int{}, + nil, + ) + + ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) + ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + + validateAccountRes := gemini.ValidatedAccount{ + ID: accountID.String(), + ValidDocuments: []gemini.ValidDocument{ + { + Type: "passport", + IssuingCountry: "US", + }, + }, + } + + mockGeminiClient.EXPECT().FetchValidateAccount( + gomock.Any(), + gomock.Any(), + gomock.Any(), + ).Return(validateAccountRes, nil) + + mockSQLCustodianLink(mock, "gemini") + + // begin linking tx + mock.ExpectBegin() + + // make sure old linking id matches new one for same custodian + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) + + // acquire lock for linkingID + mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // not before linked + mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + + mockSQLCustodianLink(mock, "gemini") + + var max = sqlmock.NewRows([]string{"max"}).AddRow(4) + var open = sqlmock.NewRows([]string{"used"}).AddRow(0) + + var custLinks = sqlmock.NewRows([]string{"custodian", "linking_id"}).AddRow("gemini", linkingID.String()) + + // linking limit checks + mock.ExpectQuery("^select wc1.custodian, wc1.linking_id from wallet_custodian (.+)").WithArgs(linkingID).WillReturnRows(custLinks) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID, 4).WillReturnRows(max) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(open) + mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(sqlmock.NewRows([]string{"wallet_id"}).AddRow(uuid.NewV4().String())) + // get last un linking + var lastUnlink = sqlmock.NewRows([]string{"last_unlinking"}).AddRow(time.Now()) + mock.ExpectQuery("^select max(.+)").WithArgs(linkingID).WillReturnRows(lastUnlink) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). + AddRow(time.Now(), time.Now()) + + // insert into wallet custodian + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction + mock.ExpectCommit() + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusOK, rw.Code, string(b)) + + var l LinkDepositAccountResponse + err := json.Unmarshal(b, &l) + require.NoError(t, err) +} + +func TestLinkZebPayWalletV3_InvalidKyc(t *testing.T) { + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + // setup jwt token for the test + var secret = []byte("a jwt secret") + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: secret}, (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + panic(err) + } + + var ( + // setup test variables + idFrom = uuid.NewV4() + ctx = middleware.AddKeyID(context.Background(), idFrom.String()) + accountID = uuid.NewV4() + idTo = accountID + s, _ = initSvcWithMockDB(t) + handler = LinkZebPayDepositAccountV3(s) + rw = httptest.NewRecorder() + ) + + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + ctx = context.WithValue(ctx, appctx.ZebPayLinkingKeyCTXKey, base64.StdEncoding.EncodeToString(secret)) + + linkingInfo, err := jwt.Signed(sig).Claims(map[string]interface{}{ + "accountId": accountID, + "depositId": idTo, + "countryCode": "IN", + "iat": time.Now().Unix(), + "exp": time.Now().Add(5 * time.Second).Unix(), + }).CompactSerialize() + if err != nil { + panic(err) + } + + // this is our main request + r := httptest.NewRequest( + "POST", + fmt.Sprintf("/v3/wallet/zebpay/%s/claim", idFrom), + bytes.NewBufferString(fmt.Sprintf( + `{"linking_info": "%s"}`, + linkingInfo, + )), + ) + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Post("/v3/wallet/zebpay/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusForbidden, rw.Code, string(b)) + + var l LinkDepositAccountResponse + err = json.Unmarshal(b, &l) + require.NoError(t, err) +} + +func TestLinkZebPayWalletV3(t *testing.T) { + VerifiedWalletEnable = true + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + // setup jwt token for the test + var secret = []byte("a jwt secret") + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: secret}, (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + panic(err) + } + + var ( + idFrom = uuid.NewV4() + ctx = middleware.AddKeyID(context.Background(), idFrom.String()) + accountID = uuid.NewV4() + idTo = accountID + + mockReputationClient = mockreputation.NewMockClient(mockCtrl) + + s, mock = initSvcWithMockDB(t) + + handler = LinkZebPayDepositAccountV3(s) + rw = httptest.NewRecorder() + ) + + ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + ctx = context.WithValue(ctx, appctx.ZebPayLinkingKeyCTXKey, base64.StdEncoding.EncodeToString(secret)) + + linkingInfo, err := jwt.Signed(sig).Claims(map[string]interface{}{ + "accountId": accountID, "depositId": idTo, "iat": time.Now().Unix(), "exp": time.Now().Add(5 * time.Second).Unix(), + "isValid": true, "countryCode": "IN", + }).CompactSerialize() + if err != nil { + panic(err) + } + + // this is our main request + r := httptest.NewRequest( + "POST", + fmt.Sprintf("/v3/wallet/zebpay/%s/claim", idFrom), + bytes.NewBufferString(fmt.Sprintf( + `{"linking_info": "%s"}`, + linkingInfo, + )), + ) + + mockReputationClient.EXPECT().IsLinkingReputable( + gomock.Any(), // ctx + gomock.Any(), // wallet id + gomock.Any(), // country + ).Return( + true, + []int{}, + nil, + ) + + // begin linking tx + mock.ExpectBegin() + + // make sure old linking id matches new one for same custodian + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) + var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) + + // acquire lock for linkingID + mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "zebpay").WillReturnRows(linkingIDRows) + + mockSQLCustodianLink(mock, "zebpay") + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + // this wallet has been linked prior, with the same linking id that the request is with + // SHOULD SKIP THE linking limit checks + clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). + AddRow(time.Now(), time.Now()) + + // insert into wallet custodian + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "zebpay", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "zebpay", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction + mock.ExpectCommit() + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Post("/v3/wallet/zebpay/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusOK, rw.Code, string(b)) + + var l LinkDepositAccountResponse + err = json.Unmarshal(b, &l) + require.NoError(t, err) + + assert.Equal(t, "IN", l.GeoCountry) +} + +func TestLinkGeminiWalletV3(t *testing.T) { + VerifiedWalletEnable = true + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + var ( + // setup test variables + idFrom = uuid.NewV4() + ctx = middleware.AddKeyID(context.Background(), idFrom.String()) + accountID = uuid.NewV4() + idTo = accountID + + linkingInfo = "this is the fake jwt for linking_info" + + // setup mock clients + mockReputationClient = mockreputation.NewMockClient(mockCtrl) + mockGeminiClient = mockgemini.NewMockClient(mockCtrl) + + // this is our main request + r = httptest.NewRequest( + "POST", + fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), + bytes.NewBufferString(fmt.Sprintf(` + { + "linking_info": "%s", + "recipient_id": "%s" + }`, linkingInfo, idTo)), + ) + s, mock = initSvcWithMockDB(t) + + handler = LinkGeminiDepositAccountV3(s) + rw = httptest.NewRecorder() + ) + + ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) + ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + + validateAccountRes := gemini.ValidatedAccount{ + ID: accountID.String(), + ValidDocuments: []gemini.ValidDocument{ + { + Type: "passport", + IssuingCountry: "US", + }, + }, + } + + mockGeminiClient.EXPECT().FetchValidateAccount( + gomock.Any(), + gomock.Any(), + gomock.Any(), + ).Return(validateAccountRes, nil) + + mockReputationClient.EXPECT().IsLinkingReputable( + gomock.Any(), // ctx + gomock.Any(), // wallet id + gomock.Any(), // country + ).Return( + true, + []int{}, + nil, + ) + + mockSQLCustodianLink(mock, "gemini") + + // begin linking tx + mock.ExpectBegin() + + // make sure old linking id matches new one for same custodian + linkingID := uuid.NewV5(ClaimNamespace, idTo.String()) + var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) + + // acquire lock for linkingID + mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnRows(linkingIDRows) + + mockSQLCustodianLink(mock, "gemini") + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + // this wallet has been linked prior, with the same linking id that the request is with + // SHOULD SKIP THE linking limit checks + clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). + AddRow(time.Now(), time.Now()) + + // insert into wallet custodian + mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(ClaimNamespace, accountID.String())).WillReturnRows(clRows) + + // updates the link to the wallet_custodian record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction + mock.ExpectCommit() + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(rw, r) + + b := rw.Body.Bytes() + require.Equal(t, http.StatusOK, rw.Code, string(b)) + + var l LinkDepositAccountResponse + err := json.Unmarshal(b, &l) + require.NoError(t, err) + + assert.Equal(t, "US", l.GeoCountry) +} + +func TestDisconnectCustodianLinkV3(t *testing.T) { + VerifiedWalletEnable = true + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + var ( + // setup test variables + idFrom = uuid.NewV4() + ctx = middleware.AddKeyID(context.Background(), idFrom.String()) + + // this is our main request + r = httptest.NewRequest( + "DELETE", + fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), nil) + + s, mock = initSvcWithMockDB(t) + + handler = DisconnectCustodianLinkV3(s) + w = httptest.NewRecorder() + ) + + // create transaction + mock.ExpectBegin() + + // removes the link to the user_deposit_destination record in wallets + mock.ExpectExec("^update wallets (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + // updates the disconnected date on the record, and returns no error and one changed row + mock.ExpectExec("^update wallet_custodian(.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + + // commit transaction because we are done disconnecting + mock.ExpectCommit() + + ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + + r = r.WithContext(ctx) + + router := chi.NewRouter() + router.Delete("/v3/wallet/{custodian}/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) + router.ServeHTTP(w, r) + + if resp := w.Result(); resp.StatusCode != http.StatusOK { + t.Errorf("invalid response expected %d, got %d", http.StatusOK, resp.StatusCode) + } +} + +func TestIsAllowedOrigin(t *testing.T) { + type tcGiven struct { + origin string + allowedOrigins []string + } + + type testCase struct { + name string + given tcGiven + exp bool + } + + tests := []testCase{ + { + name: "allowed", + given: tcGiven{ + origin: "test", + allowedOrigins: []string{"random-1", "random-2", "random-3", "test"}, + }, + exp: true, + }, + { + name: "empty_origin", + given: tcGiven{ + allowedOrigins: []string{"random-1", "random-2", "random-3", "test"}, + }, + exp: false, + }, + { + name: "origin_not_in_allowed_origins", + given: tcGiven{ + origin: "test", + allowedOrigins: []string{"random-1", "random-2", "random-3", "random-4"}, + }, + exp: false, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := isAllowedOrigin(tc.given.origin, tc.given.allowedOrigins) + assert.Equal(t, tc.exp, actual) + }) + } +} + +func signRequest(req *http.Request, publicKey httpsignature.Ed25519PubKey, privateKey ed25519.PrivateKey) error { + var s httpsignature.SignatureParams + s.Algorithm = httpsignature.ED25519 + s.KeyID = hex.EncodeToString(publicKey) + s.Headers = []string{"digest", "(request-target)"} + return s.Sign(privateKey, crypto.Hash(0), req) +} + +type result struct{} + +func (r result) LastInsertId() (int64, error) { return 1, nil } +func (r result) RowsAffected() (int64, error) { return 1, nil } + +func mockSQLCustodianLink(mock sqlmock.Sqlmock, custodian string) { + clRow := sqlmock.NewRows([]string{"wallet_id", "custodian", "linking_id", "created_at", "disconnected_at", "linked_at"}). + AddRow(uuid.NewV4().String(), custodian, uuid.NewV4().String(), time.Now(), time.Now(), time.Now()) + mock.ExpectQuery("^select(.+) from wallet_custodian(.+)"). + WillReturnRows(clRow) +} + +func initSvcWithMockDB(t *testing.T) (*Service, sqlmock.Sqlmock) { + db, mock, _ := sqlmock.New() + datastore := Datastore( + &Postgres{ + Postgres: datastoreutils.Postgres{ + DB: sqlx.NewDb(db, "postgres"), + }, + }) + + mtc := &mockMtc{} + + gem := &mockGemini{ + fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { + return "US" + }, + } + + dappConf := DAppConfig{} + + s, err := InitService(datastore, nil, nil, nil, nil, nil, nil, nil, mtc, gem, dappConf) + require.NoError(t, err) + + return s, mock +} + +type mockGemini struct { + fnGetIssuingCountry func(acc gemini.ValidatedAccount, fallback bool) string + fnIsRegionAllowed func(ctx context.Context, issuingCountry string, custodianRegions custodian.Regions) error +} + +func (m *mockGemini) GetIssuingCountry(acc gemini.ValidatedAccount, fallback bool) string { + if m.fnGetIssuingCountry == nil { + return "" + } + return m.fnGetIssuingCountry(acc, fallback) +} + +func (m *mockGemini) IsRegionAvailable(ctx context.Context, issuingCountry string, custodianRegions custodian.Regions) error { + if m.fnIsRegionAllowed == nil { + return nil + } + return m.fnIsRegionAllowed(ctx, issuingCountry, custodianRegions) +} + +type mockMtc struct { + fnLinkSuccessZP func(cc string) + fnLinkFailureZP func(cc string) +} + +func (m *mockMtc) LinkSuccessZP(cc string) { + if m.fnLinkSuccessZP != nil { + m.fnLinkSuccessZP(cc) + } +} + +func (m *mockMtc) LinkFailureZP(cc string) { + if m.fnLinkFailureZP != nil { + m.fnLinkFailureZP(cc) + } +} + +func (m *mockMtc) LinkFailureGemini(_ string) {} +func (m *mockMtc) LinkSuccessGemini(_ string) {} +func (m *mockMtc) CountDocTypeByIssuingCntry(_ []gemini.ValidDocument) {} diff --git a/services/wallet/controllers_v3_test.go b/services/wallet/controllers_v3_test.go index 9179ead76..d9414458d 100644 --- a/services/wallet/controllers_v3_test.go +++ b/services/wallet/controllers_v3_test.go @@ -1,3 +1,5 @@ +//go:build integration + package wallet_test import ( @@ -5,8 +7,7 @@ import ( "context" "crypto" "crypto/ed25519" - "crypto/sha256" - "database/sql" + "crypto/rand" "encoding/base64" "encoding/hex" "encoding/json" @@ -16,1120 +17,720 @@ import ( "testing" "time" - "github.com/DATA-DOG/go-sqlmock" - "github.com/brave-intl/bat-go/libs/clients/gemini" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - mockgemini "github.com/brave-intl/bat-go/libs/clients/gemini/mock" - mockreputation "github.com/brave-intl/bat-go/libs/clients/reputation/mock" + "github.com/brave-intl/bat-go/libs/altcurrency" + mock_reputation "github.com/brave-intl/bat-go/libs/clients/reputation/mock" appctx "github.com/brave-intl/bat-go/libs/context" - "github.com/brave-intl/bat-go/libs/custodian" - datastoreutils "github.com/brave-intl/bat-go/libs/datastore" "github.com/brave-intl/bat-go/libs/handlers" "github.com/brave-intl/bat-go/libs/httpsignature" - "github.com/brave-intl/bat-go/libs/logging" - "github.com/brave-intl/bat-go/libs/middleware" + walletutils "github.com/brave-intl/bat-go/libs/wallet" + "github.com/brave-intl/bat-go/libs/wallet/provider/uphold" "github.com/brave-intl/bat-go/services/wallet" + "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/brave-intl/bat-go/services/wallet/storage" + "github.com/btcsuite/btcutil/base58" "github.com/go-chi/chi" "github.com/golang/mock/gomock" - "github.com/jmoiron/sqlx" uuid "github.com/satori/go.uuid" - "gopkg.in/square/go-jose.v2" - "gopkg.in/square/go-jose.v2/jwt" + "github.com/shopspring/decimal" + "github.com/spf13/viper" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) -func TestCreateBraveWalletV3(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - var ( - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - // add the datastore to the context - ctx = context.Background() - handler = wallet.CreateBraveWalletV3 - r = httptest.NewRequest("POST", "/v3/wallet/brave", nil) - ) - // no logger, setup - ctx, _ = logging.SetupLogger(ctx) - - // setup sqlmock - mock.ExpectExec("^INSERT INTO wallets (.+)").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(result{}) - - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - - // setup keypair - publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) - must(t, "failed to generate keypair", err) - - err = signRequest(r, publicKey, privKey) - must(t, "failed to sign request", err) - - r = r.WithContext(ctx) - - var rw = httptest.NewRecorder() - handlers.AppHandler(handler).ServeHTTP(rw, r) - - b := rw.Body.Bytes() - require.Equal(t, http.StatusCreated, rw.Code, string(b)) +type WalletControllersTestSuite struct { + suite.Suite } -func TestCreateUpholdWalletV3(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - var ( - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - // add the datastore to the context - ctx = context.Background() - handler = wallet.CreateUpholdWalletV3 - r = httptest.NewRequest("POST", "/v3/wallet/uphold", bytes.NewBufferString(`{ - "signedCreationRequest": "eyJib2R5Ijp7ImRlbm9taW5hdGlvbiI6eyJhbW91bnQiOiIwIiwiY3VycmVuY3kiOiJCQVQifSwiZGVzdGluYXRpb24iOiJhNmRmZjJiYS1kMGQxLTQxYzQtOGU1Ni1hMjYwNWJjYWY0YWYifSwiaGVhZGVycyI6eyJkaWdlc3QiOiJTSEEtMjU2PWR2RTAzVHdpRmFSR0c0MUxLSkR4aUk2a3c5M0h0cTNsclB3VllldE5VY1E9Iiwic2lnbmF0dXJlIjoia2V5SWQ9XCJwcmltYXJ5XCIsYWxnb3JpdGhtPVwiZWQyNTUxOVwiLGhlYWRlcnM9XCJkaWdlc3RcIixzaWduYXR1cmU9XCJkcXBQdERESXE0djNiS1V5eHB6Q3Vyd01nSzRmTWk1MUJjakRLc2pTak90K1h1MElZZlBTMWxEZ01aRkhiaWJqcGh0MVd3V3l5enFad3lVNW0yN1FDUT09XCIifSwib2N0ZXRzIjoie1wiZGVub21pbmF0aW9uXCI6e1wiYW1vdW50XCI6XCIwXCIsXCJjdXJyZW5jeVwiOlwiQkFUXCJ9LFwiZGVzdGluYXRpb25cIjpcImE2ZGZmMmJhLWQwZDEtNDFjNC04ZTU2LWEyNjA1YmNhZjRhZlwifSJ9"}`)) - ) - // no logger, setup - ctx, _ = logging.SetupLogger(ctx) - - // setup sqlmock - mock.ExpectExec("^INSERT INTO wallets (.+)").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(result{}) - - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - - r = r.WithContext(ctx) - - var rw = httptest.NewRecorder() - handlers.AppHandler(handler).ServeHTTP(rw, r) - - b := rw.Body.Bytes() - require.Equal(t, http.StatusBadRequest, rw.Code, string(b)) +func TestWalletControllersTestSuite(t *testing.T) { + suite.Run(t, new(WalletControllersTestSuite)) } -func TestGetWalletV3(t *testing.T) { - var ( - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - roDatastore = wallet.ReadOnlyDatastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - // add the datastore to the context - ctx = context.Background() - id = uuid.NewV4() - r = httptest.NewRequest("GET", fmt.Sprintf("/v3/wallet/%s", id), nil) - handler = wallet.GetWalletV3 - rw = httptest.NewRecorder() - rows = sqlmock.NewRows([]string{"id", "provider", "provider_id", "public_key", "provider_linking_id", "anonymous_address"}). - AddRow(id, "brave", "", "12345", id, id) - ) +func (suite *WalletControllersTestSuite) SetupSuite() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err, "Failed to get postgres conn") - mock.ExpectQuery("^select (.+)").WithArgs(id).WillReturnRows(rows) + m, err := pg.NewMigrate() + suite.Require().NoError(err, "Failed to create migrate instance") - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.RODatastoreCTXKey, roDatastore) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - - r = r.WithContext(ctx) - - router := chi.NewRouter() - router.Get("/v3/wallet/{paymentID}", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) + ver, dirty, _ := m.Version() + if dirty { + suite.Require().NoError(m.Force(int(ver))) + } + if ver > 0 { + suite.Require().NoError(m.Down(), "Failed to migrate down cleanly") + } - b := rw.Body.Bytes() - require.Equal(t, http.StatusOK, rw.Code, string(b)) + suite.Require().NoError(pg.Migrate(), "Failed to fully migrate") } -func TestLinkBitFlyerWalletV3(t *testing.T) { - wallet.VerifiedWalletEnable = true - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - // setup jwt token for the test - var secret = []byte("a jwt secret") - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: secret}, (&jose.SignerOptions{}).WithType("JWT")) - if err != nil { - panic(err) - } - - var ( - idFrom = uuid.NewV4() - idTo = uuid.NewV4() - accountHash = uuid.NewV4() - timestamp = time.Now() - ) +func (suite *WalletControllersTestSuite) SetupTest() { + suite.CleanDB() +} - h := sha256.New() - if _, err := h.Write([]byte(idFrom.String())); err != nil { - panic(err) - } +func (suite *WalletControllersTestSuite) TearDownTest() { + suite.CleanDB() +} - externalAccountID := hex.EncodeToString(h.Sum(nil)) +func (suite *WalletControllersTestSuite) CleanDB() { + tables := []string{"claim_creds", "claims", "wallets", "issuers", "promotions", "wallet_custodian"} - linkingInfo := wallet.BitFlyerLinkingInfo{ - DepositID: idTo.String(), - RequestID: "1", - AccountHash: accountHash.String(), - ExternalAccountID: externalAccountID, - Timestamp: timestamp, - } + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err, "Failed to get postgres conn") - tokenString, err := jwt.Signed(sig).Claims(linkingInfo).CompactSerialize() - if err != nil { - panic(err) + for _, table := range tables { + _, err = pg.RawDB().Exec("delete from " + table) + suite.Require().NoError(err, "Failed to get clean table") } +} - var ( - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - - // add the datastore to the context - ctx = middleware.AddKeyID(context.WithValue(context.Background(), appctx.BitFlyerJWTKeyCTXKey, []byte(secret)), idFrom.String()) - r = httptest.NewRequest( - "POST", - fmt.Sprintf("/v3/wallet/bitflyer/%s/claim", idFrom), - bytes.NewBufferString(fmt.Sprintf(` - { - "linkingInfo": "%s" - }`, tokenString)), - ) - mockReputation = mockreputation.NewMockClient(mockCtrl) - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, nil, nil) - handler = wallet.LinkBitFlyerDepositAccountV3(s) - rw = httptest.NewRecorder() - ) - - mock.ExpectExec("^insert (.+)").WithArgs("1").WillReturnResult(sqlmock.NewResult(1, 1)) +func (suite *WalletControllersTestSuite) FundWallet(w *uphold.Wallet, probi decimal.Decimal) decimal.Decimal { + ctx := context.Background() + balanceBefore, err := w.GetBalance(ctx, true) + total, err := uphold.FundWallet(ctx, w, probi) + suite.Require().NoError(err, "an error should not be generated from funding the wallet") + suite.Require().True(total.GreaterThan(balanceBefore.TotalProbi), "submit with confirm should result in an increased balance") + return total +} - mockSQLCustodianLink(mock, "bitflyer") +func (suite *WalletControllersTestSuite) CheckBalance(w *uphold.Wallet, expect decimal.Decimal) { + balances, err := w.GetBalance(context.Background(), true) + suite.Require().NoError(err, "an error should not be generated from checking the wallet balance") + totalProbi := altcurrency.BAT.FromProbi(balances.TotalProbi) + errMessage := fmt.Sprintf("got an unexpected balance. expected: %s, got %s", expect.String(), totalProbi.String()) + suite.Require().True(expect.Equal(totalProbi), errMessage) +} - // begin linking tx - mock.ExpectBegin() +func (suite *WalletControllersTestSuite) TestBalanceV3() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err, "Failed to get postgres connection") - // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, accountHash.String()) + mockCtrl := gomock.NewController(suite.T()) + defer mockCtrl.Finish() - // acquire lock for linkingID - mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). - WillReturnResult(sqlmock.NewResult(1, 1)) + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) - // this wallet has been linked prior, with the same linking id that the request is with - // SHOULD SKIP THE linking limit checks - var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) - mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "bitflyer").WillReturnRows(linkingIDRows) + w1 := suite.NewWallet(service, "uphold") - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + bat1 := decimal.NewFromFloat(0.000000001) - clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). - AddRow(time.Now(), time.Now()) + suite.FundWallet(w1, bat1) - // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "bitflyer", uuid.NewV5(wallet.ClaimNamespace, accountHash.String())).WillReturnRows(clRows) + // check there is 1 bat in w1 + suite.CheckBalance(w1, bat1) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "bitflyer", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + // call the balance endpoint and check that you get back a total of 1 + handler := wallet.GetUpholdWalletBalanceV3 - mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + req, err := http.NewRequest("GET", "/v3/wallet/uphold/{paymentID}", nil) + suite.Require().NoError(err, "wallet claim request could not be created") - // commit transaction - mock.ExpectCommit() + rctx := chi.NewRouteContext() + rctx.URLParams.Add("paymentID", w1.ID) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue(req.Context(), appctx.RODatastoreCTXKey, pg)) + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputation) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + rr := httptest.NewRecorder() + handlers.AppHandler(handler).ServeHTTP(rr, req) + suite.Require().Equal(http.StatusOK, rr.Code, fmt.Sprintf("status is expected to match %d: %s", http.StatusOK, rr.Body.String())) - mockReputation.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) + var balance wallet.BalanceResponseV3 + err = json.Unmarshal(rr.Body.Bytes(), &balance) + suite.Require().NoError(err, "failed to unmarshal balance result") - r = r.WithContext(ctx) + suite.Require().Equal(balance.Total, float64(0.000000001), fmt.Sprintf("balance is expected to match %f: %f", balance.Total, float64(1))) - router := chi.NewRouter() - router.Post("/v3/wallet/bitflyer/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) + _, err = pg.RawDB().Exec(`update wallets set provider_id = '' where id = $1`, w1.ID) + suite.Require().NoError(err, "wallet provider_id could not be set as empty string") - b := rw.Body.Bytes() - require.Equal(t, http.StatusOK, rw.Code, string(b)) + req, err = http.NewRequest("GET", "/v3/wallet/uphold/{paymentID}", nil) + suite.Require().NoError(err, "wallet claim request could not be created") - var l wallet.LinkDepositAccountResponse - err = json.Unmarshal(b, &l) - require.NoError(t, err) + rctx = chi.NewRouteContext() + rctx.URLParams.Add("paymentID", w1.ID) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue(req.Context(), appctx.RODatastoreCTXKey, pg)) + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - assert.Equal(t, "JP", l.GeoCountry) + rr = httptest.NewRecorder() + handlers.AppHandler(handler).ServeHTTP(rr, req) + expectingForbidden := fmt.Sprintf("status is expected to match %d: %s", http.StatusForbidden, rr.Body.String()) + suite.Require().Equal(http.StatusForbidden, rr.Code, expectingForbidden) } -func TestLinkGeminiWalletV3RelinkBadRegion(t *testing.T) { - wallet.VerifiedWalletEnable = true - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - var ( - // setup test variables - idFrom = uuid.NewV4() - ctx = middleware.AddKeyID(context.Background(), idFrom.String()) - accountID = uuid.NewV4() - idTo = accountID - - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - linkingInfo = "this is the fake jwt for linking_info" - - // setup mock clients - mockReputationClient = mockreputation.NewMockClient(mockCtrl) - mockGeminiClient = mockgemini.NewMockClient(mockCtrl) - - // this is our main request - r = httptest.NewRequest( - "POST", - fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), - bytes.NewBufferString(fmt.Sprintf(` - { - "linking_info": "%s", - "recipient_id": "%s" - }`, linkingInfo, idTo)), - ) - - mtc = &mockMtc{} - gem = &mockGemini{ - fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { - return "US" - }, - } - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) - handler = wallet.LinkGeminiDepositAccountV3(s) - rw = httptest.NewRecorder() - ) - - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) - ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - // turn on region check - ctx = context.WithValue(ctx, appctx.UseCustodianRegionsCTXKey, true) - // configure allow region - custodianRegions := custodian.Regions{ - Gemini: custodian.GeoAllowBlockMap{ - Allow: []string{"US"}, - }, - } - ctx = context.WithValue(ctx, appctx.CustodianRegionsCTXKey, custodianRegions) - - validateAccountRes := gemini.ValidatedAccount{ - ID: accountID.String(), - ValidDocuments: []gemini.ValidDocument{ - { - Type: "passport", - IssuingCountry: "US", - }, - }, - } - - mockGeminiClient.EXPECT().FetchValidateAccount( - gomock.Any(), - gomock.Any(), - gomock.Any(), - ).Return(validateAccountRes, nil) +func (suite *WalletControllersTestSuite) TestLinkWalletV3() { + ctx := context.Background() - mockSQLCustodianLink(mock, "gemini") + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err, "Failed to get postgres connection") - // begin linking tx - mock.ExpectBegin() + mockCtrl := gomock.NewController(suite.T()) + defer mockCtrl.Finish() - // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) - // acquire lock for linkingID - mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). - WillReturnResult(sqlmock.NewResult(1, 1)) + w1 := suite.NewWallet(service, "uphold") + w2 := suite.NewWallet(service, "uphold") + w3 := suite.NewWallet(service, "uphold") + w4 := suite.NewWallet(service, "uphold") + bat1 := decimal.NewFromFloat(0.000000001) + bat2 := decimal.NewFromFloat(0.000000002) - // not before linked - mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + suite.FundWallet(w1, bat1) + suite.FundWallet(w2, bat1) + suite.FundWallet(w3, bat1) + suite.FundWallet(w4, bat1) - var max = sqlmock.NewRows([]string{"max"}).AddRow(4) - var open = sqlmock.NewRows([]string{"used"}).AddRow(0) + anonCard1ID, err := w1.CreateCardAddress(ctx, "anonymous") + suite.Require().NoError(err, "create anon card must not fail") + anonCard1UUID := uuid.Must(uuid.FromString(anonCard1ID)) - var custLinks = sqlmock.NewRows([]string{"custodian", "linking_id"}).AddRow("gemini", linkingID.String()) + anonCard2ID, err := w2.CreateCardAddress(ctx, "anonymous") + suite.Require().NoError(err, "create anon card must not fail") + anonCard2UUID := uuid.Must(uuid.FromString(anonCard2ID)) - // linking limit checks - mock.ExpectQuery("^select wc1.custodian, wc1.linking_id from wallet_custodian (.+)").WithArgs(linkingID).WillReturnRows(custLinks) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID, 4).WillReturnRows(max) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(open) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(sqlmock.NewRows([]string{"wallet_id"}).AddRow(uuid.NewV4().String())) - // get last un linking - var lastUnlink = sqlmock.NewRows([]string{"last_unlinking"}).AddRow(time.Now()) - mock.ExpectQuery("^select max(.+)").WithArgs(linkingID).WillReturnRows(lastUnlink) + anonCard3ID, err := w3.CreateCardAddress(ctx, "anonymous") + suite.Require().NoError(err, "create anon card must not fail") + anonCard3UUID := uuid.Must(uuid.FromString(anonCard3ID)) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + w1ProviderID := w1.GetWalletInfo().ProviderID + w2ProviderID := w2.GetWalletInfo().ProviderID + w3ProviderID := w3.GetWalletInfo().ProviderID - clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). - AddRow(time.Now(), time.Now()) + zero := decimal.NewFromFloat(0) - // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + suite.CheckBalance(w1, bat1) + suite.claimCardV3(service, w1, w3ProviderID, http.StatusOK, bat1, &anonCard3UUID) + suite.CheckBalance(w1, zero) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + suite.CheckBalance(w2, bat1) + suite.claimCardV3(service, w2, w1ProviderID, http.StatusOK, zero, &anonCard1UUID) + suite.CheckBalance(w2, bat1) - mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + suite.CheckBalance(w2, bat1) + suite.claimCardV3(service, w2, w1ProviderID, http.StatusOK, bat1, &anonCard3UUID) + suite.CheckBalance(w2, zero) - // commit transaction - mock.ExpectCommit() + suite.CheckBalance(w3, bat2) + suite.claimCardV3(service, w3, w2ProviderID, http.StatusOK, bat1, &anonCard3UUID) + suite.CheckBalance(w3, bat1) - r = r.WithContext(ctx) + suite.CheckBalance(w3, bat1) + suite.claimCardV3(service, w3, w1ProviderID, http.StatusOK, zero, &anonCard2UUID) + suite.CheckBalance(w3, bat1) +} - router := chi.NewRouter() - router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) +func (suite *WalletControllersTestSuite) claimCardV3( + service *wallet.Service, + w *uphold.Wallet, + destination string, + status int, + amount decimal.Decimal, + anonymousAddress *uuid.UUID, +) (*walletutils.Info, string) { + signedCreationRequest, err := w.PrepareTransaction(*w.AltCurrency, altcurrency.BAT.ToProbi(amount), destination, "", "", nil) + + suite.Require().NoError(err, "transaction must be signed client side") + + // V3 Payload + reqBody := wallet.LinkUpholdDepositAccountRequest{ + SignedLinkingRequest: signedCreationRequest, + } - b := rw.Body.Bytes() - require.Equal(t, http.StatusOK, rw.Code, string(b)) + if anonymousAddress != nil { + reqBody.AnonymousAddress = anonymousAddress.String() + } - var l wallet.LinkDepositAccountResponse - err := json.Unmarshal(b, &l) - require.NoError(t, err) + body, err := json.Marshal(&reqBody) + suite.Require().NoError(err, "unable to marshal claim body") - assert.Equal(t, "US", l.GeoCountry) + info := w.GetWalletInfo() - // delete linking - r = httptest.NewRequest( - "DELETE", - fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), nil) + // V3 Handler - handler = wallet.DisconnectCustodianLinkV3(s) - rw = httptest.NewRecorder() + handler := wallet.LinkUpholdDepositAccountV3(service) - // create transaction - mock.ExpectBegin() + req, err := http.NewRequest("POST", "/v3/wallet/{paymentID}/claim", bytes.NewBuffer(body)) + suite.Require().NoError(err, "wallet claim request could not be created") - // removes the link to the user_deposit_destination record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + ctrl := gomock.NewController(suite.T()) + defer ctrl.Finish() - // updates the disconnected date on the record, and returns no error and one changed row - mock.ExpectExec("^update wallet_custodian(.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + repClient := mock_reputation.NewMockClient(ctrl) + repClient.EXPECT().IsLinkingReputable(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - // commit transaction because we are done disconnecting - mock.ExpectCommit() + rctx := chi.NewRouteContext() + rctx.URLParams.Add("paymentID", info.ID) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + req = req.WithContext(context.WithValue(req.Context(), appctx.ReputationClientCTXKey, repClient)) - r = r.WithContext(ctx) + rr := httptest.NewRecorder() + handlers.AppHandler(handler).ServeHTTP(rr, req) + suite.Require().Equal(status, rr.Code, fmt.Sprintf("status is expected to match %d: %s", status, rr.Body.String())) + linked, err := service.Datastore.GetWallet(req.Context(), uuid.Must(uuid.FromString(w.ID))) + suite.Require().NoError(err, "retrieving the wallet did not cause an error") + return linked, rr.Body.String() +} - router = chi.NewRouter() - router.Delete("/v3/wallet/{custodian}/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) +func (suite *WalletControllersTestSuite) createBody(tx string) string { + reqBody, _ := json.Marshal(wallet.UpholdCreationRequest{ + SignedCreationRequest: tx, + }) + return string(reqBody) +} - if resp := rw.Result(); resp.StatusCode != http.StatusOK { - must(t, "invalid response", fmt.Errorf("expected %d, got %d", http.StatusOK, resp.StatusCode)) +func (suite *WalletControllersTestSuite) NewWallet(service *wallet.Service, provider string) *uphold.Wallet { + publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) + publicKeyString := hex.EncodeToString(publicKey) + + bat := altcurrency.BAT + info := walletutils.Info{ + ID: uuid.NewV4().String(), + PublicKey: publicKeyString, + Provider: provider, + AltCurrency: &bat, } - - // ban the country now - custodianRegions = custodian.Regions{ - Gemini: custodian.GeoAllowBlockMap{ - Allow: []string{}, - }, + w := &uphold.Wallet{ + Info: info, + PrivKey: privKey, + PubKey: publicKey, } - ctx = context.WithValue(ctx, appctx.CustodianRegionsCTXKey, custodianRegions) - - // begin linking tx - mock.ExpectBegin() - - // acquire lock for linkingID - mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). - WillReturnResult(sqlmock.NewResult(1, 1)) - // not before linked - mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + reg, err := w.PrepareRegistration("Brave Browser Test Link") + suite.Require().NoError(err, "unable to prepare transaction") - // perform again, make sure we check haslinkedprio - hasPriorRows := sqlmock.NewRows([]string{"result"}). - AddRow(true) - mock.ExpectQuery("^select exists(select 1 from wallet_custodian (.+)").WithArgs(uuid.NewV5(wallet.ClaimNamespace, accountID.String()), idFrom).WillReturnRows(hasPriorRows) - - max = sqlmock.NewRows([]string{"max"}).AddRow(4) - open = sqlmock.NewRows([]string{"used"}).AddRow(0) + createResp := suite.createUpholdWalletV3( + service, + suite.createBody(reg), + http.StatusCreated, + publicKey, + privKey, + true, + ) - custLinks = sqlmock.NewRows([]string{"custodian", "linking_id"}).AddRow("gemini", linkingID.String()) + var returnedInfo wallet.ResponseV3 + err = json.Unmarshal([]byte(createResp), &returnedInfo) + suite.Require().NoError(err, "unable to create wallet") + convertedInfo := wallet.ResponseV3ToInfo(returnedInfo) + w.Info = *convertedInfo + return w +} - // linking limit checks - mock.ExpectQuery("^select wc1.custodian, wc1.linking_id from wallet_custodian (.+)").WithArgs(linkingID).WillReturnRows(custLinks) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID, 4).WillReturnRows(max) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(open) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(sqlmock.NewRows([]string{"wallet_id"}).AddRow(uuid.NewV4().String())) - // get last un linking - lastUnlink = sqlmock.NewRows([]string{"last_unlinking"}).AddRow(time.Now()) - mock.ExpectQuery("^select max(.+)").WithArgs(linkingID).WillReturnRows(lastUnlink) +func (suite *WalletControllersTestSuite) TestCreateBraveWalletV3() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err, "Failed to get postgres connection") - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) - clRows = sqlmock.NewRows([]string{"created_at", "linked_at"}). - AddRow(time.Now(), time.Now()) + publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) - // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + // assume 400 is already covered + // fail because of lacking signature presence + notSignedResponse := suite.createUpholdWalletV3( + service, + `{}`, + http.StatusBadRequest, + publicKey, + privKey, + false, + ) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + suite.Assert().JSONEq(`{"code":400, "data":{"validationErrors":{"decoding":"failed decoding: failed to decode signed creation request: unexpected end of JSON input", "signedCreationRequest":"value is required", "validation":"failed validation: missing signed creation request"}}, "message":"Error validating uphold create wallet request validation errors"}`, notSignedResponse, "field is not valid") - // commit transaction - mock.ExpectCommit() + createResp := suite.createBraveWalletV3( + service, + ``, + http.StatusCreated, + publicKey, + privKey, + true, + ) - r = r.WithContext(ctx) + var created wallet.ResponseV3 + err = json.Unmarshal([]byte(createResp), &created) + suite.Require().NoError(err, "unable to unmarshal response") - router = chi.NewRouter() - router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) + getResp := suite.getWallet(service, uuid.Must(uuid.FromString(created.PaymentID)), http.StatusOK) - b = rw.Body.Bytes() - require.Equal(t, http.StatusOK, rw.Code, string(b)) + var gotten wallet.ResponseV3 + err = json.Unmarshal([]byte(getResp), &gotten) + suite.Require().NoError(err, "unable to unmarshal response") + // does not return wallet provider + suite.Require().Equal(created, gotten, "the get and create return the same structure") } -func TestLinkGeminiWalletV3FirstLinking(t *testing.T) { - wallet.VerifiedWalletEnable = true +func (suite *WalletControllersTestSuite) TestCreateUpholdWalletV3() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err, "Failed to get postgres connection") - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() + service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) - var ( - // setup test variables - idFrom = uuid.NewV4() - ctx = middleware.AddKeyID(context.Background(), idFrom.String()) - accountID = uuid.NewV4() - idTo = accountID - - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - linkingInfo = "this is the fake jwt for linking_info" - - // setup mock clients - mockReputationClient = mockreputation.NewMockClient(mockCtrl) - mockGeminiClient = mockgemini.NewMockClient(mockCtrl) - - // this is our main request - r = httptest.NewRequest( - "POST", - fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), - bytes.NewBufferString(fmt.Sprintf(` - { - "linking_info": "%s", - "recipient_id": "%s" - }`, linkingInfo, idTo)), - ) + publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) - mtc = &mockMtc{} - gem = &mockGemini{ - fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { - return "US" - }, + badJSONBodyParse := suite.createUpholdWalletV3( + service, + ``, + http.StatusBadRequest, + publicKey, + privKey, + true, + ) + suite.Assert().JSONEq(`{ + "code":400, + "data": { + "validationErrors":{ + "decoding":"failed decoding: failed to decode json: EOF", + "signedCreationRequest":"value is required", + "validation":"failed validation: missing signed creation request" } - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) - handler = wallet.LinkGeminiDepositAccountV3(s) - rw = httptest.NewRecorder() + }, + "message":"Error validating uphold create wallet request validation errors" + }`, badJSONBodyParse, "should fail when parsing json") + + badFieldResponse := suite.createUpholdWalletV3( + service, + `{"signedCreationRequest":""}`, + http.StatusBadRequest, + publicKey, + privKey, + true, ) - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, + suite.Assert().JSONEq(`{ + "code":400, "data":{"validationErrors":{"decoding":"failed decoding: failed to decode signed creation request: unexpected end of JSON input", "signedCreationRequest":"value is required", "validation":"failed validation: missing signed creation request"}}, "message":"Error validating uphold create wallet request validation errors"}`, badFieldResponse, "field is not valid") + + // assume 403 is already covered + // fail because of lacking signature presence + notSignedResponse := suite.createUpholdWalletV3( + service, + `{}`, + http.StatusBadRequest, + publicKey, + privKey, + false, ) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) - ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - - validateAccountRes := gemini.ValidatedAccount{ - ID: accountID.String(), - ValidDocuments: []gemini.ValidDocument{ - { - Type: "passport", - IssuingCountry: "US", - }, - }, - } + suite.Assert().JSONEq(`{ +"code":400, "data":{"validationErrors":{"decoding":"failed decoding: failed to decode signed creation request: unexpected end of JSON input", "signedCreationRequest":"value is required", "validation":"failed validation: missing signed creation request"}}, "message":"Error validating uphold create wallet request validation errors" + }`, notSignedResponse, "field is not valid") +} - mockGeminiClient.EXPECT().FetchValidateAccount( - gomock.Any(), - gomock.Any(), - gomock.Any(), - ).Return(validateAccountRes, nil) +func (suite *WalletControllersTestSuite) TestChallenges_Success() { + paymentID := uuid.NewV4() - mockSQLCustodianLink(mock, "gemini") + body := struct { + PaymentID uuid.UUID `json:"paymentId"` + }{ + PaymentID: paymentID, + } - // begin linking tx - mock.ExpectBegin() + b, err := json.Marshal(body) + suite.Require().NoError(err) - // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) + r := httptest.NewRequest(http.MethodPost, "/v3/wallet/challenges", bytes.NewBuffer(b)) + r.Header.Set("origin", "https://my-dapp.com") - // acquire lock for linkingID - mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). - WillReturnResult(sqlmock.NewResult(1, 1)) + rw := httptest.NewRecorder() - // not before linked - mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnError(sql.ErrNoRows) + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) - var max = sqlmock.NewRows([]string{"max"}).AddRow(4) - var open = sqlmock.NewRows([]string{"used"}).AddRow(0) + chlRep := storage.NewChallenge() - var custLinks = sqlmock.NewRows([]string{"custodian", "linking_id"}).AddRow("gemini", linkingID.String()) + dac := wallet.DAppConfig{ + AllowedOrigins: []string{"https://my-dapp.com", "https://my-dapp-2.com"}, + } - // linking limit checks - mock.ExpectQuery("^select wc1.custodian, wc1.linking_id from wallet_custodian (.+)").WithArgs(linkingID).WillReturnRows(custLinks) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID, 4).WillReturnRows(max) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(open) - mock.ExpectQuery("^select (.+)").WithArgs(linkingID).WillReturnRows(sqlmock.NewRows([]string{"wallet_id"}).AddRow(uuid.NewV4().String())) - // get last un linking - var lastUnlink = sqlmock.NewRows([]string{"last_unlinking"}).AddRow(time.Now()) - mock.ExpectQuery("^select max(.+)").WithArgs(linkingID).WillReturnRows(lastUnlink) + s, err := wallet.InitService(pg, nil, chlRep, nil, nil, nil, nil, nil, nil, nil, dac) + suite.Require().NoError(err) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + svr := &http.Server{Addr: ":8080", Handler: setupRouter(s)} + svr.Handler.ServeHTTP(rw, r) + suite.Require().Equal(http.StatusCreated, rw.Code) + suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) - clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). - AddRow(time.Now(), time.Now()) + type chlResp struct { + Nonce string `json:"challengeId"` + } - // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + var resp chlResp + err = json.Unmarshal(rw.Body.Bytes(), &resp) + suite.Require().NoError(err) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + chlRepo := storage.NewChallenge() + chl, err := chlRepo.Get(context.TODO(), pg.RawDB(), paymentID) + suite.Require().NoError(err) - mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + suite.Assert().Equal(chl.Nonce, resp.Nonce) +} - // commit transaction - mock.ExpectCommit() +func (suite *WalletControllersTestSuite) TestChallenges_Options() { + req := httptest.NewRequest(http.MethodOptions, "/v3/wallet/challenges", nil) + req.Header.Add("Access-Control-Request-Method", http.MethodPost) + req.Header.Add("Access-Control-Request-Headers", "Content-Type") + req.Header.Set("origin", "https://my-dapp.com") - r = r.WithContext(ctx) + rw := httptest.NewRecorder() - router := chi.NewRouter() - router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) + s := wallet.Service{} - b := rw.Body.Bytes() - require.Equal(t, http.StatusOK, rw.Code, string(b)) + svr := &http.Server{Addr: ":8080", Handler: setupRouter(&s)} + svr.Handler.ServeHTTP(rw, req) - var l wallet.LinkDepositAccountResponse - err := json.Unmarshal(b, &l) - require.NoError(t, err) + suite.Require().Equal(http.StatusOK, rw.Code) + suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) + suite.Require().Equal(http.MethodPost, rw.Header().Get("Access-Control-Allow-Methods")) + suite.Require().Equal("Content-Type", rw.Header().Get("Access-Control-Allow-Headers")) } -func TestLinkZebPayWalletV3_InvalidKyc(t *testing.T) { - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - // setup jwt token for the test - var secret = []byte("a jwt secret") - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: secret}, (&jose.SignerOptions{}).WithType("JWT")) - if err != nil { - panic(err) - } +func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Success() { + viper.Set("enable-link-drain-flag", "true") - var ( - // setup test variables - idFrom = uuid.NewV4() - ctx = middleware.AddKeyID(context.Background(), idFrom.String()) - accountID = uuid.NewV4() - idTo = accountID - - // setup db mocks - db, _, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - - mtc = &mockMtc{ - fnLinkFailureZP: func(cc string) { - assert.Equal(t, "IN", cc) - }, - } + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) - gem = &mockGemini{} + chlRep := storage.NewChallenge() + allowList := storage.NewAllowList() - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) - handler = wallet.LinkZebPayDepositAccountV3(s) - rw = httptest.NewRecorder() - ) + // create the wallet + pub, priv, err := ed25519.GenerateKey(nil) + suite.Require().NoError(err) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - ctx = context.WithValue(ctx, appctx.ZebPayLinkingKeyCTXKey, base64.StdEncoding.EncodeToString(secret)) - - linkingInfo, err := jwt.Signed(sig).Claims(map[string]interface{}{ - "accountId": accountID, - "depositId": idTo, - "countryCode": "IN", - "iat": time.Now().Unix(), - "exp": time.Now().Add(5 * time.Second).Unix(), - }).CompactSerialize() - if err != nil { - panic(err) + paymentID := uuid.NewV4() + w := &walletutils.Info{ + ID: paymentID.String(), + Provider: "brave", + PublicKey: hex.EncodeToString(pub), + AltCurrency: ptrTo(altcurrency.BAT), } + err = pg.InsertWallet(context.TODO(), w) + suite.Require().NoError(err) - // this is our main request - r := httptest.NewRequest( - "POST", - fmt.Sprintf("/v3/wallet/zebpay/%s/claim", idFrom), - bytes.NewBufferString(fmt.Sprintf( - `{"linking_info": "%s"}`, - linkingInfo, - )), - ) - - r = r.WithContext(ctx) - - router := chi.NewRouter() - router.Post("/v3/wallet/zebpay/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) + whitelistWallet(suite.T(), pg, w.ID) - b := rw.Body.Bytes() - require.Equal(t, http.StatusForbidden, rw.Code, string(b)) + // create nonce + chl := model.NewChallenge(paymentID) - var l wallet.LinkDepositAccountResponse - err = json.Unmarshal(b, &l) - require.NoError(t, err) -} - -func TestLinkZebPayWalletV3(t *testing.T) { - wallet.VerifiedWalletEnable = true - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() + err = chlRep.Upsert(context.TODO(), pg.RawDB(), chl) + suite.Require().NoError(err) - // setup jwt token for the test - var secret = []byte("a jwt secret") - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: secret}, (&jose.SignerOptions{}).WithType("JWT")) - if err != nil { - panic(err) + dac := wallet.DAppConfig{ + AllowedOrigins: []string{"https://my-dapp.com", "https://my-dapp-2.com"}, } - var ( - // setup test variables - idFrom = uuid.NewV4() - ctx = middleware.AddKeyID(context.Background(), idFrom.String()) - accountID = uuid.NewV4() - idTo = accountID - - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - - // setup mock clients - mockReputationClient = mockreputation.NewMockClient(mockCtrl) - - mtc = &mockMtc{ - fnLinkSuccessZP: func(cc string) { - assert.Equal(t, "IN", cc) - }, - } - - gem = &mockGemini{} - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) - handler = wallet.LinkZebPayDepositAccountV3(s) - rw = httptest.NewRecorder() - ) - - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - ctx = context.WithValue(ctx, appctx.ZebPayLinkingKeyCTXKey, base64.StdEncoding.EncodeToString(secret)) - - linkingInfo, err := jwt.Signed(sig).Claims(map[string]interface{}{ - "accountId": accountID, "depositId": idTo, "iat": time.Now().Unix(), "exp": time.Now().Add(5 * time.Second).Unix(), - "isValid": true, "countryCode": "IN", - }).CompactSerialize() - if err != nil { - panic(err) + s, err := wallet.InitService(pg, nil, chlRep, allowList, nil, nil, nil, nil, nil, nil, dac) + suite.Require().NoError(err) + + // create linking message + solPub, msg, solSig := createAndSignMessage(suite.T(), w.ID, priv, chl.Nonce) + + // make request + body := struct { + SolanaPublicKey string `json:"solanaPublicKey"` + Message string `json:"message"` + SolanaSignature string `json:"solanaSignature"` + }{ + SolanaPublicKey: solPub, + Message: msg, + SolanaSignature: solSig, } - // this is our main request - r := httptest.NewRequest( - "POST", - fmt.Sprintf("/v3/wallet/zebpay/%s/claim", idFrom), - bytes.NewBufferString(fmt.Sprintf( - `{"linking_info": "%s"}`, - linkingInfo, - )), - ) - - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - - mockSQLCustodianLink(mock, "zebpay") - - // begin linking tx - mock.ExpectBegin() + b, err := json.Marshal(body) + suite.Require().NoError(err) - // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) - var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) + r := httptest.NewRequest(http.MethodPost, "/v3/wallet/solana/"+w.ID+"/connect", bytes.NewBuffer(b)) + r.Header.Set("origin", "https://my-dapp-2.com") - // acquire lock for linkingID - mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). - WillReturnResult(sqlmock.NewResult(1, 1)) + rw := httptest.NewRecorder() - mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "zebpay").WillReturnRows(linkingIDRows) + svr := &http.Server{Addr: ":8080", Handler: setupRouter(s)} + svr.Handler.ServeHTTP(rw, r) + suite.Require().Equal(http.StatusOK, rw.Code) + suite.Require().Equal("https://my-dapp-2.com", rw.Header().Get("Access-Control-Allow-Origin")) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + // assert + actual, err := pg.GetWallet(context.TODO(), paymentID) + suite.Require().NoError(err) - // this wallet has been linked prior, with the same linking id that the request is with - // SHOULD SKIP THE linking limit checks - clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). - AddRow(time.Now(), time.Now()) + suite.Require().Equal(solPub, actual.UserDepositDestination) - // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "zebpay", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + // after a successful linking the challenge should be removed from the database. + _, actualErr := chlRep.Get(context.TODO(), pg.RawDB(), paymentID) + suite.Assert().ErrorIs(actualErr, model.ErrChallengeNotFound) +} - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "zebpay", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) +func (suite *WalletControllersTestSuite) TestLinkSolanaAddress_Options() { + viper.Set("enable-link-drain-flag", "true") - mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + req := httptest.NewRequest(http.MethodOptions, "/v3/wallet/solana/ae51dce3-08e9-4beb-8a70-c51d064bb7d1/connect", nil) + req.Header.Add("Access-Control-Request-Method", http.MethodPost) + req.Header.Add("Access-Control-Request-Headers", "Content-Type") + req.Header.Set("origin", "https://my-dapp.com") - // commit transaction - mock.ExpectCommit() + rw := httptest.NewRecorder() - r = r.WithContext(ctx) + s := wallet.Service{} - router := chi.NewRouter() - router.Post("/v3/wallet/zebpay/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) + svr := &http.Server{Addr: ":8080", Handler: setupRouter(&s)} + svr.Handler.ServeHTTP(rw, req) - b := rw.Body.Bytes() - require.Equal(t, http.StatusOK, rw.Code, string(b)) + suite.Require().Equal(http.StatusOK, rw.Code) + suite.Require().Equal("https://my-dapp.com", rw.Header().Get("Access-Control-Allow-Origin")) + suite.Require().Equal(http.MethodPost, rw.Header().Get("Access-Control-Allow-Methods")) + suite.Require().Equal("Content-Type", rw.Header().Get("Access-Control-Allow-Headers")) +} - var l wallet.LinkDepositAccountResponse - err = json.Unmarshal(b, &l) +func whitelistWallet(t *testing.T, pg wallet.Datastore, Id string) { + const q = `insert into allow_list (payment_id, created_at) values($1, $2)` + _, err := pg.RawDB().Exec(q, Id, time.Now()) require.NoError(t, err) - - assert.Equal(t, "IN", l.GeoCountry) } -func TestLinkGeminiWalletV3(t *testing.T) { - wallet.VerifiedWalletEnable = true - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - var ( - // setup test variables - idFrom = uuid.NewV4() - ctx = middleware.AddKeyID(context.Background(), idFrom.String()) - accountID = uuid.NewV4() - idTo = accountID - - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - linkingInfo = "this is the fake jwt for linking_info" - - // setup mock clients - mockReputationClient = mockreputation.NewMockClient(mockCtrl) - mockGeminiClient = mockgemini.NewMockClient(mockCtrl) - - // this is our main request - r = httptest.NewRequest( - "POST", - fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), - bytes.NewBufferString(fmt.Sprintf(` - { - "linking_info": "%s", - "recipient_id": "%s" - }`, linkingInfo, idTo)), - ) - - mtc = &mockMtc{} - gem = &mockGemini{ - fnGetIssuingCountry: func(acc gemini.ValidatedAccount, fallback bool) string { - return "GB" - }, - } - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) - handler = wallet.LinkGeminiDepositAccountV3(s) - rw = httptest.NewRecorder() - ) - - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, mockReputationClient) - ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, mockGeminiClient) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") - - validateAccountRes := gemini.ValidatedAccount{ - ID: accountID.String(), - ValidDocuments: []gemini.ValidDocument{ - { - Type: "passport", - IssuingCountry: "GB", - }, - }, - } - - mockGeminiClient.EXPECT().FetchValidateAccount( - gomock.Any(), - gomock.Any(), - gomock.Any(), - ).Return(validateAccountRes, nil) - - mockReputationClient.EXPECT().IsLinkingReputable( - gomock.Any(), // ctx - gomock.Any(), // wallet id - gomock.Any(), // country - ).Return( - true, - []int{}, - nil, - ) - - mockSQLCustodianLink(mock, "gemini") - - // begin linking tx - mock.ExpectBegin() - - // make sure old linking id matches new one for same custodian - linkingID := uuid.NewV5(wallet.ClaimNamespace, idTo.String()) - var linkingIDRows = sqlmock.NewRows([]string{"linking_id"}).AddRow(linkingID) +func createAndSignMessage(t *testing.T, paymentID string, rewardsPrivKey ed25519.PrivateKey, nonce string) (solPub, msg, solSig string) { + // Create and sign the rewards message. + // The message has the format = . + rewardsMsg := paymentID + "." + nonce + sig, err := rewardsPrivKey.Sign(rand.Reader, []byte(rewardsMsg), crypto.Hash(0)) + require.NoError(t, err) - // acquire lock for linkingID - mock.ExpectExec("^SELECT pg_advisory_xact_lock\\(hashtext(.+)\\)").WithArgs(linkingID.String()). - WillReturnResult(sqlmock.NewResult(1, 1)) + rewardsSig := base64.URLEncoding.EncodeToString(sig) + rewardsPart := rewardsMsg + "." + rewardsSig - mock.ExpectQuery("^select linking_id from (.+)").WithArgs(idFrom, "gemini").WillReturnRows(linkingIDRows) + // Create the linking message and sign with the Solana key. + pub, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallet_custodian (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + solPub = base58.Encode(pub) - // this wallet has been linked prior, with the same linking id that the request is with - // SHOULD SKIP THE linking limit checks - clRows := sqlmock.NewRows([]string{"created_at", "linked_at"}). - AddRow(time.Now(), time.Now()) + const msgTmpl = ":%s\n:%s\n:%s" + msg = fmt.Sprintf(msgTmpl, paymentID, solPub, rewardsPart) - // insert into wallet custodian - mock.ExpectQuery("^insert into wallet_custodian (.+)").WithArgs(idFrom, "gemini", uuid.NewV5(wallet.ClaimNamespace, accountID.String())).WillReturnRows(clRows) + sig, err = priv.Sign(rand.Reader, []byte(msg), crypto.Hash(0)) + require.NoError(t, err) - // updates the link to the wallet_custodian record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idTo, linkingID, "gemini", idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + solSig = base64.URLEncoding.EncodeToString(sig) - mock.ExpectExec("^insert into (.+)").WithArgs(idFrom, true).WillReturnResult(sqlmock.NewResult(1, 1)) + return +} - // commit transaction - mock.ExpectCommit() +func (suite *WalletControllersTestSuite) getWallet(service *wallet.Service, paymentId uuid.UUID, code int) string { + handler := handlers.AppHandler(wallet.GetWalletV3) - r = r.WithContext(ctx) + req, err := http.NewRequest("GET", "/v3/wallet/"+paymentId.String(), nil) + suite.Require().NoError(err, "a request should be created") - router := chi.NewRouter() - router.Post("/v3/wallet/gemini/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(rw, r) + req = req.WithContext(context.WithValue(req.Context(), appctx.DatastoreCTXKey, service.Datastore)) + req = req.WithContext(context.WithValue(req.Context(), appctx.RODatastoreCTXKey, service.Datastore)) + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("paymentID", paymentId.String()) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - b := rw.Body.Bytes() - require.Equal(t, http.StatusOK, rw.Code, string(b)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) - var l wallet.LinkDepositAccountResponse - err := json.Unmarshal(b, &l) - require.NoError(t, err) + suite.Require().Equal(code, rr.Code, "known status code should be sent: "+rr.Body.String()) - assert.Equal(t, "GB", l.GeoCountry) + return rr.Body.String() } -func TestDisconnectCustodianLinkV3(t *testing.T) { - wallet.VerifiedWalletEnable = true - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - var ( - // setup test variables - idFrom = uuid.NewV4() - ctx = middleware.AddKeyID(context.Background(), idFrom.String()) - - // setup db mocks - db, mock, _ = sqlmock.New() - datastore = wallet.Datastore( - &wallet.Postgres{ - Postgres: datastoreutils.Postgres{ - DB: sqlx.NewDb(db, "postgres"), - }, - }) - - // this is our main request - r = httptest.NewRequest( - "DELETE", - fmt.Sprintf("/v3/wallet/gemini/%s/claim", idFrom), nil) - - mtc = &mockMtc{} - gem = &mockGemini{} - - s, _ = wallet.InitService(datastore, nil, nil, nil, nil, nil, mtc, gem) - handler = wallet.DisconnectCustodianLinkV3(s) - w = httptest.NewRecorder() - ) - - // create transaction - mock.ExpectBegin() - - // removes the link to the user_deposit_destination record in wallets - mock.ExpectExec("^update wallets (.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) +func (suite *WalletControllersTestSuite) createBraveWalletV3( + service *wallet.Service, + body string, + code int, + publicKey httpsignature.Ed25519PubKey, + privateKey ed25519.PrivateKey, + shouldSign bool, +) string { + + handler := handlers.AppHandler(wallet.CreateBraveWalletV3) + + bodyBuffer := bytes.NewBuffer([]byte(body)) + req, err := http.NewRequest("POST", "/v3/wallet/brave", bodyBuffer) + suite.Require().NoError(err, "a request should be created") + + // setup context + req = req.WithContext(context.WithValue(context.Background(), appctx.DatastoreCTXKey, service.Datastore)) + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + + if shouldSign { + suite.SignRequest( + req, + publicKey, + privateKey, + ) + } - // updates the disconnected date on the record, and returns no error and one changed row - mock.ExpectExec("^update wallet_custodian(.+)").WithArgs(idFrom).WillReturnResult(sqlmock.NewResult(1, 1)) + rctx := chi.NewRouteContext() + joined := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) + req = req.WithContext(joined) - // commit transaction because we are done disconnecting - mock.ExpectCommit() + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + suite.Require().Equal(code, rr.Code, "known status code should be sent: "+rr.Body.String()) - ctx = context.WithValue(ctx, appctx.DatastoreCTXKey, datastore) - ctx = context.WithValue(ctx, appctx.NoUnlinkPriorToDurationCTXKey, "-P1D") + return rr.Body.String() +} - r = r.WithContext(ctx) +func (suite *WalletControllersTestSuite) createUpholdWalletV3( + service *wallet.Service, + body string, + code int, + publicKey httpsignature.Ed25519PubKey, + privateKey ed25519.PrivateKey, + shouldSign bool, +) string { + + handler := handlers.AppHandler(wallet.CreateUpholdWalletV3) + + bodyBuffer := bytes.NewBuffer([]byte(body)) + req, err := http.NewRequest("POST", "/v3/wallet", bodyBuffer) + suite.Require().NoError(err, "a request should be created") + + // setup context + req = req.WithContext(context.WithValue(context.Background(), appctx.DatastoreCTXKey, service.Datastore)) + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + + if shouldSign { + suite.SignRequest( + req, + publicKey, + privateKey, + ) + } - router := chi.NewRouter() - router.Delete("/v3/wallet/{custodian}/{paymentID}/claim", handlers.AppHandler(handler).ServeHTTP) - router.ServeHTTP(w, r) + rctx := chi.NewRouteContext() + joined := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) + req = req.WithContext(joined) - if resp := w.Result(); resp.StatusCode != http.StatusOK { - must(t, "invalid response", fmt.Errorf("expected %d, got %d", http.StatusOK, resp.StatusCode)) - } -} + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + suite.Require().Equal(code, rr.Code, "known status code should be sent: "+rr.Body.String()) -func must(t *testing.T, msg string, err error) { - if err != nil { - t.Errorf("%s: %s\n", msg, err) - } + return rr.Body.String() } -func signRequest(req *http.Request, publicKey httpsignature.Ed25519PubKey, privateKey ed25519.PrivateKey) error { +func (suite *WalletControllersTestSuite) SignRequest(req *http.Request, publicKey httpsignature.Ed25519PubKey, privateKey ed25519.PrivateKey) { var s httpsignature.SignatureParams s.Algorithm = httpsignature.ED25519 s.KeyID = hex.EncodeToString(publicKey) s.Headers = []string{"digest", "(request-target)"} - return s.Sign(privateKey, crypto.Hash(0), req) -} - -type result struct{} - -func (r result) LastInsertId() (int64, error) { return 1, nil } -func (r result) RowsAffected() (int64, error) { return 1, nil } - -func mockSQLCustodianLink(mock sqlmock.Sqlmock, custodian string) { - clRow := sqlmock.NewRows([]string{"wallet_id", "custodian", "linking_id", "created_at", "disconnected_at", "linked_at"}). - AddRow(uuid.NewV4().String(), custodian, uuid.NewV4().String(), time.Now(), time.Now(), time.Now()) - mock.ExpectQuery("^select(.+) from wallet_custodian(.+)"). - WillReturnRows(clRow) -} -type mockGemini struct { - fnGetIssuingCountry func(acc gemini.ValidatedAccount, fallback bool) string - fnIsRegionAllowed func(ctx context.Context, issuingCountry string, custodianRegions custodian.Regions) error + err := s.Sign(privateKey, crypto.Hash(0), req) + suite.Require().NoError(err) } -func (m *mockGemini) GetIssuingCountry(acc gemini.ValidatedAccount, fallback bool) string { - if m.fnGetIssuingCountry == nil { - return "" +func setupRouter(service *wallet.Service) *chi.Mux { + mw := func(name string, h http.Handler) http.Handler { + return h } - return m.fnGetIssuingCountry(acc, fallback) + s := []string{"https://my-dapp.com", "https://my-dapp-2.com"} + r := chi.NewRouter() + r.Mount("/v3", wallet.RegisterRoutes(context.TODO(), service, r, mw, wallet.NewDAppCorsMw(s))) + return r } -func (m *mockGemini) IsRegionAvailable(ctx context.Context, issuingCountry string, custodianRegions custodian.Regions) error { - if m.fnIsRegionAllowed == nil { - return nil - } - return m.fnIsRegionAllowed(ctx, issuingCountry, custodianRegions) -} - -type mockMtc struct { - fnLinkSuccessZP func(cc string) - fnLinkFailureZP func(cc string) -} - -func (m *mockMtc) LinkSuccessZP(cc string) { - if m.fnLinkSuccessZP != nil { - m.fnLinkSuccessZP(cc) - } +func ptrTo[T any](v T) *T { + return &v } - -func (m *mockMtc) LinkFailureZP(cc string) { - if m.fnLinkFailureZP != nil { - m.fnLinkFailureZP(cc) - } -} - -func (m *mockMtc) LinkFailureGemini(_ string) {} -func (m *mockMtc) LinkSuccessGemini(_ string) {} -func (m *mockMtc) CountDocTypeByIssuingCntry(_ []gemini.ValidDocument) {} diff --git a/services/wallet/controllers_v4.go b/services/wallet/controllers_v4.go index 8e876cd8f..af6fbd89d 100644 --- a/services/wallet/controllers_v4.go +++ b/services/wallet/controllers_v4.go @@ -11,8 +11,10 @@ import ( errorutils "github.com/brave-intl/bat-go/libs/errors" "github.com/brave-intl/bat-go/libs/handlers" "github.com/brave-intl/bat-go/libs/httpsignature" + "github.com/brave-intl/bat-go/libs/inputs" "github.com/brave-intl/bat-go/libs/logging" "github.com/brave-intl/bat-go/libs/middleware" + "github.com/brave-intl/bat-go/services/wallet/model" "github.com/go-chi/chi" ) @@ -167,12 +169,47 @@ func UpdateWalletV4(s *Service) func(w http.ResponseWriter, r *http.Request) *ha } } -// GetWalletV4 is the same as get wallet v3, but we are now requiring http signatures for get wallet requests -func GetWalletV4(w http.ResponseWriter, r *http.Request) *handlers.AppError { - return GetWalletV3(w, r) -} - // GetUpholdWalletBalanceV4 produces an http handler for the service s which handles balance inquiries of uphold wallets func GetUpholdWalletBalanceV4(w http.ResponseWriter, r *http.Request) *handlers.AppError { return GetUpholdWalletBalanceV3(w, r) } + +func GetWalletV4(s *Service) func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + var ctx = r.Context() + + l := logging.Logger(ctx, "wallet") + + var id inputs.ID + if err := inputs.DecodeAndValidateString(ctx, &id, chi.URLParam(r, "paymentID")); err != nil { + l.Warn().Err(err).Str("paymentID", id.String()).Msg("failed to decode and validate paymentID from url") + return handlers.ValidationError("Error validating paymentID url parameter", map[string]interface{}{ + "paymentId": err.Error(), + }) + } + + paymentID := *id.UUID() + + info, err := s.Datastore.GetWallet(ctx, paymentID) + if err != nil { + l.Error().Err(err).Str("paymentID", id.String()).Msg("error getting wallet") + return handlers.WrapError(err, "error getting wallet from storage", http.StatusInternalServerError) + } + + if info == nil { + l.Info().Interface("paymentID", paymentID).Msg("wallet not found") + return handlers.WrapError(err, "no such wallet", http.StatusNotFound) + } + + allow, err := s.allowListRepo.GetAllowListEntry(ctx, s.Datastore.RawDB(), paymentID) + if err != nil && !errors.Is(err, model.ErrNotFound) { + return handlers.WrapError(err, "error getting allow list entry from storage", http.StatusInternalServerError) + } + + isSelfCustAvail := allow.IsAllowed(paymentID) + + resp := infoToResponseV4(info, isSelfCustAvail) + + return handlers.RenderContent(ctx, resp, w, http.StatusOK) + } +} diff --git a/services/wallet/controllers_v4_test.go b/services/wallet/controllers_v4_test.go index cd0b0dcdd..37474abab 100644 --- a/services/wallet/controllers_v4_test.go +++ b/services/wallet/controllers_v4_test.go @@ -7,6 +7,7 @@ import ( "context" "crypto" "crypto/ed25519" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -14,7 +15,10 @@ import ( "net/http/httptest" "testing" + appctx "github.com/brave-intl/bat-go/libs/context" errorutils "github.com/brave-intl/bat-go/libs/errors" + "github.com/brave-intl/bat-go/libs/middleware" + "github.com/brave-intl/bat-go/services/wallet/storage" "github.com/brave-intl/bat-go/libs/clients" @@ -73,11 +77,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_Success() { Validate(gomock.Any(), geoCountry). Return(true, nil) - service, err := wallet.InitService(storage, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: geoCountry, @@ -120,11 +124,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_GeoCountryDis Validate(gomock.Any(), gomock.Any()). Return(false, nil) - service, err := wallet.InitService(nil, nil, nil, nil, locationValidator, backoff.Retry, nil, nil) + service, err := wallet.InitService(nil, nil, nil, nil, nil, nil, locationValidator, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -172,11 +176,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_WalletAlready Validate(gomock.Any(), geoCountry). Return(true, nil) - service, err := wallet.InitService(storage, nil, nil, nil, locationValidator, nil, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, nil, nil, locationValidator, nil, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: geoCountry, @@ -243,11 +247,11 @@ func (suite *WalletControllersV4TestSuite) TestCreateBraveWalletV4_ReputationCal Validate(gomock.Any(), gomock.Any()). Return(true, nil) - service, err := wallet.InitService(storage, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, locationValidator, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -293,7 +297,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_Success() { UpsertReputationSummary(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) - service, err := wallet.InitService(storage, nil, reputationClient, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -314,7 +318,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_Success() { suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -343,7 +347,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_VerificationM storage, err := wallet.NewWritablePostgres("", false, "") suite.NoError(err) - service, err := wallet.InitService(storage, nil, nil, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, nil, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) publicKey, privateKey, err := httpsignature.GenerateEd25519Key(nil) @@ -352,7 +356,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_VerificationM paymentID := uuid.NewV5(wallet.ClaimNamespace, publicKey.String()).String() router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -381,7 +385,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_PaymentIDMism storage, err := wallet.NewWritablePostgres("", false, "") suite.NoError(err) - service, err := wallet.InitService(storage, nil, nil, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, nil, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -402,7 +406,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_PaymentIDMism suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -448,7 +452,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_GeoCountryAlr UpsertReputationSummary(gomock.Any(), gomock.Any(), gomock.Any()). Return(errorBundle) - service, err := wallet.InitService(storage, nil, reputationClient, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -469,7 +473,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_GeoCountryAlr suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -513,7 +517,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_ReputationCal UpsertReputationSummary(gomock.Any(), gomock.Any(), gomock.Any()). Return(errReputation) - service, err := wallet.InitService(storage, nil, reputationClient, nil, nil, backoff.Retry, nil, nil) + service, err := wallet.InitService(storage, nil, nil, nil, reputationClient, nil, nil, backoff.Retry, nil, nil, wallet.DAppConfig{}) suite.Require().NoError(err) // create rewards wallet with public key @@ -534,7 +538,7 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_ReputationCal suite.Require().NoError(err) router := chi.NewRouter() - wallet.RegisterRoutes(ctx, service, router) + wallet.RegisterRoutes(ctx, service, router, noOpHandler(), noOpMw()) data := wallet.V4Request{ GeoCountry: "AF", @@ -557,6 +561,94 @@ func (suite *WalletControllersV4TestSuite) TestUpdateBraveWalletV4_ReputationCal suite.Require().Equal(http.StatusInternalServerError, rw.Code) } +func (suite *WalletControllersTestSuite) TestGetWalletV4() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) + + pub, _, err := ed25519.GenerateKey(nil) + suite.Require().NoError(err) + + paymentID := uuid.NewV4() + w := &walletutils.Info{ + ID: paymentID.String(), + Provider: "brave", + PublicKey: hex.EncodeToString(pub), + AltCurrency: ptrTo(altcurrency.BAT), + } + err = pg.InsertWallet(context.TODO(), w) + suite.Require().NoError(err) + + whitelistWallet(suite.T(), pg, w.ID) + + allowList := storage.NewAllowList() + + service, _ := wallet.InitService(pg, nil, nil, allowList, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) + + handler := handlers.AppHandler(wallet.GetWalletV4(service)) + + req, err := http.NewRequest("GET", "/v4/wallets/"+paymentID.String(), nil) + suite.Require().NoError(err, "a request should be created") + + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("paymentID", paymentID.String()) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + suite.Assert().Equal(http.StatusOK, rr.Code) + + var resp wallet.ResponseV4 + err = json.Unmarshal(rr.Body.Bytes(), &resp) + suite.Require().NoError(err) + + suite.Assert().Equal(true, resp.SelfCustodyAvailable["solana"]) +} + +func (suite *WalletControllersTestSuite) TestGetWalletV4_Not_Whitelisted() { + pg, _, err := wallet.NewPostgres() + suite.Require().NoError(err) + + pub, _, err := ed25519.GenerateKey(nil) + suite.Require().NoError(err) + + paymentID := uuid.NewV4() + w := &walletutils.Info{ + ID: paymentID.String(), + Provider: "brave", + PublicKey: hex.EncodeToString(pub), + AltCurrency: ptrTo(altcurrency.BAT), + } + err = pg.InsertWallet(context.TODO(), w) + suite.Require().NoError(err) + + allowList := storage.NewAllowList() + + service, _ := wallet.InitService(pg, nil, nil, allowList, nil, nil, nil, nil, nil, nil, wallet.DAppConfig{}) + + handler := handlers.AppHandler(wallet.GetWalletV4(service)) + + req, err := http.NewRequest("GET", "/v4/wallets/"+paymentID.String(), nil) + suite.Require().NoError(err, "a request should be created") + + req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("paymentID", paymentID.String()) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + suite.Assert().Equal(http.StatusOK, rr.Code) + + var resp wallet.ResponseV4 + err = json.Unmarshal(rr.Body.Bytes(), &resp) + suite.Require().NoError(err) + + suite.Assert().Equal(false, resp.SelfCustodyAvailable["solana"]) +} + func signUpdateRequest(req *http.Request, paymentID string, privateKey ed25519.PrivateKey) error { var s httpsignature.SignatureParams s.Algorithm = httpsignature.ED25519 @@ -564,3 +656,25 @@ func signUpdateRequest(req *http.Request, paymentID string, privateKey ed25519.P s.Headers = []string{"digest", "(request-target)"} return s.Sign(privateKey, crypto.Hash(0), req) } + +func noOpHandler() middleware.InstrumentHandlerDef { + return func(name string, h http.Handler) http.Handler { + return h + } +} + +func noOpMw() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } +} + +func signRequest(req *http.Request, publicKey httpsignature.Ed25519PubKey, privateKey ed25519.PrivateKey) error { + var s httpsignature.SignatureParams + s.Algorithm = httpsignature.ED25519 + s.KeyID = hex.EncodeToString(publicKey) + s.Headers = []string{"digest", "(request-target)"} + return s.Sign(privateKey, crypto.Hash(0), req) +} diff --git a/services/wallet/datastore.go b/services/wallet/datastore.go index 071897a82..8461d71cb 100644 --- a/services/wallet/datastore.go +++ b/services/wallet/datastore.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "strconv" + "strings" "time" "github.com/brave-intl/bat-go/libs/backoff" @@ -63,7 +64,7 @@ func init() { // Datastore holds the interface for the wallet datastore type Datastore interface { datastore.Datastore - LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider, country string) error + LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider string) error GetLinkingLimitInfo(ctx context.Context, providerLinkingID string) (map[string]LinkingInfo, error) HasPriorLinking(ctx context.Context, walletID uuid.UUID, providerLinkingID uuid.UUID) (bool, error) // GetLinkingsByProviderLinkingID gets the wallet linking info by provider linking id @@ -194,7 +195,9 @@ func (pg *Postgres) UpsertWallet(ctx context.Context, wallet *walletutils.Info) return nil } -// GetWallet by ID +// TODO(clD11): address GetWallet in wallet refactor. + +// GetWallet retrieves a wallet by its walletID, if no wallet is found then nil is returned. func (pg *Postgres) GetWallet(ctx context.Context, ID uuid.UUID) (*walletutils.Info, error) { statement := ` select @@ -553,23 +556,12 @@ var ( ) // LinkWallet links a rewards wallet to the given deposit provider. -func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestination string, providerLinkingID uuid.UUID, depositProvider, country string) error { +func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestination string, providerLinkingID uuid.UUID, depositProvider string) error { walletID, err := uuid.FromString(id) if err != nil { return fmt.Errorf("invalid wallet id, not uuid: %w", err) } - repClient, ok := ctx.Value(appctx.ReputationClientCTXKey).(reputation.Client) - if !ok { - return ErrNoReputationClient - } - - // TODO(clD11): We no longer need to act on the response and only require a successful call to reputation to - // continue linking. As part of the wallet refactor we should clean this up. - if _, _, err := repClient.IsLinkingReputable(ctx, walletID, country); err != nil { - return fmt.Errorf("failed to check wallet rep: %w", err) - } - ctx, tx, rollback, commit, err := getTx(ctx, pg) if err != nil { return fmt.Errorf("error getting tx: %w", err) @@ -600,9 +592,15 @@ func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestin } if directVerifiedWalletEnable { + repClient, ok := ctx.Value(appctx.ReputationClientCTXKey).(reputation.Client) + if !ok { + return ErrNoReputationClient + } + op := func() (interface{}, error) { return nil, repClient.UpdateReputationSummary(ctx, walletID.String(), true) } + if _, err := backoff.Retry(ctx, op, retryPolicy, canRetry(nonRetriableErrors)); err != nil { return fmt.Errorf("failed to update verified wallet: %w", err) } @@ -616,17 +614,29 @@ func (pg *Postgres) LinkWallet(ctx context.Context, id string, userDepositDestin return nil } -// CustodianLink - representation of wallet_custodian record +// TODO(clD11): CustodianLink represent a wallet_custodian. Review during wallet refactor + +// CustodianLink representation a wallet_custodian record. type CustodianLink struct { - WalletID *uuid.UUID `json:"wallet_id" db:"wallet_id" valid:"uuidv4"` - Custodian string `json:"custodian" db:"custodian" valid:"in(uphold,brave,gemini,bitflyer)"` - CreatedAt time.Time `json:"created_at" db:"created_at" valid:"-"` - UpdatedAt *time.Time `json:"updated_at" db:"updated_at" valid:"-"` - LinkedAt time.Time `json:"linked_at" db:"linked_at" valid:"-"` - DisconnectedAt *time.Time `json:"disconnected_at" db:"disconnected_at" valid:"-"` - DepositDestination string `json:"deposit_destination" db:"deposit_destination" valid:"-"` - LinkingID *uuid.UUID `json:"linking_id" db:"linking_id" valid:"uuid"` - UnlinkedAt *time.Time `json:"unlinked_at" db:"unlinked_at" valid:"-"` + WalletID *uuid.UUID `json:"wallet_id" db:"wallet_id" valid:"uuidv4"` + Custodian string `json:"custodian" db:"custodian" valid:"in(uphold,brave,gemini,bitflyer)"` + CreatedAt time.Time `json:"created_at" db:"created_at" valid:"-"` + UpdatedAt *time.Time `json:"updated_at" db:"updated_at" valid:"-"` + LinkedAt time.Time `json:"linked_at" db:"linked_at" valid:"-"` + DisconnectedAt *time.Time `json:"disconnected_at" db:"disconnected_at" valid:"-"` + LinkingID *uuid.UUID `json:"linking_id" db:"linking_id" valid:"uuid"` + UnlinkedAt *time.Time `json:"unlinked_at" db:"unlinked_at" valid:"-"` +} + +// TODO(clD11): Wallet Refactor. These should not be nullable, fix pointers and raname fields for consistency. + +func NewSolanaCustodialLink(walletID uuid.UUID, depositDestination string) *CustodianLink { + const depositProviderSolana = "solana" + return &CustodianLink{ + WalletID: &walletID, + LinkingID: ptrFromUUID(uuid.NewV5(ClaimNamespace, depositDestination)), + Custodian: depositProviderSolana, + } } // GetWalletIDString - get string version of the WalletID @@ -716,16 +726,10 @@ func commitFn(ctx context.Context, tx *sqlx.Tx) func() error { // getTx will get or create a tx on the context, if created hands back rollback and commit functions func getTx(ctx context.Context, datastore Datastore) (context.Context, *sqlx.Tx, func(), func() error, error) { - // create a sublogger - sublogger := logger(ctx) - sublogger.Debug().Msg("getting tx from context") - // get tx tx, noContextTx := ctx.Value(appctx.DatabaseTransactionCTXKey).(*sqlx.Tx) if !noContextTx { - sublogger.Debug().Msg("no tx in context") tx, err := createTx(ctx, datastore) if err != nil || tx == nil { - sublogger.Error().Err(err).Msg("error creating tx") return ctx, nil, func() {}, func() error { return nil }, fmt.Errorf("failed to create tx: %w", err) } ctx = context.WithValue(ctx, appctx.DatabaseTransactionCTXKey, tx) @@ -861,8 +865,19 @@ func (pg *Postgres) ConnectCustodialWallet(ctx context.Context, cl *CustodianLin return fmt.Errorf("failed to get linking id from custodian record: %w", err) } - if !uuid.Equal(existingLinkingID, *new(uuid.UUID)) { - // check if the member matches the associated member + // TODO(clD11): WR. The relinking check below only considers the currently linked wallet and not + // the custodian. We can combine/refactor these checks for both linkings, custodians and limits. + + if err := validateCustodianLinking(ctx, pg, *cl.WalletID, cl.Custodian); err != nil { + if errors.Is(err, errCustodianLinkMismatch) { + return errCustodianLinkMismatch + } + return handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) + } + + // Relinking. + if !uuid.Equal(existingLinkingID, *new(uuid.UUID)) { // if not a new wallet + // check if the currently linked wallet matches the proposed linking. if !uuid.Equal(*cl.LinkingID, existingLinkingID) { return handlers.WrapError(errors.New("wallets do not match"), "mismatched provider accounts", http.StatusForbidden) } @@ -1015,20 +1030,33 @@ func (pg *Postgres) SendVerifiedWalletOutbox(ctx context.Context, client reputat return true, nil } -// helper to make logger easier +func validateCustodianLinking(ctx context.Context, storage Datastore, walletID uuid.UUID, depositProvider string) error { + c, err := storage.GetCustodianLinkByWalletID(ctx, walletID) + if err != nil && !errors.Is(err, model.ErrNoWalletCustodian) { + return err + } + + // if there are no instances of wallet custodian then it is + // considered a new linking and therefore valid. + if c == nil { + return nil + } + + if !strings.EqualFold(c.Custodian, depositProvider) { + return errCustodianLinkMismatch + } + + return nil +} + func logger(ctx context.Context) *zerolog.Logger { - // get logger return logging.Logger(ctx, "wallet") } // helper to create a tx -func createTx(ctx context.Context, datastore Datastore) (tx *sqlx.Tx, err error) { - logger(ctx).Debug(). - Msg("creating transaction") +func createTx(_ context.Context, datastore Datastore) (tx *sqlx.Tx, err error) { tx, err = datastore.RawDB().Beginx() if err != nil { - logger(ctx).Error().Err(err). - Msg("error creating transaction") return tx, fmt.Errorf("failed to create transaction: %w", err) } return tx, nil @@ -1043,3 +1071,7 @@ func waitAndLockTx(ctx context.Context, tx *sqlx.Tx, id uuid.UUID) error { } return nil } + +func ptrFromUUID(u uuid.UUID) *uuid.UUID { + return &u +} diff --git a/services/wallet/datastore_test.go b/services/wallet/datastore_test.go index 477ee9136..93984ed41 100644 --- a/services/wallet/datastore_test.go +++ b/services/wallet/datastore_test.go @@ -58,7 +58,7 @@ func (suite *WalletPostgresTestSuite) TearDownTest() { } func (suite *WalletPostgresTestSuite) CleanDB() { - tables := []string{"claim_creds", "claims", "wallets", "issuers", "promotions"} + tables := []string{"claim_creds", "claims", "wallets", "issuers", "promotions", "verified_wallet_outbox"} pg, _, err := NewPostgres() suite.Require().NoError(err, "Failed to get postgres conn") @@ -232,7 +232,8 @@ func (suite *WalletPostgresTestSuite) TestLinkWallet_Concurrent_InsertUpdate() { for i := 0; i < runs; i++ { go func() { defer wg.Done() - err = pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, walletInfo.Provider, "") + err := pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, walletInfo.Provider) + suite.Require().NoError(err) }() } wg.Wait() @@ -272,7 +273,7 @@ func (suite *WalletPostgresTestSuite) seedWallet(pg Datastore) (string, uuid.UUI err := pg.UpsertWallet(ctx, walletInfo) suite.Require().NoError(err, "save wallet should succeed") - err = pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, "uphold", "") + err = pg.LinkWallet(ctx, walletInfo.ID, userDepositDestination, providerLinkingID, "uphold") suite.Require().NoError(err, "link wallet should succeed") } @@ -322,7 +323,8 @@ func (suite *WalletPostgresTestSuite) TestLinkWallet_Concurrent_MaxLinkCount() { for i := 0; i < len(wallets); i++ { go func(index int) { defer wg.Done() - err = pg.LinkWallet(ctx, wallets[index].ID, userDepositDestination, providerLinkingID, wallets[index].Provider, "") + // Once we reach the limit this will return an error which is expected hence we can ignore it. + _ = pg.LinkWallet(ctx, wallets[index].ID, userDepositDestination, providerLinkingID, wallets[index].Provider) }(i) } wg.Wait() diff --git a/services/wallet/instrumented_datastore.go b/services/wallet/instrumented_datastore.go index 7bcf2e1f0..18f962c4b 100644 --- a/services/wallet/instrumented_datastore.go +++ b/services/wallet/instrumented_datastore.go @@ -13,6 +13,7 @@ import ( "github.com/brave-intl/bat-go/libs/backoff" "github.com/brave-intl/bat-go/libs/clients/reputation" walletutils "github.com/brave-intl/bat-go/libs/wallet" + migrate "github.com/golang-migrate/migrate/v4" "github.com/jmoiron/sqlx" "github.com/prometheus/client_golang/prometheus" @@ -255,7 +256,7 @@ func (_d DatastoreWithPrometheus) InsertWalletTx(ctx context.Context, tx *sqlx.T } // LinkWallet implements Datastore -func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider string, country string) (err error) { +func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, id string, providerID string, providerLinkingID uuid.UUID, depositProvider string) (err error) { _since := time.Now() defer func() { result := "ok" @@ -265,7 +266,7 @@ func (_d DatastoreWithPrometheus) LinkWallet(ctx context.Context, id string, pro datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "LinkWallet", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.LinkWallet(ctx, id, providerID, providerLinkingID, depositProvider, country) + return _d.base.LinkWallet(ctx, id, providerID, providerLinkingID, depositProvider) } // Migrate implements Datastore diff --git a/services/wallet/keystore_test.go b/services/wallet/keystore_test.go deleted file mode 100644 index eef23498c..000000000 --- a/services/wallet/keystore_test.go +++ /dev/null @@ -1,530 +0,0 @@ -//go:build integration - -package wallet_test - -import ( - "bytes" - "context" - "crypto" - "crypto/ed25519" - "encoding/hex" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/brave-intl/bat-go/libs/altcurrency" - mock_reputation "github.com/brave-intl/bat-go/libs/clients/reputation/mock" - appctx "github.com/brave-intl/bat-go/libs/context" - "github.com/brave-intl/bat-go/libs/handlers" - "github.com/brave-intl/bat-go/libs/httpsignature" - walletutils "github.com/brave-intl/bat-go/libs/wallet" - "github.com/brave-intl/bat-go/libs/wallet/provider/uphold" - "github.com/brave-intl/bat-go/services/wallet" - "github.com/go-chi/chi" - "github.com/golang/mock/gomock" - uuid "github.com/satori/go.uuid" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/suite" -) - -type WalletControllersTestSuite struct { - suite.Suite -} - -func TestWalletControllersTestSuite(t *testing.T) { - suite.Run(t, new(WalletControllersTestSuite)) -} - -func (suite *WalletControllersTestSuite) SetupSuite() { - pg, _, err := wallet.NewPostgres() - suite.Require().NoError(err, "Failed to get postgres conn") - - m, err := pg.NewMigrate() - suite.Require().NoError(err, "Failed to create migrate instance") - - ver, dirty, _ := m.Version() - if dirty { - suite.Require().NoError(m.Force(int(ver))) - } - if ver > 0 { - suite.Require().NoError(m.Down(), "Failed to migrate down cleanly") - } - - suite.Require().NoError(pg.Migrate(), "Failed to fully migrate") -} - -func (suite *WalletControllersTestSuite) SetupTest() { - suite.CleanDB() -} - -func (suite *WalletControllersTestSuite) TearDownTest() { - suite.CleanDB() -} - -func (suite *WalletControllersTestSuite) CleanDB() { - tables := []string{"claim_creds", "claims", "wallets", "issuers", "promotions", "wallet_custodian"} - - pg, _, err := wallet.NewPostgres() - suite.Require().NoError(err, "Failed to get postgres conn") - - for _, table := range tables { - _, err = pg.RawDB().Exec("delete from " + table) - suite.Require().NoError(err, "Failed to get clean table") - } -} - -func noUUID() *uuid.UUID { - return nil -} - -func (suite *WalletControllersTestSuite) FundWallet(w *uphold.Wallet, probi decimal.Decimal) decimal.Decimal { - ctx := context.Background() - balanceBefore, err := w.GetBalance(ctx, true) - total, err := uphold.FundWallet(ctx, w, probi) - suite.Require().NoError(err, "an error should not be generated from funding the wallet") - suite.Require().True(total.GreaterThan(balanceBefore.TotalProbi), "submit with confirm should result in an increased balance") - return total -} - -func (suite *WalletControllersTestSuite) CheckBalance(w *uphold.Wallet, expect decimal.Decimal) { - balances, err := w.GetBalance(context.Background(), true) - suite.Require().NoError(err, "an error should not be generated from checking the wallet balance") - totalProbi := altcurrency.BAT.FromProbi(balances.TotalProbi) - errMessage := fmt.Sprintf("got an unexpected balance. expected: %s, got %s", expect.String(), totalProbi.String()) - suite.Require().True(expect.Equal(totalProbi), errMessage) -} - -func (suite *WalletControllersTestSuite) TestBalanceV3() { - pg, _, err := wallet.NewPostgres() - suite.Require().NoError(err, "Failed to get postgres connection") - - mockCtrl := gomock.NewController(suite.T()) - defer mockCtrl.Finish() - - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) - - w1 := suite.NewWallet(service, "uphold") - - bat1 := decimal.NewFromFloat(0.000000001) - - suite.FundWallet(w1, bat1) - - // check there is 1 bat in w1 - suite.CheckBalance(w1, bat1) - - // call the balance endpoint and check that you get back a total of 1 - handler := wallet.GetUpholdWalletBalanceV3 - - req, err := http.NewRequest("GET", "/v3/wallet/uphold/{paymentID}", nil) - suite.Require().NoError(err, "wallet claim request could not be created") - - rctx := chi.NewRouteContext() - rctx.URLParams.Add("paymentID", w1.ID) - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - req = req.WithContext(context.WithValue(req.Context(), appctx.RODatastoreCTXKey, pg)) - req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - - rr := httptest.NewRecorder() - handlers.AppHandler(handler).ServeHTTP(rr, req) - suite.Require().Equal(http.StatusOK, rr.Code, fmt.Sprintf("status is expected to match %d: %s", http.StatusOK, rr.Body.String())) - - var balance wallet.BalanceResponseV3 - err = json.Unmarshal(rr.Body.Bytes(), &balance) - suite.Require().NoError(err, "failed to unmarshal balance result") - - suite.Require().Equal(balance.Total, float64(0.000000001), fmt.Sprintf("balance is expected to match %f: %f", balance.Total, float64(1))) - - _, err = pg.RawDB().Exec(`update wallets set provider_id = '' where id = $1`, w1.ID) - suite.Require().NoError(err, "wallet provider_id could not be set as empty string") - - req, err = http.NewRequest("GET", "/v3/wallet/uphold/{paymentID}", nil) - suite.Require().NoError(err, "wallet claim request could not be created") - - rctx = chi.NewRouteContext() - rctx.URLParams.Add("paymentID", w1.ID) - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - req = req.WithContext(context.WithValue(req.Context(), appctx.RODatastoreCTXKey, pg)) - req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - - rr = httptest.NewRecorder() - handlers.AppHandler(handler).ServeHTTP(rr, req) - expectingForbidden := fmt.Sprintf("status is expected to match %d: %s", http.StatusForbidden, rr.Body.String()) - suite.Require().Equal(http.StatusForbidden, rr.Code, expectingForbidden) -} - -func (suite *WalletControllersTestSuite) TestLinkWalletV3() { - ctx := context.Background() - - pg, _, err := wallet.NewPostgres() - suite.Require().NoError(err, "Failed to get postgres connection") - - mockCtrl := gomock.NewController(suite.T()) - defer mockCtrl.Finish() - - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) - - w1 := suite.NewWallet(service, "uphold") - w2 := suite.NewWallet(service, "uphold") - w3 := suite.NewWallet(service, "uphold") - w4 := suite.NewWallet(service, "uphold") - bat1 := decimal.NewFromFloat(0.000000001) - bat2 := decimal.NewFromFloat(0.000000002) - - suite.FundWallet(w1, bat1) - suite.FundWallet(w2, bat1) - suite.FundWallet(w3, bat1) - suite.FundWallet(w4, bat1) - - anonCard1ID, err := w1.CreateCardAddress(ctx, "anonymous") - suite.Require().NoError(err, "create anon card must not fail") - anonCard1UUID := uuid.Must(uuid.FromString(anonCard1ID)) - - anonCard2ID, err := w2.CreateCardAddress(ctx, "anonymous") - suite.Require().NoError(err, "create anon card must not fail") - anonCard2UUID := uuid.Must(uuid.FromString(anonCard2ID)) - - anonCard3ID, err := w3.CreateCardAddress(ctx, "anonymous") - suite.Require().NoError(err, "create anon card must not fail") - anonCard3UUID := uuid.Must(uuid.FromString(anonCard3ID)) - - w1ProviderID := w1.GetWalletInfo().ProviderID - w2ProviderID := w2.GetWalletInfo().ProviderID - w3ProviderID := w3.GetWalletInfo().ProviderID - - zero := decimal.NewFromFloat(0) - - suite.CheckBalance(w1, bat1) - suite.claimCardV3(service, w1, w3ProviderID, http.StatusOK, bat1, &anonCard3UUID) - suite.CheckBalance(w1, zero) - - suite.CheckBalance(w2, bat1) - suite.claimCardV3(service, w2, w1ProviderID, http.StatusOK, zero, &anonCard1UUID) - suite.CheckBalance(w2, bat1) - - suite.CheckBalance(w2, bat1) - suite.claimCardV3(service, w2, w1ProviderID, http.StatusOK, bat1, &anonCard3UUID) - suite.CheckBalance(w2, zero) - - suite.CheckBalance(w3, bat2) - suite.claimCardV3(service, w3, w2ProviderID, http.StatusOK, bat1, &anonCard3UUID) - suite.CheckBalance(w3, bat1) - - suite.CheckBalance(w3, bat1) - suite.claimCardV3(service, w3, w1ProviderID, http.StatusOK, zero, &anonCard2UUID) - suite.CheckBalance(w3, bat1) -} - -func (suite *WalletControllersTestSuite) claimCardV3( - service *wallet.Service, - w *uphold.Wallet, - destination string, - status int, - amount decimal.Decimal, - anonymousAddress *uuid.UUID, -) (*walletutils.Info, string) { - signedCreationRequest, err := w.PrepareTransaction(*w.AltCurrency, altcurrency.BAT.ToProbi(amount), destination, "", "", nil) - - suite.Require().NoError(err, "transaction must be signed client side") - - // V3 Payload - reqBody := wallet.LinkUpholdDepositAccountRequest{ - SignedLinkingRequest: signedCreationRequest, - } - - if anonymousAddress != nil { - reqBody.AnonymousAddress = anonymousAddress.String() - } - - body, err := json.Marshal(&reqBody) - suite.Require().NoError(err, "unable to marshal claim body") - - info := w.GetWalletInfo() - - // V3 Handler - - handler := wallet.LinkUpholdDepositAccountV3(service) - - req, err := http.NewRequest("POST", "/v3/wallet/{paymentID}/claim", bytes.NewBuffer(body)) - suite.Require().NoError(err, "wallet claim request could not be created") - - ctrl := gomock.NewController(suite.T()) - defer ctrl.Finish() - - repClient := mock_reputation.NewMockClient(ctrl) - repClient.EXPECT().IsLinkingReputable(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - - rctx := chi.NewRouteContext() - rctx.URLParams.Add("paymentID", info.ID) - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - req = req.WithContext(context.WithValue(req.Context(), appctx.ReputationClientCTXKey, repClient)) - - rr := httptest.NewRecorder() - handlers.AppHandler(handler).ServeHTTP(rr, req) - suite.Require().Equal(status, rr.Code, fmt.Sprintf("status is expected to match %d: %s", status, rr.Body.String())) - linked, err := service.Datastore.GetWallet(req.Context(), uuid.Must(uuid.FromString(w.ID))) - suite.Require().NoError(err, "retrieving the wallet did not cause an error") - return linked, rr.Body.String() -} - -func (suite *WalletControllersTestSuite) createBody( - tx string, -) string { - reqBody, _ := json.Marshal(wallet.UpholdCreationRequest{ - SignedCreationRequest: tx, - }) - return string(reqBody) -} - -func (suite *WalletControllersTestSuite) NewWallet(service *wallet.Service, provider string) *uphold.Wallet { - publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) - publicKeyString := hex.EncodeToString(publicKey) - - bat := altcurrency.BAT - info := walletutils.Info{ - ID: uuid.NewV4().String(), - PublicKey: publicKeyString, - Provider: provider, - AltCurrency: &bat, - } - w := &uphold.Wallet{ - Info: info, - PrivKey: privKey, - PubKey: publicKey, - } - - reg, err := w.PrepareRegistration("Brave Browser Test Link") - suite.Require().NoError(err, "unable to prepare transaction") - - createResp := suite.createUpholdWalletV3( - service, - suite.createBody(reg), - http.StatusCreated, - publicKey, - privKey, - true, - ) - - var returnedInfo wallet.ResponseV3 - err = json.Unmarshal([]byte(createResp), &returnedInfo) - suite.Require().NoError(err, "unable to create wallet") - convertedInfo := wallet.ResponseV3ToInfo(returnedInfo) - w.Info = *convertedInfo - return w -} - -func (suite *WalletControllersTestSuite) TestCreateBraveWalletV3() { - pg, _, err := wallet.NewPostgres() - suite.Require().NoError(err, "Failed to get postgres connection") - - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) - - publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) - - // assume 400 is already covered - // fail because of lacking signature presence - notSignedResponse := suite.createUpholdWalletV3( - service, - `{}`, - http.StatusBadRequest, - publicKey, - privKey, - false, - ) - - suite.Assert().JSONEq(`{"code":400, "data":{"validationErrors":{"decoding":"failed decoding: failed to decode signed creation request: unexpected end of JSON input", "signedCreationRequest":"value is required", "validation":"failed validation: missing signed creation request"}}, "message":"Error validating uphold create wallet request validation errors"}`, notSignedResponse, "field is not valid") - - createResp := suite.createBraveWalletV3( - service, - ``, - http.StatusCreated, - publicKey, - privKey, - true, - ) - - var created wallet.ResponseV3 - err = json.Unmarshal([]byte(createResp), &created) - suite.Require().NoError(err, "unable to unmarshal response") - - getResp := suite.getWallet(service, uuid.Must(uuid.FromString(created.PaymentID)), http.StatusOK) - - var gotten wallet.ResponseV3 - err = json.Unmarshal([]byte(getResp), &gotten) - suite.Require().NoError(err, "unable to unmarshal response") - // does not return wallet provider - suite.Require().Equal(created, gotten, "the get and create return the same structure") -} - -func (suite *WalletControllersTestSuite) TestCreateUpholdWalletV3() { - pg, _, err := wallet.NewPostgres() - suite.Require().NoError(err, "Failed to get postgres connection") - - service, _ := wallet.InitService(pg, nil, nil, nil, nil, nil, nil, nil) - - publicKey, privKey, err := httpsignature.GenerateEd25519Key(nil) - - badJSONBodyParse := suite.createUpholdWalletV3( - service, - ``, - http.StatusBadRequest, - publicKey, - privKey, - true, - ) - suite.Assert().JSONEq(`{ - "code":400, - "data": { - "validationErrors":{ - "decoding":"failed decoding: failed to decode json: EOF", - "signedCreationRequest":"value is required", - "validation":"failed validation: missing signed creation request" - } - }, - "message":"Error validating uphold create wallet request validation errors" - }`, badJSONBodyParse, "should fail when parsing json") - - badFieldResponse := suite.createUpholdWalletV3( - service, - `{"signedCreationRequest":""}`, - http.StatusBadRequest, - publicKey, - privKey, - true, - ) - - suite.Assert().JSONEq(`{ - "code":400, "data":{"validationErrors":{"decoding":"failed decoding: failed to decode signed creation request: unexpected end of JSON input", "signedCreationRequest":"value is required", "validation":"failed validation: missing signed creation request"}}, "message":"Error validating uphold create wallet request validation errors"}`, badFieldResponse, "field is not valid") - - // assume 403 is already covered - // fail because of lacking signature presence - notSignedResponse := suite.createUpholdWalletV3( - service, - `{}`, - http.StatusBadRequest, - publicKey, - privKey, - false, - ) - - suite.Assert().JSONEq(`{ -"code":400, "data":{"validationErrors":{"decoding":"failed decoding: failed to decode signed creation request: unexpected end of JSON input", "signedCreationRequest":"value is required", "validation":"failed validation: missing signed creation request"}}, "message":"Error validating uphold create wallet request validation errors" - }`, notSignedResponse, "field is not valid") -} - -func (suite *WalletControllersTestSuite) getWallet( - service *wallet.Service, - paymentId uuid.UUID, - code int, -) string { - handler := handlers.AppHandler(wallet.GetWalletV3) - - req, err := http.NewRequest("GET", "/v3/wallet/"+paymentId.String(), nil) - suite.Require().NoError(err, "a request should be created") - - req = req.WithContext(context.WithValue(req.Context(), appctx.DatastoreCTXKey, service.Datastore)) - req = req.WithContext(context.WithValue(req.Context(), appctx.RODatastoreCTXKey, service.Datastore)) - req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - rctx := chi.NewRouteContext() - rctx.URLParams.Add("paymentID", paymentId.String()) - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - suite.Require().Equal(code, rr.Code, "known status code should be sent: "+rr.Body.String()) - - return rr.Body.String() -} - -func (suite *WalletControllersTestSuite) createBraveWalletV3( - service *wallet.Service, - body string, - code int, - publicKey httpsignature.Ed25519PubKey, - privateKey ed25519.PrivateKey, - shouldSign bool, -) string { - - handler := handlers.AppHandler(wallet.CreateBraveWalletV3) - - bodyBuffer := bytes.NewBuffer([]byte(body)) - req, err := http.NewRequest("POST", "/v3/wallet/brave", bodyBuffer) - suite.Require().NoError(err, "a request should be created") - - // setup context - req = req.WithContext(context.WithValue(context.Background(), appctx.DatastoreCTXKey, service.Datastore)) - req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - - if shouldSign { - suite.SignRequest( - req, - publicKey, - privateKey, - ) - } - - rctx := chi.NewRouteContext() - joined := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) - req = req.WithContext(joined) - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - suite.Require().Equal(code, rr.Code, "known status code should be sent: "+rr.Body.String()) - - return rr.Body.String() -} - -func (suite *WalletControllersTestSuite) createUpholdWalletV3( - service *wallet.Service, - body string, - code int, - publicKey httpsignature.Ed25519PubKey, - privateKey ed25519.PrivateKey, - shouldSign bool, -) string { - - handler := handlers.AppHandler(wallet.CreateUpholdWalletV3) - - bodyBuffer := bytes.NewBuffer([]byte(body)) - req, err := http.NewRequest("POST", "/v3/wallet", bodyBuffer) - suite.Require().NoError(err, "a request should be created") - - // setup context - req = req.WithContext(context.WithValue(context.Background(), appctx.DatastoreCTXKey, service.Datastore)) - req = req.WithContext(context.WithValue(req.Context(), appctx.NoUnlinkPriorToDurationCTXKey, "-P1D")) - - if shouldSign { - suite.SignRequest( - req, - publicKey, - privateKey, - ) - } - - rctx := chi.NewRouteContext() - joined := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) - req = req.WithContext(joined) - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - suite.Require().Equal(code, rr.Code, "known status code should be sent: "+rr.Body.String()) - - return rr.Body.String() -} - -func (suite *WalletControllersTestSuite) SignRequest( - req *http.Request, - publicKey httpsignature.Ed25519PubKey, - privateKey ed25519.PrivateKey, -) { - var s httpsignature.SignatureParams - s.Algorithm = httpsignature.ED25519 - s.KeyID = hex.EncodeToString(publicKey) - s.Headers = []string{"digest", "(request-target)"} - - err := s.Sign(privateKey, crypto.Hash(0), req) - suite.Require().NoError(err) -} diff --git a/services/wallet/model/model.go b/services/wallet/model/model.go index e4220bd8c..ea3907d01 100644 --- a/services/wallet/model/model.go +++ b/services/wallet/model/model.go @@ -1,8 +1,58 @@ package model -import "errors" +import ( + "encoding/base64" + "time" -var ErrNoWalletCustodian = errors.New("model: no linked wallet custodian") + uuid "github.com/satori/go.uuid" +) + +const ( + ErrWalletNotWhitelisted Error = "model: wallet not whitelisted" + ErrNotFound Error = "model: not found" + ErrChallengeNotFound Error = "model: challenge not found" + ErrChallengeExpired Error = "model: challenge expired" + ErrNoRowsDeleted Error = "model: no rows deleted" + ErrNotInserted Error = "model: not inserted" + ErrNoWalletCustodian Error = "model: no linked wallet custodian" + ErrInternalServer Error = "model: internal server error" + ErrWalletNotFound Error = "model: wallet not found" +) + +type AllowListEntry struct { + PaymentID uuid.UUID `db:"payment_id"` + CreatedAt time.Time `db:"created_at"` +} + +func (a AllowListEntry) IsAllowed(paymentID uuid.UUID) bool { + return !uuid.Equal(a.PaymentID, uuid.Nil) && uuid.Equal(a.PaymentID, paymentID) +} + +type Challenge struct { + PaymentID uuid.UUID `db:"payment_id"` + CreatedAt time.Time `db:"created_at"` + Nonce string `db:"nonce"` +} + +func NewChallenge(paymentID uuid.UUID) Challenge { + return Challenge{ + PaymentID: paymentID, + CreatedAt: time.Now(), + Nonce: base64.URLEncoding.EncodeToString(uuid.NewV4().Bytes()), + } +} + +func (c *Challenge) IsValid(now time.Time) error { + if c.hasExpired(now) { + return ErrChallengeExpired + } + return nil +} + +func (c *Challenge) hasExpired(now time.Time) bool { + expiresAt := c.CreatedAt.Add(5 * time.Minute) + return expiresAt.Before(now) +} type Error string diff --git a/services/wallet/model/model_test.go b/services/wallet/model/model_test.go new file mode 100644 index 000000000..55313a6d7 --- /dev/null +++ b/services/wallet/model/model_test.go @@ -0,0 +1,117 @@ +package model + +import ( + "testing" + "time" + + uuid "github.com/satori/go.uuid" + "github.com/stretchr/testify/assert" +) + +func TestAllowListEntry_IsAllowed(t *testing.T) { + type tcGiven struct { + allow AllowListEntry + paymentID uuid.UUID + } + + type exp struct { + isAllowed bool + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "default_allow_list_entry", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("dc5a802a-87e9-47a5-9d9c-af4c8f171bf3"), + }, + exp: exp{isAllowed: false}, + }, + { + name: "payment_id_nil", + given: tcGiven{ + paymentID: uuid.Nil, + }, + exp: exp{isAllowed: false}, + }, + { + name: "payment_ids_not_equal", + given: tcGiven{ + allow: AllowListEntry{ + PaymentID: uuid.FromStringOrNil("d1359406-42f1-4364-99b7-77840e8594e8"), + }, + paymentID: uuid.FromStringOrNil("dc5a802a-87e9-47a5-9d9c-af4c8f171bf3"), + }, + exp: exp{isAllowed: false}, + }, + { + name: "payment_ids_are_equal", + given: tcGiven{ + allow: AllowListEntry{ + PaymentID: uuid.FromStringOrNil("356a634a-dbae-4f95-b276-f3f0f0a53509"), + }, + paymentID: uuid.FromStringOrNil("356a634a-dbae-4f95-b276-f3f0f0a53509"), + }, + exp: exp{isAllowed: true}, + }, + } + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.allow.IsAllowed(tc.given.paymentID) + assert.Equal(t, tc.exp.isAllowed, actual) + }) + } +} + +func TestChallenge_IsValid(t *testing.T) { + type tcGiven struct { + chl Challenge + now time.Time + } + + type testCase struct { + name string + given tcGiven + assertErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "expired", + given: tcGiven{ + chl: Challenge{ + CreatedAt: time.Now(), + }, + now: time.Now().Add(6 * time.Minute), + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, ErrChallengeExpired) + }, + }, + { + name: "valid", + given: tcGiven{ + chl: Challenge{ + CreatedAt: time.Now(), + }, + now: time.Now(), + }, + assertErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.NoError(t, err) + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.given.chl.IsValid(tc.given.now) + tc.assertErr(t, err) + }) + } +} diff --git a/services/wallet/outputs.go b/services/wallet/outputs.go index 48dde81a6..33eb8f20a 100644 --- a/services/wallet/outputs.go +++ b/services/wallet/outputs.go @@ -163,6 +163,74 @@ func infoToResponseV3(info *walletutils.Info) ResponseV3 { return resp } +// ResponseV4 - wallet creation response +type ResponseV4 struct { + PaymentID string `json:"paymentId"` + DepositAccountProvider *DepositAccountProviderDetailsV3 `json:"depositAccountProvider,omitempty"` + WalletProvider *ProviderDetailsV3 `json:"walletProvider,omitempty"` + AltCurrency string `json:"altcurrency"` + PublicKey string `json:"publicKey"` + SelfCustodyAvailable map[string]bool `json:"selfCustodyAvailable"` +} + +func infoToResponseV4(info *walletutils.Info, selfCustody bool) ResponseV4 { + var ( + linkingID string + anonymousAddress string + ) + if info == nil { + return ResponseV4{} + } + + var altCurrency = convertAltCurrency(info.AltCurrency) + + if info.ProviderLinkingID == nil { + linkingID = "" + } else { + linkingID = info.ProviderLinkingID.String() + } + + if info.AnonymousAddress == nil { + anonymousAddress = "" + } else { + anonymousAddress = info.AnonymousAddress.String() + } + + // common to all wallets + resp := ResponseV4{ + PaymentID: info.ID, + AltCurrency: altCurrency, + PublicKey: info.PublicKey, + WalletProvider: &ProviderDetailsV3{ + Name: info.Provider, + }, + SelfCustodyAvailable: map[string]bool{ + "solana": selfCustody, + }, + } + + // setup the wallet provider (anon card uphold) + if info.Provider == "uphold" { + // this is a uphold provided wallet (anon card based) + resp.WalletProvider.ID = info.ProviderID + resp.WalletProvider.AnonymousAddress = anonymousAddress + resp.WalletProvider.LinkingID = linkingID + } + + // now setup user deposit account + if info.UserDepositAccountProvider != nil { + // this brave wallet has a linked deposit account + resp.DepositAccountProvider = &DepositAccountProviderDetailsV3{ + Name: info.UserDepositAccountProvider, + ID: &info.UserDepositDestination, + LinkingID: linkingID, + AnonymousAddress: anonymousAddress, + } + } + + return resp +} + // BalanceResponseV3 - wallet creation response type BalanceResponseV3 struct { Total float64 `json:"total"` diff --git a/services/wallet/service.go b/services/wallet/service.go index eef9c1817..5741d1a25 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -16,8 +16,11 @@ import ( "github.com/brave-intl/bat-go/services/wallet/metric" "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/brave-intl/bat-go/services/wallet/storage" "github.com/go-chi/chi" + "github.com/go-chi/cors" "github.com/go-jose/go-jose/v3/jwt" + "github.com/jmoiron/sqlx" "github.com/lib/pq" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" @@ -104,6 +107,16 @@ type GeoValidator interface { Validate(ctx context.Context, geolocation string) (bool, error) } +type challengeRepo interface { + Get(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.Challenge, error) + Upsert(ctx context.Context, dbi sqlx.ExecerContext, chl model.Challenge) error + Delete(ctx context.Context, dbi sqlx.ExecerContext, paymentID uuid.UUID) error +} + +type allowListRepo interface { + GetAllowListEntry(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.AllowListEntry, error) +} + type metricSvc interface { LinkSuccessZP(cc string) LinkFailureZP(cc string) @@ -121,6 +134,8 @@ type geminiSvc interface { type Service struct { Datastore Datastore RoDatastore ReadOnlyDatastore + chlRepo challengeRepo + allowListRepo allowListRepo repClient reputation.Client geminiClient gemini.Client geoValidator GeoValidator @@ -130,20 +145,39 @@ type Service struct { custodianRegions custodian.Regions metric metricSvc gemini geminiSvc + dappConf DAppConfig } -// InitService creates a service using the passed datastore and clients configured from the environment -func InitService(datastore Datastore, roDatastore ReadOnlyDatastore, repClient reputation.Client, geminiClient gemini.Client, geoCountryValidator GeoValidator, retry backoff.RetryFunc, metric metricSvc, gemini geminiSvc) (*Service, error) { +type DAppConfig struct { + AllowedOrigins []string +} + +// InitService creates a new instances of the wallet service. +func InitService( + datastore Datastore, + roDatastore ReadOnlyDatastore, + chlRepo challengeRepo, + allowList allowListRepo, + repClient reputation.Client, + geminiClient gemini.Client, + geoCountryValidator GeoValidator, + retry backoff.RetryFunc, + metric metricSvc, + gemini geminiSvc, + dappConf DAppConfig) (*Service, error) { service := &Service{ - crMu: new(sync.RWMutex), - Datastore: datastore, - RoDatastore: roDatastore, - repClient: repClient, - geminiClient: geminiClient, - geoValidator: geoCountryValidator, - retry: retry, - metric: metric, - gemini: gemini, + Datastore: datastore, + RoDatastore: roDatastore, + chlRepo: chlRepo, + allowListRepo: allowList, + repClient: repClient, + geminiClient: geminiClient, + geoValidator: geoCountryValidator, + retry: retry, + metric: metric, + gemini: gemini, + dappConf: dappConf, + crMu: new(sync.RWMutex), } return service, nil } @@ -163,16 +197,19 @@ func (service *Service) ReadableDatastore() ReadOnlyDatastore { // SetupService - create a new wallet service func SetupService(ctx context.Context) (context.Context, *Service) { - logger := logging.Logger(ctx, "wallet.SetupService") + l := logging.Logger(ctx, "wallet.SetupService") + + chlRepo := storage.NewChallenge() + alRepo := storage.NewAllowList() db, err := NewWritablePostgres(viper.GetString("datastore"), false, "wallet_db") if err != nil { - logger.Panic().Err(err).Msg("unable connect to wallet db") + l.Panic().Err(err).Msg("unable connect to wallet db") } roDB, err := NewReadOnlyPostgres(viper.GetString("ro-datastore"), false, "wallet_ro_db") if err != nil { - logger.Panic().Err(err).Msg("unable connect to wallet db") + l.Panic().Err(err).Msg("unable connect to wallet db") } ctx = context.WithValue(ctx, appctx.RODatastoreCTXKey, roDB) @@ -184,7 +221,7 @@ func SetupService(ctx context.Context) (context.Context, *Service) { // jwt key is hex encoded string decodedBitFlyerJWTKey, err := hex.DecodeString(viper.GetString("bitflyer-jwt-key")) if err != nil { - logger.Error().Err(err).Msg("invalid bitflyer jwt key") + l.Error().Err(err).Msg("invalid bitflyer jwt key") } ctx = context.WithValue(ctx, appctx.BitFlyerJWTKeyCTXKey, decodedBitFlyerJWTKey) @@ -192,7 +229,7 @@ func SetupService(ctx context.Context) (context.Context, *Service) { repClient, err := reputation.New() // it's okay to not fatally fail if this environment is local and we cant make a rep client if err != nil && os.Getenv("ENV") != "local" { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } ctx = context.WithValue(ctx, appctx.ReputationClientCTXKey, repClient) @@ -201,19 +238,19 @@ func SetupService(ctx context.Context) (context.Context, *Service) { if os.Getenv("GEMINI_ENABLED") == "true" { geminiClient, err = gemini.New() if err != nil { - logger.Panic().Err(err).Msg("failed to create gemini client") + l.Panic().Err(err).Msg("failed to create gemini client") } ctx = context.WithValue(ctx, appctx.GeminiClientCTXKey, geminiClient) } - cfg, err := appaws.BaseAWSConfig(ctx, logger) + cfg, err := appaws.BaseAWSConfig(ctx, l) if err != nil { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } awsClient, err := appaws.NewClient(cfg) if err != nil { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } // put the configured aws client on ctx @@ -222,41 +259,55 @@ func SetupService(ctx context.Context) (context.Context, *Service) { // get the s3 bucket and object bucket, bucketOK := ctx.Value(appctx.ParametersMergeBucketCTXKey).(string) if !bucketOK { - logger.Panic().Err(errors.New("bucket not in context")). + l.Panic().Err(errors.New("bucket not in context")). Msg("failed to initialize wallet service") } object, ok := ctx.Value(appctx.DisabledWalletGeoCountriesCTXKey).(string) if !ok { - logger.Panic().Err(errors.New("wallet geo countries disabled ctx key value not found")). + l.Panic().Err(errors.New("wallet geo countries disabled ctx key value not found")). Msg("failed to initialize wallet service") } - config := Config{ + geoCountryValidator := NewGeoCountryValidator(awsClient, Config{ bucket: bucket, object: object, - } - - geoCountryValidator := NewGeoCountryValidator(awsClient, config) + }) mtc := metric.New() gemx := newGeminix("passport", "drivers_license", "national_identity_card", "passport_card") - s, err := InitService(db, roDB, repClient, geminiClient, geoCountryValidator, backoff.Retry, mtc, gemx) + dappAO := strings.Split(os.Getenv("DAPP_ALLOWED_CORS_ORIGINS"), ",") + if len(dappAO) == 0 { + l.Panic().Err(errors.New("dapp allowed origins missing")).Msg("failed to initialize wallet service") + } + + dappConf := DAppConfig{ + AllowedOrigins: dappAO, + } + + s, err := InitService(db, roDB, chlRepo, alRepo, repClient, geminiClient, geoCountryValidator, backoff.Retry, mtc, gemx, dappConf) if err != nil { - logger.Panic().Err(err).Msg("failed to initialize wallet service") + l.Panic().Err(err).Msg("failed to initialize wallet service") } _, err = s.RefreshCustodianRegionsWorker(ctx) if err != nil { - logger.Error().Err(err).Msg("failed to initialize custodian regions") + l.Error().Err(err).Msg("failed to initialize custodian regions") } + decJob := deleteExpiredChallengeTask{exec: db.RawDB(), deleter: chlRepo, deleteAfterMin: 5} + s.jobs = []srv.Job{ { Func: s.RefreshCustodianRegionsWorker, Cadence: 15 * time.Minute, Workers: 1, }, + { + Func: decJob.deleteExpiredChallenges, + Cadence: 10 * time.Minute, + Workers: 1, + }, } if VerifiedWalletEnable { @@ -269,7 +320,7 @@ func SetupService(ctx context.Context) (context.Context, *Service) { err = cmd.SetupJobWorkers(ctx, s.Jobs()) if err != nil { - logger.Error().Err(err).Msg("error initializing job workers") + l.Error().Err(err).Msg("error initializing job workers") } return ctx, s @@ -288,7 +339,7 @@ func (service *Service) getCustodianRegions() custodian.Regions { } // RegisterRoutes - register the wallet api routes given a chi.Mux -func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux) *chi.Mux { +func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux, metricsMw middleware.InstrumentHandlerDef, dAppCorsMw func(next http.Handler) http.Handler) *chi.Mux { // setup our wallet routes r.Route("/v3/wallet", func(r chi.Router) { // rate limited to 2 per minute... @@ -319,39 +370,62 @@ func RegisterRoutes(ctx context.Context, s *Service, r *chi.Mux) *chi.Mux { "LinkGeminiDepositAccount", LinkGeminiDepositAccountV3(s))).ServeHTTP) r.Post("/zebpay/{paymentID}/connect", middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( "LinkZebPayDepositAccount", LinkZebPayDepositAccountV3(s))).ServeHTTP) + + r.Method(http.MethodPost, "/solana/{paymentID}/connect", metricsMw("LinkSolanaAddress", dAppCorsMw(LinkSolanaAddress(s)))) + r.Method(http.MethodOptions, "/solana/{paymentID}/connect", metricsMw("LinkSolanaAddressOptions", dAppCorsMw(noOpHandler()))) } - r.Get("/linking-info", middleware.SimpleTokenAuthorizedOnly( - middleware.InstrumentHandlerFunc("GetLinkingInfo", GetLinkingInfoV3(s))).ServeHTTP) + r.Get("/linking-info", middleware.SimpleTokenAuthorizedOnly(middleware.InstrumentHandlerFunc("GetLinkingInfo", GetLinkingInfoV3(s))).ServeHTTP) // get wallet routes - r.Get("/{paymentID}", middleware.InstrumentHandlerFunc( - "GetWallet", GetWalletV3)) - r.Get("/recover/{publicKey}", middleware.InstrumentHandlerFunc( - "RecoverWallet", RecoverWalletV3)) + r.Get("/{paymentID}", middleware.InstrumentHandlerFunc("GetWallet", GetWalletV3)) + r.Get("/recover/{publicKey}", middleware.InstrumentHandlerFunc("RecoverWallet", RecoverWalletV3)) // get wallet balance routes - r.Get("/uphold/{paymentID}", middleware.InstrumentHandlerFunc( - "GetUpholdWalletBalance", GetUpholdWalletBalanceV3)) + r.Get("/uphold/{paymentID}", middleware.InstrumentHandlerFunc("GetUpholdWalletBalance", GetUpholdWalletBalanceV3)) + + r.Post("/challenges", middleware.RateLimiter(ctx, 2)(metricsMw("CreateChallenge", dAppCorsMw(CreateChallenge(s)))).ServeHTTP) + r.Options("/challenges", middleware.RateLimiter(ctx, 2)(metricsMw("CreateChallengeOptions", dAppCorsMw(noOpHandler()))).ServeHTTP) }) r.Route("/v4/wallets", func(r chi.Router) { - r.Use(middleware.RateLimiter(ctx, 2)) - r.Post("/", middleware.InstrumentHandlerFunc("CreateWalletV4", CreateWalletV4(s))) - r.Patch("/{paymentID}", middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( - "UpdateWalletV4", UpdateWalletV4(s))).ServeHTTP) - r.Get("/{paymentID}", - middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( - "GetWalletV4", GetWalletV4)).ServeHTTP) - // get wallet balance routes - r.Get("/uphold/{paymentID}", - middleware.HTTPSignedOnly(s)(middleware.InstrumentHandlerFunc( - "GetUpholdWalletBalanceV4", GetUpholdWalletBalanceV4)).ServeHTTP) + r.Post("/", middleware.RateLimiter(ctx, 2)( + middleware.InstrumentHandlerFunc("CreateWalletV4", CreateWalletV4(s))).ServeHTTP) + + r.Patch("/{paymentID}", middleware.RateLimiter(ctx, 2)(middleware.HTTPSignedOnly(s)( + middleware.InstrumentHandlerFunc("UpdateWalletV4", UpdateWalletV4(s)))).ServeHTTP) + + r.Get("/{paymentID}", middleware.RateLimiter(ctx, 7)(middleware.HTTPSignedOnly(s)( + middleware.InstrumentHandlerFunc("GetWalletV4", GetWalletV4(s)))).ServeHTTP) + + r.Get("/uphold/{paymentID}", middleware.RateLimiter(ctx, 2)(middleware.HTTPSignedOnly(s)( + middleware.InstrumentHandlerFunc("GetUpholdWalletBalanceV4", GetUpholdWalletBalanceV4))).ServeHTTP) }) return r } +// TODO(clD11): WR. Move once we address the rest_run.go and grant.go start functions. + +func NewDAppCorsMw(origins []string) func(next http.Handler) http.Handler { + opts := cors.Options{ + Debug: false, + AllowedOrigins: origins, + AllowedHeaders: []string{"Accept", "Content-Type"}, + ExposedHeaders: []string{""}, + AllowedMethods: []string{http.MethodPost}, + AllowCredentials: false, + MaxAge: 300, + } + return cors.Handler(opts) +} + +func noOpHandler() http.Handler { + return http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + return + }) +} + // SubmitAnonCardTransaction validates and submits a transaction on behalf of an anonymous card func (service *Service) SubmitAnonCardTransaction( ctx context.Context, @@ -421,20 +495,11 @@ func (service *Service) LinkBitFlyerWallet(ctx context.Context, walletID uuid.UU country = "JP" ) - err := validateCustodianLinking(ctx, service.Datastore, walletID, depositProvider) - if err != nil { - if errors.Is(err, errCustodianLinkMismatch) { - return "", errCustodianLinkMismatch - } - return "", handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) - } - // In the controller validation, we verified that the account hash and deposit id were signed by bitflyer // we also validated that this "info" signed the request to perform the linking with http signature // we assume that since we got linkingInfo signed from BF that they are KYC providerLinkingID := uuid.NewV5(ClaimNamespace, accountHash) - err = service.Datastore.LinkWallet(ctx, walletID.String(), depositID, providerLinkingID, depositProvider, country) - if err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), depositID, providerLinkingID, depositProvider, country); err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) } @@ -468,18 +533,8 @@ func (service *Service) LinkZebPayWallet(ctx context.Context, walletID uuid.UUID return "", err } - err = validateCustodianLinking(ctx, service.Datastore, walletID, depositProvider) - if err != nil { - service.metric.LinkFailureZP(claims.CountryCode) - - if errors.Is(err, errCustodianLinkMismatch) { - return "", errCustodianLinkMismatch - } - return "", handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) - } - providerLinkingID := uuid.NewV5(ClaimNamespace, claims.AccountID) - if err := service.Datastore.LinkWallet(ctx, walletID.String(), claims.DepositID, providerLinkingID, depositProvider, claims.CountryCode); err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), claims.DepositID, providerLinkingID, depositProvider, claims.CountryCode); err != nil { service.metric.LinkFailureZP(claims.CountryCode) if errors.Is(err, ErrUnusualActivity) { @@ -560,7 +615,7 @@ func (service *Service) LinkGeminiWallet(ctx context.Context, walletID uuid.UUID } service.metric.LinkSuccessGemini(issuingCountry) - if err := service.Datastore.LinkWallet(ctx, walletID.String(), depositID, linkingID, depositProvider, issuingCountry); err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), depositID, linkingID, depositProvider, issuingCountry); err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) } @@ -610,14 +665,6 @@ func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wall return "", fmt.Errorf("failed to parse uphold id: %w", err) } - err = validateCustodianLinking(ctx, service.Datastore, walletID, depositProvider) - if err != nil { - if errors.Is(err, errCustodianLinkMismatch) { - return "", errCustodianLinkMismatch - } - return "", handlers.WrapError(err, "failed to check linking mismatch", http.StatusInternalServerError) - } - // verify that the user is kyc from uphold. (for all wallet provider cases) if uID, ok, c, err := wallet.IsUserKYC(ctx, transactionInfo.Destination); err != nil { // check if this gemini accountID has already been linked to this wallet, @@ -659,9 +706,7 @@ func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wall probi = transactionInfo.Probi providerLinkingID := uuid.NewV5(ClaimNamespace, userID) - // tx.Destination will be stored as UserDepositDestination in the wallet info upon linking - err = service.Datastore.LinkWallet(ctx, info.ID, transactionInfo.Destination, providerLinkingID, depositProvider, country) - if err != nil { + if err := service.linkCustodialAccount(ctx, walletID.String(), transactionInfo.Destination, providerLinkingID, depositProvider, country); err != nil { if errors.Is(err, ErrUnusualActivity) { return "", handlers.WrapError(err, "unable to link - unusual activity", http.StatusBadRequest) } @@ -689,6 +734,87 @@ func (service *Service) LinkUpholdWallet(ctx context.Context, wallet uphold.Wall return country, nil } +func (service *Service) LinkSolanaAddress(ctx context.Context, paymentID uuid.UUID, req linkSolanaAddrRequest) error { + if err := isWalletWhitelisted(ctx, service.Datastore.RawDB(), service.allowListRepo, paymentID); err != nil { + return err + } + + ctx, txn, rollback, commit, err := getTx(ctx, service.Datastore) + if err != nil { + return err + } + defer rollback() + + chl, err := service.chlRepo.Get(ctx, txn, paymentID) + if err != nil { + return err + } + + if err := chl.IsValid(time.Now()); err != nil { + return err + } + + w, err := service.Datastore.GetWallet(ctx, paymentID) + if err != nil { + return err + } + + if w == nil { + return model.ErrWalletNotFound + } + + if err := w.LinkSolanaAddress(ctx, walletutils.SolanaLinkReq{ + Pub: req.SolanaPublicKey, + Sig: req.SolanaSignature, + Msg: req.Message, + Nonce: chl.Nonce, + }); err != nil { + return err + } + + cl := NewSolanaCustodialLink(paymentID, w.UserDepositDestination) + + if err := service.Datastore.LinkWallet(ctx, w.ID, w.UserDepositDestination, *cl.LinkingID, cl.Custodian); err != nil { + return err + } + + if err := service.chlRepo.Delete(ctx, txn, chl.PaymentID); err != nil { + return err + } + + if err := commit(); err != nil { + return err + } + + return nil +} + +func (service *Service) linkCustodialAccount(ctx context.Context, wID string, userDepositDestination string, providerLinkingID uuid.UUID, depositProvider, country string) error { + walletID, err := uuid.FromString(wID) + if err != nil { + return fmt.Errorf("invalid wallet id, not uuid: %w", err) + } + + repClient, ok := ctx.Value(appctx.ReputationClientCTXKey).(reputation.Client) + if !ok { + return ErrNoReputationClient + } + + if _, _, err := repClient.IsLinkingReputable(ctx, walletID, country); err != nil { + return fmt.Errorf("failed to check wallet rep: %w", err) + } + + return service.Datastore.LinkWallet(ctx, walletID.String(), userDepositDestination, providerLinkingID, depositProvider) +} + +func (service *Service) CreateChallenge(ctx context.Context, paymentID uuid.UUID) (model.Challenge, error) { + chl := model.NewChallenge(paymentID) + if err := service.chlRepo.Upsert(ctx, service.Datastore.RawDB(), chl); err != nil { + return model.Challenge{}, fmt.Errorf("error creating challenge: %w", err) + } + return chl, nil +} + // DisconnectCustodianLink - removes the link to the custodian wallet that is active func (service *Service) DisconnectCustodianLink(ctx context.Context, custodian string, walletID uuid.UUID) error { if err := service.Datastore.DisconnectCustodialWallet(ctx, walletID); err != nil { @@ -880,22 +1006,13 @@ func (c *claimsZP) validateTime(now time.Time) error { return nil } -func validateCustodianLinking(ctx context.Context, storage Datastore, walletID uuid.UUID, depositProvider string) error { - c, err := storage.GetCustodianLinkByWalletID(ctx, walletID) - if err != nil && !errors.Is(err, model.ErrNoWalletCustodian) { - return err - } - - // if there are no instances of wallet custodian then it is - // considered a new linking and therefore valid. - if c == nil { - return nil - } - - if !strings.EqualFold(c.Custodian, depositProvider) { - return errCustodianLinkMismatch +func isWalletWhitelisted(ctx context.Context, dbi sqlx.QueryerContext, alRepo allowListRepo, paymentID uuid.UUID) error { + if _, err := alRepo.GetAllowListEntry(ctx, dbi, paymentID); err != nil { + if errors.Is(err, model.ErrNotFound) { + return model.ErrWalletNotWhitelisted + } + return fmt.Errorf("error checking allow list entry: %w", err) } - return nil } @@ -940,3 +1057,20 @@ func parseZebPayClaims(ctx context.Context, verificationToken string) (claimsZP, return claims, nil } + +type deleter interface { + DeleteAfter(ctx context.Context, dbi sqlx.ExecerContext, interval time.Duration) error +} + +type deleteExpiredChallengeTask struct { + exec sqlx.ExecerContext + deleter deleter + deleteAfterMin time.Duration +} + +func (d *deleteExpiredChallengeTask) deleteExpiredChallenges(ctx context.Context) (bool, error) { + if err := d.deleter.DeleteAfter(ctx, d.exec, d.deleteAfterMin); err != nil && !errors.Is(err, model.ErrNoRowsDeleted) { + return false, fmt.Errorf("error deleting expired challenges: %w", err) + } + return true, nil +} diff --git a/services/wallet/storage/storage.go b/services/wallet/storage/storage.go new file mode 100644 index 000000000..dc6353a83 --- /dev/null +++ b/services/wallet/storage/storage.go @@ -0,0 +1,114 @@ +package storage + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/jmoiron/sqlx" + uuid "github.com/satori/go.uuid" +) + +type Challenge struct{} + +func NewChallenge() *Challenge { return &Challenge{} } + +// Get retrieves a model.Challenge from the database by the given paymentID. +func (c *Challenge) Get(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.Challenge, error) { + const q = `select * from challenge where payment_id = $1` + + var result model.Challenge + if err := sqlx.GetContext(ctx, dbi, &result, q, paymentID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return result, model.ErrChallengeNotFound + } + return result, err + } + + return result, nil +} + +// Upsert persists a model.Challenge to the database. +func (c *Challenge) Upsert(ctx context.Context, dbi sqlx.ExecerContext, chl model.Challenge) error { + const q = `insert into challenge (payment_id, created_at, nonce) values($1, $2, $3) on conflict (payment_id) do update set created_at = $2, nonce = $3` + + result, err := dbi.ExecContext(ctx, q, chl.PaymentID, chl.CreatedAt, chl.Nonce) + if err != nil { + return err + } + + row, err := result.RowsAffected() + if err != nil { + return err + } + + if row != 1 { + return model.ErrNotInserted + } + + return nil +} + +// Delete removes a model.Challenge from the database identified by the paymentID. +func (c *Challenge) Delete(ctx context.Context, dbi sqlx.ExecerContext, paymentID uuid.UUID) error { + const q = `delete from challenge where payment_id = $1` + + result, err := dbi.ExecContext(ctx, q, paymentID) + if err != nil { + return err + } + + row, err := result.RowsAffected() + if err != nil { + return err + } + + if row == 0 { + return model.ErrNoRowsDeleted + } + + return nil +} + +// DeleteAfter removes model.Challenge's from the database where their created at plus the specified interval it +// less than the current time. The interval should be specified in minutes. +func (c *Challenge) DeleteAfter(ctx context.Context, dbi sqlx.ExecerContext, interval time.Duration) error { + const q = `delete from challenge where created_at + interval '1 min' * $1 < now()` + + result, err := dbi.ExecContext(ctx, q, interval) + if err != nil { + return err + } + + row, err := result.RowsAffected() + if err != nil { + return err + } + + if row == 0 { + return model.ErrNoRowsDeleted + } + + return nil +} + +type AllowList struct{} + +func NewAllowList() *AllowList { return &AllowList{} } + +// GetAllowListEntry retrieves a model.AllowListEntry from the database for the given paymentID. +func (a *AllowList) GetAllowListEntry(ctx context.Context, dbi sqlx.QueryerContext, paymentID uuid.UUID) (model.AllowListEntry, error) { + const q = `select * from allow_list where payment_id = $1` + + var result model.AllowListEntry + if err := sqlx.GetContext(ctx, dbi, &result, q, paymentID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return result, model.ErrNotFound + } + return result, err + } + + return result, nil +} diff --git a/services/wallet/storage/storage_test.go b/services/wallet/storage/storage_test.go new file mode 100644 index 000000000..270579f4a --- /dev/null +++ b/services/wallet/storage/storage_test.go @@ -0,0 +1,437 @@ +//go:build integration + +package storage + +import ( + "context" + "testing" + "time" + + "github.com/brave-intl/bat-go/libs/datastore" + "github.com/brave-intl/bat-go/services/wallet/model" + "github.com/jmoiron/sqlx" + uuid "github.com/satori/go.uuid" + should "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" +) + +func TestChallenge_Get(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + type tcGiven struct { + paymentID uuid.UUID + chal model.Challenge + } + + type exp struct { + chal model.Challenge + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "get", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-1", + }}, + exp: exp{ + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-1", + }, + err: nil, + }, + }, + { + name: "challenge_not_found", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("1b8c218f-2585-49c1-90cd-b82006eb9865"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("54e4a78e-2c69-4fb2-8d72-47921bb0b374"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-2", + }}, + exp: exp{ + err: model.ErrChallengeNotFound, + }, + }, + } + + const q = `insert into challenge (payment_id, created_at, nonce) values($1, $2, $3)` + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + _, err1 := dbi.ExecContext(ctx, q, tc.given.chal.PaymentID, tc.given.chal.CreatedAt, tc.given.chal.Nonce) + must.NoError(t, err1) + + c := Challenge{} + actual, err2 := c.Get(ctx, dbi, tc.given.paymentID) + + should.Equal(t, tc.exp.err, err2) + should.Equal(t, tc.exp.chal.PaymentID, actual.PaymentID) + should.Equal(t, tc.exp.chal.CreatedAt, actual.CreatedAt) + should.Equal(t, tc.exp.chal.Nonce, actual.Nonce) + }) + } +} + +func TestChallenge_Upsert(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + const q = `select * from challenge where payment_id = $1` + + chlRepo := &Challenge{} + + t.Run("insert", func(t *testing.T) { + exp := model.Challenge{ + PaymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "a", + } + + err1 := chlRepo.Upsert(context.TODO(), dbi, exp) + must.NoError(t, err1) + + var actual model.Challenge + err2 := sqlx.GetContext(context.TODO(), dbi, &actual, q, exp.PaymentID) + must.NoError(t, err2) + + should.Equal(t, exp.PaymentID, actual.PaymentID) + should.Equal(t, exp.CreatedAt, actual.CreatedAt) + should.Equal(t, exp.Nonce, actual.Nonce) + }) + + t.Run("upsert", func(t *testing.T) { + ctx := context.Background() + + exp := model.Challenge{ + PaymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + CreatedAt: time.Date(2024, 12, 1, 1, 1, 1, 0, time.UTC), + Nonce: "b", + } + + err1 := chlRepo.Upsert(ctx, dbi, exp) + must.NoError(t, err1) + + var actual model.Challenge + err2 := sqlx.GetContext(ctx, dbi, &actual, q, exp.PaymentID) + must.NoError(t, err2) + + should.Equal(t, exp.PaymentID, actual.PaymentID) + should.Equal(t, exp.CreatedAt, actual.CreatedAt) + should.Equal(t, exp.Nonce, actual.Nonce) + }) +} + +func TestChallenge_Delete(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + type tcGiven struct { + paymentID uuid.UUID + chal model.Challenge + } + + type exp struct { + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "delete", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("66e4751f-cd72-4bb0-aebd-66c50a2e8c45"), + Nonce: "nonce-1", + }}, + exp: exp{ + err: nil, + }, + }, + { + name: "no_rows_deleted", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("34fe675b-aebf-4209-90b6-a7ba4452087a"), + chal: model.Challenge{ + PaymentID: uuid.FromStringOrNil("f6d43f0c-db24-4d65-9f02-b000ff8ec782"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + Nonce: "nonce-2", + }}, + exp: exp{ + err: model.ErrNoRowsDeleted, + }, + }, + } + + const q = `insert into challenge (payment_id, created_at, nonce) values($1, $2, $3)` + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + _, err1 := dbi.ExecContext(ctx, q, tc.given.chal.PaymentID, tc.given.chal.CreatedAt, tc.given.chal.Nonce) + must.NoError(t, err1) + + c := Challenge{} + err2 := c.Delete(ctx, dbi, tc.given.paymentID) + should.Equal(t, tc.exp.err, err2) + }) + } +} + +func TestChallenge_DeleteAfter(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE challenge;") + }() + + type tcGiven struct { + interval time.Duration + challenges []model.Challenge + } + + type exp struct { + errDel error + errGet error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "delete_single", + given: tcGiven{ + interval: 1, + challenges: []model.Challenge{ + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now().Add(-6 * time.Minute), + Nonce: "nonce-1", + }, + }, + }, + exp: exp{ + errDel: nil, + errGet: model.ErrChallengeNotFound, + }, + }, + { + name: "delete_multiple", + given: tcGiven{ + interval: 1, + challenges: []model.Challenge{ + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now().Add(-6 * time.Minute), + Nonce: "nonce-1", + }, + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now().Add(-10 * time.Minute), + Nonce: "nonce-2", + }, + }, + }, + exp: exp{ + errDel: nil, + errGet: model.ErrChallengeNotFound, + }, + }, + { + name: "delete_none", + given: tcGiven{ + interval: 1, + challenges: []model.Challenge{ + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now(), + Nonce: "nonce-1", + }, + { + PaymentID: uuid.NewV4(), + CreatedAt: time.Now(), + Nonce: "nonce-2", + }, + }, + }, + exp: exp{ + errDel: model.ErrNoRowsDeleted, + errGet: nil, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + c := Challenge{} + + for j := range tc.given.challenges { + err := c.Upsert(ctx, dbi, tc.given.challenges[j]) + must.NoError(t, err) + } + + err := c.DeleteAfter(ctx, dbi, tc.given.interval) + must.Equal(t, tc.exp.errDel, err) + + for j := range tc.given.challenges { + _, err := c.Get(ctx, dbi, tc.given.challenges[j].PaymentID) + must.Equal(t, tc.exp.errGet, err) + } + }) + } +} + +func TestAllowList_GetAllowListEntry(t *testing.T) { + dbi, err := setupDBI() + must.NoError(t, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE allow_list;") + }() + + type tcGiven struct { + paymentID uuid.UUID + allow model.AllowListEntry + } + + type exp struct { + allow model.AllowListEntry + err error + } + + type testCase struct { + name string + given tcGiven + exp exp + } + + tests := []testCase{ + { + name: "success", + given: tcGiven{ + paymentID: uuid.FromStringOrNil("6d85a314-0fa8-4594-9cb9-c9141b61a887"), + allow: model.AllowListEntry{ + PaymentID: uuid.FromStringOrNil("6d85a314-0fa8-4594-9cb9-c9141b61a887"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + }, + }, + exp: exp{ + allow: model.AllowListEntry{ + PaymentID: uuid.FromStringOrNil("6d85a314-0fa8-4594-9cb9-c9141b61a887"), + CreatedAt: time.Date(2024, 1, 1, 1, 1, 1, 0, time.UTC), + }, + err: nil, + }, + }, + { + name: "not_found", + given: tcGiven{ + paymentID: uuid.NewV4(), + }, + exp: exp{ + err: model.ErrNotFound, + }, + }, + } + + const q = `insert into allow_list (payment_id, created_at) values($1, $2)` + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + _, err1 := dbi.ExecContext(ctx, q, tc.given.allow.PaymentID, tc.given.allow.CreatedAt) + must.NoError(t, err1) + + a := &AllowList{} + actual, err2 := a.GetAllowListEntry(ctx, dbi, tc.given.paymentID) + should.Equal(t, tc.exp.err, err2) + should.Equal(t, tc.exp.allow, actual) + }) + } +} + +func setupDBI() (*sqlx.DB, error) { + pg, err := datastore.NewPostgres("", false, "") + if err != nil { + return nil, err + } + + mg, err := pg.NewMigrate() + if err != nil { + return nil, err + } + + ver, dirty, err := mg.Version() + if err != nil { + return nil, err + } + + if dirty { + if err := mg.Force(int(ver)); err != nil { + return nil, err + } + } + + if ver > 0 { + if err := mg.Down(); err != nil { + return nil, err + } + } + + if err := pg.Migrate(); err != nil { + return nil, err + } + + return pg.RawDB(), nil +} diff --git a/tools/go.mod b/tools/go.mod index 968705fe8..1b0f5c17b 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -97,6 +97,7 @@ require ( github.com/fatih/color v1.15.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-chi/chi v4.1.2+incompatible // indirect + github.com/go-chi/cors v1.2.1 // indirect github.com/go-jose/go-jose/v3 v3.0.1 // indirect github.com/go-openapi/analysis v0.21.4 // indirect github.com/go-openapi/errors v0.20.3 // indirect diff --git a/tools/go.sum b/tools/go.sum index d64e2a4f7..86753a598 100644 --- a/tools/go.sum +++ b/tools/go.sum @@ -565,6 +565,8 @@ github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2H github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks=