// NodeWeight pairs a node ID with a weight.
//
// Values of this type are used as set elements and map keys, so the struct
// must stay comparable (no slices, maps or funcs as fields).
type NodeWeight struct {
	// Node identifies the node.
	Node NodeID
	// Weight is the weight associated with Node.
	Weight uint64
}
p.peers.SampleValidator() } -func (p *lockedPeers) GetValidators() set.Set[ids.NodeID] { +func (p *lockedPeers) GetValidators() set.Set[ids.NodeWeight] { p.lock.RLock() defer p.lock.RUnlock() return p.peers.GetValidators() } -func (p *lockedPeers) ConnectedValidators() set.Set[ids.NodeID] { +func (p *lockedPeers) ConnectedValidators() set.Set[ids.NodeWeight] { p.lock.RLock() defer p.lock.RUnlock() @@ -272,14 +271,21 @@ func (p *peerData) SampleValidator() (ids.NodeID, bool) { return p.connectedValidators.Peek() } -func (p *peerData) GetValidators() set.Set[ids.NodeID] { - return set.Of(maps.Keys(p.validators)...) +func (p *peerData) GetValidators() set.Set[ids.NodeWeight] { + res := set.NewSet[ids.NodeWeight](len(p.validators)) + for k, v := range p.validators { + res.Add(ids.NodeWeight{Node: k, Weight: v}) + } + return res } -func (p *peerData) ConnectedValidators() set.Set[ids.NodeID] { +func (p *peerData) ConnectedValidators() set.Set[ids.NodeWeight] { // The set is copied to avoid future changes from being reflected in the // returned set. - copied := set.NewSet[ids.NodeID](len(p.connectedValidators)) - copied.Union(p.connectedValidators) + copied := set.NewSet[ids.NodeWeight](len(p.connectedValidators)) + for _, vdrID := range p.connectedValidators.List() { + weight := p.validators[vdrID] + copied.Add(ids.NodeWeight{Node: vdrID, Weight: weight}) + } return copied } diff --git a/snow/engine/snowman/engine_decorator.go b/snow/engine/snowman/engine_decorator.go new file mode 100644 index 00000000000..9f9f870a62a --- /dev/null +++ b/snow/engine/snowman/engine_decorator.go @@ -0,0 +1,58 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package snowman + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" +) + +type EngineStragglerDetector struct { + Listener func(duration time.Duration) +} + +func (ed *EngineStragglerDetector) AttachToEngine(e *Engine) common.Engine { + minConfRatio := float64(e.Params.AlphaConfidence) / float64(e.Params.K) + sd := newStragglerDetector(e.Config.Ctx.Log, minConfRatio, e.Consensus.LastAccepted, + e.Config.ConnectedValidators.ConnectedValidators, e.Config.ConnectedValidators.ConnectedPercent, + e.Consensus.Processing, e.acceptedFrontiers.LastAccepted) + de := &DecoratedEngine{Engine: e} + de.decorate("Chits", func(e *Engine) { + behindDuration := sd.CheckIfWeAreStragglingBehind() + if behindDuration > 0 { + e.Config.Ctx.Log.Info("We are behind the rest of the network", zap.Float64("seconds", behindDuration.Seconds())) + } + e.metrics.stragglingDuration.Set(float64(behindDuration)) + ed.Listener(behindDuration) + }) + + return de +} + +type DecoratedEngine struct { + decorations map[string]func(*Engine) + + *Engine +} + +func (de *DecoratedEngine) decorate(method string, f func(*Engine)) { + if de.decorations == nil { + de.decorations = map[string]func(*Engine){} + } + de.decorations[method] = f +} + +func (de *DecoratedEngine) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { + f, ok := de.decorations["Chits"] + if !ok { + panic("programming error: decorator for Chits not registered") + } + f(de.Engine) + return de.Engine.Chits(ctx, nodeID, requestID, preferredID, preferredIDAtHeight, acceptedID) +} diff --git a/snow/engine/snowman/metrics.go b/snow/engine/snowman/metrics.go index 922b18200d4..68856ba1054 100644 --- a/snow/engine/snowman/metrics.go +++ b/snow/engine/snowman/metrics.go @@ -23,6 +23,7 @@ type metrics struct { numBlocked prometheus.Gauge numBlockers prometheus.Gauge 
numNonVerifieds prometheus.Gauge + stragglingDuration prometheus.Gauge numBuilt prometheus.Counter numBuildsFailed prometheus.Counter numUselessPutBytes prometheus.Counter @@ -41,6 +42,10 @@ type metrics struct { func newMetrics(reg prometheus.Registerer) (*metrics, error) { errs := wrappers.Errs{} m := &metrics{ + stragglingDuration: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "straggling_duration", + Help: "For how long we have been straggling behind the rest, in nano-seconds.", + }), bootstrapFinished: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "bootstrap_finished", Help: "Whether or not bootstrap process has completed. 1 is success, 0 is fail or ongoing.", @@ -128,6 +133,7 @@ func newMetrics(reg prometheus.Registerer) (*metrics, error) { m.issued.WithLabelValues(unknownSource) errs.Add( + reg.Register(m.stragglingDuration), reg.Register(m.bootstrapFinished), reg.Register(m.numRequests), reg.Register(m.numBlocked), diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go new file mode 100644 index 00000000000..9b83fd25827 --- /dev/null +++ b/snow/engine/snowman/straggler_detect.go @@ -0,0 +1,306 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package snowman + +import ( + "fmt" + "time" + + "go.uber.org/zap" + "golang.org/x/exp/maps" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + + safemath "github.com/ava-labs/avalanchego/utils/math" +) + +const ( + minStragglerCheckInterval = time.Second + stakeThresholdForStragglerSuspicion = 0.75 +) + +type stragglerDetectorConfig struct { + // getTime returns the current time + getTime func() time.Time + + // minStragglerCheckInterval determines how frequently we are allowed to check if we are stragglers. 
+ minStragglerCheckInterval time.Duration + + // log logs events + log logging.Logger + + // minConfirmationThreshold is the minimum percentage that below it, we do not check if we are stragglers. + minConfirmationThreshold float64 + + // connectedPercent returns the percent of connected nodes. + connectedPercent func() float64 + + // connectedValidators returns a set of tuples of NodeID and corresponding weight. + connectedValidators func() set.Set[ids.NodeWeight] + + // lastAcceptedByNodeID returns the last reported height a node has reported, or false if it is unknown. + lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) + + // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. + // This means that when the last accepted block is given as input, true is returned, as by definition + // its descendants have not been accepted by consensus, but this block is known. + // For any block ID belonging to an ancestor of the last accepted block, false is returned, + // as the last accepted block has been accepted by consensus. + processing func(id ids.ID) bool + + // lastAccepted returns the last accepted block of this node. + lastAccepted func() ids.ID + + // getSnapshot returns a snapshot of the network's nodes and their last accepted blocks, + // or false if it fails from some reason. + getSnapshot func() (snapshot, bool) + + // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot + haveWeFailedCatchingUp func(snapshot) bool +} + +type stragglerDetector struct { + stragglerDetectorConfig + + // continuousStragglingPeriod defines the time we have been straggling continuously. + continuousStragglingPeriod time.Duration + + // previousStragglerCheckTime is the last time we checked whether + // our block height is behind the rest of the network + previousStragglerCheckTime time.Time + + // prevSnapshot is the snapshot from a past iteration. 
+ prevSnapshot snapshot +} + +func newStragglerDetector( + log logging.Logger, + minConfirmationThreshold float64, + lastAccepted func() (ids.ID, uint64), + connectedValidators func() set.Set[ids.NodeWeight], + connectedPercent func() float64, + processing func(id ids.ID) bool, + lastAcceptedByNodeID func(ids.NodeID) (ids.ID, bool), +) *stragglerDetector { + sd := &stragglerDetector{ + stragglerDetectorConfig: stragglerDetectorConfig{ + lastAccepted: dropHeight(lastAccepted), + processing: processing, + minStragglerCheckInterval: minStragglerCheckInterval, + log: log, + connectedValidators: connectedValidators, + connectedPercent: connectedPercent, + minConfirmationThreshold: minConfirmationThreshold, + lastAcceptedByNodeID: lastAcceptedByNodeID, + getTime: time.Now, + }, + } + + sd.getSnapshot = sd.getNetworkSnapshot + sd.haveWeFailedCatchingUp = sd.failedCatchingUp + + return sd +} + +// CheckIfWeAreStragglingBehind returns for how long our ledger is behind the rest +// of the nodes in the network. If we are not behind, zero is returned. 
+func (sd *stragglerDetector) CheckIfWeAreStragglingBehind() time.Duration { + now := sd.getTime() + if sd.previousStragglerCheckTime.IsZero() { + sd.previousStragglerCheckTime = now + return 0 + } + + // Don't check too often, only once in every minStragglerCheckInterval + if sd.previousStragglerCheckTime.Add(sd.minStragglerCheckInterval).After(now) { + return 0 + } + + defer func() { + sd.previousStragglerCheckTime = now + }() + + if sd.prevSnapshot.isEmpty() { + snapshot, ok := sd.getSnapshot() + if !ok { + sd.log.Trace("No node snapshot obtained") + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot + } else { + if sd.haveWeFailedCatchingUp(sd.prevSnapshot) { + timeSinceLastCheck := now.Sub(sd.previousStragglerCheckTime) + sd.continuousStragglingPeriod += timeSinceLastCheck + } else { + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot{} + } + + return sd.continuousStragglingPeriod +} + +func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { + totalValidatorWeight, nodeWeights2Blocks := s.totalValidatorWeight, s.nodeWeights2Blocks + + var processingWeight uint64 + for nw, lastAccepted := range nodeWeights2Blocks { + if sd.processing(lastAccepted) { + newProcessingWeight, err := safemath.Add(processingWeight, nw.Weight) + if err != nil { + sd.log.Error("Cumulative weight overflow", zap.Uint64("cumulative", processingWeight), zap.Uint64("added", nw.Weight)) + return false + } + processingWeight = newProcessingWeight + } + } + + sd.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) + + ratio := float64(processingWeight) / float64(totalValidatorWeight) + + if ratio > stakeThresholdForStragglerSuspicion { + sd.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) + return true + } + + sd.log.Trace("Nodes ahead of us:", zap.Float64("ratio", ratio)) + + return false +} + +func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { + 
nodeWeight2lastAccepted, totalValidatorWeight, _ := sd.getNetworkInfo() + if len(nodeWeight2lastAccepted) == 0 { + return snapshot{}, false + } + + ourLastAcceptedBlock := sd.lastAccepted() + + prevLastAcceptedCount := len(nodeWeight2lastAccepted) + for k, v := range nodeWeight2lastAccepted { + if ourLastAcceptedBlock.Compare(v) == 0 { + delete(nodeWeight2lastAccepted, k) + } + } + newLastAcceptedCount := len(nodeWeight2lastAccepted) + + sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Int("new", newLastAcceptedCount)) + + // Ensure we have collected last accepted blocks that are not our own last accepted block + // for at least 80% stake of the total weight we are connected to. + + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return snapshot{}, false + } + + ratio := float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) + + if ratio < 0.8 { + sd.log.Trace("Most stake we're connected to has the same height as we do", + zap.Float64("ratio", ratio)) + return snapshot{}, false + } + + snap := snapshot{ + nodeWeights2Blocks: nodeWeight2lastAccepted, + totalValidatorWeight: totalValidatorWeight, + } + + if sd.haveWeFailedCatchingUp(snap) { + return snap, true + } + + return snapshot{}, false +} + +func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64, uint64) { + ratio := sd.connectedPercent() + if ratio < sd.minConfirmationThreshold { + // We don't know for sure whether we're behind or not. + // Even if we're behind, it's pointless to act before we have established + // connectivity with enough validators. 
+ sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", ratio)) + return nil, 0, 0 + } + + validators := nodeWeights(sd.connectedValidators().List()) + + nodeWeight2lastAccepted := make(nodeWeights2Blocks, len(validators)) + + for _, vdr := range validators { + lastAccepted, ok := sd.lastAcceptedByNodeID(vdr.Node) + if !ok { + continue + } + nodeWeight2lastAccepted[vdr] = lastAccepted + } + + totalValidatorWeight, err := validators.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return nil, 0, 0 + } + + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return nil, 0, 0 + } + + if totalValidatorWeight == 0 { + sd.log.Trace("Connected to zero weight") + return nil, 0, 0 + } + + ratio = float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) + + // Ensure we have collected last accepted blocks for at least 80% stake of the total weight we are connected to. 
+ if ratio < 0.8 { + sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", + zap.Float64("ratio", ratio)) + return nil, 0, 0 + } + return nodeWeight2lastAccepted, totalValidatorWeight, totalWeightWeKnowItsLastAcceptedBlock +} + +type snapshot struct { + totalValidatorWeight uint64 + nodeWeights2Blocks nodeWeights2Blocks +} + +func (s snapshot) isEmpty() bool { + return s.totalValidatorWeight == 0 || len(s.nodeWeights2Blocks) == 0 +} + +type nodeWeights2Blocks map[ids.NodeWeight]ids.ID + +func (nw2b nodeWeights2Blocks) totalWeight() (uint64, error) { + return nodeWeights(maps.Keys(nw2b)).totalWeight() +} + +func dropHeight(f func() (ids.ID, uint64)) func() ids.ID { + return func() ids.ID { + id, _ := f() + return id + } +} + +type nodeWeights []ids.NodeWeight + +func (nws nodeWeights) totalWeight() (uint64, error) { + var weight uint64 + for _, nw := range nws { + newWeight, err := safemath.Add(weight, nw.Weight) + if err != nil { + return 0, fmt.Errorf("cumulative weight: %d, tried to add %d: %w", weight, nw.Weight, err) + } + weight = newWeight + } + return weight, nil +} diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go new file mode 100644 index 00000000000..9faf2098d22 --- /dev/null +++ b/snow/engine/snowman/straggler_detect_test.go @@ -0,0 +1,373 @@ +package snowman + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + + safemath "github.com/ava-labs/avalanchego/utils/math" +) + +func TestNodeWeights(t *testing.T) { + nws := nodeWeights{ + {Weight: 100}, + {Weight: 50}, + } + + total, err := nws.totalWeight() + require.NoError(t, err) + require.Equal(t, uint64(150), total) +} + +func TestNodeWeightsOverflow(t *testing.T) { + nws := nodeWeights{ + {Weight: math.MaxUint64 - 100}, + 
{Weight: 110}, + } + + total, err := nws.totalWeight() + require.ErrorIs(t, err, safemath.ErrOverflow) + require.Equal(t, uint64(0), total) +} + +func TestNodeWeights2Blocks(t *testing.T) { + nw2b := nodeWeights2Blocks{ + ids.NodeWeight{Weight: 5}: ids.Empty, + ids.NodeWeight{Weight: 10}: ids.Empty, + } + + total, err := nw2b.totalWeight() + require.NoError(t, err) + require.Equal(t, uint64(15), total) +} + +func TestGetNetworkSnapshot(t *testing.T) { + n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") + require.NoError(t, err) + + n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") + require.NoError(t, err) + + connectedValidators := func(s []ids.NodeWeight) func() set.Set[ids.NodeWeight] { + return func() set.Set[ids.NodeWeight] { + var set set.Set[ids.NodeWeight] + for _, nw := range s { + set.Add(nw) + } + return set + } + } + + for _, testCase := range []struct { + description string + lastAccepted ids.ID + lastAcceptedFromNodes map[ids.NodeID]ids.ID + processing map[ids.ID]struct{} + connectedValidators func() set.Set[ids.NodeWeight] + connectedPercent float64 + expectedSnapshot snapshot + expectedOK bool + expectedLogged string + }{ + { + description: "not enough connected validators", + connectedValidators: connectedValidators([]ids.NodeWeight{}), + expectedLogged: "not enough connected stake to determine network info", + }, + { + description: "connected to zero weight", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{}), + expectedLogged: "Connected to zero weight", + }, + { + description: "not enough info", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, Node: n1}, {Weight: 999999, Node: n2}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + expectedLogged: "Not collected enough information about last accepted blocks", + }, + { + description: "we're in sync", + connectedPercent: 1.0, + 
connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + lastAccepted: ids.ID{0x1}, + expectedLogged: "Most stake we're connected to has the same height as we do", + }, + { + description: "we're behind", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + processing: map[ids.ID]struct{}{{0x1}: {}}, + lastAccepted: ids.ID{0x0}, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 999999}: {0x1}, + }}, + expectedOK: true, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := newStragglerDetector(log, 0.75, + func() (ids.ID, uint64) { + return testCase.lastAccepted, 0 + }, + testCase.connectedValidators, func() float64 { return testCase.connectedPercent }, + func(id ids.ID) bool { + _, ok := testCase.processing[id] + return ok + }, + func(vdr ids.NodeID) (ids.ID, bool) { + id, ok := testCase.lastAcceptedFromNodes[vdr] + return id, ok + }) + + snapshot, ok := sd.getNetworkSnapshot() + require.Equal(t, testCase.expectedSnapshot, snapshot) + require.Equal(t, testCase.expectedOK, ok) + require.Contains(t, buff.String(), testCase.expectedLogged) + }) + } +} + +func TestFailedCatchingUp(t *testing.T) { + n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") + require.NoError(t, err) + + n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") + require.NoError(t, err) + + for _, testCase := range []struct { + description string + lastAccepted ids.ID + lastAcceptedFromNodes map[ids.NodeID]ids.ID + processing map[ids.ID]struct{} + connectedValidators []ids.NodeWeight + connectedPercent 
float64 + input snapshot + expected bool + expectedLogged string + }{ + { + description: "stake overflow", + input: snapshot{ + nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 11}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "Cumulative weight overflow", + }, + { + description: "Straggling behind stake minority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 25}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "Nodes ahead of us", + }, + { + description: "Straggling behind stake majority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 26}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "We are straggling behind", + expected: true, + }, + { + description: "In sync with the majority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 75}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 25}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x2}: {}, + }, + expectedLogged: "Nodes ahead of us", + }, + } { + t.Run(testCase.description, func(t *testing.T) { + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := newStragglerDetector(log, 0.75, + func() (ids.ID, uint64) { + return testCase.lastAccepted, 0 + }, + func() set.Set[ids.NodeWeight] { + var set set.Set[ids.NodeWeight] + for _, nw := range testCase.connectedValidators { + set.Add(nw) + } + return set + }, func() float64 
{ return testCase.connectedPercent }, + func(id ids.ID) bool { + _, ok := testCase.processing[id] + return ok + }, + func(vdr ids.NodeID) (ids.ID, bool) { + id, ok := testCase.lastAcceptedFromNodes[vdr] + return id, ok + }) + + require.Equal(t, testCase.expected, sd.failedCatchingUp(testCase.input)) + require.Contains(t, buff.String(), testCase.expectedLogged) + }) + } +} + +func TestCheckIfWeAreStragglingBehind(t *testing.T) { + fakeClock := make(chan time.Time, 1) + + snapshots := make(chan snapshot, 1) + assertNoSnapshotsRemain := func() { + select { + case <-snapshots: + require.Fail(t, "Should not have any snapshots in standby") + default: + } + } + nonEmptySnap := snapshot{ + totalValidatorWeight: 100, + nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Weight: 100}: ids.Empty, + }, + } + + var haveWeFailedCatchingUpReturns bool + + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := stragglerDetector{ + stragglerDetectorConfig: stragglerDetectorConfig{ + minStragglerCheckInterval: time.Second, + getTime: func() time.Time { + now := <-fakeClock + return now + }, + log: log, + getSnapshot: func() (snapshot, bool) { + s := <-snapshots + return s, !s.isEmpty() + }, + haveWeFailedCatchingUp: func(_ snapshot) bool { + return haveWeFailedCatchingUpReturns + }, + }, + } + + fakeTime := time.Now() + + for _, testCase := range []struct { + description string + timeAdvanced time.Duration + evalExtraAssertions func() + expectedStragglingTime time.Duration + snapshotsRead []snapshot + haveWeFailedCatchingUpReturns bool + }{ + { + description: "First invocation only sets the time", + evalExtraAssertions: func() {}, + }, + { + description: "Should not check yet, as it is not time yet", + timeAdvanced: time.Millisecond * 500, + evalExtraAssertions: func() {}, + }, + { + description: "Advance time some more, so now we should check", + timeAdvanced: time.Millisecond * 501, + 
snapshotsRead: []snapshot{{}}, + evalExtraAssertions: func() { + require.Contains(t, buff.String(), "No node snapshot obtained") + }, + }, + { + description: "Advance time some more to the first check where the snapshot isn't empty", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + evalExtraAssertions: func() { + require.Empty(t, buff.String()) + }, + }, + { + description: "The next check returns we have failed catching up.", + timeAdvanced: time.Second * 2, + expectedStragglingTime: time.Second * 2, + haveWeFailedCatchingUpReturns: true, + evalExtraAssertions: func() { + require.Empty(t, sd.prevSnapshot) + }, + }, + { + description: "The third snapshot is due to a fresh check", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + // We carry over the total straggling time from previous testCase to this check, + // as we need the next check to nullify it. + expectedStragglingTime: time.Second * 2, + evalExtraAssertions: func() {}, + }, + { + description: "The fourth check returns we have succeeded in catching up", + timeAdvanced: time.Second * 2, + evalExtraAssertions: func() {}, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + fakeTime = fakeTime.Add(testCase.timeAdvanced) + fakeClock <- fakeTime + + // Load the snapshot expected to be retrieved in this testCase, if applicable. + if len(testCase.snapshotsRead) > 0 { + snapshots <- testCase.snapshotsRead[0] + } + + haveWeFailedCatchingUpReturns = testCase.haveWeFailedCatchingUpReturns + require.Equal(t, testCase.expectedStragglingTime, sd.CheckIfWeAreStragglingBehind()) + testCase.evalExtraAssertions() + + // Cleanup the log buffer, and make sure no snapshots remain for next testCase. 
+ buff.Reset() + assertNoSnapshotsRemain() + haveWeFailedCatchingUpReturns = false + }) + } +} diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index 0dbcb844fb9..fbbc8113e2d 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -66,5 +66,13 @@ func (h *handler) getDisconnectedValidators() set.Set[ids.NodeID] { connectedVdrs := h.peerTracker.ConnectedValidators() // vdrs - connectedVdrs is equal to the disconnectedVdrs vdrs.Difference(connectedVdrs) - return vdrs + return trimWeights(vdrs) +} + +func trimWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { + var res set.Set[ids.NodeID] + for _, nw := range weights.List() { + res.Add(nw.Node) + } + return res }