Skip to content

Commit

Permalink
Implement Top on the p2p validator set manager
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph committed Mar 8, 2024
1 parent 50ca08e commit 158c202
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 24 deletions.
10 changes: 8 additions & 2 deletions network/p2p/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,10 @@ func TestNodeSamplerClientOption(t *testing.T) {
},
GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
return map[ids.NodeID]*validators.GetValidatorOutput{
nodeID1: nil,
nodeID1: {
NodeID: nodeID1,
Weight: 1,
},
}, nil
},
}
Expand All @@ -575,7 +578,10 @@ func TestNodeSamplerClientOption(t *testing.T) {
},
GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
return map[ids.NodeID]*validators.GetValidatorOutput{
nodeID1: nil,
nodeID1: {
NodeID: nodeID1,
Weight: 1,
},
}, nil
},
}
Expand Down
108 changes: 87 additions & 21 deletions network/p2p/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,36 @@
package p2p

import (
"cmp"
"context"
"math"
"sync"
"time"

"go.uber.org/zap"

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/snow/validators"
"github.com/ava-labs/avalanchego/utils"
"github.com/ava-labs/avalanchego/utils/logging"
"github.com/ava-labs/avalanchego/utils/sampler"
"github.com/ava-labs/avalanchego/utils/set"
)

var (
_ ValidatorSet = (*Validators)(nil)
_ NodeSampler = (*Validators)(nil)
_ ValidatorSet = (*Validators)(nil)
_ ValidatorPortion = (*Validators)(nil)
_ NodeSampler = (*Validators)(nil)
)

type ValidatorSet interface {
Has(ctx context.Context, nodeID ids.NodeID) bool // TODO return error
}

type ValidatorPortion interface {
Top(ctx context.Context, percentage float64) []ids.NodeID // TODO return error
}

