diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 0c4792cb924..bf877d57391 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -25,6 +25,7 @@ import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import com.google.common.base.Joiner; import com.google.common.base.MoreObjects; import com.google.common.collect.HashMultiset; import com.google.common.collect.Multiset; @@ -34,6 +35,7 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.Metadata; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.util.MultiChildLoadBalancer; @@ -47,6 +49,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -70,6 +73,7 @@ final class RingHashLoadBalancer extends MultiChildLoadBalancer { private final XdsLogger logger; private final SynchronizationContext syncContext; private List ring; + private String requestHashHeader = ""; RingHashLoadBalancer(Helper helper) { super(helper); @@ -99,6 +103,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (config == null) { throw new IllegalArgumentException("Missing RingHash configuration"); } + requestHashHeader = config.requestHashHeader; Map serverWeights = new HashMap<>(); long totalWeight = 0L; for (EquivalentAddressGroup eag : addrList) { @@ -213,7 +218,8 @@ protected void updateOverallBalancingState() { overallState = TRANSIENT_FAILURE; } - RingHashPicker picker = new RingHashPicker(syncContext, ring, getChildLbStates()); + RingHashPicker picker = + new RingHashPicker(syncContext, ring, getChildLbStates(), requestHashHeader); getHelper().updateBalancingState(overallState, picker); this.currentConnectivityState = overallState; } @@ -334,6 +340,7 @@ public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { } private static final class RingHashPicker extends SubchannelPicker { + private final Random random = new Random(); private final SynchronizationContext syncContext; private final List ring; // Avoid synchronization between pickSubchannel and subchannel's connectivity state change, @@ -341,16 +348,22 @@ private static final class RingHashPicker extends SubchannelPicker { // TODO(chengyuanzhang): can be more performance-friendly with // IdentityHashMap and RingEntry contains Subchannel. private final Map pickableSubchannels; // read-only + private final String requestHashHeader; + private boolean hasEndpointInConnectingState = false; private RingHashPicker( SynchronizationContext syncContext, List ring, - Collection children) { + Collection children, String requestHashHeader) { this.syncContext = syncContext; this.ring = ring; + this.requestHashHeader = requestHashHeader; pickableSubchannels = new HashMap<>(children.size()); for (ChildLbState childLbState : children) { pickableSubchannels.put((Endpoint)childLbState.getKey(), new SubchannelView(childLbState, childLbState.getCurrentState())); + if (childLbState.getCurrentState() == CONNECTING) { + hasEndpointInConnectingState = true; + } } } @@ -381,38 +394,78 @@ private int getTargetIndex(Long requestHash) { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); - if (requestHash == null) { - return PickResult.withError(RPC_HASH_NOT_FOUND); + // Determine request hash. + boolean usingRandomHash = false; + Long requestHash; + if (requestHashHeader.isEmpty()) { + // Set by the xDS config selector. + requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); + if (requestHash == null) { + return PickResult.withError(RPC_HASH_NOT_FOUND); + } + } else { + Iterable headerValues = + args.getHeaders() + .getAll(Metadata.Key.of(requestHashHeader, Metadata.ASCII_STRING_MARSHALLER)); + if (headerValues != null) { + requestHash = hashFunc.hashAsciiString(Joiner.on(",").join(headerValues)); + } else { + requestHash = random.nextLong(); + usingRandomHash = true; + } } int targetIndex = getTargetIndex(requestHash); - // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore - // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If - // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in - // IDLE, we initiate a connection. - for (int i = 0; i < ring.size(); i++) { - int index = (targetIndex + i) % ring.size(); - SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); - ChildLbState childLbState = subchannelView.childLbState; - - if (subchannelView.connectivityState == READY) { - return childLbState.getCurrentPicker().pickSubchannel(args); + if (!usingRandomHash) { + // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore + // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If + // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in + // IDLE, we initiate a connection. + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + + // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs + // are failed unless there is a READY connection. + if (subchannelView.connectivityState == CONNECTING) { + return PickResult.withNoResult(); + } + + if (subchannelView.connectivityState == IDLE) { + syncContext.execute(() -> { + childLbState.getLb().requestConnection(); + }); + + return PickResult.withNoResult(); // Indicates that this should be retried after backoff + } } - - // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs - // are failed unless there is a READY connection. - if (subchannelView.connectivityState == CONNECTING) { - return PickResult.withNoResult(); + } else { + // Using a random hash. Find and use the first READY ring entry, triggering at most one + // entry to attempt connection. + boolean requestedConnection = hasEndpointInConnectingState; + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + if (!requestedConnection && subchannelView.connectivityState == IDLE) { + syncContext.execute( + () -> { + childLbState.getLb().requestConnection(); + }); + requestedConnection = true; + } } - - if (subchannelView.connectivityState == IDLE) { - syncContext.execute(() -> { - childLbState.getLb().requestConnection(); - }); - - return PickResult.withNoResult(); // Indicates that this should be retried after backoff + if (requestedConnection) { + return PickResult.withNoResult(); } } @@ -460,13 +513,16 @@ public int compareTo(RingEntry entry) { static final class RingHashConfig { final long minRingSize; final long maxRingSize; + final String requestHashHeader; - RingHashConfig(long minRingSize, long maxRingSize) { + RingHashConfig(long minRingSize, long maxRingSize, String requestHashHeader) { checkArgument(minRingSize > 0, "minRingSize <= 0"); checkArgument(maxRingSize > 0, "maxRingSize <= 0"); checkArgument(minRingSize <= maxRingSize, "minRingSize > maxRingSize"); + checkNotNull(requestHashHeader); this.minRingSize = minRingSize; this.maxRingSize = maxRingSize; + this.requestHashHeader = requestHashHeader; } @Override @@ -474,6 +530,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("minRingSize", minRingSize) .add("maxRingSize", maxRingSize) + .add("requestHashHeader", requestHashHeader) .toString(); } } diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index dad79384569..035ff76c585 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -24,6 +24,7 @@ import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.RingHashOptions; @@ -81,6 +82,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( Map rawLoadBalancingPolicyConfig) { Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize"); Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize"); + String requestHashHeader = ""; + if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false)) { + requestHashHeader = JsonUtil.getString(rawLoadBalancingPolicyConfig, "requestHashHeader"); + } long maxRingSizeCap = RingHashOptions.getRingSizeCap(); if (minRingSize == null) { minRingSize = DEFAULT_MIN_RING_SIZE; @@ -88,6 +93,9 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( if (maxRingSize == null) { maxRingSize = DEFAULT_MAX_RING_SIZE; } + if (requestHashHeader == null) { + requestHashHeader = ""; + } if (minRingSize > maxRingSizeCap) { minRingSize = maxRingSizeCap; } @@ -98,6 +106,7 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( return ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( "Invalid 'mingRingSize'/'maxRingSize'")); } - return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize)); + return ConfigOrError.fromConfig( + new RingHashConfig(minRingSize, maxRingSize, requestHashHeader)); } } diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 9243abba6d3..28898c0930f 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -163,7 +163,7 @@ public void uncaughtException(Thread t, Throwable e) { GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( new FakeLoadBalancerProvider("round_robin"), null))); private final Object ringHash = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( - new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L)); + new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L, "")); private final Object leastRequest = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java index 87615a125c0..3036db5b09f 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java @@ -42,6 +42,8 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerProviderTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY = + "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY"; private final SynchronizationContext syncContext = new SynchronizationContext( new UncaughtExceptionHandler() { @@ -81,6 +83,7 @@ public void parseLoadBalancingConfig_valid() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10L); assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -92,6 +95,7 @@ public void parseLoadBalancingConfig_missingRingSize_useDefaults() throws IOExce RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MIN_RING_SIZE); assertThat(config.maxRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MAX_RING_SIZE); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -127,6 +131,7 @@ public void parseLoadBalancingConfig_ringTooLargeUsesCap() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.DEFAULT_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -142,6 +147,7 @@ public void parseLoadBalancingConfig_ringCapCanBeRaised() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -159,6 +165,7 @@ public void parseLoadBalancingConfig_ringCapIsClampedTo8M() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -176,6 +183,7 @@ public void parseLoadBalancingConfig_ringCapCanBeLowered() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -193,6 +201,7 @@ public void parseLoadBalancingConfig_ringCapLowerLimitIs1() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -219,6 +228,59 @@ public void parseLoadBalancingConfig_minRingSizeGreaterThanMaxRingSize() throws .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); } + @Test + public void parseLoadBalancingConfig_requestHashHeaderIgnoredWhenEnvVarNotSet() + throws IOException { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderSetWhenEnvVarSet() throws IOException { + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEqualTo("dummy-hash"); + assertThat(config.toString()).contains("minRingSize=10"); + assertThat(config.toString()).contains("maxRingSize=100"); + assertThat(config.toString()).contains("requestHashHeader=dummy-hash"); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderUnsetWhenEnvVarSet_useDefaults() + throws IOException { + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = "{\"minRingSize\" : 10, \"maxRingSize\" : 100}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } + } + @SuppressWarnings("unchecked") private static Map parseJsonObject(String json) throws IOException { return (Map) JsonParser.parse(json); diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index 65fc1527b0c..05669b4a230 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -98,6 +98,9 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String CUSTOM_REQUEST_HASH_HEADER = "custom-request-hash-header"; + private static final Metadata.Key CUSTOM_METADATA_KEY = + Metadata.Key.of(CUSTOM_REQUEST_HASH_HEADER, Metadata.ASCII_STRING_MARSHALLER); private static final Attributes.Key CUSTOM_KEY = Attributes.Key.create("custom-key"); private static final ConnectivityStateInfo CSI_CONNECTING = ConnectivityStateInfo.forNonError(CONNECTING); @@ -142,7 +145,7 @@ public void tearDown() { @Test public void subchannelLazyConnectUntilPicked() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -176,7 +179,7 @@ public void subchannelLazyConnectUntilPicked() { @Test public void subchannelNotAutoReconnectAfterReenteringIdle() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -207,7 +210,7 @@ public void subchannelNotAutoReconnectAfterReenteringIdle() { @Test public void aggregateSubchannelStates_connectingReadyIdleFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); InOrder inOrder = Mockito.inOrder(helper); @@ -266,7 +269,7 @@ private void verifyConnection(int times) { @Test public void aggregateSubchannelStates_allSubchannelsInTransientFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1); List subChannelList = initializeLbSubchannels(config, servers, STAY_IN_CONNECTING); @@ -324,7 +327,7 @@ private void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state) @Test public void ignoreShutdownSubchannelStateChange() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -340,7 +343,7 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void deterministicPickWithHostsPartiallyRemoved() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1, 1); initializeLbSubchannels(config, servers); InOrder inOrder = Mockito.inOrder(helper); @@ -380,7 +383,7 @@ public void deterministicPickWithHostsPartiallyRemoved() { @Test public void deterministicPickWithNewHostsAdded() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); // server0 and server1 initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -412,6 +415,138 @@ public void deterministicPickWithNewHostsAdded() { inOrder.verifyNoMoreInteractions(); } + @Test + public void deterministicPickWithRequestHashHeader_oneHeaderValue() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header where the rpc hash hits server1. + Metadata headers = new Metadata(); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void deterministicPickWithRequestHashHeader_multipleHeaderValues() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header with multiple values for the same key where + // the rpc hash hits server1. + Metadata headers = new Metadata(); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server0_0"); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void pickWithRandomHash_allSubchannelsReady() { + // Large ring to better reflect the request distribution. + RingHashConfig config = new RingHashConfig(10000, 10000, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + Map pickCounts = new HashMap<>(); + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + pickCounts.put(subchannel.getAddresses(), 0); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel 10000 times with random hash. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + for (int i = 0; i < 10000; ++i) { + Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel(); + EquivalentAddressGroup addr = pickedSubchannel.getAddresses(); + pickCounts.put(addr, pickCounts.get(addr) + 1); + } + + // Verify the distribution is uniform where server0 and server1 are roughly picked 5000 times. + assertThat(pickCounts.get(servers.get(0))).isWithin(500).of(5000); + assertThat(pickCounts.get(servers.get(1))).isWithin(500).of(5000); + } + + @Test + public void pickWithRandomHash_atLeastOneSubchannelConnecting() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to CONNECTING. + deliverSubchannelState(getSubChannel(servers.get(0)), CSI_CONNECTING); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + + // Pick subchannel with random hash does not trigger connection. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(0); + } + + @Test + public void pickWithRandomHash_firstSubchannelInTransientFailure_remainingSubchannelsIdle() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to TRANSIENT_FAILURE. + deliverSubchannelUnreachable(getSubChannel(servers.get(0))); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(0); + + // Pick subchannel with random hash does trigger connection by walking the ring + // and choosing the first (at most one) IDLE subchannel along the way. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(1); + } + private Subchannel getSubChannel(EquivalentAddressGroup eag) { return subchannels.get(Collections.singletonList(eag)); } @@ -419,7 +554,7 @@ private Subchannel getSubChannel(EquivalentAddressGroup eag) { @Test public void skipFailingHosts_pickNextNonFailingHost() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( @@ -489,7 +624,7 @@ private PickSubchannelArgs getDefaultPickSubchannelArgsForServer(int serverid) { @Test public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -555,7 +690,7 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { @Test public void removingAddressShutdownSubchannel() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List svs1 = createWeightedServerAddrs(1, 1, 1); List subchannels1 = initializeLbSubchannels(config, svs1, STAY_IN_CONNECTING); @@ -572,7 +707,7 @@ public void removingAddressShutdownSubchannel() { @Test public void allSubchannelsInTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -599,7 +734,7 @@ public void allSubchannelsInTransientFailure() { @Test public void firstSubchannelIdle() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -620,7 +755,7 @@ public void firstSubchannelIdle() { @Test public void firstSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -644,7 +779,7 @@ private Subchannel getSubchannel(List servers, int serve @Test public void firstSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); List subchannelList = @@ -675,7 +810,7 @@ public void firstSubchannelFailure() { @Test public void secondSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -706,7 +841,7 @@ public void secondSubchannelConnecting() { @Test public void secondSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -733,7 +868,7 @@ public void secondSubchannelFailure() { @Test public void thirdSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -762,7 +897,7 @@ public void thirdSubchannelConnecting() { @Test public void stickyTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -791,7 +926,7 @@ public void stickyTransientFailure() { @Test public void largeWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(Integer.MAX_VALUE, 10, 100); // MAX:10:100 @@ -829,7 +964,7 @@ public void largeWeights() { @Test public void hostSelectionProportionalToWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 initializeLbSubchannels(config, servers); @@ -872,7 +1007,7 @@ public void nameResolutionErrorWithNoActiveSubchannels() { @Test public void nameResolutionErrorWithActiveSubchannels() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -894,7 +1029,7 @@ public void nameResolutionErrorWithActiveSubchannels() { @Test public void duplicateAddresses() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createRepeatedServerAddrs(1, 2, 3); initializeLbSubchannels(config, servers, DO_NOT_VERIFY); @@ -940,7 +1075,7 @@ protected Helper delegate() { InOrder inOrder = Mockito.inOrder(helper); List servers = createWeightedServerAddrs(1, 1); - initializeLbSubchannels(new RingHashConfig(10, 100), servers); + initializeLbSubchannels(new RingHashConfig(10, 100, ""), servers); Subchannel subchannel0 = subchannels.get(Collections.singletonList(servers.get(0))); Subchannel subchannel1 = subchannels.get(Collections.singletonList(servers.get(1)));