diff --git a/ids/node_weight.go b/ids/node_weight.go index 21309586ca2..07d25201546 100644 --- a/ids/node_weight.go +++ b/ids/node_weight.go @@ -4,6 +4,6 @@ package ids type NodeWeight struct { - Node NodeID + ID NodeID Weight uint64 } diff --git a/snow/engine/common/tracker/peers.go b/snow/engine/common/tracker/peers.go index 65dda6f7d1f..0dc7100c2c1 100644 --- a/snow/engine/common/tracker/peers.go +++ b/snow/engine/common/tracker/peers.go @@ -274,7 +274,7 @@ func (p *peerData) SampleValidator() (ids.NodeID, bool) { func (p *peerData) GetValidators() set.Set[ids.NodeWeight] { res := set.NewSet[ids.NodeWeight](len(p.validators)) for k, v := range p.validators { - res.Add(ids.NodeWeight{Node: k, Weight: v}) + res.Add(ids.NodeWeight{ID: k, Weight: v}) } return res } @@ -285,7 +285,7 @@ func (p *peerData) ConnectedValidators() set.Set[ids.NodeWeight] { copied := set.NewSet[ids.NodeWeight](len(p.connectedValidators)) for _, vdrID := range p.connectedValidators.List() { weight := p.validators[vdrID] - copied.Add(ids.NodeWeight{Node: vdrID, Weight: weight}) + copied.Add(ids.NodeWeight{ID: vdrID, Weight: weight}) } return copied } diff --git a/snow/engine/common/tracker/peers_test.go b/snow/engine/common/tracker/peers_test.go index ac577a399b4..24d1ca1dc72 100644 --- a/snow/engine/common/tracker/peers_test.go +++ b/snow/engine/common/tracker/peers_test.go @@ -58,10 +58,10 @@ func TestConnectedValidators(t *testing.T) { require.NoError(p.Connected(context.Background(), nodeID2, version.CurrentApp)) require.Equal(uint64(11), p.ConnectedWeight()) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.ConnectedValidators())) + require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}, ids.NodeWeight{ID: nodeID2, Weight: 6}).Equals(p.GetValidators())) + 
require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}, ids.NodeWeight{ID: nodeID2, Weight: 6}).Equals(p.ConnectedValidators())) require.NoError(p.Disconnected(context.Background(), nodeID2)) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}).Equals(p.ConnectedValidators())) + require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}, ids.NodeWeight{ID: nodeID2, Weight: 6}).Equals(p.GetValidators())) + require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}).Equals(p.ConnectedValidators())) } diff --git a/snow/engine/snowman/engine_decorator.go b/snow/engine/snowman/engine_decorator.go index abeb67d0761..41102695dc7 100644 --- a/snow/engine/snowman/engine_decorator.go +++ b/snow/engine/snowman/engine_decorator.go @@ -21,9 +21,30 @@ type decoratedEngineWithStragglerDetector struct { func NewDecoratedEngineWithStragglerDetector(e *Engine, time func() time.Time, f func(time.Duration)) common.Engine { minConfRatio := float64(e.Params.AlphaConfidence) / float64(e.Params.K) - sd := newStragglerDetector(time, e.Config.Ctx.Log, minConfRatio, e.Consensus.LastAccepted, - e.Config.ConnectedValidators.ConnectedValidators, e.Config.ConnectedValidators.ConnectedPercent, - e.Consensus.Processing, e.acceptedFrontiers.LastAccepted) + + sa := &snapshotAnalyzer{ + log: e.Config.Ctx.Log, + processing: e.Consensus.Processing, + } + + s := &snapshotter{ + log: e.Config.Ctx.Log, + connectedValidators: e.Config.ConnectedValidators.ConnectedValidators, + minConfirmationThreshold: minConfRatio, + lastAcceptedByNodeID: e.acceptedFrontiers.LastAccepted, + lastAccepted: dropHeight(e.Consensus.LastAccepted), + } + + conf := stragglerDetectorConfig{ + getSnapshot: s.getNetworkSnapshot, + areWeBehindTheRest: sa.areWeBehindTheRest, + minStragglerCheckInterval: minStragglerCheckInterval, + log: e.Config.Ctx.Log, + getTime: time, + } + + sd := 
newStragglerDetector(conf) + return &decoratedEngineWithStragglerDetector{ Engine: e, f: f, diff --git a/snow/engine/snowman/engine_decorator_test.go b/snow/engine/snowman/engine_decorator_test.go index 367631249ea..c40664c1e05 100644 --- a/snow/engine/snowman/engine_decorator_test.go +++ b/snow/engine/snowman/engine_decorator_test.go @@ -13,12 +13,13 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/snow/consensus/snowman/snowmantest" + "github.com/ava-labs/avalanchego/utils/timer/mockable" ) func TestEngineStragglerDetector(t *testing.T) { require := require.New(t) - fakeClock := make(chan time.Time, 1) + var fakeClock mockable.Clock conf := DefaultConfig(t) peerID, _, sender, vm, engine := setup(t, conf) @@ -26,24 +27,14 @@ func TestEngineStragglerDetector(t *testing.T) { parent := snowmantest.BuildChild(snowmantest.Genesis) require.NoError(conf.Consensus.Add(parent)) - listenerShouldInvokeWith := []time.Duration{0, 0, time.Second * 2} - - fakeTime := func() time.Time { - select { - case now := <-fakeClock: - return now - default: - require.Fail("should have a time.Time in the channel") - return time.Time{} - } - } + listenerShouldInvokeWith := []time.Duration{0, 0, minStragglerCheckInterval * 2} f := func(duration time.Duration) { require.Equal(listenerShouldInvokeWith[0], duration) listenerShouldInvokeWith = listenerShouldInvokeWith[1:] } - decoratedEngine := NewDecoratedEngineWithStragglerDetector(engine, fakeTime, f) + decoratedEngine := NewDecoratedEngineWithStragglerDetector(engine, fakeClock.Time, f) vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { switch blkID { @@ -62,13 +53,13 @@ func TestEngineStragglerDetector(t *testing.T) { } now := time.Now() - fakeClock <- now + fakeClock.Set(now) require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) - now = now.Add(time.Second * 2) - 
fakeClock <- now + now = now.Add(minStragglerCheckInterval * 2) + fakeClock.Set(now) require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) - now = now.Add(time.Second * 2) - fakeClock <- now + now = now.Add(minStragglerCheckInterval * 2) + fakeClock.Set(now) require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) require.Empty(listenerShouldInvokeWith) } diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go index 699cd5f2102..47282aa62e7 100644 --- a/snow/engine/snowman/straggler_detect.go +++ b/snow/engine/snowman/straggler_detect.go @@ -18,7 +18,7 @@ import ( ) const ( - minStragglerCheckInterval = time.Second + minStragglerCheckInterval = 10 * time.Second stakeThresholdForStragglerSuspicion = 0.75 minimumStakeThresholdRequiredForNetworkInfo = 0.8 knownStakeThresholdRequiredForAnalysis = 0.8 @@ -34,34 +34,13 @@ type stragglerDetectorConfig struct { // log logs events log logging.Logger - // minConfirmationThreshold is the minimum stake percentage that below it, we do not check if we are stragglers. - minConfirmationThreshold float64 - - // connectedPercent returns the stake percentage of connected nodes. - connectedPercent func() float64 - - // connectedValidators returns a set of tuples of NodeID and corresponding weight. - connectedValidators func() set.Set[ids.NodeWeight] - - // lastAcceptedByNodeID returns the last accepted height a node has reported, or false if it is unknown. - lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) - - // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. - // This means that when the last accepted block is given as input, true is returned, as by definition - // its descendants have not been accepted by consensus, but this block is known. 
- // For any block ID belonging to an ancestor of the last accepted block, false is returned, - // as the last accepted block has been accepted by consensus. - processing func(id ids.ID) bool - - // lastAccepted returns the last accepted block of this node. - lastAccepted func() ids.ID - // getSnapshot returns a snapshot of the network's nodes and their last accepted blocks, - // or false if it fails from some reason. + // excluding nodes that have the same last accepted block as we do. + // Returns false if it fails for some reason. getSnapshot func() (snapshot, bool) - // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot - haveWeFailedCatchingUp func(snapshot) bool + // areWeBehindTheRest returns whether we have not replicated enough blocks of the given snapshot + areWeBehindTheRest func(snapshot) bool } type stragglerDetector struct { @@ -78,34 +57,10 @@ type stragglerDetector struct { prevSnapshot snapshot } -func newStragglerDetector( - getTime func() time.Time, - log logging.Logger, - minConfirmationThreshold float64, - lastAccepted func() (ids.ID, uint64), - connectedValidators func() set.Set[ids.NodeWeight], - connectedPercent func() float64, - processing func(id ids.ID) bool, - lastAcceptedByNodeID func(ids.NodeID) (ids.ID, bool), -) *stragglerDetector { - sd := &stragglerDetector{ - stragglerDetectorConfig: stragglerDetectorConfig{ - lastAccepted: dropHeight(lastAccepted), - processing: processing, - minStragglerCheckInterval: minStragglerCheckInterval, - log: log, - connectedValidators: connectedValidators, - connectedPercent: connectedPercent, - minConfirmationThreshold: minConfirmationThreshold, - lastAcceptedByNodeID: lastAcceptedByNodeID, - getTime: getTime, - }, +func newStragglerDetector(config stragglerDetectorConfig) *stragglerDetector { + return &stragglerDetector{ + stragglerDetectorConfig: config, } - - sd.getSnapshot = sd.getNetworkSnapshot - sd.haveWeFailedCatchingUp = sd.failedCatchingUp - - 
return sd } // CheckIfWeAreStragglingBehind returns for how long our ledger is behind the rest @@ -127,65 +82,92 @@ func (sd *stragglerDetector) CheckIfWeAreStragglingBehind() time.Duration { }() if sd.prevSnapshot.isEmpty() { - snapshot, ok := sd.getSnapshot() - if !ok { - sd.log.Trace("No node snapshot obtained") - sd.continuousStragglingPeriod = 0 - } - sd.prevSnapshot = snapshot + sd.obtainSnapshot() } else { - if sd.haveWeFailedCatchingUp(sd.prevSnapshot) { - timeSinceLastCheck := now.Sub(sd.previousStragglerCheckTime) - sd.continuousStragglingPeriod += timeSinceLastCheck - } else { - sd.continuousStragglingPeriod = 0 - } - sd.prevSnapshot = snapshot{} + sd.evaluateSnapshot(now) } return sd.continuousStragglingPeriod } -func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { +func (sd *stragglerDetector) obtainSnapshot() { + snap, ok := sd.getSnapshot() + sd.prevSnapshot = snap + if !ok || !sd.areWeBehindTheRest(snap) { + sd.log.Trace("No node snapshot obtained") + sd.continuousStragglingPeriod = 0 + sd.prevSnapshot = snapshot{} + } +} + +func (sd *stragglerDetector) evaluateSnapshot(now time.Time) { + if sd.areWeBehindTheRest(sd.prevSnapshot) { + timeSinceLastCheck := now.Sub(sd.previousStragglerCheckTime) + sd.continuousStragglingPeriod += timeSinceLastCheck + } else { + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot{} +} + +type snapshotAnalyzer struct { + log logging.Logger + + // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. + // This means that when the last accepted block is given as input, true is returned, as by definition + // its descendants have not been accepted by consensus, but this block is known. + // For any block ID belonging to an ancestor of the last accepted block, false is returned, + // as the last accepted block has been accepted by consensus. 
+ processing func(id ids.ID) bool +} + +func (sa *snapshotAnalyzer) areWeBehindTheRest(s snapshot) bool { + if s.isEmpty() { + return false + } + totalValidatorWeight, nodeWeightsToBlocks := s.totalValidatorWeight, s.nodeWeightsToBlocks - var processingWeight uint64 - for nw, lastAccepted := range nodeWeightsToBlocks { - if sd.processing(lastAccepted) { - newProcessingWeight, err := safemath.Add(processingWeight, nw.Weight) - if err != nil { - sd.log.Error("Cumulative weight overflow", zap.Uint64("cumulative", processingWeight), zap.Uint64("added", nw.Weight)) - return false - } - processingWeight = newProcessingWeight - } + processingWeight, err := nodeWeightsToBlocks.filter(sa.processing).totalWeight() + if err != nil { + sa.log.Error("Failed computing total weight", zap.Error(err)) + return false } - sd.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) + sa.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) ratio := float64(processingWeight) / float64(totalValidatorWeight) if ratio > stakeThresholdForStragglerSuspicion { - sd.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) + sa.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) return true } - sd.log.Trace("Nodes ahead of us:", zap.Float64("ratio", ratio)) + sa.log.Trace("Nodes ahead of us:", zap.Float64("ratio", ratio)) return false } -func (sd *stragglerDetector) validateNetInfo(netInfo netInfo) bool { - if netInfo.connStakePercent < sd.minConfirmationThreshold { - // We don't know for sure whether we're behind or not. - // Even if we're behind, it's pointless to act before we have established - // connectivity with enough validators. 
- sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", netInfo.connStakePercent)) - return false - } +type snapshotter struct { + // log logs events + log logging.Logger + + // minConfirmationThreshold is the minimum stake percentage below which we do not check whether we are stragglers. + minConfirmationThreshold float64 + // lastAccepted returns the last accepted block of this node. + lastAccepted func() ids.ID + + // connectedValidators returns a set of tuples of NodeID and corresponding weight. + connectedValidators func() set.Set[ids.NodeWeight] + + // lastAcceptedByNodeID returns the last accepted block ID a node has reported, or false if it is unknown. + lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) +} + +func (s *snapshotter) validateNetInfo(netInfo netInfo) bool { if netInfo.totalValidatorWeight == 0 { - sd.log.Trace("Connected to zero weight") + s.log.Trace("Connected to zero weight") return false } @@ -194,13 +176,13 @@ func (sd *stragglerDetector) validateNetInfo(netInfo netInfo) bool { // Ensure we have collected last accepted blocks for at least 80% (or so) stake of the total weight we are connected to. 
if totalKnownLastBlockStakePercent < minimumStakeThresholdRequiredForNetworkInfo { - sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", + s.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", zap.Float64("ratio", totalKnownLastBlockStakePercent)) return false } if stakeAheadOfUs < knownStakeThresholdRequiredForAnalysis { - sd.log.Trace("Most stake we're connected to has the same height as we do", + s.log.Trace("Most stake we're connected to has the same height as we do", zap.Float64("ratio", stakeAheadOfUs)) return false } @@ -208,15 +190,15 @@ func (sd *stragglerDetector) validateNetInfo(netInfo netInfo) bool { return true } -func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { - ourLastAcceptedBlock := sd.lastAccepted() +func (s *snapshotter) getNetworkSnapshot() (snapshot, bool) { + ourLastAcceptedBlock := s.lastAccepted() - netInfo, err := sd.getNetworkInfo(ourLastAcceptedBlock) + netInfo, err := s.getNetworkInfo(ourLastAcceptedBlock) if err != nil { return snapshot{}, false } - if !sd.validateNetInfo(netInfo) { + if !s.validateNetInfo(netInfo) { return snapshot{}, false } @@ -225,31 +207,17 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { totalValidatorWeight: netInfo.totalValidatorWeight, } - if sd.haveWeFailedCatchingUp(snap) { - return snap, true - } - - return snapshot{}, false + return snap, true } -type netInfo struct { - connStakePercent float64 - totalPendingStake uint64 - totalValidatorWeight uint64 - totalWeightWeKnowItsLastAcceptedBlock uint64 - nodeWeightToLastAccepted nodeWeightsToBlocks -} - -func (sd *stragglerDetector) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInfo, error) { +func (s *snapshotter) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInfo, error) { var res netInfo - res.connStakePercent = sd.connectedPercent() - - validators := 
nodeWeights(sd.connectedValidators().List()) + validators := nodeWeights(s.connectedValidators().List()) nodeWeightToLastAccepted := make(nodeWeightsToBlocks, len(validators)) for _, vdr := range validators { - lastAccepted, ok := sd.lastAcceptedByNodeID(vdr.Node) + lastAccepted, ok := s.lastAcceptedByNodeID(vdr.ID) if !ok { continue } @@ -258,14 +226,14 @@ func (sd *stragglerDetector) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInf totalValidatorWeight, err := validators.totalWeight() if err != nil { - sd.log.Error("Failed computing total weight", zap.Error(err)) + s.log.Error("Failed computing total weight", zap.Error(err)) return netInfo{}, err } res.totalValidatorWeight = totalValidatorWeight totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightToLastAccepted.totalWeight() if err != nil { - sd.log.Error("Failed computing total weight", zap.Error(err)) + s.log.Error("Failed computing total weight", zap.Error(err)) return netInfo{}, err } res.totalWeightWeKnowItsLastAcceptedBlock = totalWeightWeKnowItsLastAcceptedBlock @@ -273,27 +241,32 @@ func (sd *stragglerDetector) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInf prevLastAcceptedCount := len(nodeWeightToLastAccepted) // Ensure we have collected last accepted blocks that are not our own last accepted block. 
- for nodeWeight, lastAccepted := range nodeWeightToLastAccepted { - if ourLastAcceptedBlock.Compare(lastAccepted) == 0 { - delete(nodeWeightToLastAccepted, nodeWeight) - } - } + nodeWeightToLastAccepted = nodeWeightToLastAccepted.filter(func(id ids.ID) bool { + return ourLastAcceptedBlock.Compare(id) != 0 + }) res.nodeWeightToLastAccepted = nodeWeightToLastAccepted totalPendingStake, err := nodeWeightToLastAccepted.totalWeight() if err != nil { - sd.log.Error("Failed computing total weight", zap.Error(err)) + s.log.Error("Failed computing total weight", zap.Error(err)) return netInfo{}, err } res.totalPendingStake = totalPendingStake - sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Uint64("new", totalPendingStake)) + s.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Uint64("new", totalPendingStake)) return res, nil } +type netInfo struct { + totalPendingStake uint64 + totalValidatorWeight uint64 + totalWeightWeKnowItsLastAcceptedBlock uint64 + nodeWeightToLastAccepted nodeWeightsToBlocks +} + type snapshot struct { totalValidatorWeight uint64 nodeWeightsToBlocks nodeWeightsToBlocks @@ -309,6 +282,7 @@ func (nwb nodeWeightsToBlocks) totalWeight() (uint64, error) { return nodeWeights(maps.Keys(nwb)).totalWeight() } +// dropHeight removes the second return parameter from the function f() and keeps its first return parameter, ids.ID. 
func dropHeight(f func() (ids.ID, uint64)) func() ids.ID { return func() ids.ID { id, _ := f() @@ -329,3 +303,13 @@ func (nws nodeWeights) totalWeight() (uint64, error) { } return weight, nil } + +func (nwb nodeWeightsToBlocks) filter(f func(ids.ID) bool) nodeWeightsToBlocks { + filtered := make(nodeWeightsToBlocks, len(nwb)) + for nw, id := range nwb { + if f(id) { + filtered[nw] = id + } + } + return filtered +} diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go index 220975a46bc..5e65fb95da7 100644 --- a/snow/engine/snowman/straggler_detect_test.go +++ b/snow/engine/snowman/straggler_detect_test.go @@ -4,6 +4,7 @@ package snowman import ( + "fmt" "math" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/utils/timer/mockable" safemath "github.com/ava-labs/avalanchego/utils/math" ) @@ -71,26 +73,18 @@ func TestGetNetworkSnapshot(t *testing.T) { lastAcceptedFromNodes map[ids.NodeID]ids.ID processing map[ids.ID]struct{} connectedValidators func() set.Set[ids.NodeWeight] - connectedPercent float64 expectedSnapshot snapshot expectedOK bool expectedLogged string }{ - { - description: "not enough connected validators", - connectedValidators: connectedValidators([]ids.NodeWeight{}), - expectedLogged: "not enough connected stake to determine network info", - }, { description: "connected to zero weight", - connectedPercent: 1.0, connectedValidators: connectedValidators([]ids.NodeWeight{}), expectedLogged: "Connected to zero weight", }, { description: "not enough info", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, Node: n1}, {Weight: 999999, Node: n2}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, ID: n1}, {Weight: 999999, ID: n2}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: 
{0x1}, }, @@ -98,8 +92,7 @@ func TestGetNetworkSnapshot(t *testing.T) { }, { description: "we're in sync", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: {0x1}, }, @@ -108,48 +101,49 @@ func TestGetNetworkSnapshot(t *testing.T) { }, { description: "we're behind", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: {0x1}, }, processing: map[ids.ID]struct{}{{0x1}: {}}, lastAccepted: ids.ID{0x0}, expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 999999}: {0x1}, + ids.NodeWeight{ID: n1, Weight: 999999}: {0x1}, }}, expectedOK: true, }, { description: "we're not behind", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: {0x1}, }, processing: map[ids.ID]struct{}{{0x2}: {}}, lastAccepted: ids.ID{0x0}, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeightsToBlocks: nodeWeightsToBlocks{ + ids.NodeWeight{ID: n1, Weight: 999999}: {0x1}, + }}, + expectedOK: true, }, } { t.Run(testCase.description, func(t *testing.T) { var buff logBuffer log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) - sd := newStragglerDetector(nil, log, 0.75, - func() (ids.ID, uint64) { - return testCase.lastAccepted, 0 + s := &snapshotter{ + log: log, + connectedValidators: testCase.connectedValidators, + minConfirmationThreshold: 0.75, + lastAccepted: 
func() ids.ID { + return testCase.lastAccepted }, - testCase.connectedValidators, func() float64 { return testCase.connectedPercent }, - func(id ids.ID) bool { - _, ok := testCase.processing[id] - return ok - }, - func(vdr ids.NodeID) (ids.ID, bool) { + lastAcceptedByNodeID: func(vdr ids.NodeID) (ids.ID, bool) { id, ok := testCase.lastAcceptedFromNodes[vdr] return id, ok - }) + }, + } - snapshot, ok := sd.getNetworkSnapshot() + snapshot, ok := s.getNetworkSnapshot() require.Equal(t, testCase.expectedSnapshot, snapshot) require.Equal(t, testCase.expectedOK, ok) require.Contains(t, buff.String(), testCase.expectedLogged) @@ -168,7 +162,6 @@ func TestFailedCatchingUp(t *testing.T) { lastAcceptedFromNodes map[ids.NodeID]ids.ID processing map[ids.ID]struct{} connectedValidators []ids.NodeWeight - connectedPercent float64 input snapshot expected bool expectedLogged string @@ -176,23 +169,24 @@ func TestFailedCatchingUp(t *testing.T) { { description: "stake overflow", input: snapshot{ + totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 11}: ids.ID{0x2}, + ids.NodeWeight{ID: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 11}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ {0x1}: {}, {0x2}: {}, }, - expectedLogged: "Cumulative weight overflow", + expectedLogged: "Failed computing total weight", }, { description: "Straggling behind stake minority", input: snapshot{ totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 25}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + ids.NodeWeight{ID: n1, Weight: 25}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 50}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ @@ -205,8 +199,8 @@ func TestFailedCatchingUp(t *testing.T) { description: "Straggling behind stake majority", input: snapshot{ 
totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 26}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + ids.NodeWeight{ID: n1, Weight: 26}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 50}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ @@ -220,8 +214,8 @@ func TestFailedCatchingUp(t *testing.T) { description: "In sync with the majority", input: snapshot{ totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 75}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 25}: ids.ID{0x2}, + ids.NodeWeight{ID: n1, Weight: 75}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 25}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ @@ -234,34 +228,22 @@ func TestFailedCatchingUp(t *testing.T) { var buff logBuffer log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) - sd := newStragglerDetector(nil, log, 0.75, - func() (ids.ID, uint64) { - return testCase.lastAccepted, 0 - }, - func() set.Set[ids.NodeWeight] { - var set set.Set[ids.NodeWeight] - for _, nw := range testCase.connectedValidators { - set.Add(nw) - } - return set - }, func() float64 { return testCase.connectedPercent }, - func(id ids.ID) bool { + sa := &snapshotAnalyzer{ + log: log, + processing: func(id ids.ID) bool { _, ok := testCase.processing[id] return ok }, - func(vdr ids.NodeID) (ids.ID, bool) { - id, ok := testCase.lastAcceptedFromNodes[vdr] - return id, ok - }) + } - require.Equal(t, testCase.expected, sd.failedCatchingUp(testCase.input)) + require.Equal(t, testCase.expected, sa.areWeBehindTheRest(testCase.input)) require.Contains(t, buff.String(), testCase.expectedLogged) }) } } func TestCheckIfWeAreStragglingBehind(t *testing.T) { - fakeClock := make(chan time.Time, 1) + var fakeClock mockable.Clock snapshots := make(chan snapshot, 1) assertNoSnapshotsRemain := func() { @@ -286,16 +268,13 @@ func 
TestCheckIfWeAreStragglingBehind(t *testing.T) { sd := stragglerDetector{ stragglerDetectorConfig: stragglerDetectorConfig{ minStragglerCheckInterval: time.Second, - getTime: func() time.Time { - now := <-fakeClock - return now - }, - log: log, + getTime: fakeClock.Time, + log: log, getSnapshot: func() (snapshot, bool) { s := <-snapshots return s, !s.isEmpty() }, - haveWeFailedCatchingUp: func(_ snapshot) bool { + areWeBehindTheRest: func(_ snapshot) bool { return haveWeFailedCatchingUpReturns }, }, @@ -329,9 +308,10 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { }, }, { - description: "Advance time some more to the first check where the snapshot isn't empty", - timeAdvanced: time.Second * 2, - snapshotsRead: []snapshot{nonEmptySnap}, + description: "Advance time some more to the first check where the snapshot isn't empty", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + haveWeFailedCatchingUpReturns: true, evalExtraAssertions: func(t *testing.T) { require.Empty(t, buff.String()) }, @@ -346,11 +326,12 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { }, }, { - description: "The third snapshot is due to a fresh check", - timeAdvanced: time.Second * 2, - snapshotsRead: []snapshot{nonEmptySnap}, + description: "The third snapshot is due to a fresh check", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + haveWeFailedCatchingUpReturns: true, // We carry over the total straggling time from previous testCase to this check, - // as we need the next check to nullify it. + // as we expect the next check to nullify it. 
expectedStragglingTime: time.Second * 2, evalExtraAssertions: func(_ *testing.T) {}, }, @@ -361,8 +342,9 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { }, } { t.Run(testCase.description, func(t *testing.T) { + fmt.Println(testCase.description) fakeTime = fakeTime.Add(testCase.timeAdvanced) - fakeClock <- fakeTime + fakeClock.Set(fakeTime) // Load the snapshot expected to be retrieved in this testCase, if applicable. if len(testCase.snapshotsRead) > 0 { diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index 4c43e0ead00..25c1463e4f7 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -72,7 +72,7 @@ func (h *handler) getDisconnectedValidators() set.Set[ids.NodeID] { func withoutWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { var res set.Set[ids.NodeID] for _, nw := range weights.List() { - res.Add(nw.Node) + res.Add(nw.ID) } return res }