Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions balancer/pickfirst/internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
var (
// RandShuffle pseudo-randomizes the order of addresses.
RandShuffle = rand.Shuffle
// RandFloat64 returns, as a float64, a pseudo-random number in [0.0,1.0).
RandFloat64 = rand.Float64
// TimeAfterFunc allows mocking the timer for testing connection delay
// related functionality.
TimeAfterFunc = func(d time.Duration, f func()) func() {
Expand Down
57 changes: 55 additions & 2 deletions balancer/pickfirst/pickfirst.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
package pickfirst

import (
"cmp"
"encoding/json"
"errors"
"fmt"
"math"
"net"
"net/netip"
"slices"
"sync"
"time"

Expand All @@ -34,6 +37,8 @@ import (
"google.golang.org/grpc/connectivity"
expstats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/balancer/weight"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
Expand Down Expand Up @@ -258,8 +263,42 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
// will change the order of endpoints but not touch the order of the
// addresses within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should continue cloning the endpoints list here. Directly mutating the ResolverState can lead to data races if the caller of UpdateClientConnState reads the state concurrently. It would also be helpful to add a comment explaining this to prevent the accidental removal of the copy in the future.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argh. I feel like we need to either make resolver.State.Endpoints an immutable type (EndpointList) or guarantee it's deeply copied when resolver.State is copied.

internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
if envconfig.PickFirstWeightedShuffling {
type weightedEndpoint struct {
endpoint resolver.Endpoint
weight float64
}

// For each endpoint, compute a key as described in A113 and
// https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf:
var weightedEndpoints []weightedEndpoint
for _, endpoint := range endpoints {
u := internal.RandFloat64() // Random number in [0.0, 1.0)
weight := weightAttribute(endpoint)
weightedEndpoints = append(weightedEndpoints, weightedEndpoint{
endpoint: endpoint,
weight: math.Pow(u, 1.0/float64(weight)),
})
}
// Sort endpoints by key in descending order and reconstruct the
// endpoints slice.
slices.SortFunc(weightedEndpoints, func(a, b weightedEndpoint) int {
return cmp.Compare(b.weight, a.weight)
})

// Here, and in the "else" block below, we clone the endpoints
// slice to avoid mutating the resolver state. Doing the latter
// would lead to data races if the caller is accessing the same
// slice concurrently.
Comment on lines +289 to +292
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is important even if it's not a concurrent update. We just shouldn't be mutating the incoming update's endpoints because that likely changes state in the parent.

sortedEndpoints := make([]resolver.Endpoint, len(endpoints))
for i, we := range weightedEndpoints {
sortedEndpoints[i] = we.endpoint
}
endpoints = sortedEndpoints
} else {
endpoints = slices.Clone(endpoints)
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
}

// "Flatten the list by concatenating the ordered list of addresses for
Expand Down Expand Up @@ -906,3 +945,17 @@ func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
return a.Addr == b.Addr && a.ServerName == b.ServerName &&
a.Attributes.Equal(b.Attributes)
}

// weightAttribute reports the weight carried by the endpoint's weight
// attribute, falling back to 1 when the stored weight is zero.
//
// Endpoints produced in the xDS context are guaranteed to carry a non-zero
// weight. Outside of xDS the attribute may be unset, in which case the
// endpoint is treated as having the default weight of 1.
func weightAttribute(e resolver.Endpoint) uint32 {
	if w := weight.FromEndpoint(e).Weight; w != 0 {
		return w
	}
	// A zero weight means the attribute was not set; use the default.
	return 1
}
71 changes: 71 additions & 0 deletions balancer/pickfirst/pickfirst_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/balancer/weight"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
Expand Down Expand Up @@ -425,6 +427,8 @@ func (s) TestPickFirst_StickyTransientFailure(t *testing.T) {

// Tests the PF LB policy with shuffling enabled.
func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, false)

const serviceConfig = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}`

// Install a shuffler that always reverses two entries.
Expand Down Expand Up @@ -485,6 +489,8 @@ func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
// Endpoints field in the resolver update to test the shuffling of the
// Addresses.
func (s) TestPickFirst_ShuffleAddressListNoEndpoints(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, false)

// Install a shuffler that always reverses two entries.
origShuf := pfinternal.RandShuffle
defer func() { pfinternal.RandShuffle = origShuf }()
Expand Down Expand Up @@ -560,8 +566,73 @@ func (s) TestPickFirst_ShuffleAddressListNoEndpoints(t *testing.T) {
}
}

// Tests the PF LB policy when shuffling is requested in the service config and
// weighted shuffling is turned on through the env config.
func (s) TestPickFirst_ShuffleAddressList_WeightedShuffling(t *testing.T) {
	testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, true)

	const serviceConfig = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}`

	// Replace the random source with one that always yields the same value.
	// The three endpoints below carry increasing weights, so the weighted
	// shuffling algorithm assigns them increasing keys. Because the algorithm
	// orders endpoints by key in descending order, the last endpoint is the
	// one expected to be picked once shuffling takes effect.
	savedRand := pfinternal.RandFloat64
	pfinternal.RandFloat64 = func() float64 { return 0.5 }
	defer func() { pfinternal.RandFloat64 = savedRand }()

	// Set up our backends.
	cc, r, backends := setupPickFirst(t, 3)
	addrs := stubBackendsToResolverAddrs(backends)

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	// newEndpoint wraps a single address in an endpoint that carries the
	// given weight.
	newEndpoint := func(addr resolver.Address, w uint32) resolver.Endpoint {
		ep := resolver.Endpoint{Addresses: []resolver.Address{addr}}
		return weight.Set(ep, weight.EndpointInfo{Weight: w})
	}

	// One endpoint per backend, with increasing weights.
	lightEP := newEndpoint(addrs[0], 357913941)   // Normalized weight of 1/6
	mediumEP := newEndpoint(addrs[1], 715827882)  // Normalized weight of 2/6
	heavyEP := newEndpoint(addrs[2], 1073741824)  // Normalized weight of 3/6

	// With shuffling disabled, an update containing all endpoints should
	// result in a connection to backend 0.
	r.UpdateState(resolver.State{Endpoints: []resolver.Endpoint{lightEP, mediumEP, heavyEP}})
	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
		t.Fatal(err)
	}

	// Enable shuffling through the service config. The endpoint order gets
	// reversed, but the channel is already connected to backend 0 and should
	// stay there.
	sc := parseServiceConfig(t, r, serviceConfig)
	shuffledState := resolver.State{
		ServiceConfig: sc,
		Endpoints:     []resolver.Endpoint{lightEP, mediumEP, heavyEP},
	}
	r.UpdateState(shuffledState)
	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
		t.Fatal(err)
	}

	// An update without any addresses should move the channel into
	// TransientFailure.
	r.UpdateState(resolver.State{})
	testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)

	// Resend the shuffling-enabled update. The channel is no longer connected
	// to backend 0, so this time it should connect to backend 2.
	r.UpdateState(shuffledState)
	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[2]); err != nil {
		t.Fatal(err)
	}
}

