Skip to content

Commit

Permalink
Refactor endpoints to remove default
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Oct 16, 2023
1 parent 39af70f commit e228acc
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand Down Expand Up @@ -153,84 +154,17 @@ internal class EndpointConfigCustomization(
rustTemplate(
"""
pub fn set_endpoint_resolver(&mut self, endpoint_resolver: #{Option}<$sharedEndpointResolver>) -> &mut Self {
self.config.store_or_unset(endpoint_resolver);
self.runtime_components.set_endpoint_resolver(endpoint_resolver.map(|r|#{wrap_resolver}));
self
}
""",
*codegenScope,
)
}

ServiceConfig.BuilderBuild -> {
rustTemplate(
"#{set_endpoint_resolver}(&mut resolver);",
"set_endpoint_resolver" to setEndpointResolverFn(),
)
}

is ServiceConfig.OperationConfigOverride -> {
rustTemplate(
"#{set_endpoint_resolver}(&mut resolver);",
"set_endpoint_resolver" to setEndpointResolverFn(),
"wrap_resolver" to codegenContext.wrapResolver { rust("r") },
)
}

else -> emptySection
}
}
}

private fun defaultResolver(): RuntimeType {
// For now, fallback to a default endpoint resolver that always fails. In the future,
// the endpoint resolver will be required (so that it can be unwrapped).
return typesGenerator.defaultResolver() ?: RuntimeType.forInlineFun(
"MissingResolver",
ClientRustModule.Config.endpoint,
) {
rustTemplate(
"""
##[derive(Debug)]
pub(crate) struct MissingResolver;
impl MissingResolver {
pub(crate) fn new() -> Self { Self }
}
impl<T> #{ResolveEndpoint}<T> for MissingResolver {
fn resolve_endpoint(&self, _params: &T) -> #{Result} {
Err(#{ResolveEndpointError}::message("an endpoint resolver must be provided."))
}
}
""",
"ResolveEndpoint" to types.resolveEndpoint,
"ResolveEndpointError" to types.resolveEndpointError,
"Result" to types.smithyHttpEndpointModule.resolve("Result"),
)
}
}

