diff --git a/zk/stages/stage_batches.go b/zk/stages/stage_batches.go
index 9a841475ed4..1af00e5f275 100644
--- a/zk/stages/stage_batches.go
+++ b/zk/stages/stage_batches.go
@@ -92,22 +92,42 @@ type DatastreamClient interface {
 	GetStreamingAtomic() *atomic.Bool
 	GetProgressAtomic() *atomic.Uint64
 	EnsureConnected() (bool, error)
+	Start() error
+	Stop()
 }
 
+type dsClientCreatorHandler func(context.Context, *ethconfig.Zk, uint16) (DatastreamClient, error)
+
 type BatchesCfg struct {
-	db                  kv.RwDB
-	blockRoutineStarted bool
-	dsClient            DatastreamClient
-	zkCfg               *ethconfig.Zk
+	db                      kv.RwDB
+	blockRoutineStarted     bool
+	dsClient                DatastreamClient
+	temporalDSClientCreator dsClientCreatorHandler
+	zkCfg                   *ethconfig.Zk
 }
 
-func StageBatchesCfg(db kv.RwDB, dsClient DatastreamClient, zkCfg *ethconfig.Zk) BatchesCfg {
-	return BatchesCfg{
+func StageBatchesCfg(db kv.RwDB, dsClient DatastreamClient, zkCfg *ethconfig.Zk, options ...Option) BatchesCfg {
+	cfg := BatchesCfg{
 		db:                  db,
 		blockRoutineStarted: false,
 		dsClient:            dsClient,
 		zkCfg:               zkCfg,
 	}
+
+	for _, opt := range options {
+		opt(&cfg)
+	}
+
+	return cfg
+}
+
+type Option func(*BatchesCfg)
+
+// WithDSClientCreator is a functional option to set the datastream client creator callback.
+func WithDSClientCreator(handler dsClientCreatorHandler) Option {
+	return func(c *BatchesCfg) {
+		c.temporalDSClientCreator = handler
+	}
 }
 
 var emptyHash = common.Hash{0}
@@ -171,7 +191,16 @@ func SpawnStageBatches(
 		return err
 	}
 
-	newDSClient, cleanup, err := newStreamClient(ctx, cfg.zkCfg, latestForkId)
+	var (
+		tmpDSClient DatastreamClient
+		cleanup     func()
+	)
+
+	if cfg.temporalDSClientCreator != nil {
+		tmpDSClient, err = cfg.temporalDSClientCreator(ctx, cfg.zkCfg, uint16(latestForkId))
+	} else {
+		tmpDSClient, cleanup, err = newStreamClient(ctx, cfg.zkCfg, latestForkId)
+	}
 	defer cleanup()
 	if err != nil {
 		log.Warn(fmt.Sprintf("[%s] Error when starting datastream client. Error: %s", logPrefix, err))
@@ -182,10 +211,11 @@ func SpawnStageBatches(
 
 	var highestL2Block *types.FullL2Block
 	for highestL2Block == nil {
-		highestL2Block, err = newDSClient.GetL2BlockByNumber(highestBlockInDS)
+		highestL2Block, err = tmpDSClient.GetL2BlockByNumber(highestBlockInDS)
 		if err != nil {
 			return err
 		}
+
 		if highestL2Block == nil {
 			highestBlockInDS--
 		} else {
@@ -368,19 +398,29 @@ LOOP:
 				return err
 			}
 
-			newDSClient, cleanup, err := newStreamClient(ctx, cfg.zkCfg, latestForkId)
-			defer cleanup()
+			var (
+				tmpDSClient DatastreamClient
+				cleanup     func()
+			)
+
+			if cfg.temporalDSClientCreator != nil {
+				tmpDSClient, err = cfg.temporalDSClientCreator(ctx, cfg.zkCfg, uint16(latestForkId))
+			} else {
+				tmpDSClient, cleanup, err = newStreamClient(ctx, cfg.zkCfg, latestForkId)
+				defer cleanup()
+			}
+
 			if err != nil {
 				log.Warn(fmt.Sprintf("[%s] Error when starting datastream client... Error: %s", logPrefix, err))
 				return err
 			}
 
-			ancestorBlockNum, _, err := findCommonAncestor(eriDb, hermezDb, newDSClient, l2Block.L2BlockNumber)
+			ancestorBlockNum, ancestorBlockHash, err := findCommonAncestor(eriDb, hermezDb, tmpDSClient, l2Block.L2BlockNumber)
 			if err != nil {
 				return err
 			}
 
-			unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum)
+			unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum, ancestorBlockHash)
 			if err != nil {
 				return err
 			}
@@ -451,12 +491,12 @@ LOOP:
 
 				// unwind/rollback blocks until the latest common ancestor block
 				log.Warn(fmt.Sprintf("[%s] Parent block hashes mismatch on block %d. Triggering unwind...", logPrefix, l2Block.L2BlockNumber), "db parent block hash", dbParentBlockHash, "ds parent block hash", lastHash)
-				ancestorBlockNum, _, err := findCommonAncestor(eriDb, hermezDb, cfg.dsClient, l2Block.L2BlockNumber)
+				ancestorBlockNum, ancestorBlockHash, err := findCommonAncestor(eriDb, hermezDb, cfg.dsClient, l2Block.L2BlockNumber)
 				if err != nil {
 					return err
 				}
 
-				unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum)
+				unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum, ancestorBlockHash)
 				if err != nil {
 					return err
 				}
@@ -1072,12 +1112,16 @@ func findCommonAncestor(
 }
 
 // resolveUnwindBlock resolves the unwind block as the latest block in the previous batch, relative to the found ancestor block.
-func resolveUnwindBlock(eriDb erigon_db.ReadOnlyErigonDb, hermezDb state.ReadOnlyHermezDb, ancestorBlockNum uint64) (uint64, common.Hash, uint64, error) {
+func resolveUnwindBlock(eriDb erigon_db.ReadOnlyErigonDb, hermezDb state.ReadOnlyHermezDb, ancestorBlockNum uint64, ancestorBlockHash common.Hash) (uint64, common.Hash, uint64, error) {
 	batchNum, err := hermezDb.GetBatchNoByL2Block(ancestorBlockNum)
 	if err != nil {
 		return 0, emptyHash, 0, err
 	}
 
+	if batchNum == 0 {
+		return ancestorBlockNum, ancestorBlockHash, batchNum, nil
+	}
+
 	unwindBlockNum, err := hermezDb.GetHighestBlockInBatch(batchNum - 1)
 	if err != nil {
 		return 0, emptyHash, 0, err
diff --git a/zk/stages/stage_batches_test.go b/zk/stages/stage_batches_test.go
index e37930849f5..b97d3352eb3 100644
--- a/zk/stages/stage_batches_test.go
+++ b/zk/stages/stage_batches_test.go
@@ -48,13 +48,21 @@ func TestUnwindBatches(t *testing.T) {
 	require.NoError(t, err)
 
 	dsClient := NewTestDatastreamClient(fullL2Blocks, gerUpdates)
-	cfg := StageBatchesCfg(db1, dsClient, &ethconfig.Zk{})
+
+	tmpDSClientCreator := func(_ context.Context, _ *ethconfig.Zk, _ uint16) (DatastreamClient, error) {
+		return NewTestDatastreamClient(fullL2Blocks, gerUpdates), nil
+	}
+	cfg := StageBatchesCfg(db1, dsClient, &ethconfig.Zk{}, WithDSClientCreator(tmpDSClientCreator))
 
 	s := &stagedsync.StageState{ID: stages.Batches, BlockNumber: 0}
 	u := &stagedsync.Sync{}
 	us := &stagedsync.UnwindState{ID: stages.Batches, UnwindPoint: 0, CurrentBlockNumber: uint64(currentBlockNumber)}
 	err = stages.SaveStageProgress(tx, stages.L1VerificationsBatchNo, 20)
 	require.NoError(t, err)
+	err = stages.SaveStageProgress(tx, stages.Batches, 5)
+	require.NoError(t, err)
+	err = stages.SaveStageProgress(tx, stages.Execution, uint64(currentBlockNumber))
+	require.NoError(t, err)
 
 	// get bucket sizes pre inserts
 	bucketSized := make(map[string]uint64)
diff --git a/zk/stages/test_utils.go b/zk/stages/test_utils.go
index 10d6a673b14..0f242844eb5 100644
--- a/zk/stages/test_utils.go
+++ b/zk/stages/test_utils.go
@@ -18,6 +18,7 @@ type TestDatastreamClient struct {
 	errChan         chan error
 	batchStartChan  chan types.BatchStart
 	batchEndChan    chan types.BatchEnd
+	isStarted       bool
 }
 
 func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []types.GerUpdate) *TestDatastreamClient {
@@ -95,3 +96,12 @@ func (c *TestDatastreamClient) GetStreamingAtomic() *atomic.Bool {
 func (c *TestDatastreamClient) GetProgressAtomic() *atomic.Uint64 {
 	return &c.progress
 }
+
+func (c *TestDatastreamClient) Start() error {
+	c.isStarted = true
+	return nil
+}
+
+func (c *TestDatastreamClient) Stop() {
+	c.isStarted = false
+}
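
For context, the new `WithDSClientCreator` option lets callers inject their own datastream-client factory, which `SpawnStageBatches` prefers over `newStreamClient` when set (the test above uses it to return a `TestDatastreamClient`). Below is a minimal sketch of how a caller might wire it up, assuming a hypothetical `newCustomDSClient` constructor and in-scope `db`, `dsClient`, and `zkCfg` values:

```go
// Sketch only: newCustomDSClient is a hypothetical constructor returning any
// implementation of the DatastreamClient interface; db, dsClient, and zkCfg
// are assumed to already be in scope.
creator := func(ctx context.Context, zkCfg *ethconfig.Zk, latestForkId uint16) (DatastreamClient, error) {
	return newCustomDSClient(ctx, zkCfg, latestForkId)
}

// When the option is set, the stage calls the injected creator instead of newStreamClient.
cfg := StageBatchesCfg(db, dsClient, zkCfg, WithDSClientCreator(creator))
```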