// Test config parsing with the env var turned on and off for various scenarios.
func (s) TestPickFirst_ParseConfig_Success(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, false)

// Install a shuffler that always reverses two entries.
origShuf := pfinternal.RandShuffle
defer func() { pfinternal.RandShuffle = origShuf }()
Expand Down
40 changes: 23 additions & 17 deletions balancer/ringhash/ringhash_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1240,29 +1240,30 @@ func (s) TestRingHash_UnsupportedHashPolicyUntilChannelIdHashing(t *testing.T) {
// Tests that a ring hash policy that hashes using a random value spreads RPCs
// across all the backends according to locality and endpoint weights.
func (s) TestRingHash_RandomHashingDistributionAccordingToLocalityAndEndpointWeight(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, true)
backends := backendAddrs(startTestServiceBackends(t, 2))

const clusterName = "cluster"
const locality1Weight = uint32(1)
const endpoint1Weight = uint32(1)
const locality2Weight = uint32(2)
const endpoint2Weight = uint32(2)
const locality0Weight = uint32(1)
Comment thread
arjan-bal marked this conversation as resolved.
const endpoint0Weight = uint32(1)
const locality1Weight = uint32(2)
const endpoint1Weight = uint32(2)
endpoints := e2e.EndpointResourceWithOptions(e2e.EndpointOptions{
ClusterName: clusterName,
Localities: []e2e.LocalityOptions{
{
Backends: []e2e.BackendOptions{{
Ports: []uint32{testutils.ParsePort(t, backends[0])},
Weight: endpoint1Weight,
Weight: endpoint0Weight,
}},
Weight: locality1Weight,
Weight: locality0Weight,
},
{
Backends: []e2e.BackendOptions{{
Ports: []uint32{testutils.ParsePort(t, backends[1])},
Weight: endpoint2Weight,
Weight: endpoint1Weight,
}},
Weight: locality2Weight,
Weight: locality1Weight,
},
},
})
Expand All @@ -1289,21 +1290,26 @@ func (s) TestRingHash_RandomHashingDistributionAccordingToLocalityAndEndpointWei
defer conn.Close()
client := testgrpc.NewTestServiceClient(conn)

const weight1 = endpoint1Weight * locality1Weight
const weight2 = endpoint2Weight * locality2Weight
const wantRPCs1 = float64(weight1) / float64(weight1+weight2)
const wantRPCs2 = float64(weight2) / float64(weight1+weight2)
numRPCs := computeIdealNumberOfRPCs(t, math.Min(wantRPCs1, wantRPCs2), errorTolerance)
// The target fraction of RPCs to each backend is computed as the product of
// the probability of selecting the locality and the probability of
// selecting the endpoint within the locality. The probability of selecting
// locality0 is 1/3 and locality1 is 2/3. Since there is only one endpoint
// in each locality, the probability of selecting the endpoint within the
// locality is 1. Therefore, the target fractions end up as 1/3 and 2/3
// respectively.
const wantRPCs0 = float64(1) / float64(3)
Comment thread
arjan-bal marked this conversation as resolved.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW you only need to float64 either the numerator or the denominator.

const wantRPCs1 = float64(2) / float64(3)
numRPCs := computeIdealNumberOfRPCs(t, math.Min(wantRPCs0, wantRPCs1), errorTolerance)

// Send a large number of RPCs and check that they are distributed randomly.
gotPerBackend := checkRPCSendOK(ctx, t, client, numRPCs)
got := float64(gotPerBackend[backends[0]]) / float64(numRPCs)
if !cmp.Equal(got, wantRPCs1, cmpopts.EquateApprox(0, errorTolerance)) {
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[2], got, wantRPCs1, errorTolerance)
if !cmp.Equal(got, wantRPCs0, cmpopts.EquateApprox(0, errorTolerance)) {
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[0], got, wantRPCs0, errorTolerance)
}
got = float64(gotPerBackend[backends[1]]) / float64(numRPCs)
if !cmp.Equal(got, wantRPCs2, cmpopts.EquateApprox(0, errorTolerance)) {
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[2], got, wantRPCs2, errorTolerance)
if !cmp.Equal(got, wantRPCs1, cmpopts.EquateApprox(0, errorTolerance)) {
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[1], got, wantRPCs1, errorTolerance)
}
}

Expand Down
6 changes: 6 additions & 0 deletions internal/envconfig/envconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ var (
// This feature is defined in gRFC A81 and is enabled by setting the
// environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true".
XDSAuthorityRewrite = boolFromEnv("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)

// PickFirstWeightedShuffling indicates whether weighted endpoint shuffling
// is enabled in the pick_first LB policy, as defined in gRFC A113. This
// feature can be disabled by setting the environment variable
// GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING to "false".
PickFirstWeightedShuffling = boolFromEnv("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true)
)

func boolFromEnv(envVar string, def bool) bool {
Expand Down
Loading