private fun setEndpointResolverFn(): RuntimeType = RuntimeType.forInlineFun("set_endpoint_resolver", ClientRustModule.config) {
// TODO(enableNewSmithyRuntimeCleanup): Simplify the endpoint resolvers
rustTemplate(
"""
fn set_endpoint_resolver(resolver: &mut #{Resolver}<'_>) {
let endpoint_resolver = if resolver.is_initial() {
Some(resolver.resolve_config::<#{OldSharedEndpointResolver}<#{Params}>>().cloned().unwrap_or_else(||
#{OldSharedEndpointResolver}::new(#{DefaultResolver}::new())
))
} else if resolver.is_latest_set::<#{OldSharedEndpointResolver}<#{Params}>>() {
resolver.resolve_config::<#{OldSharedEndpointResolver}<#{Params}>>().cloned()
} else {
None
};
if let Some(endpoint_resolver) = endpoint_resolver {
let shared = #{SharedEndpointResolver}::new(
#{DefaultEndpointResolver}::<#{Params}>::new(endpoint_resolver)
);
resolver.runtime_components_mut().set_endpoint_resolver(#{Some}(shared));
}
}
""",
*codegenScope,
"DefaultResolver" to defaultResolver(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegen
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.CustomRuntimeFunction
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.endpointTestsModule
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rulesgen.SmithyEndpointsStdLib
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate

/**
Expand Down Expand Up @@ -104,6 +110,24 @@ class EndpointsDecorator : ClientCodegenDecorator {
EndpointConfigCustomization(codegenContext, EndpointTypesGenerator.fromContext(codegenContext))
}

override fun serviceRuntimePluginCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ServiceRuntimePluginCustomization>,
): List<ServiceRuntimePluginCustomization> {
return baseCustomizations + object : ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection): Writable {
return when (section) {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> writable {
codegenContext.defaultEndpointResolver()
?.let { resolver -> section.registerEndpointResolver(this, resolver) }
}

else -> emptySection
}
}
}
}

override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) {
val generator = EndpointTypesGenerator.fromContext(codegenContext)
rustCrate.withModule(ClientRustModule.Config.endpoint) {
Expand All @@ -113,3 +137,53 @@ class EndpointsDecorator : ClientCodegenDecorator {
}
}
}

private fun ClientCodegenContext.defaultEndpointResolver(): Writable {
val generator = EndpointTypesGenerator.fromContext(this)
val defaultResolver = generator.defaultResolver() ?: missingResolver()
val ctx = arrayOf("DefaultResolver" to defaultResolver)
return wrapResolver { rustTemplate("#{DefaultResolver}::new()", *ctx) }
}

fun ClientCodegenContext.wrapResolver(resolver: Writable): Writable {
val generator = EndpointTypesGenerator.fromContext(this)
return resolver.map { base ->
val types = Types(runtimeConfig)
val ctx = arrayOf(
"DefaultEndpointResolver" to RuntimeType.smithyRuntime(runtimeConfig)
.resolve("client::orchestrator::endpoints::DefaultEndpointResolver"),
"Params" to generator.paramsStruct(),
"OldSharedEndpointResolver" to types.sharedEndpointResolver,
)

rustTemplate(
"#{DefaultEndpointResolver}::<#{Params}>::new(#{OldSharedEndpointResolver}::new(#{base}))",
*ctx,
"base" to base,
)
}
}

private fun ClientCodegenContext.missingResolver(): RuntimeType = RuntimeType.forInlineFun(
"MissingResolver",
ClientRustModule.Config.endpoint,
) {
val types = Types(runtimeConfig)
rustTemplate(
"""
##[derive(Debug)]
pub(crate) struct MissingResolver;
impl MissingResolver {
pub(crate) fn new() -> Self { Self }
}
impl<T> #{ResolveEndpoint}<T> for MissingResolver {
fn resolve_endpoint(&self, _params: &T) -> #{Result} {
Err(#{ResolveEndpointError}::message("an endpoint resolver must be provided."))
}
}
""",
"ResolveEndpoint" to types.resolveEndpoint,
"ResolveEndpointError" to types.resolveEndpointError,
"Result" to types.smithyHttpEndpointModule.resolve("Result"),
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ConfigOverrideRuntimePluginGenerator(
) -> Self {
let mut layer = config_override.config;
let mut components = config_override.runtime_components;
let mut resolver = #{Resolver}::overrid(initial_config, initial_components, &mut layer, &mut components);
let resolver = #{Resolver}::overrid(initial_config, initial_components, &mut layer, &mut components);
#{config}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.isNotEmpty
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
Expand All @@ -17,7 +16,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.pre
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.util.dq

sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
/**
Expand All @@ -27,16 +25,6 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
*/
class DeclareSingletons : ServiceRuntimePluginSection("DeclareSingletons")

/**
* Hook for adding additional things to config inside service runtime plugins.
*/
data class AdditionalConfig(val newLayerName: String, val serviceConfigName: String) : ServiceRuntimePluginSection("AdditionalConfig") {
/** Adds a value to the config bag */
fun putConfigValue(writer: RustWriter, value: Writable) {
writer.rust("$newLayerName.store_put(#T);", value)
}
}

data class RegisterRuntimeComponents(val serviceConfigName: String) : ServiceRuntimePluginSection("RegisterRuntimeComponents") {
/** Generates the code to register an interceptor */
fun registerInterceptor(writer: RustWriter, interceptor: Writable) {
Expand All @@ -47,6 +35,10 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
writer.rust("runtime_components.push_auth_scheme(#T);", authScheme)
}

fun registerEndpointResolver(writer: RustWriter, resolver: Writable) {
writer.rust("runtime_components.set_endpoint_resolver(Some(#T));", resolver)
}

fun registerIdentityResolver(writer: RustWriter, identityResolver: Writable) {
writer.rust("runtime_components.push_identity_resolver(#T);", identityResolver)
}
Expand Down Expand Up @@ -84,29 +76,24 @@ class ServiceRuntimePluginGenerator(
writer: RustWriter,
customizations: List<ServiceRuntimePluginCustomization>,
) {
val additionalConfig = writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg", "_service_config"))
}
writer.rustTemplate(
"""
##[derive(::std::fmt::Debug)]
pub(crate) struct ServiceRuntimePlugin {
config: #{Option}<#{FrozenLayer}>,
runtime_components: #{RuntimeComponentsBuilder},
}
impl ServiceRuntimePlugin {
pub fn new(_service_config: crate::config::Config) -> Self {
let config = { #{config} };
let mut runtime_components = #{RuntimeComponentsBuilder}::new("ServiceRuntimePlugin");
#{runtime_components}
Self { config, runtime_components }
Self { runtime_components }
}
}
impl #{RuntimePlugin} for ServiceRuntimePlugin {
fn config(&self) -> #{Option}<#{FrozenLayer}> {
self.config.clone()
None
}
fn runtime_components(&self, _: &#{RuntimeComponentsBuilder}) -> #{Cow}<'_, #{RuntimeComponentsBuilder}> {
Expand All @@ -118,21 +105,6 @@ class ServiceRuntimePluginGenerator(
#{declare_singletons}
""",
*codegenScope,
"config" to writable {
if (additionalConfig.isNotEmpty()) {
rustTemplate(
"""
let mut cfg = #{Layer}::new(${codegenContext.serviceShape.id.name.dq()});
#{additional_config}
#{Some}(cfg.freeze())
""",
*codegenScope,
"additional_config" to additionalConfig,
)
} else {
rust("None")
}
},
"runtime_components" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.RegisterRuntimeComponents("_service_config"))
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ class ServiceConfigGenerator(
rustTemplate(
"""
let mut layer = self.config;
let mut resolver = #{Resolver}::initial(&mut layer, &mut self.runtime_components);
""",
*codegenScope,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,21 @@ internal class ConfigOverrideRuntimePluginGeneratorTest {
.resolve("client::endpoint::EndpointResolverParams"),
"RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig),
"RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(runtimeConfig),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
)
rustCrate.testModule {
addDependency(CargoDependency.Tokio.toDevDependency().withFeature("test-util"))
tokioTest("test_operation_overrides_endpoint_resolver") {
rustTemplate(
"""
use #{RuntimePlugin};
use ::aws_smithy_runtime_api::client::endpoint::EndpointResolver;
let expected_url = "http://localhost:1234/";
let client_config = crate::config::Config::builder().build();
let (http_client, req) = #{capture_request}(None);
let client_config = crate::config::Config::builder().http_client(http_client).build();
let config_override =
crate::config::Config::builder().endpoint_resolver(expected_url);
let sut = crate::config::ConfigOverrideRuntimePlugin::new(
config_override,
client_config.config,
&client_config.runtime_components,
);
let prev = #{RuntimeComponentsBuilder}::new("prev");
let sut_components = sut.runtime_components(&prev);
let endpoint_resolver = sut_components.endpoint_resolver().unwrap();
let endpoint = endpoint_resolver
.resolve_endpoint(&#{EndpointResolverParams}::new(crate::config::endpoint::Params {}))
.await
.unwrap();
assert_eq!(expected_url, endpoint.url());
let client = crate::Client::from_conf(client_config);
let _ = dbg!(client.say_hello().customize().config_override(config_override).send().await);
assert_eq!("http://localhost:1234/", req.expect_request().uri());
""",
*codegenScope,
)
Expand Down

0 comments on commit e228acc

Please sign in to comment.