diff --git a/smt/pkg/smt/witness.go b/smt/pkg/smt/witness.go index e496a9b0749..1d67c8330ca 100644 --- a/smt/pkg/smt/witness.go +++ b/smt/pkg/smt/witness.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "math/big" + "reflect" + "strings" libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon/smt/pkg/utils" @@ -144,48 +146,55 @@ func BuildSMTfromWitness(w *trie.Witness) (*SMT, error) { switch op := operator.(type) { case *trie.OperatorSMTLeafValue: valScaler := new(big.Int).SetBytes(op.Value) - addr := libcommon.BytesToAddress(op.Address) + hexAddr := libcommon.BytesToAddress(op.Address).String() switch op.NodeType { case utils.KEY_BALANCE: - balanceMap[addr.String()] = valScaler + balanceMap[hexAddr] = valScaler case utils.KEY_NONCE: - nonceMap[addr.String()] = valScaler + nonceMap[hexAddr] = valScaler case utils.SC_STORAGE: - if _, ok := storageMap[addr.String()]; !ok { - storageMap[addr.String()] = make(map[string]string) + if _, ok := storageMap[hexAddr]; !ok { + storageMap[hexAddr] = make(map[string]string) } stKey := hexutils.BytesToHex(op.StorageKey) stKey = fmt.Sprintf("0x%s", stKey) - storageMap[addr.String()][stKey] = valScaler.String() + storageMap[hexAddr][stKey] = valScaler.String() } path = path[:len(path)-1] - nodeChildCountMap[intArrayToString(path)] += 1 + nodePathAsString := intArrayToString(path) + nodeChildCountMap[nodePathAsString] += 1 - for len(path) != 0 && nodeChildCountMap[intArrayToString(path)] == nodesBranchValueMap[intArrayToString(path)] { + for len(path) != 0 && nodeChildCountMap[nodePathAsString] == nodesBranchValueMap[nodePathAsString] { path = path[:len(path)-1] + nodePathAsString = intArrayToString(path) } - if nodeChildCountMap[intArrayToString(path)] < nodesBranchValueMap[intArrayToString(path)] { + + if nodeChildCountMap[nodePathAsString] < nodesBranchValueMap[nodePathAsString] { path = append(path, 1) } case *trie.OperatorCode: - addr := libcommon.BytesToAddress(w.Operators[i+1].(*trie.OperatorSMTLeafValue).Address) + smtLeafValueOp, ok := w.Operators[i+1].(*trie.OperatorSMTLeafValue) + if !ok { + return nil, fmt.Errorf("expected %T, but found %T witness operator", (*trie.OperatorSMTLeafValue)(nil), reflect.TypeOf(smtLeafValueOp)) + } + hexAddr := libcommon.BytesToAddress(smtLeafValueOp.Address).String() code := hexutils.BytesToHex(op.Code) if len(code) > 0 { - if err := s.Db.AddCode(hexutils.HexToBytes(code)); err != nil { + if err := s.Db.AddCode(op.Code); err != nil { return nil, err } code = fmt.Sprintf("0x%s", code) } - contractMap[addr.String()] = code + contractMap[hexAddr] = code case *trie.OperatorBranch: if firstNode { @@ -271,9 +280,9 @@ func BuildSMTfromWitness(w *trie.Witness) (*SMT, error) { } func intArrayToString(a []int) string { - s := "" + var s strings.Builder for _, v := range a { - s += fmt.Sprintf("%d", v) + s.WriteString(fmt.Sprintf("%d", v)) } - return s + return s.String() }