From c5c460daa1cb711c06c5c4f19d44a4bb4486c2de Mon Sep 17 00:00:00 2001 From: Nathan Seva Date: Wed, 25 Oct 2023 11:38:23 -0500 Subject: [PATCH] accounts are now sync.Map type --- internal/handler/wallet/create.go | 2 +- internal/handler/wallet/delete.go | 2 +- internal/handler/wallet/get.go | 6 +-- internal/handler/wallet/get_all_assets.go | 2 +- internal/handler/wallet/import.go | 4 +- .../handler/wallet/nodeFetcherMock_test.go | 2 +- internal/handler/wallet/update.go | 4 +- pkg/network/account.go | 2 +- pkg/network/interface.go | 2 +- pkg/wallet/testUtils.go | 5 ++ pkg/wallet/validation.go | 37 ++++++++++++-- pkg/wallet/wallet.go | 48 ++++++++----------- pkg/wallet/wallet_test.go | 20 ++++---- 13 files changed, 83 insertions(+), 53 deletions(-) diff --git a/internal/handler/wallet/create.go b/internal/handler/wallet/create.go index 60fcd2baa..c718a1fca 100644 --- a/internal/handler/wallet/create.go +++ b/internal/handler/wallet/create.go @@ -65,7 +65,7 @@ func (w *walletCreate) Handle(params operations.CreateAccountParams) middleware. w.prompterApp.EmitEvent(walletapp.PromptResultEvent, walletapp.EventData{Success: true}) - infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc}) + infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc}) if err != nil { return operations.NewCreateAccountInternalServerError().WithPayload( &models.Error{ diff --git a/internal/handler/wallet/delete.go b/internal/handler/wallet/delete.go index e9b048cf1..bcf9b4edc 100644 --- a/internal/handler/wallet/delete.go +++ b/internal/handler/wallet/delete.go @@ -35,7 +35,7 @@ func (w *walletDelete) Handle(params operations.DeleteAccountParams) middleware. return errResp } - infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc}) + infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc}) if err != nil { return operations.NewDeleteAccountInternalServerError().WithPayload( &models.Error{ diff --git a/internal/handler/wallet/get.go b/internal/handler/wallet/get.go index dd41ea851..899c2269f 100644 --- a/internal/handler/wallet/get.go +++ b/internal/handler/wallet/get.go @@ -38,7 +38,7 @@ func (w *walletGet) Handle(params operations.GetAccountParams) middleware.Respon return errResp } - modelWallet, err := newAccountModel(*acc) + modelWallet, err := newAccountModel(acc) if err != nil { return newErrorResponse(err.Error(), errorGetAccount, http.StatusInternalServerError) } @@ -81,7 +81,7 @@ func (w *walletGet) Handle(params operations.GetAccountParams) middleware.Respon } } - infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc}) + infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc}) if err != nil { return operations.NewGetAccountInternalServerError().WithPayload( &models.Error{ @@ -96,7 +96,7 @@ func (w *walletGet) Handle(params operations.GetAccountParams) middleware.Respon return operations.NewGetAccountOK().WithPayload(modelWallet) } -func newAccountModel(acc account.Account) (*models.Account, error) { +func newAccountModel(acc *account.Account) (*models.Account, error) { address, err := acc.Address.MarshalText() if err != nil { return nil, err diff --git a/internal/handler/wallet/get_all_assets.go b/internal/handler/wallet/get_all_assets.go index fbd48bdf0..6b3a0645d 100644 --- a/internal/handler/wallet/get_all_assets.go +++ b/internal/handler/wallet/get_all_assets.go @@ -38,7 +38,7 @@ func (g *getAllAssets) Handle(params operations.GetAllAssetsParams) middleware.R AssetsWithBalance := make([]*models.AssetInfoWithBalance, 0) // Fetch the account information for the wallet using the massaClient - infos, err := g.massaClient.GetAccountsInfos([]account.Account{*acc}) + infos, err := g.massaClient.GetAccountsInfos([]*account.Account{acc}) if err != nil { // Handle the error and return an internal server error response errorMsg := fmt.Sprintf("Failed to fetch balance for asset %s: %s", "MASSA", err.Error()) diff --git a/internal/handler/wallet/import.go b/internal/handler/wallet/import.go index 47feb97f9..1a367afa1 100644 --- a/internal/handler/wallet/import.go +++ b/internal/handler/wallet/import.go @@ -45,7 +45,7 @@ func (w *walletImport) Handle(_ operations.ImportAccountParams) middleware.Respo w.prompterApp.EmitEvent(walletapp.PromptResultEvent, walletapp.EventData{Success: true}) - infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc}) + infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc}) if err != nil { return operations.NewImportAccountInternalServerError().WithPayload( &models.Error{ @@ -54,7 +54,7 @@ func (w *walletImport) Handle(_ operations.ImportAccountParams) middleware.Respo }) } - modelWallet, err := newAccountModel(*acc) + modelWallet, err := newAccountModel(acc) if err != nil { return newErrorResponse(err.Error(), errorGetAccount, http.StatusInternalServerError) } diff --git a/internal/handler/wallet/nodeFetcherMock_test.go b/internal/handler/wallet/nodeFetcherMock_test.go index e21db868e..666931023 100644 --- a/internal/handler/wallet/nodeFetcherMock_test.go +++ b/internal/handler/wallet/nodeFetcherMock_test.go @@ -16,7 +16,7 @@ func NewNodeFetcherMock() *NodeFetcherMock { } // returns dummy balances -func (n *NodeFetcherMock) GetAccountsInfos(accounts []account.Account) ([]network.AccountInfos, error) { +func (n *NodeFetcherMock) GetAccountsInfos(accounts []*account.Account) ([]network.AccountInfos, error) { infos := make([]network.AccountInfos, len(accounts)) for i, acc := range accounts { diff --git a/internal/handler/wallet/update.go b/internal/handler/wallet/update.go index e14a62e1c..d0fc64dc8 100644 --- a/internal/handler/wallet/update.go +++ b/internal/handler/wallet/update.go @@ -39,12 +39,12 @@ func (w *walletUpdateAccount) Handle(params operations.UpdateAccountParams) midd return newErrorResponse(err.Error(), "", http.StatusInternalServerError) } - modelWallet, err := newAccountModel(*newAcc) + modelWallet, err := newAccountModel(newAcc) if err != nil { return newErrorResponse(err.Error(), errorGetAccount, http.StatusInternalServerError) } - infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc}) + infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc}) if err != nil { return operations.NewGetAccountInternalServerError().WithPayload( &models.Error{ diff --git a/pkg/network/account.go b/pkg/network/account.go index 649d1065d..ce963b5f5 100644 --- a/pkg/network/account.go +++ b/pkg/network/account.go @@ -24,7 +24,7 @@ type AccountInfos struct { Balance uint64 } -func (n *NodeFetcher) GetAccountsInfos(accounts []account.Account) ([]AccountInfos, error) { +func (n *NodeFetcher) GetAccountsInfos(accounts []*account.Account) ([]AccountInfos, error) { client, err := NewMassaClient() if err != nil { return nil, err diff --git a/pkg/network/interface.go b/pkg/network/interface.go index 4fe037098..9a36d5522 100644 --- a/pkg/network/interface.go +++ b/pkg/network/interface.go @@ -12,7 +12,7 @@ func NewNodeFetcher() *NodeFetcher { } type NodeFetcherInterface interface { - GetAccountsInfos(accounts []account.Account) ([]AccountInfos, error) + GetAccountsInfos(accounts []*account.Account) ([]AccountInfos, error) MakeOperation(fee uint64, operation sendOperation.Operation) ([]byte, error) MakeRPCCall(msg []byte, signature []byte, publicKey string) ([]string, error) AssetExistInNetwork(contractAddress string) bool diff --git a/pkg/wallet/testUtils.go b/pkg/wallet/testUtils.go index 25e47c1ca..26a6366f5 100644 --- a/pkg/wallet/testUtils.go +++ b/pkg/wallet/testUtils.go @@ -9,6 +9,11 @@ import ( "github.com/stretchr/testify/assert" ) +// Get the number of accounts in the wallet +func (w *Wallet) GetAccountCount() int { + return len(w.AllAccounts()) +} + func ClearAccounts(t *testing.T, walletPath string) { files, err := os.ReadDir(walletPath) assert.NoError(t, err) diff --git a/pkg/wallet/validation.go b/pkg/wallet/validation.go index bf2adb69a..dd304a2f4 100644 --- a/pkg/wallet/validation.go +++ b/pkg/wallet/validation.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/massalabs/station-massa-wallet/pkg/types" + "github.com/massalabs/station-massa-wallet/pkg/wallet/account" ) var ( @@ -15,20 +16,48 @@ var ( ) func (w *Wallet) NicknameIsUnique(nickname string) error { - for _, account := range w.accounts { + var duplicateNickname string + + w.accounts.Range(func(_, acc interface{}) bool { + account, ok := acc.(*account.Account) + if !ok { + return true + } + if strings.EqualFold(account.Nickname, nickname) { - return fmt.Errorf("%w: %s", ErrNicknameNotUnique, nickname) + duplicateNickname = nickname + return false } + + return true + }) + + if duplicateNickname != "" { + return fmt.Errorf("%w: %s", ErrNicknameNotUnique, duplicateNickname) } return nil } func (w *Wallet) AddressIsUnique(address *types.Address) error { - for _, account := range w.accounts { + duplicateAddress := false + + w.accounts.Range(func(_, acc interface{}) bool { + account, ok := acc.(*account.Account) + if !ok { + return true + } + if bytes.Equal(account.Address.Data, address.Data) { - return ErrAddressNotUnique + duplicateAddress = true + return false } + + return true + }) + + if duplicateAddress { + return ErrAddressNotUnique } return nil diff --git a/pkg/wallet/wallet.go b/pkg/wallet/wallet.go index c8555a9ca..954e067a9 100644 --- a/pkg/wallet/wallet.go +++ b/pkg/wallet/wallet.go @@ -20,15 +20,14 @@ var ( ) type Wallet struct { - accounts map[string]*account.Account // Mapping from nickname to account - InvalidAccountNicknames []string // List of invalid account nicknames + accounts *sync.Map // Mapping from nickname to account + InvalidAccountNicknames []string // List of invalid account nicknames WalletPath string - mutex sync.Mutex } func New(walletPath string) (*Wallet, error) { wallet := &Wallet{ - accounts: make(map[string]*account.Account), + accounts: &sync.Map{}, } if walletPath == "" { @@ -66,7 +65,9 @@ func (w *Wallet) Discover() error { if strings.HasPrefix(fileName, "wallet_") && strings.HasSuffix(fileName, ".yaml") { nickname := w.nicknameFromFilePath(filePath) - if w.accounts[nickname] != nil { + + _, found := w.accounts.Load(nickname) + if found { continue } @@ -116,17 +117,11 @@ func (w *Wallet) AddAccount(acc *account.Account, persist bool) error { } } - w.addAccount(acc) + w.accounts.Store(acc.Nickname, acc) return nil } -func (w *Wallet) addAccount(acc *account.Account) { - w.mutex.Lock() - w.accounts[acc.Nickname] = acc - w.mutex.Unlock() -} - // GenerateAccount generates a new account and adds it to the wallet. // It returns the generated account. // It destroys the password. @@ -146,8 +141,8 @@ func (w *Wallet) GenerateAccount(password *memguard.LockedBuffer, nickname strin // Get an account from the wallet by nickname func (w *Wallet) GetAccount(nickname string) (*account.Account, error) { - if w.accounts[nickname] != nil { - return w.accounts[nickname], nil + if acc, found := w.accounts.Load(nickname); found { + return acc.(*account.Account), nil } accountPath, err := w.AccountPath(nickname) @@ -171,7 +166,7 @@ func (w *Wallet) GetAccount(nickname string) (*account.Account, error) { // Delete an account from the wallet func (w *Wallet) DeleteAccount(nickname string) error { - if w.accounts[nickname] == nil { + if _, found := w.accounts.Load(nickname); !found { return AccountNotFoundError } @@ -185,24 +180,23 @@ func (w *Wallet) DeleteAccount(nickname string) error { return fmt.Errorf("deleting account file: %w", err) } - delete(w.accounts, nickname) + w.accounts.Delete(nickname) return nil } -// Get the number of accounts in the wallet -func (w *Wallet) GetAccountCount() int { - return len(w.accounts) -} - -func (w *Wallet) AllAccounts() []account.Account { - accounts := make([]account.Account, 0, len(w.accounts)) +func (w *Wallet) AllAccounts() []*account.Account { + var accounts []*account.Account - for _, acc := range w.accounts { - accounts = append(accounts, *acc) - } + w.accounts.Range(func(_, value interface{}) bool { + acc, ok := value.(*account.Account) + if ok { + accounts = append(accounts, acc) + } + return true + }) - sort.SliceStable(accounts, func(i, j int) bool { + sort.Slice(accounts, func(i, j int) bool { return accounts[i].Nickname < accounts[j].Nickname }) diff --git a/pkg/wallet/wallet_test.go b/pkg/wallet/wallet_test.go index dacdc0b79..dfdf58a56 100644 --- a/pkg/wallet/wallet_test.go +++ b/pkg/wallet/wallet_test.go @@ -24,7 +24,9 @@ func TestWallet(t *testing.T) { var w *Wallet sampleSalt := [16]byte{145, 114, 211, 33, 247, 163, 215, 171, 90, 186, 97, 47, 43, 252, 68, 170} sampleNonce := [12]byte{113, 122, 168, 123, 48, 187, 178, 12, 209, 91, 243, 63} - sampleNickname := "bonjour2" + sampleNickname := "bonjour" + sampleNickname2 := "unit-test" + sampleNickname3 := "version-0" sampleAccount, err := account.New( uint8(account.LastVersion), sampleNickname, @@ -115,7 +117,7 @@ func TestWallet(t *testing.T) { assert.Equal(t, 1, w.GetAccountCount()) - sampleAccount.Nickname = "bonjour2" + sampleAccount.Nickname = sampleNickname }) t.Run("Get Account", func(t *testing.T) { @@ -146,27 +148,27 @@ func TestWallet(t *testing.T) { t.Run("Get Account: new file added manually", func(t *testing.T) { // User can add an account file in the account folder by its own, // we want to wallet to be able to manage this account. - nickname := "unit-test" - accountPath, err := w.AccountPath(nickname) + accountPath, err := w.AccountPath(sampleNickname2) assert.NoError(t, err) copy(t, "../../tests/wallet_unit-test.yaml", accountPath) - acc := assertAccountIsPresent(t, w, nickname) + assertAccountIsPresent(t, w, sampleNickname) + acc := assertAccountIsPresent(t, w, sampleNickname2) assert.Equal(t, uint8(1), acc.Version) assert.Equal(t, 2, w.GetAccountCount()) }) t.Run("Invalid or unsupported version", func(t *testing.T) { - nickname := "version-0" - accountPath, err := w.AccountPath(nickname) + accountPath, err := w.AccountPath(sampleNickname3) assert.NoError(t, err) copy(t, "../../tests/wallet_version-0.yaml", accountPath) newWallet, err := New(walletPath) assert.NoError(t, err) - assertAccountIsPresent(t, newWallet, "unit-test") - assert.Len(t, newWallet.accounts, 2) + assertAccountIsPresent(t, w, sampleNickname) + assertAccountIsPresent(t, newWallet, sampleNickname2) + assert.Equal(t, 2, newWallet.GetAccountCount()) assert.Len(t, newWallet.InvalidAccountNicknames, 1) })