From 3069e1c5fd3c6fb10b845e8f0dc7c3bbd05c7cf1 Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Sat, 8 Mar 2025 00:54:36 +0000 Subject: [PATCH 1/6] xds: add support for custom per-target credentials on the transport. --- .../io/grpc/xds/GrpcXdsTransportFactory.java | 54 +++-- .../InternalSharedXdsClientPoolProvider.java | 15 +- .../grpc/xds/SharedXdsClientPoolProvider.java | 50 +++-- .../grpc/xds/GrpcXdsClientImplTestBase.java | 50 ++--- .../grpc/xds/GrpcXdsTransportFactoryTest.java | 7 +- .../io/grpc/xds/LoadReportClientTest.java | 14 +- ...TargetXdsTransportCallCredentialsTest.java | 189 ++++++++++++++++++ .../io/grpc/xds/XdsClientFallbackTest.java | 57 ++++-- 8 files changed, 352 insertions(+), 84 deletions(-) create mode 100644 xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java index 74c28ba2d2d..0da51bf47f7 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java +++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; @@ -34,35 +35,50 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory { - static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY = - new GrpcXdsTransportFactory(); + private final CallCredentials callCredentials; + + GrpcXdsTransportFactory(CallCredentials callCredentials) { + this.callCredentials = callCredentials; + } @Override public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - return new GrpcXdsTransport(serverInfo); + return new GrpcXdsTransport(serverInfo, callCredentials); } @VisibleForTesting public XdsTransport createForTest(ManagedChannel channel) { - return new GrpcXdsTransport(channel); + return new GrpcXdsTransport(channel, callCredentials); } @VisibleForTesting static class GrpcXdsTransport implements XdsTransport { private final ManagedChannel channel; + private final CallCredentials callCredentials; public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) { + this(serverInfo, null); + } + + @VisibleForTesting + public GrpcXdsTransport(ManagedChannel channel) { + this(channel, null); + } + + public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) { String target = serverInfo.target(); ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig(); this.channel = Grpc.newChannelBuilder(target, channelCredentials) .keepAliveTime(5, TimeUnit.MINUTES) .build(); + this.callCredentials = callCredentials; } @VisibleForTesting - public GrpcXdsTransport(ManagedChannel channel) { + public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) { this.channel = checkNotNull(channel, "channel"); + this.callCredentials = callCredentials; } @Override @@ -72,7 +88,8 @@ public StreamingCall createStreamingCall( MethodDescriptor.Marshaller respMarshaller) { Context prevContext = Context.ROOT.attach(); try { - return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller); + return new XdsStreamingCall<>( + fullMethodName, reqMarshaller, respMarshaller, callCredentials); } finally { Context.ROOT.detach(prevContext); } @@ -89,16 +106,21 @@ private class XdsStreamingCall implements private final ClientCall call; - public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller reqMarshaller, - MethodDescriptor.Marshaller respMarshaller) { - this.call = channel.newCall( - MethodDescriptor.newBuilder() - .setFullMethodName(methodName) - .setType(MethodDescriptor.MethodType.BIDI_STREAMING) - .setRequestMarshaller(reqMarshaller) - .setResponseMarshaller(respMarshaller) - .build(), - CallOptions.DEFAULT); // TODO(zivy): support waitForReady + public XdsStreamingCall( + String methodName, + MethodDescriptor.Marshaller reqMarshaller, + MethodDescriptor.Marshaller respMarshaller, + CallCredentials callCredentials) { + this.call = + channel.newCall( + MethodDescriptor.newBuilder() + .setFullMethodName(methodName) + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setRequestMarshaller(reqMarshaller) + .setResponseMarshaller(respMarshaller) + .build(), + CallOptions.DEFAULT.withCallCredentials( + callCredentials)); // TODO(zivy): support waitForReady } @Override diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index 85b59fabfa0..5585992e204 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import io.grpc.CallCredentials; import io.grpc.Internal; import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; @@ -42,6 +43,18 @@ public static ObjectPool getOrCreate(String target) public static ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) throws XdsInitializationException { - return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target, metricRecorder); + return getOrCreate(target, metricRecorder, null); + } + + public static ObjectPool getOrCreate( + String target, CallCredentials transportCallCredentials) throws XdsInitializationException { + return getOrCreate(target, new MetricRecorder() {}, transportCallCredentials); + } + + public static ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { + return SharedXdsClientPoolProvider.getDefaultProvider() + .getOrCreate(target, metricRecorder, transportCallCredentials); } } diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index 2bc7be4a014..5302880d48c 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -17,11 +17,11 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.CallCredentials; import io.grpc.MetricRecorder; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; @@ -87,6 +87,12 @@ public ObjectPool get(String target) { @Override public ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) throws XdsInitializationException { + return getOrCreate(target, metricRecorder, null); + } + + public ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { ObjectPool ref = targetToXdsClientMap.get(target); if (ref == null) { synchronized (lock) { @@ -102,7 +108,9 @@ public ObjectPool getOrCreate(String target, MetricRecorder metricRec if (bootstrapInfo.servers().isEmpty()) { throw new XdsInitializationException("No xDS server provided"); } - ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target, metricRecorder); + ref = + new RefCountedXdsClientObjectPool( + bootstrapInfo, target, metricRecorder, transportCallCredentials); targetToXdsClientMap.put(target, ref); } } @@ -126,6 +134,7 @@ class RefCountedXdsClientObjectPool implements ObjectPool { private final BootstrapInfo bootstrapInfo; private final String target; // The target associated with the xDS client. private final MetricRecorder metricRecorder; + private final CallCredentials transportCallCredentials; private final Object lock = new Object(); @GuardedBy("lock") private ScheduledExecutorService scheduler; @@ -137,11 +146,21 @@ class RefCountedXdsClientObjectPool implements ObjectPool { private XdsClientMetricReporterImpl metricReporter; @VisibleForTesting - RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target, - MetricRecorder metricRecorder) { + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) { + this(bootstrapInfo, target, metricRecorder, null); + } + + @VisibleForTesting + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, + String target, + MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { this.bootstrapInfo = checkNotNull(bootstrapInfo); this.target = target; this.metricRecorder = metricRecorder; + this.transportCallCredentials = transportCallCredentials; } @Override @@ -153,16 +172,19 @@ public XdsClient getObject() { } scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target); - xdsClient = new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, - bootstrapInfo, - scheduler, - BACKOFF_POLICY_PROVIDER, - GrpcUtil.STOPWATCH_SUPPLIER, - TimeProvider.SYSTEM_TIME_PROVIDER, - MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo), - metricReporter); + GrpcXdsTransportFactory xdsTransportFactory = + new GrpcXdsTransportFactory(transportCallCredentials); + xdsClient = + new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + scheduler, + BACKOFF_POLICY_PROVIDER, + GrpcUtil.STOPWATCH_SUPPLIER, + TimeProvider.SYSTEM_TIME_PROVIDER, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + metricReporter); metricReporter.setXdsClient(xdsClient); } refCount++; diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index 51c07cb3537..e332ed5b472 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; @@ -324,29 +323,32 @@ public void setUp() throws IOException { .start()); channel = cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()); - XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { - @Override - public XdsTransport create(ServerInfo serverInfo) { - if (serverInfo.target().equals(SERVER_URI)) { - return new GrpcXdsTransport(channel); - } - if (serverInfo.target().equals(SERVER_URI_CUSTOME_AUTHORITY)) { - if (channelForCustomAuthority == null) { - channelForCustomAuthority = cleanupRule.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build()); - } - return new GrpcXdsTransport(channelForCustomAuthority); - } - if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { - if (channelForEmptyAuthority == null) { - channelForEmptyAuthority = cleanupRule.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build()); + XdsTransportFactory xdsTransportFactory = + new XdsTransportFactory() { + @Override + public XdsTransport create(ServerInfo serverInfo) { + if (serverInfo.target().equals(SERVER_URI)) { + return new GrpcXdsTransport(channel); + } + if (serverInfo.target().equals(SERVER_URI_CUSTOME_AUTHORITY)) { + if (channelForCustomAuthority == null) { + channelForCustomAuthority = + cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + } + return new GrpcXdsTransport(channelForCustomAuthority); + } + if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { + if (channelForEmptyAuthority == null) { + channelForEmptyAuthority = + cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + } + return new GrpcXdsTransport(channelForEmptyAuthority); + } + throw new IllegalArgumentException("Can not create channel for " + serverInfo); } - return new GrpcXdsTransport(channelForEmptyAuthority); - } - throw new IllegalArgumentException("Can not create channel for " + serverInfo); - } - }; + }; xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion(), true); @@ -4193,7 +4195,7 @@ public void serverFailureMetricReport_forRetryAndBackoff() { private XdsClientImpl createXdsClient(String serverUri) { BootstrapInfo bootstrapInfo = buildBootStrap(serverUri); return new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, + new GrpcXdsTransportFactory(null), bootstrapInfo, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 703e429fa23..9202d010b23 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -92,9 +92,10 @@ public void onCompleted() { @Test public void callApis() throws Exception { XdsTransportFactory.XdsTransport xdsTransport = - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create( - Bootstrapper.ServerInfo.create("localhost:" + server.getPort(), - InsecureChannelCredentials.create())); + (new GrpcXdsTransportFactory(null)) + .create( + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create())); MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); XdsTransportFactory.StreamingCall streamingCall = diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index c11a3a6e0d2..7fe0da751dd 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -178,11 +178,15 @@ public void cancelled(Context context) { when(backoffPolicy2.nextBackoffNanos()) .thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L)); addFakeStatsData(); - lrsClient = new LoadReportClient(loadStatsManager, - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel), - NODE, - syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, - fakeClock.getStopwatchSupplier()); + lrsClient = + new LoadReportClient( + loadStatsManager, + (new GrpcXdsTransportFactory(null)).createForTest(channel), + NODE, + syncContext, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); syncContext.execute(new Runnable() { @Override public void run() { diff --git a/xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java b/xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java new file mode 100644 index 00000000000..a634b743515 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java @@ -0,0 +1,189 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.mockito.Mockito.when; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; +import io.grpc.CallCredentials; +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.Metadata; +import io.grpc.MetricRecorder; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.auth.MoreCallCredentials; +import io.grpc.internal.ObjectPool; +import io.grpc.stub.StreamObserver; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.client.XdsInitializationException; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class PerTargetXdsTransportCallCredentialsTest { + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); + static final Context.Key TOKEN_CONTEXT_KEY = Context.key("token"); + private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + + private FakeAdsService adsService; + private Server xdsServer; + private String xdsServerUri; + // Used to notify the xDS client that the fake server has finished processing the request. + private CountDownLatch handleDiscoveryRequest; + + @Mock private GrpcBootstrapperImpl bootstrapper; + @Mock private ResourceWatcher ldsResourceWatcher; + + @Before + public void setup() throws Exception { + adsService = new FakeAdsService(); + xdsServer = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(adsService) + .intercept(new CallCredsServerInterceptor()) + .build() + .start(); + xdsServerUri = "localhost:" + xdsServer.getPort(); + handleDiscoveryRequest = new CountDownLatch(1); + } + + @After + public void tearDown() { + xdsServer.shutdown(); + } + + private class CallCredsServerInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler serverCallHandler) { + String callCredsValue = metadata.get(AUTHORIZATION_METADATA_KEY); + if (callCredsValue == null) { + serverCall.close( + Status.UNAUTHENTICATED.withDescription("Missing call credentials"), new Metadata()); + return new ServerCall.Listener() { + // noop + }; + } + // set access tokenValue into current context, to be consumed by the server + Context ctx = Context.current().withValue(TOKEN_CONTEXT_KEY, callCredsValue); + return Contexts.interceptCall(ctx, serverCall, metadata, serverCallHandler); + } + } + + private class FakeAdsService + extends AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase { + private String token; + + @Override + public StreamObserver streamAggregatedResources( + final StreamObserver responseObserver) { + StreamObserver requestObserver = + new StreamObserver() { + @Override + public void onNext(DiscoveryRequest value) { + token = TOKEN_CONTEXT_KEY.get().substring("Bearer".length()).trim(); + responseObserver.onNext(DiscoveryResponse.newBuilder().build()); + handleDiscoveryRequest.countDown(); + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + handleDiscoveryRequest.countDown(); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + handleDiscoveryRequest.countDown(); + } + }; + + return requestObserver; + } + + public boolean receivedToken(String expected) { + return token.equals(expected); + } + } + + @Test + public void usePerTargetXdsTransportCallCredentials() throws XdsInitializationException { + // Set up bootstrap & xDS client pool provider + ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); + BootstrapInfo bootstrapInfo = + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); + + // Create custom xDS transport CallCredentials + CallCredentials sampleCreds = + MoreCallCredentials.from( + OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null))); + + // Create xDS client & transport, and verify that the custom CallCredentials were used + ObjectPool xdsClientPool = + provider.getOrCreate("target", metricRecorder, sampleCreds); + XdsClient xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); + assertThat(waitForXdsServerDone()).isTrue(); + assertThat(adsService.receivedToken("token")).isTrue(); + } + + private boolean waitForXdsServerDone() { + try { + return handleDiscoveryRequest.await(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new AssertionError( + "Interrupted while waiting for xDS server to finish handling request", e); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java index 97c2695f209..0ea3e29c9ca 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -344,17 +343,18 @@ public void connect_then_mainServerDown_fallbackServerUp() throws Exception { mainXdsServer.restartXdsServer(); fallbackServer.restartXdsServer(); ExecutorService executor = Executors.newFixedThreadPool(1); - XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { - @Override - public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - ChannelCredentials channelCredentials = - (ChannelCredentials) serverInfo.implSpecificConfig(); - return new GrpcXdsTransportFactory.GrpcXdsTransport( - Grpc.newChannelBuilder(serverInfo.target(), channelCredentials) - .executor(executor) - .build()); - } - }; + XdsTransportFactory xdsTransportFactory = + new XdsTransportFactory() { + @Override + public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { + ChannelCredentials channelCredentials = + (ChannelCredentials) serverInfo.implSpecificConfig(); + return new GrpcXdsTransportFactory.GrpcXdsTransport( + Grpc.newChannelBuilder(serverInfo.target(), channelCredentials) + .executor(executor) + .build()); + } + }; XdsClientImpl xdsClient = CommonBootstrapperTestUtils.createXdsClient( new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), xdsTransportFactory, fakeClock, new ExponentialBackoffPolicy.Provider(), @@ -442,9 +442,14 @@ public void fallbackFromBadUrlToGoodOne() { String garbageUri = "some. garbage"; String validUri = "localhost:" + mainXdsServer.getServer().getPort(); - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(garbageUri, validUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri, validUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS); @@ -462,9 +467,14 @@ public void testGoodUrlFollowedByBadUrl() { String garbageUri = "some. garbage"; String validUri = "localhost:" + mainXdsServer.getServer().getPort(); - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(validUri, garbageUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(validUri, garbageUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); verify(ldsWatcher, timeout(5000)).onChanged( @@ -481,9 +491,14 @@ public void testTwoBadUrl() { String garbageUri1 = "some. garbage"; String garbageUri2 = "other garbage"; - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(garbageUri1, garbageUri2), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri1, garbageUri2), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS); From 081cbe0609dafee9d3bca6ffecb1659ce2f67f39 Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Tue, 11 Mar 2025 18:23:30 +0000 Subject: [PATCH 2/6] Fix style --- xds/src/test/java/io/grpc/xds/LoadReportClientTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index 7fe0da751dd..9bdf86132b6 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -181,7 +181,7 @@ public void cancelled(Context context) { lrsClient = new LoadReportClient( loadStatsManager, - (new GrpcXdsTransportFactory(null)).createForTest(channel), + new GrpcXdsTransportFactory(null).createForTest(channel), NODE, syncContext, fakeClock.getScheduledExecutorService(), From 3aa5d6b37bd007ac98f1c37d8b44d0d5eec34050 Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Tue, 11 Mar 2025 18:38:07 +0000 Subject: [PATCH 3/6] Fix style --- xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 9202d010b23..5191d88f9c9 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -92,7 +92,7 @@ public void onCompleted() { @Test public void callApis() throws Exception { XdsTransportFactory.XdsTransport xdsTransport = - (new GrpcXdsTransportFactory(null)) + new GrpcXdsTransportFactory(null) .create( Bootstrapper.ServerInfo.create( "localhost:" + server.getPort(), InsecureChannelCredentials.create())); @@ -140,4 +140,3 @@ public void onStatusReceived(Status status) { } } } - From 524e4921367fe6bd42cbcbcd7549d873524d2408 Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Wed, 19 Mar 2025 19:26:03 +0000 Subject: [PATCH 4/6] Simplify test & formatting --- .../grpc/xds/GrpcXdsClientImplTestBase.java | 47 +++++----- .../grpc/xds/GrpcXdsTransportFactoryTest.java | 1 + .../xds/SharedXdsClientPoolProviderTest.java | 86 +++++++++++++++++++ .../io/grpc/xds/XdsClientFallbackTest.java | 23 +++-- 4 files changed, 120 insertions(+), 37 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index e332ed5b472..36131464d08 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -323,32 +323,29 @@ public void setUp() throws IOException { .start()); channel = cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()); - XdsTransportFactory xdsTransportFactory = - new XdsTransportFactory() { - @Override - public XdsTransport create(ServerInfo serverInfo) { - if (serverInfo.target().equals(SERVER_URI)) { - return new GrpcXdsTransport(channel); - } - if (serverInfo.target().equals(SERVER_URI_CUSTOME_AUTHORITY)) { - if (channelForCustomAuthority == null) { - channelForCustomAuthority = - cleanupRule.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build()); - } - return new GrpcXdsTransport(channelForCustomAuthority); - } - if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { - if (channelForEmptyAuthority == null) { - channelForEmptyAuthority = - cleanupRule.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build()); - } - return new GrpcXdsTransport(channelForEmptyAuthority); - } - throw new IllegalArgumentException("Can not create channel for " + serverInfo); + XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { + @Override + public XdsTransport create(ServerInfo serverInfo) { + if (serverInfo.target().equals(SERVER_URI)) { + return new GrpcXdsTransport(channel); + } + if (serverInfo.target().equals(SERVER_URI_CUSTOME_AUTHORITY)) { + if (channelForCustomAuthority == null) { + channelForCustomAuthority = cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + } + return new GrpcXdsTransport(channelForCustomAuthority); + } + if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { + if (channelForEmptyAuthority == null) { + channelForEmptyAuthority = cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); } - }; + return new GrpcXdsTransport(channelForEmptyAuthority); + } + throw new IllegalArgumentException("Can not create channel for " + serverInfo); + } + }; xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion(), true); diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 5191d88f9c9..66e0d4b3198 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -140,3 +140,4 @@ public void onStatusReceived(Status status) { } } } + diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 4fb77f0be42..980d3a96f8c 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -18,20 +18,36 @@ import static com.google.common.truth.Truth.assertThat; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.CallCredentials; +import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.Metadata; import io.grpc.MetricRecorder; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.ObjectPool; import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.client.XdsInitializationException; import java.util.Collections; +import java.util.concurrent.TimeUnit; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -54,9 +70,12 @@ public class SharedXdsClientPoolProviderTest { private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); private final MetricRecorder metricRecorder = new MetricRecorder() {}; private static final String DUMMY_TARGET = "dummy"; + static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); @Mock private GrpcBootstrapperImpl bootstrapper; + @Mock private ResourceWatcher ldsResourceWatcher; @Test public void noServer() throws XdsInitializationException { @@ -138,4 +157,71 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1); xdsClientPool.returnObject(xdsClient2); } + + private class CallCredsServerInterceptor implements ServerInterceptor { + private String token; + private SettableFuture requestDone = SettableFuture.create(); + + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler next) { + String callCreds = metadata.get(AUTHORIZATION_METADATA_KEY); + if (callCreds != null) { + token = callCreds.substring("Bearer".length()).trim(); + } + requestDone.set(null); + return next.startCall(serverCall, metadata); + } + + public String getToken() { + return token; + } + + public Void waitForRequestDone(long timeout, TimeUnit unit) throws Exception { + return requestDone.get(timeout, unit); + } + } + + @Test + public void xdsClient_usesCallCredentials() throws Exception { + // Set up fake xDS server + XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService(); + CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor(); + Server xdsServer = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(fakeXdsService) + .intercept(callCredentialsInterceptor) + .build() + .start(); + String xdsServerUri = "localhost:" + xdsServer.getPort(); + + // Set up bootstrap & xDS client pool provider + ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); + BootstrapInfo bootstrapInfo = + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); + + // Create custom xDS transport CallCredentials + CallCredentials sampleCreds = + MoreCallCredentials.from( + OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null))); + + // Create xDS client that uses the CallCredentials on the transport + ObjectPool xdsClientPool = + provider.getOrCreate("target", metricRecorder, sampleCreds); + XdsClient xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); + + // Wait for xDS server to get the request and verify that it received the CallCredentials + assertThat(callCredentialsInterceptor.waitForRequestDone(5, TimeUnit.SECONDS)).isNull(); + assertThat(callCredentialsInterceptor.getToken()).isEqualTo("token"); + + // Clean up + xdsClientPool.returnObject(xdsClient); + xdsServer.shutdownNow(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java index 0ea3e29c9ca..036b9f6f55d 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java @@ -343,18 +343,17 @@ public void connect_then_mainServerDown_fallbackServerUp() throws Exception { mainXdsServer.restartXdsServer(); fallbackServer.restartXdsServer(); ExecutorService executor = Executors.newFixedThreadPool(1); - XdsTransportFactory xdsTransportFactory = - new XdsTransportFactory() { - @Override - public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - ChannelCredentials channelCredentials = - (ChannelCredentials) serverInfo.implSpecificConfig(); - return new GrpcXdsTransportFactory.GrpcXdsTransport( - Grpc.newChannelBuilder(serverInfo.target(), channelCredentials) - .executor(executor) - .build()); - } - }; + XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { + @Override + public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { + ChannelCredentials channelCredentials = + (ChannelCredentials) serverInfo.implSpecificConfig(); + return new GrpcXdsTransportFactory.GrpcXdsTransport( + Grpc.newChannelBuilder(serverInfo.target(), channelCredentials) + .executor(executor) + .build()); + } + }; XdsClientImpl xdsClient = CommonBootstrapperTestUtils.createXdsClient( new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), xdsTransportFactory, fakeClock, new ExponentialBackoffPolicy.Provider(), From 4e76e21f3c490222f31ba877f48b052516852b15 Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Wed, 19 Mar 2025 21:43:17 +0000 Subject: [PATCH 5/6] Clean up old & new test --- ...TargetXdsTransportCallCredentialsTest.java | 189 ------------------ .../xds/SharedXdsClientPoolProviderTest.java | 21 +- 2 files changed, 6 insertions(+), 204 deletions(-) delete mode 100644 xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java diff --git a/xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java b/xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java deleted file mode 100644 index a634b743515..00000000000 --- a/xds/src/test/java/io/grpc/xds/PerTargetXdsTransportCallCredentialsTest.java +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Copyright 2024 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; -import static org.mockito.Mockito.when; - -import com.google.auth.oauth2.AccessToken; -import com.google.auth.oauth2.OAuth2Credentials; -import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc; -import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; -import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; -import io.grpc.CallCredentials; -import io.grpc.Context; -import io.grpc.Contexts; -import io.grpc.Grpc; -import io.grpc.InsecureChannelCredentials; -import io.grpc.InsecureServerCredentials; -import io.grpc.Metadata; -import io.grpc.MetricRecorder; -import io.grpc.Server; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.grpc.auth.MoreCallCredentials; -import io.grpc.internal.ObjectPool; -import io.grpc.stub.StreamObserver; -import io.grpc.xds.XdsListenerResource.LdsUpdate; -import io.grpc.xds.client.Bootstrapper.BootstrapInfo; -import io.grpc.xds.client.Bootstrapper.ServerInfo; -import io.grpc.xds.client.EnvoyProtoData.Node; -import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsClient.ResourceWatcher; -import io.grpc.xds.client.XdsInitializationException; -import java.util.Collections; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -@RunWith(JUnit4.class) -public class PerTargetXdsTransportCallCredentialsTest { - @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - static final Metadata.Key AUTHORIZATION_METADATA_KEY = - Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); - static final Context.Key TOKEN_CONTEXT_KEY = Context.key("token"); - private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); - private final MetricRecorder metricRecorder = new MetricRecorder() {}; - - private FakeAdsService adsService; - private Server xdsServer; - private String xdsServerUri; - // Used to notify the xDS client that the fake server has finished processing the request. - private CountDownLatch handleDiscoveryRequest; - - @Mock private GrpcBootstrapperImpl bootstrapper; - @Mock private ResourceWatcher ldsResourceWatcher; - - @Before - public void setup() throws Exception { - adsService = new FakeAdsService(); - xdsServer = - Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) - .addService(adsService) - .intercept(new CallCredsServerInterceptor()) - .build() - .start(); - xdsServerUri = "localhost:" + xdsServer.getPort(); - handleDiscoveryRequest = new CountDownLatch(1); - } - - @After - public void tearDown() { - xdsServer.shutdown(); - } - - private class CallCredsServerInterceptor implements ServerInterceptor { - @Override - public ServerCall.Listener interceptCall( - ServerCall serverCall, - Metadata metadata, - ServerCallHandler serverCallHandler) { - String callCredsValue = metadata.get(AUTHORIZATION_METADATA_KEY); - if (callCredsValue == null) { - serverCall.close( - Status.UNAUTHENTICATED.withDescription("Missing call credentials"), new Metadata()); - return new ServerCall.Listener() { - // noop - }; - } - // set access tokenValue into current context, to be consumed by the server - Context ctx = Context.current().withValue(TOKEN_CONTEXT_KEY, callCredsValue); - return Contexts.interceptCall(ctx, serverCall, metadata, serverCallHandler); - } - } - - private class FakeAdsService - extends AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase { - private String token; - - @Override - public StreamObserver streamAggregatedResources( - final StreamObserver responseObserver) { - StreamObserver requestObserver = - new StreamObserver() { - @Override - public void onNext(DiscoveryRequest value) { - token = TOKEN_CONTEXT_KEY.get().substring("Bearer".length()).trim(); - responseObserver.onNext(DiscoveryResponse.newBuilder().build()); - handleDiscoveryRequest.countDown(); - } - - @Override - public void onError(Throwable t) { - responseObserver.onError(t); - handleDiscoveryRequest.countDown(); - } - - @Override - public void onCompleted() { - responseObserver.onCompleted(); - handleDiscoveryRequest.countDown(); - } - }; - - return requestObserver; - } - - public boolean receivedToken(String expected) { - return token.equals(expected); - } - } - - @Test - public void usePerTargetXdsTransportCallCredentials() throws XdsInitializationException { - // Set up bootstrap & xDS client pool provider - ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); - BootstrapInfo bootstrapInfo = - BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); - SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); - - // Create custom xDS transport CallCredentials - CallCredentials sampleCreds = - MoreCallCredentials.from( - OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null))); - - // Create xDS client & transport, and verify that the custom CallCredentials were used - ObjectPool xdsClientPool = - provider.getOrCreate("target", metricRecorder, sampleCreds); - XdsClient xdsClient = xdsClientPool.getObject(); - xdsClient.watchXdsResource( - XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); - assertThat(waitForXdsServerDone()).isTrue(); - assertThat(adsService.receivedToken("token")).isTrue(); - } - - private boolean waitForXdsServerDone() { - try { - return handleDiscoveryRequest.await(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { - throw new AssertionError( - "Interrupted while waiting for xDS server to finish handling request", e); - } - } -} diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 980d3a96f8c..86e4fc83a8c 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -159,28 +159,19 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh } private class CallCredsServerInterceptor implements ServerInterceptor { - private String token; - private SettableFuture requestDone = SettableFuture.create(); + private SettableFuture tokenFuture = SettableFuture.create(); @Override public ServerCall.Listener interceptCall( ServerCall serverCall, Metadata metadata, ServerCallHandler next) { - String callCreds = metadata.get(AUTHORIZATION_METADATA_KEY); - if (callCreds != null) { - token = callCreds.substring("Bearer".length()).trim(); - } - requestDone.set(null); + tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY)); return next.startCall(serverCall, metadata); } - public String getToken() { - return token; - } - - public Void waitForRequestDone(long timeout, TimeUnit unit) throws Exception { - return requestDone.get(timeout, unit); + public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception { + return tokenFuture.get(timeout, unit); } } @@ -217,8 +208,8 @@ public void xdsClient_usesCallCredentials() throws Exception { XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); // Wait for xDS server to get the request and verify that it received the CallCredentials - assertThat(callCredentialsInterceptor.waitForRequestDone(5, TimeUnit.SECONDS)).isNull(); - assertThat(callCredentialsInterceptor.getToken()).isEqualTo("token"); + assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS)) + .isEqualTo("Bearer token"); // Clean up xdsClientPool.returnObject(xdsClient); From 78d54dd33986337adcf0b64885768b3d9ef66f84 Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Thu, 20 Mar 2025 00:40:41 +0000 Subject: [PATCH 6/6] Remove unnecessary InternalSharedXdsClientPoolProvider.getOrCreate() overload --- .../io/grpc/xds/InternalSharedXdsClientPoolProvider.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index 5585992e204..9c98bba93cf 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -46,11 +46,6 @@ public static ObjectPool getOrCreate(String target, MetricRecorder me return getOrCreate(target, metricRecorder, null); } - public static ObjectPool getOrCreate( - String target, CallCredentials transportCallCredentials) throws XdsInitializationException { - return getOrCreate(target, new MetricRecorder() {}, transportCallCredentials); - } - public static ObjectPool getOrCreate( String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) throws XdsInitializationException {