Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for L1 MPT #1104

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions cmd/geth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ var (
utils.ScrollAlphaFlag,
utils.ScrollSepoliaFlag,
utils.ScrollFlag,
utils.ScrollMPTFlag,
utils.VMEnableDebugFlag,
utils.NetworkIdFlag,
utils.EthStatsURLFlag,
Expand Down
1 change: 1 addition & 0 deletions cmd/geth/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{
utils.ScrollAlphaFlag,
utils.ScrollSepoliaFlag,
utils.ScrollFlag,
utils.ScrollMPTFlag,
utils.SyncModeFlag,
utils.ExitWhenSyncedFlag,
utils.GCModeFlag,
Expand Down
30 changes: 20 additions & 10 deletions cmd/utils/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ var (
Name: "scroll",
Usage: "Scroll mainnet",
}
ScrollMPTFlag = cli.BoolFlag{
Name: "scroll-mpt",
Usage: "Use MPT trie for state storage",
}
DeveloperFlag = cli.BoolFlag{
Name: "dev",
Usage: "Ephemeral proof-of-authority network with a pre-funded developer account, mining enabled",
Expand Down Expand Up @@ -1879,12 +1883,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
stack.Config().L1Confirmations = rpc.FinalizedBlockNumber
log.Info("Setting flag", "--l1.sync.startblock", "4038000")
stack.Config().L1DeploymentBlock = 4038000
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name)
if cfg.Genesis.Config.Scroll.UseZktrie {
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
}
log.Info("Pruning disabled")
cfg.NoPruning = true
}
log.Info("Pruning disabled")
cfg.NoPruning = true
case ctx.GlobalBool(ScrollFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 534352
Expand All @@ -1895,12 +1902,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
stack.Config().L1Confirmations = rpc.FinalizedBlockNumber
log.Info("Setting flag", "--l1.sync.startblock", "18306000")
stack.Config().L1DeploymentBlock = 18306000
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name)
if cfg.Genesis.Config.Scroll.UseZktrie {
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
}
log.Info("Pruning disabled")
cfg.NoPruning = true
}
log.Info("Pruning disabled")
cfg.NoPruning = true
case ctx.GlobalBool(DeveloperFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 1337
Expand Down
3 changes: 2 additions & 1 deletion core/block_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ func (v *BlockValidator) ValidateState(block *types.Block, statedb *state.StateD
}
// Validate the state root against the received state root and throw
// an error if they don't match.
if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); header.Root != root {
shouldValidateStateRoot := v.config.Scroll.UseZktrie != v.config.IsEuclid(header.Time)
if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); shouldValidateStateRoot && header.Root != root {
return fmt.Errorf("invalid merkle root (remote: %x local: %x)", header.Root, root)
}
return nil
Expand Down
7 changes: 5 additions & 2 deletions core/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,9 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
return NonStatTy, err
}
triedb := bc.stateCache.TrieDB()
if block.Root() != root {
rawdb.WriteDiskStateRoot(bc.db, block.Root(), root)
}

// If we're running an archive node, always flush
if bc.cacheConfig.TrieDirtyDisabled {
Expand Down Expand Up @@ -1677,7 +1680,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
}

// Enable prefetching to pull in trie node paths while processing transactions
statedb.StartPrefetcher("chain")
statedb.StartPrefetcher("chain", nil)
activeState = statedb

// If we have a followup block, run that against the current state to pre-cache
Expand Down Expand Up @@ -1814,7 +1817,7 @@ func (bc *BlockChain) BuildAndWriteBlock(parentBlock *types.Block, header *types
return NonStatTy, err
}

statedb.StartPrefetcher("l1sync")
statedb.StartPrefetcher("l1sync", nil)
defer statedb.StopPrefetcher()

header.ParentHash = parentBlock.Hash()
Expand Down
5 changes: 4 additions & 1 deletion core/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
}
statedb.Commit(false)
statedb.Database().TrieDB().Commit(root, true, nil)

if g.Config.Scroll.GenesisStateRoot != nil {
head.Root = *g.Config.Scroll.GenesisStateRoot
rawdb.WriteDiskStateRoot(db, head.Root, root)
}
return types.NewBlock(head, nil, nil, nil, trie.NewStackTrie(nil))
}

