From 226f22266d308a73e2ae9587f41cdaf660c4f8b3 Mon Sep 17 00:00:00 2001 From: Jaykumar Gosar Date: Tue, 25 Apr 2023 01:00:02 -0700 Subject: [PATCH] Remove identity join in async client for endpoint discovery --- .../codegen/poet/client/AsyncClientClass.java | 27 +++--- .../poet/client/specs/JsonProtocolSpec.java | 10 ++- .../client/test-endpoint-discovery-async.java | 87 +++++++++++-------- 3 files changed, 73 insertions(+), 51 deletions(-) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 32a522c3bab2..cef74db645ca 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -35,6 +35,7 @@ import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.WildcardTypeName; import java.net.URI; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -80,6 +81,7 @@ import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -374,24 +376,29 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation } } - builder.addStatement("$T cachedEndpoint = null", URI.class); + builder.addStatement("$T<$T> endpointFuture = $T.completedFuture(null)", + CompletableFuture.class, URI.class, CompletableFuture.class); builder.beginControlFlow("if (endpointDiscoveryEnabled)"); - builder.addCode("$T key = $N.overrideConfiguration()", String.class, opModel.getInput().getVariableName()) + ParameterizedTypeName identityFutureTypeName = ParameterizedTypeName.get(ClassName.get(CompletableFuture.class), + WildcardTypeName.subtypeOf(AwsCredentialsIdentity.class)); + builder.addCode("$T identityFuture = $N.overrideConfiguration()", identityFutureTypeName, + opModel.getInput().getVariableName()) .addCode(" .flatMap($T::credentialsIdentityProvider)", AwsRequestOverrideConfiguration.class) .addCode(" .orElseGet(() -> clientConfiguration.option($T.CREDENTIALS_IDENTITY_PROVIDER))", AwsClientOption.class) - // TODO: avoid join inside async - .addCode(" .resolveIdentity().join().accessKeyId();"); + .addCode(" .resolveIdentity();"); - builder.addCode("$1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class) - .addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired()) - .addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class) - .addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))", + builder.addCode("endpointFuture = identityFuture.thenApply(credentials -> {") + .addCode(" $1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class) + .addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired()) + .addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class) + .addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))", opModel.getInput().getVariableName()) - .addCode(" .build();"); + .addCode(" .build();") + .addCode(" return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);") + .addCode("});"); - builder.addStatement("cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest)"); builder.endControlFlow(); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index 41361004b80f..1a3186c0c2f0 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -240,8 +240,9 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper : pojoResponseType; TypeName executeFutureValueType = executeFutureValueType(opModel, poetExtensions); - builder.add("\n\n$T<$T> executeFuture = clientHandler.execute(new $T<$T, $T>()\n", - CompletableFuture.class, executeFutureValueType, ClientExecutionParams.class, requestType, responseType) + builder.add("\n\n$T<$T> executeFuture = ", CompletableFuture.class, executeFutureValueType) + .add(opModel.getEndpointDiscovery() != null ? "endpointFuture.thenCompose(cachedEndpoint -> " : "") + .add("clientHandler.execute(new $T<$T, $T>()\n", ClientExecutionParams.class, requestType, responseType) .add(".withOperationName(\"$N\")\n", opModel.getOperationName()) .add(".withMarshaller($L)\n", asyncMarshaller(model, opModel, marshaller, protocolFactory)) .add(asyncRequestBody(opModel)) @@ -257,8 +258,9 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)) .add(NoneAuthTypeRequestTrait.create(opModel)) - .add(".withInput($L)$L);", - opModel.getInput().getVariableName(), asyncResponseTransformerVariable(isStreaming, isRestJson, opModel)); + .add(".withInput($L)$L)", + opModel.getInput().getVariableName(), asyncResponseTransformerVariable(isStreaming, isRestJson, opModel)) + .add(opModel.getEndpointDiscovery() != null ? ");" : ";"); if (opModel.hasStreamingOutput()) { builder.addStatement("$T<$T, ReturnT> finalAsyncResponseTransformer = asyncResponseTransformer", diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java index 243af9e65099..d408310a5d94 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java @@ -23,6 +23,7 @@ import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.http.HttpResponseHandler; import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -175,26 +176,30 @@ public CompletableFuture testDiscovery throw new IllegalStateException( "This operation requires endpoint discovery, but endpoint discovery was disabled on the client."); } - URI cachedEndpoint = null; + CompletableFuture endpointFuture = CompletableFuture.completedFuture(null); if (endpointDiscoveryEnabled) { - String key = testDiscoveryIdentifiersRequiredRequest.overrideConfiguration() - .flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) - .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) - .resolveIdentity().join().accessKeyId(); - EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) - .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) - .overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null)) - .build(); - cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest); + CompletableFuture identityFuture = + testDiscoveryIdentifiersRequiredRequest.overrideConfiguration() + .flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) + .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) + .resolveIdentity(); + endpointFuture = identityFuture.thenApply(credentials -> { + EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) + .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) + .overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null)) + .build(); + return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + }); } - CompletableFuture executeFuture = clientHandler - .execute(new ClientExecutionParams() + CompletableFuture executeFuture = + endpointFuture.thenCompose(cachedEndpoint -> + clientHandler.execute(new ClientExecutionParams() .withOperationName("TestDiscoveryIdentifiersRequired") .withMarshaller(new TestDiscoveryIdentifiersRequiredRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withMetricCollector(apiCallMetricCollector).discoveredEndpoint(cachedEndpoint) - .withInput(testDiscoveryIdentifiersRequiredRequest)); + .withInput(testDiscoveryIdentifiersRequiredRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -243,25 +248,29 @@ public CompletableFuture testDiscoveryOptional( operationMetadata); boolean endpointDiscoveryEnabled = clientConfiguration.option(SdkClientOption.ENDPOINT_DISCOVERY_ENABLED); boolean endpointOverridden = clientConfiguration.option(SdkClientOption.ENDPOINT_OVERRIDDEN) == Boolean.TRUE; - URI cachedEndpoint = null; + CompletableFuture endpointFuture = CompletableFuture.completedFuture(null); if (endpointDiscoveryEnabled) { - String key = testDiscoveryOptionalRequest.overrideConfiguration() - .flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) - .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) - .resolveIdentity().join().accessKeyId(); - EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false) - .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) - .overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build(); - cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest); + CompletableFuture identityFuture = + testDiscoveryOptionalRequest.overrideConfiguration() + .flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) + .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) + .resolveIdentity(); + endpointFuture = identityFuture.thenApply(credentials -> { + EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false) + .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) + .overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build(); + return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + }); } - CompletableFuture executeFuture = clientHandler - .execute(new ClientExecutionParams() + CompletableFuture executeFuture = + endpointFuture.thenCompose(cachedEndpoint -> + clientHandler.execute(new ClientExecutionParams() .withOperationName("TestDiscoveryOptional") .withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withMetricCollector(apiCallMetricCollector).discoveredEndpoint(cachedEndpoint) - .withInput(testDiscoveryOptionalRequest)); + .withInput(testDiscoveryOptionalRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -318,25 +327,29 @@ public CompletableFuture testDiscoveryRequired( throw new IllegalStateException( "This operation requires endpoint discovery, but endpoint discovery was disabled on the client."); } - URI cachedEndpoint = null; + CompletableFuture endpointFuture = CompletableFuture.completedFuture(null); if (endpointDiscoveryEnabled) { - String key = testDiscoveryRequiredRequest.overrideConfiguration() - .flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) - .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) - .resolveIdentity().join().accessKeyId(); - EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) - .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) - .overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build(); - cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest); + CompletableFuture identityFuture = + testDiscoveryRequiredRequest.overrideConfiguration() + .flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) + .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) + .resolveIdentity(); + endpointFuture = identityFuture.thenApply(credentials -> { + EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) + .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) + .overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build(); + return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + }); } - CompletableFuture executeFuture = clientHandler - .execute(new ClientExecutionParams() + CompletableFuture executeFuture = + endpointFuture.thenCompose(cachedEndpoint -> + clientHandler.execute(new ClientExecutionParams() .withOperationName("TestDiscoveryRequired") .withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withMetricCollector(apiCallMetricCollector).discoveredEndpoint(cachedEndpoint) - .withInput(testDiscoveryRequiredRequest)); + .withInput(testDiscoveryRequiredRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); });