-
Notifications
You must be signed in to change notification settings - Fork 4.7k
*: Implementation of weighted random shuffling (A113) #8864
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6f9d9cf
633fab8
3bd5a7d
3a582da
aaa9d86
d4d3d88
fa17114
3f4487d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,11 +21,14 @@ | |
| package pickfirst | ||
|
|
||
| import ( | ||
| "cmp" | ||
| "encoding/json" | ||
| "errors" | ||
| "fmt" | ||
| "math" | ||
| "net" | ||
| "net/netip" | ||
| "slices" | ||
| "sync" | ||
| "time" | ||
|
|
||
|
|
@@ -34,6 +37,8 @@ import ( | |
| "google.golang.org/grpc/connectivity" | ||
| expstats "google.golang.org/grpc/experimental/stats" | ||
| "google.golang.org/grpc/grpclog" | ||
| "google.golang.org/grpc/internal/balancer/weight" | ||
| "google.golang.org/grpc/internal/envconfig" | ||
| internalgrpclog "google.golang.org/grpc/internal/grpclog" | ||
| "google.golang.org/grpc/internal/pretty" | ||
| "google.golang.org/grpc/resolver" | ||
|
|
@@ -258,8 +263,42 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState | |
| // will change the order of endpoints but not touch the order of the | ||
| // addresses within each endpoint. - A61 | ||
| if cfg.ShuffleAddressList { | ||
| endpoints = append([]resolver.Endpoint{}, endpoints...) | ||
| internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) | ||
| if envconfig.PickFirstWeightedShuffling { | ||
| type weightedEndpoint struct { | ||
| endpoint resolver.Endpoint | ||
| weight float64 | ||
| } | ||
|
|
||
| // For each endpoint, compute a key as described in A113 and | ||
| // https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf: | ||
| var weightedEndpoints []weightedEndpoint | ||
| for _, endpoint := range endpoints { | ||
| u := internal.RandFloat64() // Random number in [0.0, 1.0) | ||
| weight := weightAttribute(endpoint) | ||
| weightedEndpoints = append(weightedEndpoints, weightedEndpoint{ | ||
| endpoint: endpoint, | ||
| weight: math.Pow(u, 1.0/float64(weight)), | ||
| }) | ||
| } | ||
| // Sort endpoints by key in descending order and reconstruct the | ||
| // endpoints slice. | ||
| slices.SortFunc(weightedEndpoints, func(a, b weightedEndpoint) int { | ||
| return cmp.Compare(b.weight, a.weight) | ||
| }) | ||
|
|
||
| // Here, and in the "else" block below, we clone the endpoints | ||
| // slice to avoid mutating the resolver state. Doing the latter | ||
| // would lead to data races if the caller is accessing the same | ||
| // slice concurrently. | ||
|
Comment on lines
+289
to
+292
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is important even if it's not a concurrent update. We just shouldn't be mutating the incoming update's endpoints because that likely changes state in the parent. |
||
| sortedEndpoints := make([]resolver.Endpoint, len(endpoints)) | ||
| for i, we := range weightedEndpoints { | ||
| sortedEndpoints[i] = we.endpoint | ||
| } | ||
| endpoints = sortedEndpoints | ||
| } else { | ||
| endpoints = slices.Clone(endpoints) | ||
| internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) | ||
| } | ||
| } | ||
|
|
||
| // "Flatten the list by concatenating the ordered list of addresses for | ||
|
|
@@ -906,3 +945,17 @@ func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool { | |
| return a.Addr == b.Addr && a.ServerName == b.ServerName && | ||
| a.Attributes.Equal(b.Attributes) | ||
| } | ||
|
|
||
| // weightAttribute is a convenience function which returns the value of the | ||
| // weight endpoint Attribute. | ||
| // | ||
| // When used in the xDS context, the weight attribute is guaranteed to be | ||
| // non-zero. But, when used in a non-xDS context, the weight attribute could be | ||
| // unset. A Default of 1 is used in the latter case. | ||
| func weightAttribute(e resolver.Endpoint) uint32 { | ||
| w := weight.FromEndpoint(e).Weight | ||
| if w == 0 { | ||
| return 1 | ||
| } | ||
| return w | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1240,29 +1240,30 @@ func (s) TestRingHash_UnsupportedHashPolicyUntilChannelIdHashing(t *testing.T) { | |
| // Tests that ring hash policy that hashes using a random value can spread RPCs | ||
| // across all the backends according to locality weight. | ||
| func (s) TestRingHash_RandomHashingDistributionAccordingToLocalityAndEndpointWeight(t *testing.T) { | ||
| testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, true) | ||
| backends := backendAddrs(startTestServiceBackends(t, 2)) | ||
|
|
||
| const clusterName = "cluster" | ||
| const locality1Weight = uint32(1) | ||
| const endpoint1Weight = uint32(1) | ||
| const locality2Weight = uint32(2) | ||
| const endpoint2Weight = uint32(2) | ||
| const locality0Weight = uint32(1) | ||
|
arjan-bal marked this conversation as resolved.
|
||
| const endpoint0Weight = uint32(1) | ||
| const locality1Weight = uint32(2) | ||
| const endpoint1Weight = uint32(2) | ||
| endpoints := e2e.EndpointResourceWithOptions(e2e.EndpointOptions{ | ||
| ClusterName: clusterName, | ||
| Localities: []e2e.LocalityOptions{ | ||
| { | ||
| Backends: []e2e.BackendOptions{{ | ||
| Ports: []uint32{testutils.ParsePort(t, backends[0])}, | ||
| Weight: endpoint1Weight, | ||
| Weight: endpoint0Weight, | ||
| }}, | ||
| Weight: locality1Weight, | ||
| Weight: locality0Weight, | ||
| }, | ||
| { | ||
| Backends: []e2e.BackendOptions{{ | ||
| Ports: []uint32{testutils.ParsePort(t, backends[1])}, | ||
| Weight: endpoint2Weight, | ||
| Weight: endpoint1Weight, | ||
| }}, | ||
| Weight: locality2Weight, | ||
| Weight: locality1Weight, | ||
| }, | ||
| }, | ||
| }) | ||
|
|
@@ -1289,21 +1290,26 @@ func (s) TestRingHash_RandomHashingDistributionAccordingToLocalityAndEndpointWei | |
| defer conn.Close() | ||
| client := testgrpc.NewTestServiceClient(conn) | ||
|
|
||
| const weight1 = endpoint1Weight * locality1Weight | ||
| const weight2 = endpoint2Weight * locality2Weight | ||
| const wantRPCs1 = float64(weight1) / float64(weight1+weight2) | ||
| const wantRPCs2 = float64(weight2) / float64(weight1+weight2) | ||
| numRPCs := computeIdealNumberOfRPCs(t, math.Min(wantRPCs1, wantRPCs2), errorTolerance) | ||
| // The target fraction of RPCs to each backend is computed as the product of | ||
| // the probability of selecting the locality and the probability of | ||
| // selecting the endpoint within the locality. The probability of selecting | ||
| // locality0 is 1/3 and locality1 is 2/3. Since there is only one endpoint | ||
| // in each locality, the probability of selecting the endpoint within the | ||
| // locality is 1. Therefore, the target fractions end up as 1/3 and 2/3 | ||
| // respectively. | ||
| const wantRPCs0 = float64(1) / float64(3) | ||
|
arjan-bal marked this conversation as resolved.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. FWIW you only need to float64 either the numerator or the denominator. |
||
| const wantRPCs1 = float64(2) / float64(3) | ||
| numRPCs := computeIdealNumberOfRPCs(t, math.Min(wantRPCs0, wantRPCs1), errorTolerance) | ||
|
|
||
| // Send a large number of RPCs and check that they are distributed randomly. | ||
| gotPerBackend := checkRPCSendOK(ctx, t, client, numRPCs) | ||
| got := float64(gotPerBackend[backends[0]]) / float64(numRPCs) | ||
| if !cmp.Equal(got, wantRPCs1, cmpopts.EquateApprox(0, errorTolerance)) { | ||
| t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[2], got, wantRPCs1, errorTolerance) | ||
| if !cmp.Equal(got, wantRPCs0, cmpopts.EquateApprox(0, errorTolerance)) { | ||
| t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[0], got, wantRPCs0, errorTolerance) | ||
| } | ||
| got = float64(gotPerBackend[backends[1]]) / float64(numRPCs) | ||
| if !cmp.Equal(got, wantRPCs2, cmpopts.EquateApprox(0, errorTolerance)) { | ||
| t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[2], got, wantRPCs2, errorTolerance) | ||
| if !cmp.Equal(got, wantRPCs1, cmpopts.EquateApprox(0, errorTolerance)) { | ||
| t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[1], got, wantRPCs1, errorTolerance) | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should continue cloning the endpoints list here. Directly mutating the
ResolverState can lead to data races if the caller of UpdateClientConnState reads the state concurrently. It would also be helpful to add a comment explaining this to prevent the accidental removal of the copy in the future. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Argh. I feel like we need to either make
resolver.State.Endpoints an immutable type (EndpointList) or guarantee it's deeply copied when resolver.State is copied.