Expand Down
14 changes: 14 additions & 0 deletions core/rawdb/accessors_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,17 @@ func DeleteTrieNode(db ethdb.KeyValueWriter, hash common.Hash) {
log.Crit("Failed to delete trie node", "err", err)
}
}

func WriteDiskStateRoot(db ethdb.KeyValueWriter, headerRoot, diskRoot common.Hash) {
if err := db.Put(diskStateRootKey(headerRoot), diskRoot.Bytes()); err != nil {
log.Crit("Failed to store disk state root", "err", err)
}
}

func ReadDiskStateRoot(db ethdb.KeyValueReader, headerRoot common.Hash) (common.Hash, error) {
data, err := db.Get(diskStateRootKey(headerRoot))
if err != nil {
return common.Hash{}, err
}
return common.BytesToHash(data), nil
}
6 changes: 6 additions & 0 deletions core/rawdb/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ var (

// Scroll da syncer store
daSyncedL1BlockNumberKey = []byte("LastDASyncedL1BlockNumber")

diskStateRootPrefix = []byte("disk-state-root")
)

// Use the updated "L1" prefix on all new networks
Expand Down Expand Up @@ -312,3 +314,7 @@ func batchMetaKey(batchIndex uint64) []byte {
func committedBatchMetaKey(batchIndex uint64) []byte {
return append(committedBatchMetaPrefix, encodeBigEndian(batchIndex)...)
}

func diskStateRootKey(headerRoot common.Hash) []byte {
return append(diskStateRootPrefix, headerRoot.Bytes()...)
}
6 changes: 6 additions & 0 deletions core/state/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ type Trie interface {
// nodes of the longest existing prefix of the key (at least the root), ending
// with the node that proves the absence of the key.
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error

// Witness returns a set containing all trie nodes that have been accessed.
Witness() map[string]struct{}
}

// NewDatabase creates a backing store for state. The returned database is safe for
Expand Down Expand Up @@ -136,6 +139,9 @@ type cachingDB struct {

// OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
if diskRoot, err := rawdb.ReadDiskStateRoot(db.db.DiskDB(), root); err == nil {
root = diskRoot
}
if db.zktrie {
tr, err := trie.NewZkTrie(root, trie.NewZktrieDatabaseFromTriedb(db.db))
if err != nil {
Expand Down
16 changes: 13 additions & 3 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,18 @@ func (s *stateObject) Code(db Database) []byte {
// CodeSize returns the size of the contract code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently.
func (s *stateObject) CodeSize() uint64 {
return s.data.CodeSize
func (s *stateObject) CodeSize(db Database) uint64 {
if s.code != nil {
return uint64(len(s.code))
}
if bytes.Equal(s.KeccakCodeHash(), emptyKeccakCodeHash) {
return 0
}
size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.KeccakCodeHash()))
if err != nil {
s.setError(fmt.Errorf("can't load code size %x: %v", s.KeccakCodeHash(), err))
}
return uint64(size)
}

func (s *stateObject) SetCode(code []byte) {
Expand Down Expand Up @@ -534,7 +544,7 @@ func (s *stateObject) setNonce(nonce uint64) {
}

func (s *stateObject) PoseidonCodeHash() []byte {
return s.data.PoseidonCodeHash
panic("shouldn't be called")
omerfirmak marked this conversation as resolved.
Show resolved Hide resolved
}

func (s *stateObject) KeccakCodeHash() []byte {
Expand Down
11 changes: 6 additions & 5 deletions core/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ func TestSnapshotEmpty(t *testing.T) {
}

func TestSnapshot2(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateDb := NewDatabase(rawdb.NewMemoryDatabase())
state, _ := New(common.Hash{}, stateDb, nil)

stateobjaddr0 := common.BytesToAddress([]byte("so0"))
stateobjaddr1 := common.BytesToAddress([]byte("so1"))
Expand Down Expand Up @@ -201,7 +202,7 @@ func TestSnapshot2(t *testing.T) {
so0Restored.GetState(state.db, storageaddr)
so0Restored.Code(state.db)
// non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t)
compareStateObjects(so0Restored, so0, stateDb, t)

// deleted should be nil, both before and after restore of state copy
so1Restored := state.getStateObject(stateobjaddr1)
Expand All @@ -210,7 +211,7 @@ func TestSnapshot2(t *testing.T) {
}
}

func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
func compareStateObjects(so0, so1 *stateObject, db Database, t *testing.T) {
if so0.Address() != so1.Address() {
t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
}
Expand All @@ -229,8 +230,8 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
if !bytes.Equal(so0.PoseidonCodeHash(), so1.PoseidonCodeHash()) {
t.Fatalf("PoseidonCodeHash mismatch: have %v, want %v", so0.PoseidonCodeHash(), so1.PoseidonCodeHash())
}
if so0.CodeSize() != so1.CodeSize() {
t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(), so1.CodeSize())
if so0.CodeSize(db) != so1.CodeSize(db) {
t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(db), so1.CodeSize(db))
}
if !bytes.Equal(so0.code, so1.code) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
Expand Down
55 changes: 52 additions & 3 deletions core/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/core/rawdb"
"github.com/scroll-tech/go-ethereum/core/state/snapshot"
"github.com/scroll-tech/go-ethereum/core/stateless"
"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/log"
Expand Down Expand Up @@ -106,6 +107,9 @@ type StateDB struct {
validRevisions []revision
nextRevisionId int

// State witness if cross validation is needed
witness *stateless.Witness

// Measurements gathered during execution for debugging purposes
AccountReads time.Duration
AccountHashes time.Duration
Expand Down Expand Up @@ -159,11 +163,15 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
// StartPrefetcher initializes a new trie prefetcher to pull in nodes from the
// state trie concurrently while the state is mutated so that when we reach the
// commit phase, most of the needed data is already hot.
func (s *StateDB) StartPrefetcher(namespace string) {
func (s *StateDB) StartPrefetcher(namespace string, witness *stateless.Witness) {
if s.prefetcher != nil {
s.prefetcher.close()
s.prefetcher = nil
}

// Enable witness collection if requested
s.witness = witness

if s.snap != nil {
s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace)
}
Expand Down Expand Up @@ -289,6 +297,9 @@ func (s *StateDB) TxIndex() int {
func (s *StateDB) GetCode(addr common.Address) []byte {
stateObject := s.getStateObject(addr)
if stateObject != nil {
if s.witness != nil {
s.witness.AddCode(stateObject.Code(s.db))
}
return stateObject.Code(s.db)
}
return nil
Expand All @@ -297,7 +308,10 @@ func (s *StateDB) GetCode(addr common.Address) []byte {
func (s *StateDB) GetCodeSize(addr common.Address) uint64 {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.CodeSize()
if s.witness != nil {
s.witness.AddCode(stateObject.Code(s.db))
}
return stateObject.CodeSize(s.db)
}
return 0
}
Expand Down Expand Up @@ -725,6 +739,9 @@ func (s *StateDB) Copy() *StateDB {
journal: newJournal(),
hasher: crypto.NewKeccakState(),
}
if s.witness != nil {
state.witness = s.witness.Copy()
}
// Copy the dirty states, logs, and preimages
for addr := range s.journal.dirties {
// As documented [here](https://github.com/scroll-tech/go-ethereum/pull/16485#issuecomment-380438527),
Expand Down Expand Up @@ -913,7 +930,33 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
// to pull useful data from disk.
for addr := range s.stateObjectsPending {
if obj := s.stateObjects[addr]; !obj.deleted {

// If witness building is enabled and the state object has a trie,
// gather the witnesses for its specific storage trie
if s.witness != nil && obj.trie != nil {
s.witness.AddState(obj.trie.Witness())
}

obj.updateRoot(s.db)

// If witness building is enabled and the state object has a trie,
// gather the witnesses for its specific storage trie
if s.witness != nil && obj.trie != nil {
s.witness.AddState(obj.trie.Witness())
}
}
omerfirmak marked this conversation as resolved.
Show resolved Hide resolved
}

if s.witness != nil {
// If witness building is enabled, gather the account trie witness for read-only operations
for _, obj := range s.stateObjects {
if len(obj.originStorage) == 0 {
continue
}

if trie := obj.getTrie(s.db); trie != nil {
s.witness.AddState(trie.Witness())
}
}
}
// Now we're about to start to write changes to the trie. The trie is so far
Expand Down Expand Up @@ -945,7 +988,13 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
if metrics.EnabledExpensive {
defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now())
}
return s.trie.Hash()

hash := s.trie.Hash()
// If witness building is enabled, gather the account trie witness
if s.witness != nil {
s.witness.AddState(s.trie.Witness())
}
return hash
}

// SetTxContext sets the current transaction hash and index which are
Expand Down
Loading
Loading