From 801762c73b078bbeada30733bb1753f0d8be7613 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Tue, 29 Oct 2024 13:10:27 -0400 Subject: [PATCH] Split `writeCurrentStakers` into multiple functions (#3500) --- vms/platformvm/state/stakers.go | 29 +++ vms/platformvm/state/state.go | 375 +++++++++++++++-------------- vms/platformvm/state/state_test.go | 150 ++++++++---- 3 files changed, 324 insertions(+), 230 deletions(-) diff --git a/vms/platformvm/state/stakers.go b/vms/platformvm/state/stakers.go index 658796855958..14e4dcf7b1ef 100644 --- a/vms/platformvm/state/stakers.go +++ b/vms/platformvm/state/stakers.go @@ -5,6 +5,7 @@ package state import ( "errors" + "fmt" "github.com/google/btree" @@ -273,6 +274,34 @@ type diffValidator struct { deletedDelegators map[ids.ID]*Staker } +func (d *diffValidator) WeightDiff() (ValidatorWeightDiff, error) { + weightDiff := ValidatorWeightDiff{ + Decrease: d.validatorStatus == deleted, + } + if d.validatorStatus != unmodified { + weightDiff.Amount = d.validator.Weight + } + + for _, staker := range d.deletedDelegators { + if err := weightDiff.Add(true, staker.Weight); err != nil { + return ValidatorWeightDiff{}, fmt.Errorf("failed to decrease node weight diff: %w", err) + } + } + + addedDelegatorIterator := iterator.FromTree(d.addedDelegators) + defer addedDelegatorIterator.Release() + + for addedDelegatorIterator.Next() { + staker := addedDelegatorIterator.Value() + + if err := weightDiff.Add(false, staker.Weight); err != nil { + return ValidatorWeightDiff{}, fmt.Errorf("failed to increase node weight diff: %w", err) + } + } + + return weightDiff, nil +} + // GetValidator attempts to fetch the validator with the given subnetID and // nodeID. // Invariant: Assumes that the validator will never be removed and then added. diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 928d069476f3..be6bee28cf0f 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -4,6 +4,7 @@ package state import ( + "bytes" "context" "errors" "fmt" @@ -14,6 +15,7 @@ import ( "github.com/google/btree" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" + "golang.org/x/exp/maps" "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/cache/metercacher" @@ -1801,7 +1803,9 @@ func (s *state) write(updateValidators bool, height uint64) error { return errors.Join( s.writeBlocks(), s.writeExpiry(), - s.writeCurrentStakers(updateValidators, height, codecVersion), + s.updateValidatorManager(updateValidators), + s.writeValidatorDiffs(height), + s.writeCurrentStakers(codecVersion), s.writePendingStakers(), s.WriteValidatorMetadata(s.currentValidatorList, s.currentSubnetValidatorList, codecVersion), // Must be called after writeCurrentStakers s.writeTXs(), @@ -2049,47 +2053,75 @@ func (s *state) writeExpiry() error { return nil } -func (s *state) writeCurrentStakers(updateValidators bool, height uint64, codecVersion uint16) error { +// getInheritedPublicKey returns the primary network validator's public key. +// +// Note: This function may return a nil public key and no error if the primary +// network validator does not have a public key. +func (s *state) getInheritedPublicKey(nodeID ids.NodeID) (*bls.PublicKey, error) { + if vdr, ok := s.currentStakers.validators[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { + // The primary network validator is present. + return vdr.validator.PublicKey, nil + } + if vdr, ok := s.currentStakers.validatorDiffs[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { + // The primary network validator is being modified. + return vdr.validator.PublicKey, nil + } + return nil, fmt.Errorf("%w: %s", errMissingPrimaryNetworkValidator, nodeID) +} + +// updateValidatorManager updates the validator manager with the pending +// validator set changes. +// +// This function must be called prior to writeCurrentStakers. +func (s *state) updateValidatorManager(updateValidators bool) error { + if !updateValidators { + return nil + } + for subnetID, validatorDiffs := range s.currentStakers.validatorDiffs { - // We must write the primary network stakers last because writing subnet - // validator diffs may depend on the primary network validator diffs to - // inherit the public keys. - if subnetID == constants.PrimaryNetworkID { - continue - } + // Record the change in weight and/or public key for each validator. + for nodeID, diff := range validatorDiffs { + weightDiff, err := diff.WeightDiff() + if err != nil { + return err + } - delete(s.currentStakers.validatorDiffs, subnetID) + if weightDiff.Amount == 0 { + continue // No weight change; go to the next validator. + } - err := s.writeCurrentStakersSubnetDiff( - subnetID, - validatorDiffs, - updateValidators, - height, - codecVersion, - ) - if err != nil { - return err - } - } + if weightDiff.Decrease { + if err := s.validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount); err != nil { + return fmt.Errorf("failed to reduce validator weight: %w", err) + } + continue + } - if validatorDiffs, ok := s.currentStakers.validatorDiffs[constants.PrimaryNetworkID]; ok { - delete(s.currentStakers.validatorDiffs, constants.PrimaryNetworkID) + if diff.validatorStatus != added { + if err := s.validators.AddWeight(subnetID, nodeID, weightDiff.Amount); err != nil { + return fmt.Errorf("failed to increase validator weight: %w", err) + } + continue + } - err := s.writeCurrentStakersSubnetDiff( - constants.PrimaryNetworkID, - validatorDiffs, - updateValidators, - height, - codecVersion, - ) - if err != nil { - return err - } - } + pk, err := s.getInheritedPublicKey(nodeID) + if err != nil { + // This should never happen as there should always be a primary + // network validator corresponding to a subnet validator. + return err + } - // TODO: Move validator set management out of the state package - if !updateValidators { - return nil + err = s.validators.AddStaker( + subnetID, + nodeID, + pk, + diff.validator.TxID, + weightDiff.Amount, + ) + if err != nil { + return fmt.Errorf("failed to add validator: %w", err) + } + } } // Update the stake metrics @@ -2103,185 +2135,168 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64, codecV return nil } -func (s *state) writeCurrentStakersSubnetDiff( - subnetID ids.ID, - validatorDiffs map[ids.NodeID]*diffValidator, - updateValidators bool, - height uint64, - codecVersion uint16, -) error { - // Select db to write to - validatorDB := s.currentSubnetValidatorList - delegatorDB := s.currentSubnetDelegatorList - if subnetID == constants.PrimaryNetworkID { - validatorDB = s.currentValidatorList - delegatorDB = s.currentDelegatorList - } +type validatorDiff struct { + weightDiff ValidatorWeightDiff + prevPublicKey []byte + newPublicKey []byte +} - // Record the change in weight and/or public key for each validator. - for nodeID, validatorDiff := range validatorDiffs { - var ( - staker *Staker - pk *bls.PublicKey - weightDiff = &ValidatorWeightDiff{ - Decrease: validatorDiff.validatorStatus == deleted, - } - ) - if validatorDiff.validatorStatus != unmodified { - staker = validatorDiff.validator - - pk = staker.PublicKey - // For non-primary network validators, the public key is inherited - // from the primary network. - if subnetID != constants.PrimaryNetworkID { - if vdr, ok := s.currentStakers.validators[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { - // The primary network validator is still present after - // writing. - pk = vdr.validator.PublicKey - } else if vdr, ok := s.currentStakers.validatorDiffs[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { - // The primary network validator is being removed during - // writing. - pk = vdr.validator.PublicKey - } else { - // This should never happen as the primary network diffs are - // written last and subnet validator times must be a subset - // of the primary network validator times. - return fmt.Errorf("%w: %s", errMissingPrimaryNetworkValidator, nodeID) - } +// calculateValidatorDiffs calculates the validator set diff contained by the +// pending validator set changes. +// +// This function must be called prior to writeCurrentStakers. +func (s *state) calculateValidatorDiffs() (map[subnetIDNodeID]*validatorDiff, error) { + changes := make(map[subnetIDNodeID]*validatorDiff) + + // Calculate the changes to the pre-ACP-77 validator set + for subnetID, subnetDiffs := range s.currentStakers.validatorDiffs { + for nodeID, diff := range subnetDiffs { + weightDiff, err := diff.WeightDiff() + if err != nil { + return nil, err } - weightDiff.Amount = staker.Weight - } + pk, err := s.getInheritedPublicKey(nodeID) + if err != nil { + // This should never happen as there should always be a primary + // network validator corresponding to a subnet validator. + return nil, err + } - switch validatorDiff.validatorStatus { - case added: + change := &validatorDiff{ + weightDiff: weightDiff, + } if pk != nil { - // Record that the public key for the validator is being added. - // This means the prior value for the public key was nil. - err := s.validatorPublicKeyDiffsDB.Put( - marshalDiffKey(subnetID, height, nodeID), - nil, - ) - if err != nil { - return err + pkBytes := bls.PublicKeyToUncompressedBytes(pk) + if diff.validatorStatus != added { + change.prevPublicKey = pkBytes + } + if diff.validatorStatus != deleted { + change.newPublicKey = pkBytes } } - // The validator is being added. - // - // Invariant: It's impossible for a delegator to have been rewarded - // in the same block that the validator was added. - startTime := uint64(staker.StartTime.Unix()) - metadata := &validatorMetadata{ - txID: staker.TxID, - lastUpdated: staker.StartTime, - - UpDuration: 0, - LastUpdated: startTime, - StakerStartTime: startTime, - PotentialReward: staker.PotentialReward, - PotentialDelegateeReward: 0, + subnetIDNodeID := subnetIDNodeID{ + subnetID: subnetID, + nodeID: nodeID, } + changes[subnetIDNodeID] = change + } + } - metadataBytes, err := MetadataCodec.Marshal(codecVersion, metadata) - if err != nil { - return fmt.Errorf("failed to serialize current validator: %w", err) - } + return changes, nil +} - if err = validatorDB.Put(staker.TxID[:], metadataBytes); err != nil { - return fmt.Errorf("failed to write current validator to list: %w", err) - } +// writeValidatorDiffs writes the validator set diff contained by the pending +// validator set changes to disk. +// +// This function must be called prior to writeCurrentStakers. +func (s *state) writeValidatorDiffs(height uint64) error { + changes, err := s.calculateValidatorDiffs() + if err != nil { + return err + } - s.validatorState.LoadValidatorMetadata(nodeID, subnetID, metadata) - case deleted: - if pk != nil { - // Record that the public key for the validator is being - // removed. This means we must record the prior value of the - // public key. - // - // Note: We store the uncompressed public key here as it is - // significantly more efficient to parse when applying diffs. - err := s.validatorPublicKeyDiffsDB.Put( - marshalDiffKey(subnetID, height, nodeID), - bls.PublicKeyToUncompressedBytes(pk), - ) - if err != nil { - return err - } + // Write the changes to the database + for subnetIDNodeID, diff := range changes { + diffKey := marshalDiffKey(subnetIDNodeID.subnetID, height, subnetIDNodeID.nodeID) + if diff.weightDiff.Amount != 0 { + err := s.validatorWeightDiffsDB.Put( + diffKey, + marshalWeightDiff(&diff.weightDiff), + ) + if err != nil { + return err } - - if err := validatorDB.Delete(staker.TxID[:]); err != nil { - return fmt.Errorf("failed to delete current staker: %w", err) + } + if !bytes.Equal(diff.prevPublicKey, diff.newPublicKey) { + err := s.validatorPublicKeyDiffsDB.Put( + diffKey, + diff.prevPublicKey, + ) + if err != nil { + return err } - - s.validatorState.DeleteValidatorMetadata(nodeID, subnetID) } + } + return nil +} - err := writeCurrentDelegatorDiff( - delegatorDB, - weightDiff, - validatorDiff, - codecVersion, - ) - if err != nil { - return err +func (s *state) writeCurrentStakers(codecVersion uint16) error { + for subnetID, validatorDiffs := range s.currentStakers.validatorDiffs { + // Select db to write to + validatorDB := s.currentSubnetValidatorList + delegatorDB := s.currentSubnetDelegatorList + if subnetID == constants.PrimaryNetworkID { + validatorDB = s.currentValidatorList + delegatorDB = s.currentDelegatorList } - if weightDiff.Amount == 0 { - // No weight change to record; go to next validator. - continue - } + // Record the change in weight and/or public key for each validator. + for nodeID, validatorDiff := range validatorDiffs { + switch validatorDiff.validatorStatus { + case added: + staker := validatorDiff.validator - err = s.validatorWeightDiffsDB.Put( - marshalDiffKey(subnetID, height, nodeID), - marshalWeightDiff(weightDiff), - ) - if err != nil { - return err - } + // The validator is being added. + // + // Invariant: It's impossible for a delegator to have been rewarded + // in the same block that the validator was added. + startTime := uint64(staker.StartTime.Unix()) + metadata := &validatorMetadata{ + txID: staker.TxID, + lastUpdated: staker.StartTime, + + UpDuration: 0, + LastUpdated: startTime, + StakerStartTime: startTime, + PotentialReward: staker.PotentialReward, + PotentialDelegateeReward: 0, + } - // TODO: Move the validator set management out of the state package - if !updateValidators { - continue - } + metadataBytes, err := MetadataCodec.Marshal(codecVersion, metadata) + if err != nil { + return fmt.Errorf("failed to serialize current validator: %w", err) + } - if weightDiff.Decrease { - err = s.validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount) - } else { - if validatorDiff.validatorStatus == added { - err = s.validators.AddStaker( - subnetID, - nodeID, - pk, - staker.TxID, - weightDiff.Amount, - ) - } else { - err = s.validators.AddWeight(subnetID, nodeID, weightDiff.Amount) + if err = validatorDB.Put(staker.TxID[:], metadataBytes); err != nil { + return fmt.Errorf("failed to write current validator to list: %w", err) + } + + s.validatorState.LoadValidatorMetadata(nodeID, subnetID, metadata) + case deleted: + if err := validatorDB.Delete(validatorDiff.validator.TxID[:]); err != nil { + return fmt.Errorf("failed to delete current staker: %w", err) + } + + s.validatorState.DeleteValidatorMetadata(nodeID, subnetID) + } + + err := writeCurrentDelegatorDiff( + delegatorDB, + validatorDiff, + codecVersion, + ) + if err != nil { + return err } - } - if err != nil { - return fmt.Errorf("failed to update validator weight: %w", err) } } + maps.Clear(s.currentStakers.validatorDiffs) return nil } func writeCurrentDelegatorDiff( currentDelegatorList linkeddb.LinkedDB, - weightDiff *ValidatorWeightDiff, validatorDiff *diffValidator, codecVersion uint16, ) error { addedDelegatorIterator := iterator.FromTree(validatorDiff.addedDelegators) defer addedDelegatorIterator.Release() + for addedDelegatorIterator.Next() { staker := addedDelegatorIterator.Value() - if err := weightDiff.Add(false, staker.Weight); err != nil { - return fmt.Errorf("failed to increase node weight diff: %w", err) - } - metadata := &delegatorMetadata{ txID: staker.TxID, PotentialReward: staker.PotentialReward, @@ -2293,10 +2308,6 @@ func writeCurrentDelegatorDiff( } for _, staker := range validatorDiff.deletedDelegators { - if err := weightDiff.Add(true, staker.Weight); err != nil { - return fmt.Errorf("failed to decrease node weight diff: %w", err) - } - if err := currentDelegatorList.Delete(staker.TxID[:]); err != nil { return fmt.Errorf("failed to delete current staker: %w", err) } diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 85df176ce42a..143f673b4e0f 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -4,6 +4,7 @@ package state import ( + "bytes" "context" "math" "math/rand" @@ -27,7 +28,6 @@ import ( "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/logging" - "github.com/ava-labs/avalanchego/utils/maybe" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/utils/wrappers" @@ -223,8 +223,7 @@ func TestState_writeStakers(t *testing.T) { expectedValidatorSetOutput *validators.GetValidatorOutput // Check whether weight/bls keys diffs are duly stored - expectedWeightDiff *ValidatorWeightDiff - expectedPublicKeyDiff maybe.Maybe[*bls.PublicKey] + expectedValidatorDiffs map[subnetIDNodeID]*validatorDiff }{ "add current primary network validator": { staker: primaryNetworkCurrentValidatorStaker, @@ -235,11 +234,19 @@ func TestState_writeStakers(t *testing.T) { PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, Weight: primaryNetworkCurrentValidatorStaker.Weight, }, - expectedWeightDiff: &ValidatorWeightDiff{ - Decrease: false, - Amount: primaryNetworkCurrentValidatorStaker.Weight, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{ + { + subnetID: constants.PrimaryNetworkID, + nodeID: primaryNetworkCurrentValidatorStaker.NodeID, + }: { + weightDiff: ValidatorWeightDiff{ + Decrease: false, + Amount: primaryNetworkCurrentValidatorStaker.Weight, + }, + prevPublicKey: nil, + newPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + }, }, - expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), }, "add current primary network delegator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, @@ -253,15 +260,25 @@ func TestState_writeStakers(t *testing.T) { PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, Weight: primaryNetworkCurrentValidatorStaker.Weight + primaryNetworkCurrentDelegatorStaker.Weight, }, - expectedWeightDiff: &ValidatorWeightDiff{ - Decrease: false, - Amount: primaryNetworkCurrentDelegatorStaker.Weight, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{ + { + subnetID: constants.PrimaryNetworkID, + nodeID: primaryNetworkCurrentValidatorStaker.NodeID, + }: { + weightDiff: ValidatorWeightDiff{ + Decrease: false, + Amount: primaryNetworkCurrentDelegatorStaker.Weight, + }, + prevPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + newPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + }, }, }, "add pending primary network validator": { staker: primaryNetworkPendingValidatorStaker, addStakerTx: addPrimaryNetworkValidator, expectedPendingValidator: primaryNetworkPendingValidatorStaker, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{}, }, "add pending primary network delegator": { initialStakers: []*Staker{primaryNetworkPendingValidatorStaker}, @@ -270,6 +287,7 @@ func TestState_writeStakers(t *testing.T) { addStakerTx: addPrimaryNetworkDelegator, expectedPendingValidator: primaryNetworkPendingValidatorStaker, expectedPendingDelegators: []*Staker{primaryNetworkPendingDelegatorStaker}, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{}, }, "add current subnet validator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, @@ -282,21 +300,37 @@ func TestState_writeStakers(t *testing.T) { PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, Weight: subnetCurrentValidatorStaker.Weight, }, - expectedWeightDiff: &ValidatorWeightDiff{ - Decrease: false, - Amount: subnetCurrentValidatorStaker.Weight, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{ + { + subnetID: subnetID, + nodeID: subnetCurrentValidatorStaker.NodeID, + }: { + weightDiff: ValidatorWeightDiff{ + Decrease: false, + Amount: subnetCurrentValidatorStaker.Weight, + }, + prevPublicKey: nil, + newPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + }, }, - expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), }, "delete current primary network validator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, staker: primaryNetworkCurrentValidatorStaker, - expectedWeightDiff: &ValidatorWeightDiff{ - Decrease: true, - Amount: primaryNetworkCurrentValidatorStaker.Weight, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{ + { + subnetID: constants.PrimaryNetworkID, + nodeID: primaryNetworkCurrentValidatorStaker.NodeID, + }: { + weightDiff: ValidatorWeightDiff{ + Decrease: true, + Amount: primaryNetworkCurrentValidatorStaker.Weight, + }, + prevPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + newPublicKey: nil, + }, }, - expectedPublicKeyDiff: maybe.Some(primaryNetworkCurrentValidatorStaker.PublicKey), }, "delete current primary network delegator": { initialStakers: []*Staker{ @@ -314,15 +348,25 @@ func TestState_writeStakers(t *testing.T) { PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, Weight: primaryNetworkCurrentValidatorStaker.Weight, }, - expectedWeightDiff: &ValidatorWeightDiff{ - Decrease: true, - Amount: primaryNetworkCurrentDelegatorStaker.Weight, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{ + { + subnetID: constants.PrimaryNetworkID, + nodeID: primaryNetworkCurrentValidatorStaker.NodeID, + }: { + weightDiff: ValidatorWeightDiff{ + Decrease: true, + Amount: primaryNetworkCurrentDelegatorStaker.Weight, + }, + prevPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + newPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + }, }, }, "delete pending primary network validator": { - initialStakers: []*Staker{primaryNetworkPendingValidatorStaker}, - initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, - staker: primaryNetworkPendingValidatorStaker, + initialStakers: []*Staker{primaryNetworkPendingValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkPendingValidatorStaker, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{}, }, "delete pending primary network delegator": { initialStakers: []*Staker{ @@ -335,16 +379,25 @@ func TestState_writeStakers(t *testing.T) { }, staker: primaryNetworkPendingDelegatorStaker, expectedPendingValidator: primaryNetworkPendingValidatorStaker, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{}, }, "delete current subnet validator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker, subnetCurrentValidatorStaker}, initialTxs: []*txs.Tx{addPrimaryNetworkValidator, addSubnetValidator}, staker: subnetCurrentValidatorStaker, - expectedWeightDiff: &ValidatorWeightDiff{ - Decrease: true, - Amount: subnetCurrentValidatorStaker.Weight, + expectedValidatorDiffs: map[subnetIDNodeID]*validatorDiff{ + { + subnetID: subnetID, + nodeID: subnetCurrentValidatorStaker.NodeID, + }: { + weightDiff: ValidatorWeightDiff{ + Decrease: true, + Amount: subnetCurrentValidatorStaker.Weight, + }, + prevPublicKey: bls.PublicKeyToUncompressedBytes(primaryNetworkCurrentValidatorStaker.PublicKey), + newPublicKey: nil, + }, }, - expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](primaryNetworkCurrentValidatorStaker.PublicKey), }, } @@ -398,6 +451,10 @@ func TestState_writeStakers(t *testing.T) { state.AddTx(test.addStakerTx, status.Committed) } + validatorDiffs, err := state.calculateValidatorDiffs() + require.NoError(err) + require.Equal(test.expectedValidatorDiffs, validatorDiffs) + state.SetHeight(1) require.NoError(state.Commit()) @@ -453,29 +510,26 @@ func TestState_writeStakers(t *testing.T) { state.validators.GetMap(test.staker.SubnetID)[test.staker.NodeID], ) - diffKey := marshalDiffKey(test.staker.SubnetID, 1, test.staker.NodeID) - weightDiffBytes, err := state.validatorWeightDiffsDB.Get(diffKey) - if test.expectedWeightDiff == nil { - require.ErrorIs(err, database.ErrNotFound) - } else { - require.NoError(err) - - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - require.NoError(err) - require.Equal(test.expectedWeightDiff, weightDiff) - } + for subnetIDNodeID, expectedDiff := range test.expectedValidatorDiffs { + diffKey := marshalDiffKey(subnetIDNodeID.subnetID, 1, subnetIDNodeID.nodeID) + weightDiffBytes, err := state.validatorWeightDiffsDB.Get(diffKey) + if expectedDiff.weightDiff.Amount == 0 { + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) - publicKeyDiffBytes, err := state.validatorPublicKeyDiffsDB.Get(diffKey) - if test.expectedPublicKeyDiff.IsNothing() { - require.ErrorIs(err, database.ErrNotFound) - } else { - require.NoError(err) + weightDiff, err := unmarshalWeightDiff(weightDiffBytes) + require.NoError(err) + require.Equal(&expectedDiff.weightDiff, weightDiff) + } - expectedPublicKeyDiff := test.expectedPublicKeyDiff.Value() - if expectedPublicKeyDiff != nil { - require.Equal(expectedPublicKeyDiff, bls.PublicKeyFromValidUncompressedBytes(publicKeyDiffBytes)) + publicKeyDiffBytes, err := state.validatorPublicKeyDiffsDB.Get(diffKey) + if bytes.Equal(expectedDiff.prevPublicKey, expectedDiff.newPublicKey) { + require.ErrorIs(err, database.ErrNotFound) } else { - require.Empty(publicKeyDiffBytes) + require.NoError(err) + + require.Equal(expectedDiff.prevPublicKey, publicKeyDiffBytes) } }