From a0bccc23d4a9c4d8aa0709327b50fd24495dc262 Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Fri, 24 Jan 2025 19:18:03 +0000 Subject: [PATCH 1/6] explicitly set request hash header --- .../io/grpc/xds/RingHashLoadBalancer.java | 114 +++++++++++++----- .../xds/RingHashLoadBalancerProvider.java | 6 +- .../xds/ClusterResolverLoadBalancerTest.java | 2 +- .../io/grpc/xds/RingHashLoadBalancerTest.java | 46 +++---- 4 files changed, 114 insertions(+), 54 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 0c4792cb924..35f489f8483 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -34,6 +34,7 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.Metadata; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.util.MultiChildLoadBalancer; @@ -47,6 +48,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -70,6 +72,7 @@ final class RingHashLoadBalancer extends MultiChildLoadBalancer { private final XdsLogger logger; private final SynchronizationContext syncContext; private List ring; + private String requestHashHeader = ""; RingHashLoadBalancer(Helper helper) { super(helper); @@ -99,6 +102,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (config == null) { throw new IllegalArgumentException("Missing RingHash configuration"); } + requestHashHeader = config.requestHashHeader; Map serverWeights = new HashMap<>(); long totalWeight = 0L; for (EquivalentAddressGroup eag : addrList) { @@ -213,7 +217,8 @@ protected void updateOverallBalancingState() { overallState = TRANSIENT_FAILURE; } - RingHashPicker picker = new RingHashPicker(syncContext, ring, getChildLbStates()); + RingHashPicker picker = + new RingHashPicker(syncContext, ring, getChildLbStates(), requestHashHeader); getHelper().updateBalancingState(overallState, picker); this.currentConnectivityState = overallState; } @@ -334,6 +339,7 @@ public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { } private static final class RingHashPicker extends SubchannelPicker { + private final Random random = new Random(); private final SynchronizationContext syncContext; private final List ring; // Avoid synchronization between pickSubchannel and subchannel's connectivity state change, @@ -341,16 +347,22 @@ private static final class RingHashPicker extends SubchannelPicker { // TODO(chengyuanzhang): can be more performance-friendly with // IdentityHashMap and RingEntry contains Subchannel. private final Map pickableSubchannels; // read-only + private final String requestHashHeader; + private boolean hasEndpointInConnectingState = false; private RingHashPicker( SynchronizationContext syncContext, List ring, - Collection children) { + Collection children, String requestHashHeader) { this.syncContext = syncContext; this.ring = ring; + this.requestHashHeader = requestHashHeader; pickableSubchannels = new HashMap<>(children.size()); for (ChildLbState childLbState : children) { pickableSubchannels.put((Endpoint)childLbState.getKey(), new SubchannelView(childLbState, childLbState.getCurrentState())); + if (childLbState.getCurrentState() == CONNECTING) { + hasEndpointInConnectingState = true; + } } } @@ -381,38 +393,78 @@ private int getTargetIndex(Long requestHash) { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); - if (requestHash == null) { - return PickResult.withError(RPC_HASH_NOT_FOUND); + // Determine request hash. + boolean usingRandomHash = false; + Long requestHash; + if (requestHashHeader.isEmpty()) { + // Set by the xDS config selector. + requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); + if (requestHash == null) { + return PickResult.withError(RPC_HASH_NOT_FOUND); + } + } else { + String headerValue = + args.getHeaders() + .get(Metadata.Key.of(requestHashHeader, Metadata.ASCII_STRING_MARSHALLER)); + if (headerValue != null) { + requestHash = hashFunc.hashAsciiString(headerValue); + } else { + requestHash = random.nextLong(); + usingRandomHash = true; + } } int targetIndex = getTargetIndex(requestHash); - // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore - // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If - // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in - // IDLE, we initiate a connection. - for (int i = 0; i < ring.size(); i++) { - int index = (targetIndex + i) % ring.size(); - SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); - ChildLbState childLbState = subchannelView.childLbState; - - if (subchannelView.connectivityState == READY) { - return childLbState.getCurrentPicker().pickSubchannel(args); + if (!usingRandomHash) { + // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore + // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If + // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in + // IDLE, we initiate a connection. + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + + // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs + // are failed unless there is a READY connection. + if (subchannelView.connectivityState == CONNECTING) { + return PickResult.withNoResult(); + } + + if (subchannelView.connectivityState == IDLE) { + syncContext.execute(() -> { + childLbState.getLb().requestConnection(); + }); + + return PickResult.withNoResult(); // Indicates that this should be retried after backoff + } } - - // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs - // are failed unless there is a READY connection. - if (subchannelView.connectivityState == CONNECTING) { - return PickResult.withNoResult(); + } else { + // Using a random hash. Find and use the first READY ring entry, triggering at most one + // entry to attempt connection. + boolean requestedConnection = hasEndpointInConnectingState; + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + if (!requestedConnection && subchannelView.connectivityState == IDLE) { + syncContext.execute( + () -> { + childLbState.getLb().requestConnection(); + }); + requestedConnection = true; + } } - - if (subchannelView.connectivityState == IDLE) { - syncContext.execute(() -> { - childLbState.getLb().requestConnection(); - }); - - return PickResult.withNoResult(); // Indicates that this should be retried after backoff + if (requestedConnection) { + return PickResult.withNoResult(); } } @@ -460,13 +512,16 @@ public int compareTo(RingEntry entry) { static final class RingHashConfig { final long minRingSize; final long maxRingSize; + final String requestHashHeader; - RingHashConfig(long minRingSize, long maxRingSize) { + RingHashConfig(long minRingSize, long maxRingSize, String requestHashHeader) { checkArgument(minRingSize > 0, "minRingSize <= 0"); checkArgument(maxRingSize > 0, "maxRingSize <= 0"); checkArgument(minRingSize <= maxRingSize, "minRingSize > maxRingSize"); + checkNotNull(requestHashHeader); this.minRingSize = minRingSize; this.maxRingSize = maxRingSize; + this.requestHashHeader = requestHashHeader; } @Override @@ -474,6 +529,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("minRingSize", minRingSize) .add("maxRingSize", maxRingSize) + .add("requestHashHeader", requestHashHeader) .toString(); } } diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index dad79384569..47d9a94e9fd 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -81,6 +81,7 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( Map rawLoadBalancingPolicyConfig) { Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize"); Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize"); + String requestHashHeader = JsonUtil.getString(rawLoadBalancingPolicyConfig, "requestHashHeader"); long maxRingSizeCap = RingHashOptions.getRingSizeCap(); if (minRingSize == null) { minRingSize = DEFAULT_MIN_RING_SIZE; @@ -88,6 +89,9 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( if (maxRingSize == null) { maxRingSize = DEFAULT_MAX_RING_SIZE; } + if (requestHashHeader == null) { + requestHashHeader = ""; + } if (minRingSize > maxRingSizeCap) { minRingSize = maxRingSizeCap; } @@ -98,6 +102,6 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( return ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( "Invalid 'mingRingSize'/'maxRingSize'")); } - return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize)); + return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize, requestHashHeader)); } } diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 9243abba6d3..28898c0930f 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -163,7 +163,7 @@ public void uncaughtException(Thread t, Throwable e) { GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( new FakeLoadBalancerProvider("round_robin"), null))); private final Object ringHash = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L)); + new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L, "")); private final Object leastRequest = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index 65fc1527b0c..666aedad75b 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -142,7 +142,7 @@ public void tearDown() { @Test public void subchannelLazyConnectUntilPicked() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -176,7 +176,7 @@ public void subchannelLazyConnectUntilPicked() { @Test public void subchannelNotAutoReconnectAfterReenteringIdle() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -207,7 +207,7 @@ public void subchannelNotAutoReconnectAfterReenteringIdle() { @Test public void aggregateSubchannelStates_connectingReadyIdleFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); InOrder inOrder = Mockito.inOrder(helper); @@ -266,7 +266,7 @@ private void verifyConnection(int times) { @Test public void aggregateSubchannelStates_allSubchannelsInTransientFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1); List subChannelList = initializeLbSubchannels(config, servers, STAY_IN_CONNECTING); @@ -324,7 +324,7 @@ private void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state) @Test public void ignoreShutdownSubchannelStateChange() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -340,7 +340,7 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void deterministicPickWithHostsPartiallyRemoved() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1, 1); initializeLbSubchannels(config, servers); InOrder inOrder = Mockito.inOrder(helper); @@ -380,7 +380,7 @@ public void deterministicPickWithHostsPartiallyRemoved() { @Test public void deterministicPickWithNewHostsAdded() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); // server0 and server1 initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -419,7 +419,7 @@ private Subchannel getSubChannel(EquivalentAddressGroup eag) { @Test public void skipFailingHosts_pickNextNonFailingHost() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( @@ -489,7 +489,7 @@ private PickSubchannelArgs getDefaultPickSubchannelArgsForServer(int serverid) { @Test public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -555,7 +555,7 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { @Test public void removingAddressShutdownSubchannel() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List svs1 = createWeightedServerAddrs(1, 1, 1); List subchannels1 = initializeLbSubchannels(config, svs1, STAY_IN_CONNECTING); @@ -572,7 +572,7 @@ public void removingAddressShutdownSubchannel() { @Test public void allSubchannelsInTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -599,7 +599,7 @@ public void allSubchannelsInTransientFailure() { @Test public void firstSubchannelIdle() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -620,7 +620,7 @@ public void firstSubchannelIdle() { @Test public void firstSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -644,7 +644,7 @@ private Subchannel getSubchannel(List servers, int serve @Test public void firstSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); List subchannelList = @@ -675,7 +675,7 @@ public void firstSubchannelFailure() { @Test public void secondSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -706,7 +706,7 @@ public void secondSubchannelConnecting() { @Test public void secondSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -733,7 +733,7 @@ public void secondSubchannelFailure() { @Test public void thirdSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -762,7 +762,7 @@ public void thirdSubchannelConnecting() { @Test public void stickyTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -791,7 +791,7 @@ public void stickyTransientFailure() { @Test public void largeWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(Integer.MAX_VALUE, 10, 100); // MAX:10:100 @@ -829,7 +829,7 @@ public void largeWeights() { @Test public void hostSelectionProportionalToWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 initializeLbSubchannels(config, servers); @@ -872,7 +872,7 @@ public void nameResolutionErrorWithNoActiveSubchannels() { @Test public void nameResolutionErrorWithActiveSubchannels() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -894,7 +894,7 @@ public void nameResolutionErrorWithActiveSubchannels() { @Test public void duplicateAddresses() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createRepeatedServerAddrs(1, 2, 3); initializeLbSubchannels(config, servers, DO_NOT_VERIFY); @@ -940,7 +940,7 @@ protected Helper delegate() { InOrder inOrder = Mockito.inOrder(helper); List servers = createWeightedServerAddrs(1, 1); - initializeLbSubchannels(new RingHashConfig(10, 100), servers); + initializeLbSubchannels(new RingHashConfig(10, 100, ""), servers); Subchannel subchannel0 = subchannels.get(Collections.singletonList(servers.get(0))); Subchannel subchannel1 = subchannels.get(Collections.singletonList(servers.get(1))); From 29fc18c54eedd1e1422ef65c0a71eb97b8229193 Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Fri, 31 Jan 2025 20:17:18 +0000 Subject: [PATCH 2/6] add temp env var protection --- .../main/java/io/grpc/xds/RingHashLoadBalancerProvider.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index 47d9a94e9fd..e70d5d61a0e 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -24,6 +24,7 @@ import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.RingHashOptions; @@ -81,7 +82,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( Map rawLoadBalancingPolicyConfig) { Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize"); Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize"); - String requestHashHeader = JsonUtil.getString(rawLoadBalancingPolicyConfig, "requestHashHeader"); + String requestHashHeader = ""; + if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false)) { + requestHashHeader = JsonUtil.getString(rawLoadBalancingPolicyConfig, "requestHashHeader"); + } long maxRingSizeCap = RingHashOptions.getRingSizeCap(); if (minRingSize == null) { minRingSize = DEFAULT_MIN_RING_SIZE; From a208780fe76652661f20e9c29f42cbc78cc85434 Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Mon, 3 Feb 2025 06:47:42 +0000 Subject: [PATCH 3/6] add ring hash LB provider tests --- .../xds/RingHashLoadBalancerProviderTest.java | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java index 87615a125c0..6a10c2b1085 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java @@ -42,6 +42,8 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerProviderTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY = + "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY"; private final SynchronizationContext syncContext = new SynchronizationContext( new UncaughtExceptionHandler() { @@ -81,6 +83,7 @@ public void parseLoadBalancingConfig_valid() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10L); assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -92,6 +95,7 @@ public void parseLoadBalancingConfig_missingRingSize_useDefaults() throws IOExce RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MIN_RING_SIZE); assertThat(config.maxRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MAX_RING_SIZE); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -127,6 +131,7 @@ public void parseLoadBalancingConfig_ringTooLargeUsesCap() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.DEFAULT_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -142,6 +147,7 @@ public void parseLoadBalancingConfig_ringCapCanBeRaised() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -159,6 +165,7 @@ public void parseLoadBalancingConfig_ringCapIsClampedTo8M() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -176,6 +183,7 @@ public void parseLoadBalancingConfig_ringCapCanBeLowered() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -193,6 +201,7 @@ public void parseLoadBalancingConfig_ringCapLowerLimitIs1() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -219,6 +228,56 @@ public void parseLoadBalancingConfig_minRingSizeGreaterThanMaxRingSize() throws .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); } + @Test + public void parseLoadBalancingConfig_requestHashHeaderIgnoredWhenEnvVarNotSet() + throws IOException { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderSetWhenEnvVarSet() throws IOException { + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEqualTo("dummy-hash"); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderUnsetWhenEnvVarSet_useDefaults() + throws IOException { + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = "{\"minRingSize\" : 10, \"maxRingSize\" : 100}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } + } + @SuppressWarnings("unchecked") private static Map parseJsonObject(String json) throws IOException { return (Map) JsonParser.parse(json); From bd04b7294882c56e037e8782eb3b6821e3cd8172 Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Thu, 6 Feb 2025 23:02:48 +0000 Subject: [PATCH 4/6] add ring hash LB tests --- .../xds/RingHashLoadBalancerProvider.java | 3 +- .../io/grpc/xds/RingHashLoadBalancerTest.java | 103 ++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index e70d5d61a0e..035ff76c585 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -106,6 +106,7 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( return ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( "Invalid 'mingRingSize'/'maxRingSize'")); } - return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize, requestHashHeader)); + return ConfigOrError.fromConfig( + new RingHashConfig(minRingSize, maxRingSize, requestHashHeader)); } } diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index 666aedad75b..c5758b59237 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -412,6 +412,109 @@ public void deterministicPickWithNewHostsAdded() { inOrder.verifyNoMoreInteractions(); } + @Test + public void deterministicPickWithRequestHashHeader() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "custom-request-hash-key"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header where the rpc hash hits server1. + Metadata headers = new Metadata(); + headers.put( + Metadata.Key.of("custom-request-hash-key", Metadata.ASCII_STRING_MARSHALLER), + "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void pickWithRandomHash_allSubchannelsReady() { + // Large ring to better reflect the request distribution. + RingHashConfig config = new RingHashConfig(10000, 10000, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + Map pickCounts = new HashMap<>(); + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + pickCounts.put(subchannel.getAddresses(), 0); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel 10000 times with random hash. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + for (int i = 0; i < 10000; ++i) { + Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel(); + EquivalentAddressGroup addr = pickedSubchannel.getAddresses(); + pickCounts.put(addr, pickCounts.get(addr) + 1); + } + + // Verify the distribution is uniform where server0 and server1 are roughly picked 5000 times. + assertThat(pickCounts.get(servers.get(0))).isWithin(500).of(5000); + assertThat(pickCounts.get(servers.get(1))).isWithin(500).of(5000); + } + + @Test + public void pickWithRandomHash_atLeastOneSubchannelConnecting() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to CONNECTING. + deliverSubchannelState(getSubChannel(servers.get(0)), CSI_CONNECTING); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + + // Pick subchannel with random hash does not trigger connection. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(0); + } + + @Test + public void pickWithRandomHash_firstSubchannelInTransientFailure_remainingSubchannelsIdle() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to TRANSIENT_FAILURE. + deliverSubchannelUnreachable(getSubChannel(servers.get(0))); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(0); + + // Pick subchannel with random hash does trigger connection by walking the ring + // and choosing the first (at most one) IDLE subchannel along the way. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(1); + } + private Subchannel getSubChannel(EquivalentAddressGroup eag) { return subchannels.get(Collections.singletonList(eag)); } From fa52eab29d3a2fa43f7dd3e2e0dab4e757bd5a3c Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Fri, 7 Feb 2025 00:23:41 +0000 Subject: [PATCH 5/6] add test assertions for RingHashConfig.toString() --- .../java/io/grpc/xds/RingHashLoadBalancerProviderTest.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java index 6a10c2b1085..3036db5b09f 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java @@ -255,6 +255,9 @@ public void parseLoadBalancingConfig_requestHashHeaderSetWhenEnvVarSet() throws assertThat(config.minRingSize).isEqualTo(10L); assertThat(config.maxRingSize).isEqualTo(100L); assertThat(config.requestHashHeader).isEqualTo("dummy-hash"); + assertThat(config.toString()).contains("minRingSize=10"); + assertThat(config.toString()).contains("maxRingSize=100"); + assertThat(config.toString()).contains("requestHashHeader=dummy-hash"); } finally { System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); } From c8b5fa0c6628fca1dd78173b4fd37b2c54ca5f69 Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Mon, 10 Feb 2025 22:20:42 +0000 Subject: [PATCH 6/6] support multiple values for request hash header & add respective tests --- .../io/grpc/xds/RingHashLoadBalancer.java | 9 ++-- .../io/grpc/xds/RingHashLoadBalancerTest.java | 42 ++++++++++++++++--- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 35f489f8483..bf877d57391 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -25,6 +25,7 @@ import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import com.google.common.base.Joiner; import com.google.common.base.MoreObjects; import com.google.common.collect.HashMultiset; import com.google.common.collect.Multiset; @@ -403,11 +404,11 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { return PickResult.withError(RPC_HASH_NOT_FOUND); } } else { - String headerValue = + Iterable headerValues = args.getHeaders() - .get(Metadata.Key.of(requestHashHeader, Metadata.ASCII_STRING_MARSHALLER)); - if (headerValue != null) { - requestHash = hashFunc.hashAsciiString(headerValue); + .getAll(Metadata.Key.of(requestHashHeader, Metadata.ASCII_STRING_MARSHALLER)); + if (headerValues != null) { + requestHash = hashFunc.hashAsciiString(Joiner.on(",").join(headerValues)); } else { requestHash = random.nextLong(); usingRandomHash = true; diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index c5758b59237..05669b4a230 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -98,6 +98,9 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String CUSTOM_REQUEST_HASH_HEADER = "custom-request-hash-header"; + private static final Metadata.Key CUSTOM_METADATA_KEY = + Metadata.Key.of(CUSTOM_REQUEST_HASH_HEADER, Metadata.ASCII_STRING_MARSHALLER); private static final Attributes.Key CUSTOM_KEY = Attributes.Key.create("custom-key"); private static final ConnectivityStateInfo CSI_CONNECTING = ConnectivityStateInfo.forNonError(CONNECTING); @@ -413,9 +416,9 @@ public void deterministicPickWithNewHostsAdded() { } @Test - public void deterministicPickWithRequestHashHeader() { + public void deterministicPickWithRequestHashHeader_oneHeaderValue() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3, "custom-request-hash-key"); + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); InOrder inOrder = Mockito.inOrder(helper); @@ -428,9 +431,38 @@ public void deterministicPickWithRequestHashHeader() { // Pick subchannel with custom request hash header where the rpc hash hits server1. Metadata headers = new Metadata(); - headers.put( - Metadata.Key.of("custom-request-hash-key", Metadata.ASCII_STRING_MARSHALLER), - "FakeSocketAddress-server1_0"); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void deterministicPickWithRequestHashHeader_multipleHeaderValues() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header with multiple values for the same key where + // the rpc hash hits server1. + Metadata headers = new Metadata(); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server0_0"); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); PickSubchannelArgs args = new PickSubchannelArgsImpl( TestMethodDescriptors.voidMethod(),