diff --git a/services/wallet/controllers_v3.go b/services/wallet/controllers_v3.go index 8bbae9a5d..dcbcf3c0d 100644 --- a/services/wallet/controllers_v3.go +++ b/services/wallet/controllers_v3.go @@ -505,6 +505,8 @@ func LinkSolanaAddress(s *Service) handlers.AppHandler { return handlers.WrapError(model.ErrChallengeExpired, "linking challenge expired", http.StatusUnauthorized) case errors.Is(err, model.ErrWalletNotFound): return handlers.WrapError(model.ErrWalletNotFound, "rewards wallet not found", http.StatusNotFound) + case errors.Is(err, ErrTooManyCardsLinked): + return handlers.WrapError(ErrTooManyCardsLinked, "too many wallets linked", http.StatusConflict) case errors.As(err, &solErr): return handlers.WrapError(solErr, "invalid solana linking message", http.StatusUnauthorized) default: diff --git a/services/wallet/datastore.go b/services/wallet/datastore.go index 8461d71cb..121c30b8a 100644 --- a/services/wallet/datastore.go +++ b/services/wallet/datastore.go @@ -424,6 +424,10 @@ func getEnvMaxCards(custodian string) int { if v, err := strconv.Atoi(os.Getenv("ZEBPAY_WALLET_LINKING_LIMIT")); err == nil { return v } + case "solana": + if v, err := strconv.Atoi(os.Getenv("SOLANA_WALLET_LINKING_LIMIT")); err == nil { + return v + } } return 4 } diff --git a/services/wallet/datastore_pvt_test.go b/services/wallet/datastore_pvt_test.go index 69365b176..5bbf540d9 100644 --- a/services/wallet/datastore_pvt_test.go +++ b/services/wallet/datastore_pvt_test.go @@ -1,11 +1,13 @@ package wallet import ( + "os" "testing" "time" "github.com/brave-intl/bat-go/libs/ptr" should "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" ) func TestCustodianLink_isLinked(t *testing.T) { @@ -50,10 +52,64 @@ func TestCustodianLink_isLinked(t *testing.T) { }, } - for _, tc := range tests { + for i := range tests { + tc := tests[i] + t.Run(tc.name, func(t *testing.T) { actual := tc.given.cl.isLinked() should.Equal(t, tc.expected, actual) }) } } + +func TestGetEnvMaxCards(t *testing.T) { + type tcGiven struct { + custodian string + key string + value string + } + + type testCase struct { + name string + given tcGiven + exp int + } + + tests := []testCase{ + { + name: "solana", + given: tcGiven{ + custodian: "solana", + key: "SOLANA_WALLET_LINKING_LIMIT", + value: "10", + }, + exp: 10, + }, + { + name: "non_existent_custodian", + given: tcGiven{ + key: "NON_EXISTENT_CUSTODIAN", + value: "10", + }, + exp: 4, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + + t.Cleanup(func() { + err := os.Unsetenv(tc.given.key) + must.NoError(t, err) + }) + + err := os.Setenv(tc.given.key, tc.given.value) + must.NoError(t, err) + + actual := getEnvMaxCards(tc.given.custodian) + should.Equal(t, tc.exp, actual) + }) + } +}