Skip to content

Commit

Permalink
accounts are now sync.Map type
Browse files Browse the repository at this point in the history
  • Loading branch information
Thykof committed Oct 25, 2023
1 parent 51b14ab commit c5c460d
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 53 deletions.
2 changes: 1 addition & 1 deletion internal/handler/wallet/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (w *walletCreate) Handle(params operations.CreateAccountParams) middleware.
w.prompterApp.EmitEvent(walletapp.PromptResultEvent,
walletapp.EventData{Success: true})

infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc})
infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc})
if err != nil {
return operations.NewCreateAccountInternalServerError().WithPayload(
&models.Error{
Expand Down
2 changes: 1 addition & 1 deletion internal/handler/wallet/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (w *walletDelete) Handle(params operations.DeleteAccountParams) middleware.
return errResp
}

infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc})
infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc})
if err != nil {
return operations.NewDeleteAccountInternalServerError().WithPayload(
&models.Error{
Expand Down
6 changes: 3 additions & 3 deletions internal/handler/wallet/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (w *walletGet) Handle(params operations.GetAccountParams) middleware.Respon
return errResp
}

modelWallet, err := newAccountModel(*acc)
modelWallet, err := newAccountModel(acc)
if err != nil {
return newErrorResponse(err.Error(), errorGetAccount, http.StatusInternalServerError)
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func (w *walletGet) Handle(params operations.GetAccountParams) middleware.Respon
}
}

infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc})
infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc})
if err != nil {
return operations.NewGetAccountInternalServerError().WithPayload(
&models.Error{
Expand All @@ -96,7 +96,7 @@ func (w *walletGet) Handle(params operations.GetAccountParams) middleware.Respon
return operations.NewGetAccountOK().WithPayload(modelWallet)
}

func newAccountModel(acc account.Account) (*models.Account, error) {
func newAccountModel(acc *account.Account) (*models.Account, error) {
address, err := acc.Address.MarshalText()
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/handler/wallet/get_all_assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (g *getAllAssets) Handle(params operations.GetAllAssetsParams) middleware.R
AssetsWithBalance := make([]*models.AssetInfoWithBalance, 0)

// Fetch the account information for the wallet using the massaClient
infos, err := g.massaClient.GetAccountsInfos([]account.Account{*acc})
infos, err := g.massaClient.GetAccountsInfos([]*account.Account{acc})
if err != nil {
// Handle the error and return an internal server error response
errorMsg := fmt.Sprintf("Failed to fetch balance for asset %s: %s", "MASSA", err.Error())
Expand Down
4 changes: 2 additions & 2 deletions internal/handler/wallet/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (w *walletImport) Handle(_ operations.ImportAccountParams) middleware.Respo
w.prompterApp.EmitEvent(walletapp.PromptResultEvent,
walletapp.EventData{Success: true})

infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc})
infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc})
if err != nil {
return operations.NewImportAccountInternalServerError().WithPayload(
&models.Error{
Expand All @@ -54,7 +54,7 @@ func (w *walletImport) Handle(_ operations.ImportAccountParams) middleware.Respo
})
}

modelWallet, err := newAccountModel(*acc)
modelWallet, err := newAccountModel(acc)
if err != nil {
return newErrorResponse(err.Error(), errorGetAccount, http.StatusInternalServerError)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handler/wallet/nodeFetcherMock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func NewNodeFetcherMock() *NodeFetcherMock {
}

// returns dummy balances
func (n *NodeFetcherMock) GetAccountsInfos(accounts []account.Account) ([]network.AccountInfos, error) {
func (n *NodeFetcherMock) GetAccountsInfos(accounts []*account.Account) ([]network.AccountInfos, error) {
infos := make([]network.AccountInfos, len(accounts))

for i, acc := range accounts {
Expand Down
4 changes: 2 additions & 2 deletions internal/handler/wallet/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func (w *walletUpdateAccount) Handle(params operations.UpdateAccountParams) midd
return newErrorResponse(err.Error(), "", http.StatusInternalServerError)
}

modelWallet, err := newAccountModel(*newAcc)
modelWallet, err := newAccountModel(newAcc)
if err != nil {
return newErrorResponse(err.Error(), errorGetAccount, http.StatusInternalServerError)
}

infos, err := w.massaClient.GetAccountsInfos([]account.Account{*acc})
infos, err := w.massaClient.GetAccountsInfos([]*account.Account{acc})
if err != nil {
return operations.NewGetAccountInternalServerError().WithPayload(
&models.Error{
Expand Down
2 changes: 1 addition & 1 deletion pkg/network/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type AccountInfos struct {
Balance uint64
}

func (n *NodeFetcher) GetAccountsInfos(accounts []account.Account) ([]AccountInfos, error) {
func (n *NodeFetcher) GetAccountsInfos(accounts []*account.Account) ([]AccountInfos, error) {
client, err := NewMassaClient()
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/network/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func NewNodeFetcher() *NodeFetcher {
}

type NodeFetcherInterface interface {
GetAccountsInfos(accounts []account.Account) ([]AccountInfos, error)
GetAccountsInfos(accounts []*account.Account) ([]AccountInfos, error)
MakeOperation(fee uint64, operation sendOperation.Operation) ([]byte, error)
MakeRPCCall(msg []byte, signature []byte, publicKey string) ([]string, error)
AssetExistInNetwork(contractAddress string) bool
Expand Down
5 changes: 5 additions & 0 deletions pkg/wallet/testUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
"github.com/stretchr/testify/assert"
)

// Get the number of accounts in the wallet
func (w *Wallet) GetAccountCount() int {
return len(w.AllAccounts())
}

func ClearAccounts(t *testing.T, walletPath string) {
files, err := os.ReadDir(walletPath)
assert.NoError(t, err)
Expand Down
37 changes: 33 additions & 4 deletions pkg/wallet/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/massalabs/station-massa-wallet/pkg/types"
"github.com/massalabs/station-massa-wallet/pkg/wallet/account"
)

var (
Expand All @@ -15,20 +16,48 @@ var (
)

func (w *Wallet) NicknameIsUnique(nickname string) error {
for _, account := range w.accounts {
var duplicateNickname string

w.accounts.Range(func(_, acc interface{}) bool {
account, ok := acc.(*account.Account)
if !ok {
return true
}

if strings.EqualFold(account.Nickname, nickname) {
return fmt.Errorf("%w: %s", ErrNicknameNotUnique, nickname)
duplicateNickname = nickname
return false
}

return true
})

if duplicateNickname != "" {
return fmt.Errorf("%w: %s", ErrNicknameNotUnique, duplicateNickname)
}

return nil
}

func (w *Wallet) AddressIsUnique(address *types.Address) error {
for _, account := range w.accounts {
duplicateAddress := false

w.accounts.Range(func(_, acc interface{}) bool {
account, ok := acc.(*account.Account)
if !ok {
return true
}

if bytes.Equal(account.Address.Data, address.Data) {
return ErrAddressNotUnique
duplicateAddress = true
return false
}

return true
})

if duplicateAddress {
return ErrAddressNotUnique
}

return nil
Expand Down
48 changes: 21 additions & 27 deletions pkg/wallet/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ var (
)

type Wallet struct {
accounts map[string]*account.Account // Mapping from nickname to account
InvalidAccountNicknames []string // List of invalid account nicknames
accounts *sync.Map // Mapping from nickname to account
InvalidAccountNicknames []string // List of invalid account nicknames
WalletPath string
mutex sync.Mutex
}

func New(walletPath string) (*Wallet, error) {
wallet := &Wallet{
accounts: make(map[string]*account.Account),
accounts: &sync.Map{},
}

if walletPath == "" {
Expand Down Expand Up @@ -66,7 +65,9 @@ func (w *Wallet) Discover() error {

if strings.HasPrefix(fileName, "wallet_") && strings.HasSuffix(fileName, ".yaml") {
nickname := w.nicknameFromFilePath(filePath)
if w.accounts[nickname] != nil {

_, found := w.accounts.Load(nickname)
if found {
continue
}

Expand Down Expand Up @@ -116,17 +117,11 @@ func (w *Wallet) AddAccount(acc *account.Account, persist bool) error {
}
}

w.addAccount(acc)
w.accounts.Store(acc.Nickname, acc)

return nil
}

func (w *Wallet) addAccount(acc *account.Account) {
w.mutex.Lock()
w.accounts[acc.Nickname] = acc
w.mutex.Unlock()
}

// GenerateAccount generates a new account and adds it to the wallet.
// It returns the generated account.
// It destroys the password.
Expand All @@ -146,8 +141,8 @@ func (w *Wallet) GenerateAccount(password *memguard.LockedBuffer, nickname strin

// Get an account from the wallet by nickname
func (w *Wallet) GetAccount(nickname string) (*account.Account, error) {
if w.accounts[nickname] != nil {
return w.accounts[nickname], nil
if acc, found := w.accounts.Load(nickname); found {
return acc.(*account.Account), nil
}

accountPath, err := w.AccountPath(nickname)
Expand All @@ -171,7 +166,7 @@ func (w *Wallet) GetAccount(nickname string) (*account.Account, error) {

// Delete an account from the wallet
func (w *Wallet) DeleteAccount(nickname string) error {
if w.accounts[nickname] == nil {
if _, found := w.accounts.Load(nickname); !found {
return AccountNotFoundError
}

Expand All @@ -185,24 +180,23 @@ func (w *Wallet) DeleteAccount(nickname string) error {
return fmt.Errorf("deleting account file: %w", err)
}

delete(w.accounts, nickname)
w.accounts.Delete(nickname)

return nil
}

// Get the number of accounts in the wallet
func (w *Wallet) GetAccountCount() int {
return len(w.accounts)
}

func (w *Wallet) AllAccounts() []account.Account {
accounts := make([]account.Account, 0, len(w.accounts))
func (w *Wallet) AllAccounts() []*account.Account {
var accounts []*account.Account

for _, acc := range w.accounts {
accounts = append(accounts, *acc)
}
w.accounts.Range(func(_, value interface{}) bool {
acc, ok := value.(*account.Account)
if ok {
accounts = append(accounts, acc)
}
return true
})

sort.SliceStable(accounts, func(i, j int) bool {
sort.Slice(accounts, func(i, j int) bool {
return accounts[i].Nickname < accounts[j].Nickname
})

Expand Down
20 changes: 11 additions & 9 deletions pkg/wallet/wallet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ func TestWallet(t *testing.T) {
var w *Wallet
sampleSalt := [16]byte{145, 114, 211, 33, 247, 163, 215, 171, 90, 186, 97, 47, 43, 252, 68, 170}
sampleNonce := [12]byte{113, 122, 168, 123, 48, 187, 178, 12, 209, 91, 243, 63}
sampleNickname := "bonjour2"
sampleNickname := "bonjour"
sampleNickname2 := "unit-test"
sampleNickname3 := "version-0"
sampleAccount, err := account.New(
uint8(account.LastVersion),
sampleNickname,
Expand Down Expand Up @@ -115,7 +117,7 @@ func TestWallet(t *testing.T) {

assert.Equal(t, 1, w.GetAccountCount())

sampleAccount.Nickname = "bonjour2"
sampleAccount.Nickname = sampleNickname
})

t.Run("Get Account", func(t *testing.T) {
Expand Down Expand Up @@ -146,27 +148,27 @@ func TestWallet(t *testing.T) {
t.Run("Get Account: new file added manually", func(t *testing.T) {
// User can add an account file in the account folder by its own,
// we want to wallet to be able to manage this account.
nickname := "unit-test"

accountPath, err := w.AccountPath(nickname)
accountPath, err := w.AccountPath(sampleNickname2)
assert.NoError(t, err)

copy(t, "../../tests/wallet_unit-test.yaml", accountPath)

acc := assertAccountIsPresent(t, w, nickname)
assertAccountIsPresent(t, w, sampleNickname)
acc := assertAccountIsPresent(t, w, sampleNickname2)
assert.Equal(t, uint8(1), acc.Version)
assert.Equal(t, 2, w.GetAccountCount())
})

t.Run("Invalid or unsupported version", func(t *testing.T) {
nickname := "version-0"
accountPath, err := w.AccountPath(nickname)
accountPath, err := w.AccountPath(sampleNickname3)
assert.NoError(t, err)
copy(t, "../../tests/wallet_version-0.yaml", accountPath)
newWallet, err := New(walletPath)
assert.NoError(t, err)
assertAccountIsPresent(t, newWallet, "unit-test")
assert.Len(t, newWallet.accounts, 2)
assertAccountIsPresent(t, w, sampleNickname)
assertAccountIsPresent(t, newWallet, sampleNickname2)
assert.Equal(t, 2, newWallet.GetAccountCount())
assert.Len(t, newWallet.InvalidAccountNicknames, 1)
})

Expand Down

0 comments on commit c5c460d

Please sign in to comment.