func NewValidators(
peers *Peers,
log logging.Logger,
Expand All @@ -43,23 +52,39 @@ func NewValidators(

// Validators contains a set of nodes that are staking.
type Validators struct {
peers *Peers
log logging.Logger
subnetID ids.ID
validators validators.State

lock sync.Mutex
validatorIDs set.SampleableSet[ids.NodeID]
lastUpdated time.Time
peers *Peers
log logging.Logger
subnetID ids.ID
validators validators.State
maxValidatorSetStaleness time.Duration

lock sync.Mutex
validatorList []validator
validatorSet set.Set[ids.NodeID]
totalWeight uint64
lastUpdated time.Time
}

type validator struct {
nodeID ids.NodeID
weight uint64
}

func (v validator) Compare(other validator) int {
if weightCmp := cmp.Compare(v.weight, other.weight); weightCmp != 0 {
return -weightCmp // Sort in decreasing order of stake
}
return v.nodeID.Compare(other.nodeID)
}

func (v *Validators) refresh(ctx context.Context) {
if time.Since(v.lastUpdated) < v.maxValidatorSetStaleness {
return
}

v.validatorIDs.Clear()
v.validatorList = v.validatorList[:0]
v.validatorSet.Clear()
v.totalWeight = 0

height, err := v.validators.GetCurrentHeight(ctx)
if err != nil {
Expand All @@ -72,9 +97,15 @@ func (v *Validators) refresh(ctx context.Context) {
return
}

for nodeID := range validatorSet {
v.validatorIDs.Add(nodeID)
for nodeID, vdr := range validatorSet {
v.validatorList = append(v.validatorList, validator{
nodeID: nodeID,
weight: vdr.Weight,
})
v.validatorSet.Add(nodeID)
v.totalWeight += vdr.Weight
}
utils.Sort(v.validatorList)

v.lastUpdated = time.Now()
}
Expand All @@ -86,28 +117,63 @@ func (v *Validators) Sample(ctx context.Context, limit int) []ids.NodeID {

v.refresh(ctx)

// TODO: Account for peer connectivity during the sampling of validators
// rather than filtering sampled validators.
validatorIDs := v.validatorIDs.Sample(limit)
sampled := validatorIDs[:0]
var (
uniform = sampler.NewUniform()
sampled = make([]ids.NodeID, 0, limit)
)

for _, validatorID := range validatorIDs {
if !v.peers.has(validatorID) {
uniform.Initialize(uint64(len(v.validatorList)))
for len(sampled) < limit {
i, err := uniform.Next()
if err != nil {
break
}

nodeID := v.validatorList[i].nodeID
if !v.peers.has(nodeID) {
continue
}

sampled = append(sampled, validatorID)
sampled = append(sampled, nodeID)
}

return sampled
}

// Top returns the top [percentage] of validators, regardless of if they are
// connected or not.
func (v *Validators) Top(ctx context.Context, percentage float64) []ids.NodeID {
percentage = max(0, min(1, percentage)) // bound percentage inside [0, 1]

v.lock.Lock()
defer v.lock.Unlock()

v.refresh(ctx)

var (
maxSize = int(math.Ceil(percentage * float64(len(v.validatorList))))
top = make([]ids.NodeID, 0, maxSize)
currentStake uint64
targetStake = uint64(math.Ceil(percentage * float64(v.totalWeight)))
)

for _, vdr := range v.validatorList {
if currentStake >= targetStake {
break
}
top = append(top, vdr.nodeID)
currentStake += vdr.weight
}

return top
}

// Has returns if nodeID is a connected validator
func (v *Validators) Has(ctx context.Context, nodeID ids.NodeID) bool {
v.lock.Lock()
defer v.lock.Unlock()

v.refresh(ctx)

return v.peers.has(nodeID) && v.validatorIDs.Contains(nodeID)
return v.peers.has(nodeID) && v.validatorSet.Contains(nodeID)
}
136 changes: 135 additions & 1 deletion network/p2p/validators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func TestValidatorsSample(t *testing.T) {
errFoobar := errors.New("foobar")
nodeID1 := ids.GenerateTestNodeID()
nodeID2 := ids.GenerateTestNodeID()
nodeID3 := ids.GenerateTestNodeID()

type call struct {
limit int
Expand All @@ -44,6 +45,20 @@ func TestValidatorsSample(t *testing.T) {
maxStaleness time.Duration
calls []call
}{
{
// if we aren't connected to a validator, we shouldn't return it
name: "drop disconnected validators",
maxStaleness: time.Hour,
calls: []call{
{
time: time.Time{}.Add(time.Second),
limit: 2,
height: 1,
validators: []ids.NodeID{nodeID1, nodeID3},
expected: []ids.NodeID{nodeID1},
},
},
},
{
// if we don't have as many validators as requested by the caller,
// we should return all the validators we have
Expand Down Expand Up @@ -167,7 +182,10 @@ func TestValidatorsSample(t *testing.T) {

validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0)
for _, validator := range call.validators {
validatorSet[validator] = nil
validatorSet[validator] = &validators.GetValidatorOutput{
NodeID: validator,
Weight: 1,
}
}

calls = append(calls,
Expand All @@ -194,3 +212,119 @@ func TestValidatorsSample(t *testing.T) {
})
}
}

func TestValidatorsTop(t *testing.T) {
nodeID1 := ids.GenerateTestNodeID()
nodeID2 := ids.GenerateTestNodeID()
nodeID3 := ids.GenerateTestNodeID()

tests := []struct {
name string
validators []validator
percentage float64
expected []ids.NodeID
}{
{
name: "top 0% is empty",
validators: []validator{
{
nodeID: nodeID1,
weight: 1,
},
{
nodeID: nodeID2,
weight: 1,
},
},
percentage: 0,
expected: []ids.NodeID{},
},
{
name: "top 100% is full",
validators: []validator{
{
nodeID: nodeID1,
weight: 2,
},
{
nodeID: nodeID2,
weight: 1,
},
},
percentage: 1,
expected: []ids.NodeID{
nodeID1,
nodeID2,
},
},
{
name: "top 50% takes larger validator",
validators: []validator{
{
nodeID: nodeID1,
weight: 2,
},
{
nodeID: nodeID2,
weight: 1,
},
},
percentage: .5,
expected: []ids.NodeID{
nodeID1,
},
},
{
name: "top 50% bound",
validators: []validator{
{
nodeID: nodeID1,
weight: 2,
},
{
nodeID: nodeID2,
weight: 1,
},
{
nodeID: nodeID3,
weight: 1,
},
},
percentage: .5,
expected: []ids.NodeID{
nodeID1,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0)
for _, validator := range test.validators {
validatorSet[validator.nodeID] = &validators.GetValidatorOutput{
NodeID: validator.nodeID,
Weight: validator.weight,
}
}

subnetID := ids.GenerateTestID()
mockValidators := validators.NewMockState(ctrl)

mockValidators.EXPECT().GetCurrentHeight(gomock.Any()).Return(uint64(1), nil)
mockValidators.EXPECT().GetValidatorSet(gomock.Any(), uint64(1), subnetID).Return(validatorSet, nil)

network, err := NewNetwork(logging.NoLog{}, &common.FakeSender{}, prometheus.NewRegistry(), "")
require.NoError(err)

ctx := context.Background()
require.NoError(network.Connected(ctx, nodeID1, nil))
require.NoError(network.Connected(ctx, nodeID2, nil))

v := NewValidators(network.Peers, network.log, subnetID, mockValidators, time.Second)
nodeIDs := v.Top(ctx, test.percentage)
require.Equal(test.expected, nodeIDs)
})
}
}

0 comments on commit 158c202

Please sign in to comment.