Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ public final class Constant {

public static final String PACKAGE_NAME_RULES_PATTERN = "%s.endpoints";

// TODO: reviewer: maybe just auth instead of authscheme? or auth.scheme?
public static final String PACKAGE_NAME_AUTH_SCHEME_PATTERN = "%s.auth.scheme";

public static final String PACKAGE_NAME_SMOKE_TEST_PATTERN = "%s.smoketests";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ public class CustomizationConfig {
*/
private boolean requiredTraitValidationEnabled = false;

/**
* Whether SRA based auth logic should be used.
*/
private boolean useSraAuth = false;

/**
* Whether to generate auth scheme params based on endpoint params.
*/
Expand Down Expand Up @@ -702,6 +707,16 @@ public void setRequiredTraitValidationEnabled(boolean requiredTraitValidationEna
this.requiredTraitValidationEnabled = requiredTraitValidationEnabled;
}

public void setUseSraAuth(boolean useSraAuth) {
this.useSraAuth = useSraAuth;
}

// TODO(post-sra-identity-auth): Remove this customization and all related switching logic, keeping only the
// useSraAuth==true branch going forward.
public boolean useSraAuth() {
return useSraAuth;
}

public void setEnableEndpointAuthSchemeParams(boolean enableEndpointAuthSchemeParams) {
this.enableEndpointAuthSchemeParams = enableEndpointAuthSchemeParams;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
public final class AuthSchemeSpecUtils {
private static final Set<String> DEFAULT_AUTH_SCHEME_PARAMS = Collections.unmodifiableSet(setOf("region", "operation"));
private final IntermediateModel intermediateModel;
private final boolean useSraAuth;
private final Set<String> allowedEndpointAuthSchemeParams;
private final boolean allowedEndpointAuthSchemeParamsConfigured;

public AuthSchemeSpecUtils(IntermediateModel intermediateModel) {
this.intermediateModel = intermediateModel;
CustomizationConfig customization = intermediateModel.getCustomizationConfig();
this.useSraAuth = customization.useSraAuth();
if (customization.getAllowedEndpointAuthSchemeParamsConfigured()) {
this.allowedEndpointAuthSchemeParams = Collections.unmodifiableSet(
new HashSet<>(customization.getAllowedEndpointAuthSchemeParams()));
Expand All @@ -57,6 +59,10 @@ public AuthSchemeSpecUtils(IntermediateModel intermediateModel) {
}
}

public boolean useSraAuth() {
return useSraAuth;
}

private String basePackage() {
return intermediateModel.getMetadata().getFullAuthSchemePackageName();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,15 @@ public TypeSpec poetSpec() {
.build());
}

builder.addField(FieldSpec.builder(ParameterizedTypeName.get(ClassName.get(Map.class),
ClassName.get(String.class),
GENERIC_AUTH_SCHEME_TYPE),
"additionalAuthSchemes")
.addModifiers(PRIVATE, FINAL)
.initializer("new $T<>()", HashMap.class)
.build());
if (authSchemeSpecUtils.useSraAuth()) {
builder.addField(FieldSpec.builder(ParameterizedTypeName.get(ClassName.get(Map.class),
ClassName.get(String.class),
GENERIC_AUTH_SCHEME_TYPE),
"additionalAuthSchemes")
.addModifiers(PRIVATE, FINAL)
.initializer("new $T<>()", HashMap.class)
.build());
}

builder.addMethod(serviceEndpointPrefixMethod());
builder.addMethod(serviceNameMethod());
Expand All @@ -130,13 +132,18 @@ public TypeSpec poetSpec() {
mergeInternalDefaultsMethod().ifPresent(builder::addMethod);

builder.addMethod(finalizeServiceConfigurationMethod());
defaultAwsAuthSignerMethod().ifPresent(builder::addMethod);
if (!authSchemeSpecUtils.useSraAuth()) {
defaultAwsAuthSignerMethod().ifPresent(builder::addMethod);
}
builder.addMethod(signingNameMethod());
builder.addMethod(defaultEndpointProviderMethod());

builder.addMethod(authSchemeProviderMethod());
builder.addMethod(defaultAuthSchemeProviderMethod());
builder.addMethod(putAuthSchemeMethod());
if (authSchemeSpecUtils.useSraAuth()) {
builder.addMethod(authSchemeProviderMethod());
builder.addMethod(defaultAuthSchemeProviderMethod());
builder.addMethod(putAuthSchemeMethod());
builder.addMethod(authSchemesMethod());
}

if (hasClientContextParams()) {
model.getClientContextParams().forEach((n, m) -> {
Expand All @@ -157,9 +164,10 @@ public TypeSpec poetSpec() {

if (AuthUtils.usesBearerAuth(model)) {
builder.addMethod(defaultBearerTokenProviderMethod());
builder.addMethod(defaultTokenAuthSignerMethod());
if (!authSchemeSpecUtils.useSraAuth()) {
builder.addMethod(defaultTokenAuthSignerMethod());
}
}
builder.addMethod(authSchemesMethod());

addServiceHttpConfigIfNeeded(builder, model);

Expand Down Expand Up @@ -222,11 +230,14 @@ private MethodSpec mergeServiceDefaultsMethod() {
.addCode("return config.merge(c -> c");

builder.addCode(".option($T.ENDPOINT_PROVIDER, defaultEndpointProvider())", SdkClientOption.class);
builder.addCode(".option($T.AUTH_SCHEME_PROVIDER, defaultAuthSchemeProvider())", SdkClientOption.class);
builder.addCode(".option($T.AUTH_SCHEMES, authSchemes())", SdkClientOption.class);

if (defaultAwsAuthSignerMethod().isPresent()) {
builder.addCode(".option($T.SIGNER, defaultSigner())\n", SdkAdvancedClientOption.class);
if (authSchemeSpecUtils.useSraAuth()) {
builder.addCode(".option($T.AUTH_SCHEME_PROVIDER, defaultAuthSchemeProvider())", SdkClientOption.class);
builder.addCode(".option($T.AUTH_SCHEMES, authSchemes())", SdkClientOption.class);
} else {
if (defaultAwsAuthSignerMethod().isPresent()) {
builder.addCode(".option($T.SIGNER, defaultSigner())\n", SdkAdvancedClientOption.class);
}
}
builder.addCode(".option($T.CRC32_FROM_COMPRESSED_DATA_ENABLED, $L)\n",
SdkClientOption.class, crc32FromCompressedDataEnabled);
Expand All @@ -239,7 +250,9 @@ private MethodSpec mergeServiceDefaultsMethod() {

if (AuthUtils.usesBearerAuth(model)) {
builder.addCode(".option($T.TOKEN_IDENTITY_PROVIDER, defaultTokenProvider())\n", AwsClientOption.class);
builder.addCode(".option($T.TOKEN_SIGNER, defaultTokenSigner())", SdkAdvancedClientOption.class);
if (!authSchemeSpecUtils.useSraAuth()) {
builder.addCode(".option($T.TOKEN_SIGNER, defaultTokenSigner())", SdkAdvancedClientOption.class);
}
}

builder.addCode(");");
Expand Down Expand Up @@ -291,7 +304,9 @@ private MethodSpec finalizeServiceConfigurationMethod() {

List<ClassName> builtInInterceptors = new ArrayList<>();

builtInInterceptors.add(authSchemeSpecUtils.authSchemeInterceptor());
if (authSchemeSpecUtils.useSraAuth()) {
builtInInterceptors.add(authSchemeSpecUtils.authSchemeInterceptor());
}
builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName());
builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName());

Expand Down Expand Up @@ -730,19 +745,21 @@ private MethodSpec validateClientOptionsMethod() {
.addParameter(SdkClientConfiguration.class, "c")
.returns(void.class);

if (AuthUtils.usesAwsAuth(model)) {
if (AuthUtils.usesAwsAuth(model) && !authSchemeSpecUtils.useSraAuth()) {
builder.addStatement("$T.notNull(c.option($T.SIGNER), $S)",
Validate.class,
SdkAdvancedClientOption.class,
"The 'overrideConfiguration.advancedOption[SIGNER]' must be configured in the client builder.");
}

if (AuthUtils.usesBearerAuth(model)) {
builder.addStatement("$T.notNull(c.option($T.TOKEN_SIGNER), $S)",
Validate.class,
SdkAdvancedClientOption.class,
"The 'overrideConfiguration.advancedOption[TOKEN_SIGNER]' "
+ "must be configured in the client builder.");
if (!authSchemeSpecUtils.useSraAuth()) {
builder.addStatement("$T.notNull(c.option($T.TOKEN_SIGNER), $S)",
Validate.class,
SdkAdvancedClientOption.class,
"The 'overrideConfiguration.advancedOption[TOKEN_SIGNER]' "
+ "must be configured in the client builder.");
}
builder.addStatement("$T.notNull(c.option($T.TOKEN_IDENTITY_PROVIDER), $S)",
Validate.class,
AwsClientOption.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.StaticImport;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils;
import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper;
Expand Down Expand Up @@ -95,6 +96,7 @@ public final class AsyncClientClass extends AsyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final boolean useSraAuth;

public AsyncClientClass(GeneratorTaskParams dependencies) {
super(dependencies.getModel());
Expand All @@ -103,6 +105,7 @@ public AsyncClientClass(GeneratorTaskParams dependencies) {
this.className = poetExtensions.getClientClass(model.getMetadata().getAsyncClient());
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
}

@Override
Expand Down Expand Up @@ -156,9 +159,11 @@ protected void addAdditionalMethods(TypeSpec.Builder type) {
.addMethod(protocolSpec.initProtocolFactory(model))
.addMethod(resolveMetricPublishersMethod());

if (model.containsRequestSigners() || model.containsRequestEventStreams() || hasStreamingV4AuthOperations()) {
type.addMethod(applySignerOverrideMethod(poetExtensions, model));
type.addMethod(isSignerOverriddenOnClientMethod());
if (!useSraAuth) {
if (model.containsRequestSigners() || model.containsRequestEventStreams() || hasStreamingV4AuthOperations()) {
type.addMethod(applySignerOverrideMethod(poetExtensions, model));
type.addMethod(isSignerOverriddenOnClientMethod());
}
}

protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod);
Expand Down Expand Up @@ -328,10 +333,12 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation
"pair");
}

if (shouldUseAsyncWithBodySigner(opModel)) {
builder.addCode(applyAsyncWithBodyV4SignerOverride(opModel));
} else {
builder.addCode(ClientClassUtils.callApplySignerOverrideMethod(opModel));
if (!useSraAuth) {
if (shouldUseAsyncWithBodySigner(opModel)) {
builder.addCode(applyAsyncWithBodyV4SignerOverride(opModel));
} else {
builder.addCode(ClientClassUtils.callApplySignerOverrideMethod(opModel));
}
}

builder.addCode(protocolSpec.responseHandler(model, opModel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import software.amazon.awssdk.codegen.model.intermediate.Protocol;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.client.specs.Ec2ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.JsonProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
Expand All @@ -72,6 +73,7 @@ public class SyncClientClass extends SyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final boolean useSraAuth;

public SyncClientClass(GeneratorTaskParams taskParams) {
super(taskParams.getModel());
Expand All @@ -80,6 +82,7 @@ public SyncClientClass(GeneratorTaskParams taskParams) {
this.className = poetExtensions.getClientClass(model.getMetadata().getSyncClient());
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
}

@Override
Expand Down Expand Up @@ -115,9 +118,10 @@ protected void addFields(TypeSpec.Builder type) {

@Override
protected void addAdditionalMethods(TypeSpec.Builder type) {

if (model.containsRequestSigners()) {
type.addMethod(applySignerOverrideMethod(poetExtensions, model));
if (!useSraAuth) {
if (model.containsRequestSigners()) {
type.addMethod(applySignerOverrideMethod(poetExtensions, model));
}
}

model.getEndpointOperation().ifPresent(
Expand Down Expand Up @@ -220,9 +224,11 @@ private Stream<MethodSpec> operations(OperationModel opModel) {

private MethodSpec traditionalMethod(OperationModel opModel) {
MethodSpec.Builder method = SyncClientInterface.operationMethodSignature(model, opModel)
.addAnnotation(Override.class)
.addCode(ClientClassUtils.callApplySignerOverrideMethod(opModel))
.addCode(protocolSpec.responseHandler(model, opModel));
.addAnnotation(Override.class);
if (!useSraAuth) {
method.addCode(ClientClassUtils.callApplySignerOverrideMethod(opModel));
}
method.addCode(protocolSpec.responseHandler(model, opModel));

protocolSpec.errorResponseHandler(opModel).ifPresent(method::addCode);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import software.amazon.awssdk.codegen.model.intermediate.Protocol;
import software.amazon.awssdk.codegen.model.intermediate.ShapeModel;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.client.traits.HttpChecksumRequiredTrait;
import software.amazon.awssdk.codegen.poet.client.traits.HttpChecksumTrait;
import software.amazon.awssdk.codegen.poet.client.traits.NoneAuthTypeRequestTrait;
Expand All @@ -64,10 +65,12 @@ public class JsonProtocolSpec implements ProtocolSpec {

private final PoetExtension poetExtensions;
private final IntermediateModel model;
private final boolean useSraAuth;

public JsonProtocolSpec(PoetExtension poetExtensions, IntermediateModel model) {
this.poetExtensions = poetExtensions;
this.model = model;
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
}

@Override
Expand Down Expand Up @@ -187,9 +190,13 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(".withInput($L)\n", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel))
.add(NoneAuthTypeRequestTrait.create(opModel))
.add(RequestCompressionTrait.create(opModel, model));
.add(HttpChecksumTrait.create(opModel));

if (!useSraAuth) {
codeBlock.add(NoneAuthTypeRequestTrait.create(opModel));
}

codeBlock.add(RequestCompressionTrait.create(opModel, model));

if (opModel.hasStreamingInput()) {
codeBlock.add(".withRequestBody(requestBody)")
Expand Down Expand Up @@ -258,9 +265,13 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(credentialType(opModel, model))
.add(asyncRequestBody)
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel))
.add(NoneAuthTypeRequestTrait.create(opModel))
.add(RequestCompressionTrait.create(opModel, model))
.add(HttpChecksumTrait.create(opModel));

if (!useSraAuth) {
builder.add(NoneAuthTypeRequestTrait.create(opModel));
}

builder.add(RequestCompressionTrait.create(opModel, model))
.add(".withInput($L)$L)",
opModel.getInput().getVariableName(), asyncResponseTransformerVariable(isStreaming, isRestJson, opModel))
.add(opModel.getEndpointDiscovery() != null ? ");" : ";");
Expand Down
Loading