Commit

Provide a way to create a custom temporal data stream client in the tests

Stefan-Ethernal committed Aug 8, 2024
1 parent 4508caf commit c470836
Showing 3 changed files with 78 additions and 16 deletions.
74 changes: 59 additions & 15 deletions zk/stages/stage_batches.go
@@ -92,22 +92,42 @@ type DatastreamClient interface {
 	GetStreamingAtomic() *atomic.Bool
 	GetProgressAtomic() *atomic.Uint64
 	EnsureConnected() (bool, error)
+	Start() error
+	Stop()
 }
 
+type dsClientCreatorHandler func(context.Context, *ethconfig.Zk, uint16) (DatastreamClient, error)
+
 type BatchesCfg struct {
-	db                  kv.RwDB
-	blockRoutineStarted bool
-	dsClient            DatastreamClient
-	zkCfg               *ethconfig.Zk
+	db                      kv.RwDB
+	blockRoutineStarted     bool
+	dsClient                DatastreamClient
+	temporalDSClientCreator dsClientCreatorHandler
+	zkCfg                   *ethconfig.Zk
 }
 
-func StageBatchesCfg(db kv.RwDB, dsClient DatastreamClient, zkCfg *ethconfig.Zk) BatchesCfg {
-	return BatchesCfg{
+func StageBatchesCfg(db kv.RwDB, dsClient DatastreamClient, zkCfg *ethconfig.Zk, options ...Option) BatchesCfg {
+	cfg := BatchesCfg{
 		db:                  db,
 		blockRoutineStarted: false,
 		dsClient:            dsClient,
 		zkCfg:               zkCfg,
 	}
+
+	for _, opt := range options {
+		opt(&cfg)
+	}
+
+	return cfg
 }
+
+type Option func(*BatchesCfg)
+
+// WithDSClientCreator is a functional option to set the datastream client creator callback.
+func WithDSClientCreator(handler dsClientCreatorHandler) Option {
+	return func(c *BatchesCfg) {
+		c.temporalDSClientCreator = handler
+	}
+}
 
 var emptyHash = common.Hash{0}
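Taken together, the dsClientCreatorHandler type, the temporalDSClientCreator field, and the WithDSClientCreator option let a caller inject a factory for the temporary datastream client instead of always dialing a real stream server. A minimal usage sketch, mirroring how the test change below wires in a stub (the db and dsClient variables are illustrative; NewTestDatastreamClient is the helper from zk/stages/test_utils.go):

	// Build the stage config with a creator that returns an in-memory stub
	// instead of a network-backed client.
	creator := func(_ context.Context, _ *ethconfig.Zk, _ uint16) (DatastreamClient, error) {
		return NewTestDatastreamClient(nil, nil), nil
	}
	cfg := StageBatchesCfg(db, dsClient, &ethconfig.Zk{}, WithDSClientCreator(creator))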
@@ -171,7 +191,16 @@ func SpawnStageBatches(
 		return err
 	}
 
-	newDSClient, cleanup, err := newStreamClient(ctx, cfg.zkCfg, latestForkId)
+	var (
+		tmpDSClient DatastreamClient
+		cleanup     func()
+	)
+
+	if cfg.temporalDSClientCreator != nil {
+		tmpDSClient, err = cfg.temporalDSClientCreator(ctx, cfg.zkCfg, uint16(latestForkId))
+	} else {
+		tmpDSClient, cleanup, err = newStreamClient(ctx, cfg.zkCfg, latestForkId)
+	}
 	defer cleanup()
 	if err != nil {
 		log.Warn(fmt.Sprintf("[%s] Error when starting datastream client. Error: %s", logPrefix, err))
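One caveat with the branch above: when cfg.temporalDSClientCreator is non-nil, cleanup is never assigned, so the unconditional defer cleanup() would invoke a nil function value when SpawnStageBatches returns and panic. A nil guard along these lines would make the creator path safe (a sketch against the code as shown here, not part of the commit; the second call site further down already keeps the defer inside the else branch):

	// Only defer the cleanup when newStreamClient actually provided one.
	if cleanup != nil {
		defer cleanup()
	}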
@@ -182,10 +211,11 @@
 	var highestL2Block *types.FullL2Block
 
 	for highestL2Block == nil {
-		highestL2Block, err = newDSClient.GetL2BlockByNumber(highestBlockInDS)
+		highestL2Block, err = tmpDSClient.GetL2BlockByNumber(highestBlockInDS)
 		if err != nil {
 			return err
 		}
+
 		if highestL2Block == nil {
 			highestBlockInDS--
 		} else {
@@ -368,19 +398,29 @@ LOOP:
 			return err
 		}
 
-		newDSClient, cleanup, err := newStreamClient(ctx, cfg.zkCfg, latestForkId)
-		defer cleanup()
+		var (
+			tmpDSClient DatastreamClient
+			cleanup     func()
+		)
+
+		if cfg.temporalDSClientCreator != nil {
+			tmpDSClient, err = cfg.temporalDSClientCreator(ctx, cfg.zkCfg, uint16(latestForkId))
+		} else {
+			tmpDSClient, cleanup, err = newStreamClient(ctx, cfg.zkCfg, latestForkId)
+			defer cleanup()
+		}
+
 		if err != nil {
 			log.Warn(fmt.Sprintf("[%s] Error when starting datastream client... Error: %s", logPrefix, err))
 			return err
 		}
 
-		ancestorBlockNum, _, err := findCommonAncestor(eriDb, hermezDb, newDSClient, l2Block.L2BlockNumber)
+		ancestorBlockNum, ancestorBlockHash, err := findCommonAncestor(eriDb, hermezDb, tmpDSClient, l2Block.L2BlockNumber)
 		if err != nil {
 			return err
 		}
 
-		unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum)
+		unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum, ancestorBlockHash)
 		if err != nil {
 			return err
 		}
@@ -451,12 +491,12 @@ LOOP:
 			// unwind/rollback blocks until the latest common ancestor block
 			log.Warn(fmt.Sprintf("[%s] Parent block hashes mismatch on block %d. Triggering unwind...", logPrefix, l2Block.L2BlockNumber),
 				"db parent block hash", dbParentBlockHash, "ds parent block hash", lastHash)
-			ancestorBlockNum, _, err := findCommonAncestor(eriDb, hermezDb, cfg.dsClient, l2Block.L2BlockNumber)
+			ancestorBlockNum, ancestorBlockHash, err := findCommonAncestor(eriDb, hermezDb, cfg.dsClient, l2Block.L2BlockNumber)
 			if err != nil {
 				return err
 			}
 
-			unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum)
+			unwindBlockNum, unwindBlockHash, batchNum, err := resolveUnwindBlock(eriDb, hermezDb, ancestorBlockNum, ancestorBlockHash)
 			if err != nil {
 				return err
 			}
@@ -1072,12 +1112,16 @@ func findCommonAncestor(
 }
 
 // resolveUnwindBlock resolves the unwind block as the latest block in the previous batch, relative to the found ancestor block.
-func resolveUnwindBlock(eriDb erigon_db.ReadOnlyErigonDb, hermezDb state.ReadOnlyHermezDb, ancestorBlockNum uint64) (uint64, common.Hash, uint64, error) {
+func resolveUnwindBlock(eriDb erigon_db.ReadOnlyErigonDb, hermezDb state.ReadOnlyHermezDb, ancestorBlockNum uint64, ancestorBlockHash common.Hash) (uint64, common.Hash, uint64, error) {
 	batchNum, err := hermezDb.GetBatchNoByL2Block(ancestorBlockNum)
 	if err != nil {
 		return 0, emptyHash, 0, err
 	}
 
+	if batchNum == 0 {
+		return ancestorBlockNum, ancestorBlockHash, batchNum, nil
+	}
+
 	unwindBlockNum, err := hermezDb.GetHighestBlockInBatch(batchNum - 1)
 	if err != nil {
 		return 0, emptyHash, 0, err
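The batchNum == 0 guard is also why resolveUnwindBlock now takes ancestorBlockHash: when the ancestor block already belongs to the first batch there is no previous batch to resolve against, so the function returns the ancestor itself. The early return additionally avoids unsigned wrap-around in the batchNum - 1 lookup, since batch numbers are uint64. A minimal illustration of the hazard the guard avoids:

	// Subtracting from an unsigned zero wraps instead of going negative.
	var batchNum uint64 = 0
	fmt.Println(batchNum - 1) // prints 18446744073709551615 (math.MaxUint64), not -1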
10 changes: 9 additions & 1 deletion zk/stages/stage_batches_test.go
@@ -48,13 +48,21 @@ func TestUnwindBatches(t *testing.T) {
 	require.NoError(t, err)
 
 	dsClient := NewTestDatastreamClient(fullL2Blocks, gerUpdates)
-	cfg := StageBatchesCfg(db1, dsClient, &ethconfig.Zk{})
+
+	tmpDSClientCreator := func(_ context.Context, _ *ethconfig.Zk, _ uint16) (DatastreamClient, error) {
+		return NewTestDatastreamClient(fullL2Blocks, gerUpdates), nil
+	}
+	cfg := StageBatchesCfg(db1, dsClient, &ethconfig.Zk{}, WithDSClientCreator(tmpDSClientCreator))
 
 	s := &stagedsync.StageState{ID: stages.Batches, BlockNumber: 0}
 	u := &stagedsync.Sync{}
 	us := &stagedsync.UnwindState{ID: stages.Batches, UnwindPoint: 0, CurrentBlockNumber: uint64(currentBlockNumber)}
 	err = stages.SaveStageProgress(tx, stages.L1VerificationsBatchNo, 20)
 	require.NoError(t, err)
+	err = stages.SaveStageProgress(tx, stages.Batches, 5)
+	require.NoError(t, err)
+	err = stages.SaveStageProgress(tx, stages.Execution, uint64(currentBlockNumber))
+	require.NoError(t, err)
 
 	// get bucket sizes pre inserts
 	bucketSized := make(map[string]uint64)
10 changes: 10 additions & 0 deletions zk/stages/test_utils.go
@@ -18,6 +18,7 @@ type TestDatastreamClient struct {
 	errChan        chan error
 	batchStartChan chan types.BatchStart
 	batchEndChan   chan types.BatchEnd
+	isStarted      bool
 }
 
 func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []types.GerUpdate) *TestDatastreamClient {
@@ -95,3 +96,12 @@ func (c *TestDatastreamClient) GetStreamingAtomic() *atomic.Bool {
 func (c *TestDatastreamClient) GetProgressAtomic() *atomic.Uint64 {
 	return &c.progress
 }
+
+func (c *TestDatastreamClient) Start() error {
+	c.isStarted = true
+	return nil
+}
+
+func (c *TestDatastreamClient) Stop() {
+	c.isStarted = false
+}
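The stub records lifecycle state in the unexported isStarted flag, so a test in the same package can exercise the new interface methods directly. A hypothetical check, not part of the commit (require is the testify package the existing tests already use):

	func TestDatastreamClientLifecycle(t *testing.T) {
		client := NewTestDatastreamClient(nil, nil)
		require.NoError(t, client.Start())
		require.True(t, client.isStarted) // flag set by Start
		client.Stop()
		require.False(t, client.isStarted) // cleared by Stop
	}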
