diff --git a/consensus/consortium/v2/snapshot.go b/consensus/consortium/v2/snapshot.go index 322f42de5..610233d2d 100644 --- a/consensus/consortium/v2/snapshot.go +++ b/consensus/consortium/v2/snapshot.go @@ -16,6 +16,7 @@ import ( blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" "github.com/hashicorp/golang-lru/arc/v2" ) @@ -243,7 +244,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea // Change the validator set base on the size of the validators set if number > 0 && number%s.config.EpochV2 == uint64(len(snap.validators())/2) { // Get the most recent checkpoint header - checkpointHeader := FindAncientHeader(header, uint64(len(snap.validators())/2), chain, parents) + checkpointHeader := findAncestorHeader(header, number-uint64(len(snap.validators())/2), chain, parents) if checkpointHeader == nil { return nil, consensus.ErrUnknownAncestor } @@ -420,36 +421,58 @@ func (s *Snapshot) IsRecentlySigned(validator common.Address) bool { return false } -// FindAncientHeader finds the most recent checkpoint header -// Travel through the candidateParents to find the ancient header. -// If all headers in candidateParents have the number is larger than the header number, -// the search function will return the index, but it is not valid if we check with the -// header since the number and hash is not equals. The candidateParents is -// only available when it downloads blocks from the network. -// Otherwise, the candidateParents is nil, and it will be found by header hash and number. -func FindAncientHeader(header *types.Header, ite uint64, chain consensus.ChainHeaderReader, candidateParents []*types.Header) *types.Header { - ancient := header - for i := uint64(1); i <= ite; i++ { - parentHash := ancient.ParentHash - parentHeight := ancient.Number.Uint64() - 1 - found := false - if len(candidateParents) > 0 { - index := sort.Search(len(candidateParents), func(i int) bool { - return candidateParents[i].Number.Uint64() >= parentHeight - }) - if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight && - candidateParents[index].Hash() == parentHash { - ancient = candidateParents[index] - found = true - } - } - if !found { - ancient = chain.GetHeader(parentHash, parentHeight) - found = true +// findAncestorHeader traverses back to look for the requested ancestor header +// in parents list or in chaindata +// +// parents are guaranteed to be ordered and linked by the check when InsertChain +// +// There are 2 possible cases: +// Case 1: ancestor header is in parents list +// <- parents -> +// [ ancestorHeader ] +// +// Case 2: ancestor header's height is lower than parents list +// <- parents -> +// ancestorHeader ... [ ] + +func findAncestorHeader( + currentHeader *types.Header, + ancestorBlockNumber uint64, + chain consensus.ChainHeaderReader, + parents []*types.Header, +) *types.Header { + // Find the first header in parents list that is higher or equal to checkpoint block + index := sort.Search(len(parents), func(i int) bool { + return parents[i].Number.Uint64() >= ancestorBlockNumber + }) + + // This must not happen, checkpoint header's height cannot be higher the parents list + if len(parents) != 0 && index >= len(parents) { + log.Warn( + "Checkpoint header's height is higher than parents list", + "checkpointNumber", ancestorBlockNumber, + "last parent", parents[len(parents)-1].Number, + ) + return nil + } + + if len(parents) != 0 && parents[index].Number.Uint64() == ancestorBlockNumber { + // Case 1: checkpoint header is in parents list + return parents[index] + } else { + // Case 2: checkpoint header's height is lower than parents list + var headerIterator *types.Header + if len(parents) != 0 { + headerIterator = parents[0] + } else { + headerIterator = currentHeader } - if ancient == nil || !found { - return nil + for headerIterator.Number.Uint64() != ancestorBlockNumber { + headerIterator = chain.GetHeader(headerIterator.ParentHash, headerIterator.Number.Uint64()-1) + if headerIterator == nil { + return nil + } } + return headerIterator } - return ancient } diff --git a/consensus/consortium/v2/snapshot_test.go b/consensus/consortium/v2/snapshot_test.go new file mode 100644 index 000000000..0732fa3e7 --- /dev/null +++ b/consensus/consortium/v2/snapshot_test.go @@ -0,0 +1,106 @@ +package v2 + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/params" +) + +type mockChainReader struct { + headerMapping map[common.Hash]*types.Header +} + +func (chainReader *mockChainReader) Config() *params.ChainConfig { return nil } +func (chainReader *mockChainReader) CurrentHeader() *types.Header { return nil } +func (chainReader *mockChainReader) GetHeader(hash common.Hash, number uint64) *types.Header { + return chainReader.headerMapping[hash] +} +func (chainReader *mockChainReader) GetHeaderByNumber(number uint64) *types.Header { return nil } +func (chainReader *mockChainReader) GetHeaderByHash(hash common.Hash) *types.Header { return nil } +func (chainReader *mockChainReader) DB() ethdb.Database { return nil } +func (chainReader *mockChainReader) StateCache() state.Database { return nil } +func (chainReader *mockChainReader) OpEvents() []*vm.PublishEvent { return nil } + +func TestFindCheckpointHeader(t *testing.T) { + // Case 1: checkpoint header is at block 5 (in parent list) + // parent list ranges from [0, 10) + parents := make([]*types.Header, 10) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} + } + + currentHeader := &types.Header{Number: big.NewInt(10)} + checkpointHeader := findAncestorHeader(currentHeader, 5, nil, parents) + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.Coinbase != common.BigToAddress(big.NewInt(5)) { + t.Fatalf("Expect checkpoint header number: %d, got: %d", 5, checkpointHeader.Number.Int64()) + } + + // Case 2: checkpoint header is at 5 (lower than parent list) + // parent list ranges from [10, 20) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i + 10)), ParentHash: common.BigToHash(big.NewInt(int64(i + 10 - 1)))} + } + mockChain := mockChainReader{ + headerMapping: make(map[common.Hash]*types.Header), + } + // create mock chain 1 + for i := 5; i < 10; i++ { + mockChain.headerMapping[common.BigToHash(big.NewInt(int64(100+i)))] = &types.Header{ + Number: big.NewInt(int64(i)), + ParentHash: common.BigToHash(big.NewInt(int64(100 + i - 1))), + } + } + + // create mock chain 2 + for i := 5; i < 10; i++ { + mockChain.headerMapping[common.BigToHash(big.NewInt(int64(i)))] = &types.Header{ + Number: big.NewInt(int64(i)), + ParentHash: common.BigToHash(big.NewInt(int64(i - 1))), + } + } + + currentHeader = &types.Header{ParentHash: common.BigToHash(big.NewInt(19)), Number: big.NewInt(20)} + // Must traverse and get the correct header in chain 2 + checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, parents) + if checkpointHeader == nil { + t.Fatal("Failed to find checkpoint header") + } + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(4))) { + t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s", + 5, common.BigToHash(big.NewInt(int64(4))), + checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, + ) + } + + // Case 3: find checkpoint header with nil parent list + currentHeader = &types.Header{Number: big.NewInt(10), ParentHash: common.BigToHash(big.NewInt(109))} + checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, nil) + // Must traverse and get the correct header in chain 1 + if checkpointHeader == nil { + t.Fatal("Failed to find checkpoint header") + } + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(104))) { + t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s", + 5, common.BigToHash(big.NewInt(int64(104))), + checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, + ) + } + + // Case 4: checkpoint header is higher than parent list, this must not happen + // but the function must not crash in this case + // parent list ranges from [0, 10) + parents = make([]*types.Header, 10) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} + } + checkpointHeader = findAncestorHeader(nil, 10, nil, parents) + if checkpointHeader != nil { + t.Fatalf("Expect %v checkpoint header, got %v", nil, checkpointHeader) + } +}