1919import static com .google .common .truth .Truth .assertThat ;
2020import static org .mockito .ArgumentMatchers .anyInt ;
2121import static org .mockito .Mockito .mock ;
22+ import static org .mockito .Mockito .verify ;
2223import static org .mockito .Mockito .when ;
2324
2425import com .github .xds .data .orca .v3 .OrcaLoadReport ;
2526import com .google .common .collect .ImmutableMap ;
2627import com .google .common .collect .Iterables ;
2728import io .grpc .Attributes ;
29+ import io .grpc .CallOptions ;
2830import io .grpc .ClientStreamTracer ;
2931import io .grpc .ConnectivityState ;
3032import io .grpc .EquivalentAddressGroup ;
3133import io .grpc .InsecureChannelCredentials ;
3234import io .grpc .LoadBalancer ;
3335import io .grpc .LoadBalancer .CreateSubchannelArgs ;
3436import io .grpc .LoadBalancer .Helper ;
37+ import io .grpc .LoadBalancer .PickDetailsConsumer ;
3538import io .grpc .LoadBalancer .PickResult ;
3639import io .grpc .LoadBalancer .PickSubchannelArgs ;
3740import io .grpc .LoadBalancer .ResolvedAddresses ;
4548import io .grpc .SynchronizationContext ;
4649import io .grpc .internal .FakeClock ;
4750import io .grpc .internal .ObjectPool ;
51+ import io .grpc .internal .PickSubchannelArgsImpl ;
4852import io .grpc .internal .ServiceConfigUtil .PolicySelection ;
4953import io .grpc .protobuf .ProtoUtils ;
54+ import io .grpc .testing .TestMethodDescriptors ;
5055import io .grpc .xds .ClusterImplLoadBalancerProvider .ClusterImplConfig ;
5156import io .grpc .xds .Endpoints .DropOverload ;
5257import io .grpc .xds .EnvoyServerProtoData .DownstreamTlsContext ;
@@ -141,6 +146,9 @@ public AtomicLong getOrCreate(String cluster, @Nullable String edsServiceName) {
141146 }
142147 };
143148 private final Helper helper = new FakeLbHelper ();
149+ private PickSubchannelArgs pickSubchannelArgs = new PickSubchannelArgsImpl (
150+ TestMethodDescriptors .voidMethod (), new Metadata (), CallOptions .DEFAULT ,
151+ new PickDetailsConsumer () {});
144152 @ Mock
145153 private ThreadSafeRandom mockRandom ;
146154 private int xdsClientRefs ;
@@ -218,7 +226,7 @@ public void handleResolvedAddresses_childPolicyChanges() {
218226 public void nameResolutionError_beforeChildPolicyInstantiated_returnErrorPickerToUpstream () {
219227 loadBalancer .handleNameResolutionError (Status .UNIMPLEMENTED .withDescription ("not found" ));
220228 assertThat (currentState ).isEqualTo (ConnectivityState .TRANSIENT_FAILURE );
221- PickResult result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
229+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
222230 assertThat (result .getStatus ().isOk ()).isFalse ();
223231 assertThat (result .getStatus ().getCode ()).isEqualTo (Code .UNIMPLEMENTED );
224232 assertThat (result .getStatus ().getDescription ()).isEqualTo ("not found" );
@@ -243,6 +251,32 @@ public void nameResolutionError_afterChildPolicyInstantiated_propagateToDownstre
243251 .isEqualTo ("cannot reach server" );
244252 }
245253
254+ @ Test
255+ public void pick_addsLocalityLabel () {
256+ LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider ();
257+ WeightedTargetConfig weightedTargetConfig =
258+ buildWeightedTargetConfig (ImmutableMap .of (locality , 10 ));
259+ ClusterImplConfig config = new ClusterImplConfig (CLUSTER , EDS_SERVICE_NAME , LRS_SERVER_INFO ,
260+ null , Collections .<DropOverload >emptyList (),
261+ new PolicySelection (weightedTargetProvider , weightedTargetConfig ), null );
262+ EquivalentAddressGroup endpoint = makeAddress ("endpoint-addr" , locality );
263+ deliverAddressesAndConfig (Collections .singletonList (endpoint ), config );
264+ FakeLoadBalancer leafBalancer = Iterables .getOnlyElement (downstreamBalancers );
265+ Subchannel subchannel = leafBalancer .helper .createSubchannel (
266+ CreateSubchannelArgs .newBuilder ().setAddresses (leafBalancer .addresses ).build ());
267+ leafBalancer .deliverSubchannelState (subchannel , ConnectivityState .READY );
268+ assertThat (currentState ).isEqualTo (ConnectivityState .READY );
269+
270+ PickDetailsConsumer detailsConsumer = mock (PickDetailsConsumer .class );
271+ pickSubchannelArgs = new PickSubchannelArgsImpl (
272+ TestMethodDescriptors .voidMethod (), new Metadata (), CallOptions .DEFAULT , detailsConsumer );
273+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
274+ assertThat (result .getStatus ().isOk ()).isTrue ();
275+ // The value will be determined by the parent policy, so can be different than the value used in
276+ // makeAddress() for the test.
277+ verify (detailsConsumer ).addOptionalLabel ("grpc.lb.locality" , locality .toString ());
278+ }
279+
246280 @ Test
247281 public void recordLoadStats () {
248282 LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider ();
@@ -258,7 +292,7 @@ public void recordLoadStats() {
258292 CreateSubchannelArgs .newBuilder ().setAddresses (leafBalancer .addresses ).build ());
259293 leafBalancer .deliverSubchannelState (subchannel , ConnectivityState .READY );
260294 assertThat (currentState ).isEqualTo (ConnectivityState .READY );
261- PickResult result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
295+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
262296 assertThat (result .getStatus ().isOk ()).isTrue ();
263297 ClientStreamTracer streamTracer1 = result .getStreamTracerFactory ().newClientStreamTracer (
264298 ClientStreamTracer .StreamInfo .newBuilder ().build (), new Metadata ()); // first RPC call
@@ -347,7 +381,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() {
347381 CreateSubchannelArgs .newBuilder ().setAddresses (leafBalancer .addresses ).build ());
348382 leafBalancer .deliverSubchannelState (subchannel , ConnectivityState .READY );
349383 assertThat (currentState ).isEqualTo (ConnectivityState .READY );
350- PickResult result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
384+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
351385 assertThat (result .getStatus ().isOk ()).isFalse ();
352386 assertThat (result .getStatus ().getCode ()).isEqualTo (Code .UNAVAILABLE );
353387 assertThat (result .getStatus ().getDescription ()).isEqualTo ("Dropped: throttle" );
@@ -373,7 +407,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() {
373407 .build ())
374408 .setLoadBalancingPolicyConfig (config )
375409 .build ());
376- result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
410+ result = currentPicker .pickSubchannel (pickSubchannelArgs );
377411 assertThat (result .getStatus ().isOk ()).isFalse ();
378412 assertThat (result .getStatus ().getCode ()).isEqualTo (Code .UNAVAILABLE );
379413 assertThat (result .getStatus ().getDescription ()).isEqualTo ("Dropped: lb" );
@@ -386,7 +420,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() {
386420 .isEqualTo (1L );
387421 assertThat (clusterStats .totalDroppedRequests ()).isEqualTo (1L );
388422
389- result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
423+ result = currentPicker .pickSubchannel (pickSubchannelArgs );
390424 assertThat (result .getStatus ().isOk ()).isTrue ();
391425 }
392426
@@ -423,7 +457,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu
423457 leafBalancer .deliverSubchannelState (subchannel , ConnectivityState .READY );
424458 assertThat (currentState ).isEqualTo (ConnectivityState .READY );
425459 for (int i = 0 ; i < maxConcurrentRequests ; i ++) {
426- PickResult result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
460+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
427461 assertThat (result .getStatus ().isOk ()).isTrue ();
428462 ClientStreamTracer .Factory streamTracerFactory = result .getStreamTracerFactory ();
429463 streamTracerFactory .newClientStreamTracer (
@@ -434,7 +468,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu
434468 assertThat (clusterStats .clusterServiceName ()).isEqualTo (EDS_SERVICE_NAME );
435469 assertThat (clusterStats .totalDroppedRequests ()).isEqualTo (0L );
436470
437- PickResult result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
471+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
438472 clusterStats = Iterables .getOnlyElement (loadStatsManager .getClusterStatsReports (CLUSTER ));
439473 assertThat (clusterStats .clusterServiceName ()).isEqualTo (EDS_SERVICE_NAME );
440474 if (enableCircuitBreaking ) {
@@ -455,15 +489,15 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu
455489 new PolicySelection (weightedTargetProvider , weightedTargetConfig ), null );
456490 deliverAddressesAndConfig (Collections .singletonList (endpoint ), config );
457491
458- result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
492+ result = currentPicker .pickSubchannel (pickSubchannelArgs );
459493 assertThat (result .getStatus ().isOk ()).isTrue ();
460494 result .getStreamTracerFactory ().newClientStreamTracer (
461495 ClientStreamTracer .StreamInfo .newBuilder ().build (), new Metadata ()); // 101th request
462496 clusterStats = Iterables .getOnlyElement (loadStatsManager .getClusterStatsReports (CLUSTER ));
463497 assertThat (clusterStats .clusterServiceName ()).isEqualTo (EDS_SERVICE_NAME );
464498 assertThat (clusterStats .totalDroppedRequests ()).isEqualTo (0L );
465499
466- result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) ); // 102th request
500+ result = currentPicker .pickSubchannel (pickSubchannelArgs ); // 102th request
467501 clusterStats = Iterables .getOnlyElement (loadStatsManager .getClusterStatsReports (CLUSTER ));
468502 assertThat (clusterStats .clusterServiceName ()).isEqualTo (EDS_SERVICE_NAME );
469503 if (enableCircuitBreaking ) {
@@ -511,7 +545,7 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue(
511545 leafBalancer .deliverSubchannelState (subchannel , ConnectivityState .READY );
512546 assertThat (currentState ).isEqualTo (ConnectivityState .READY );
513547 for (int i = 0 ; i < ClusterImplLoadBalancer .DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS ; i ++) {
514- PickResult result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
548+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
515549 assertThat (result .getStatus ().isOk ()).isTrue ();
516550 ClientStreamTracer .Factory streamTracerFactory = result .getStreamTracerFactory ();
517551 streamTracerFactory .newClientStreamTracer (
@@ -522,7 +556,7 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue(
522556 assertThat (clusterStats .clusterServiceName ()).isEqualTo (EDS_SERVICE_NAME );
523557 assertThat (clusterStats .totalDroppedRequests ()).isEqualTo (0L );
524558
525- PickResult result = currentPicker .pickSubchannel (mock ( PickSubchannelArgs . class ) );
559+ PickResult result = currentPicker .pickSubchannel (pickSubchannelArgs );
526560 clusterStats = Iterables .getOnlyElement (loadStatsManager .getClusterStatsReports (CLUSTER ));
527561 assertThat (clusterStats .clusterServiceName ()).isEqualTo (EDS_SERVICE_NAME );
528562 if (enableCircuitBreaking ) {
@@ -697,7 +731,11 @@ public String toString() {
697731 }
698732
699733 EquivalentAddressGroup eag = new EquivalentAddressGroup (new FakeSocketAddress (name ),
700- Attributes .newBuilder ().set (InternalXdsAttributes .ATTR_LOCALITY , locality ).build ());
734+ Attributes .newBuilder ()
735+ .set (InternalXdsAttributes .ATTR_LOCALITY , locality )
736+ // Unique but arbitrary string
737+ .set (InternalXdsAttributes .ATTR_LOCALITY_NAME , locality .toString ())
738+ .build ());
701739 return AddressFilter .setPathFilter (eag , Collections .singletonList (locality .toString ()));
702740 }
703741
0 commit comments