Skip to content

Commit

Permalink
Add tests of timestream + docs
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed May 31, 2023
1 parent 233d9d6 commit 47892e6
Show file tree
Hide file tree
Showing 19 changed files with 773 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,6 @@ where
.map(|(creds, _expiry)| creds)
}

/// Attempts to load the cached value if it has been set
///
/// # Panics
/// This function panics if it is called from an asynchronous context
pub fn try_blocking_get(&self) -> Option<T> {
self.value.blocking_read().get().map(|(v, _exp)| v.clone())
}

/// Attempts to refresh the cached value with the given future.
/// If multiple threads attempt to refresh at the same time, one of them will win,
/// and the others will await that thread's result rather than multiple refreshes occurring.
Expand Down
165 changes: 142 additions & 23 deletions aws/rust-runtime/aws-inlineable/src/endpoint_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//! Maintain a cache of discovered endpoints
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::time::TimeSource;
use aws_smithy_client::erase::boxclone::BoxFuture;
use aws_smithy_http::endpoint::{ResolveEndpoint, ResolveEndpointError};
use aws_smithy_types::endpoint::Endpoint;
Expand All @@ -24,6 +25,7 @@ pub struct ReloadEndpoint {
error: Arc<Mutex<Option<ResolveEndpointError>>>,
rx: Receiver<()>,
sleep: Arc<dyn AsyncSleep>,
time: Arc<dyn TimeSource>,
}

impl Debug for ReloadEndpoint {
Expand All @@ -45,24 +47,29 @@ impl ReloadEndpoint {

/// An infinite loop task that will reload the endpoint
///
/// This task will terminate when the corresponding [`EndpointCache`] is dropped.
/// This task will terminate when the corresponding [`Client`](crate::Client) is dropped.
pub async fn reload_task(mut self) {
loop {
match self.rx.try_recv() {
Ok(_) | Err(TryRecvError::Closed) => break,
_ => {}
}
let should_reload = self
.endpoint
.lock()
.unwrap()
.as_ref()
.map(|e| e.is_expired())
.unwrap_or(true);
if should_reload {
self.reload_once().await;
}
self.sleep.sleep(Duration::from_secs(60)).await
self.reload_increment(self.time.now()).await;
self.sleep.sleep(Duration::from_secs(60)).await;
}
}

async fn reload_increment(&self, now: SystemTime) {
let should_reload = self
.endpoint
.lock()
.unwrap()
.as_ref()
.map(|e| e.is_expired(now))
.unwrap_or(true);
if should_reload {
tracing::debug!("reloading endpoint, previous endpoint was expired");
self.reload_once().await;
}
}
}
Expand All @@ -88,9 +95,10 @@ struct ExpiringEndpoint {
}

impl ExpiringEndpoint {
fn is_expired(&self) -> bool {
match SystemTime::now().duration_since(self.expiry) {
Err(e) => true,
fn is_expired(&self, now: SystemTime) -> bool {
tracing::debug!(expiry = ?self.expiry, now = ?now, delta = ?self.expiry.duration_since(now), "checking expiry status of endpoint");
match self.expiry.duration_since(now) {
Err(_) => true,
Ok(t) => t < Duration::from_secs(120),
}
}
Expand All @@ -99,6 +107,7 @@ impl ExpiringEndpoint {
pub(crate) async fn create_cache<F>(
loader_fn: impl Fn() -> F + Send + Sync + 'static,
sleep: Arc<dyn AsyncSleep>,
time: Arc<dyn TimeSource>,
) -> Result<(EndpointCache, ReloadEndpoint), ResolveEndpointError>
where
F: Future<Output = Result<(Endpoint, SystemTime), ResolveEndpointError>> + Send + 'static,
Expand All @@ -117,11 +126,12 @@ where
error: error_holder,
rx,
sleep,
time,
};
reloader.reload_once().await;
if let Err(e) = cache.resolve_endpoint() {
return Err(e);
}
// if we didn't successfully get an endpoint, bail out so the client knows
// configuration failed to work
cache.resolve_endpoint()?;
Ok((cache, reloader))
}

Expand All @@ -145,26 +155,135 @@ impl EndpointCache {
#[cfg(test)]
mod test {
use crate::endpoint_discovery::{create_cache, EndpointCache};
use aws_credential_types::time_source::TimeSource;
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_async::test_util::controlled_time_and_sleep;
use aws_smithy_async::time::SystemTimeSource;
use aws_smithy_http::endpoint::ResolveEndpointError;
use aws_smithy_types::endpoint::Endpoint;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

fn check_send<T: Send>() {}
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::timeout;

fn check_send_v<T: Send>(t: T) -> T {
t
}

#[tokio::test]
#[allow(unused_must_use)]
async fn check_traits() {
// check_send::<EndpointCache>();

let (cache, reloader) = create_cache(
|| async { Err(ResolveEndpointError::message("stub")) },
|| async {
Ok((
Endpoint::builder().url("http://foo.com").build(),
SystemTime::now(),
))
},
Arc::new(TokioSleep::new()),
Arc::new(SystemTimeSource::new()),
)
.await
.unwrap();
check_send_v(reloader.reload_task());
check_send_v(cache);
}

#[tokio::test]
async fn erroring_endpoint_always_reloaded() {
let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
let ct = Arc::new(AtomicUsize::new(0));
let (cache, reloader) = create_cache(
move || {
let shared_ct = ct.clone();
shared_ct.fetch_add(1, Ordering::AcqRel);
async move {
Ok((
Endpoint::builder()
.url(format!("http://foo.com/{shared_ct:?}"))
.build(),
expiry,
))
}
},
Arc::new(TokioSleep::new()),
Arc::new(SystemTimeSource::new()),
)
.await
.expect("returns an endpoint");
assert_eq!(
cache.resolve_endpoint().expect("ok").url(),
"http://foo.com/1"
);
// 120 second buffer
reloader
.reload_increment(expiry - Duration::from_secs(240))
.await;
assert_eq!(
cache.resolve_endpoint().expect("ok").url(),
"http://foo.com/1"
);

reloader.reload_increment(expiry).await;
assert_eq!(
cache.resolve_endpoint().expect("ok").url(),
"http://foo.com/2"
);
}

#[tokio::test]
async fn test_advance_of_task() {
let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
// expires in 8 minutes
let (time, sleep, mut gate) = controlled_time_and_sleep(expiry - Duration::from_secs(239));
let ct = Arc::new(AtomicUsize::new(0));
let (cache, reloader) = create_cache(
move || {
let shared_ct = ct.clone();
shared_ct.fetch_add(1, Ordering::AcqRel);
async move {
Ok((
Endpoint::builder()
.url(format!("http://foo.com/{shared_ct:?}"))
.build(),
expiry,
))
}
},
Arc::new(sleep.clone()),
Arc::new(time.clone()),
)
.await
.expect("first load success");
let reload_task = tokio::spawn(reloader.reload_task());
assert!(!reload_task.is_finished());
// expiry occurs after 2 sleeps
// t = 0
assert_eq!(
gate.expect_sleep().await.duration(),
Duration::from_secs(60)
);
assert_eq!(cache.resolve_endpoint().unwrap().url(), "http://foo.com/1");
// t = 60

let sleep = gate.expect_sleep().await;
// we're still holding the drop guard, so we haven't expired yet.
assert_eq!(cache.resolve_endpoint().unwrap().url(), "http://foo.com/1");
assert_eq!(sleep.duration(), Duration::from_secs(60));
sleep.allow_progress();
// t = 120

let sleep = gate.expect_sleep().await;
assert_eq!(cache.resolve_endpoint().unwrap().url(), "http://foo.com/2");
sleep.allow_progress();

let sleep = gate.expect_sleep().await;
drop(cache);
sleep.allow_progress();

timeout(Duration::from_secs(1), reload_task)
.await
.expect("task finishes successfully")
.expect("finishes");
}
}
1 change: 1 addition & 0 deletions aws/rust-runtime/aws-inlineable/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ pub mod route53_resource_id_preprocessor;
/// Convert a streaming `SdkBody` into an aws-chunked streaming body with checksum trailers
pub mod http_body_checksum;

#[allow(dead_code)]
pub mod endpoint_discovery;
1 change: 1 addition & 0 deletions aws/rust-runtime/aws-sig-auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ aws-smithy-eventstream = { path = "../../../rust-runtime/aws-smithy-eventstream"
aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" }
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async" }
aws-types = { path = "../aws-types" }
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async" }
http = "0.2.2"
tracing = "0.1"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rawTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustomizations
Expand Down Expand Up @@ -85,6 +86,10 @@ class AwsCrateDocsDecorator : ClientCodegenDecorator {
SdkSettings.from(codegenContext.settings).generateReadme
}

sealed class DocSection(name: String) : AdHocSection(name) {
data class CreateClient(val crateName: String, val clientName: String = "client", val indent: String) : DocSection("CustomExample")
}

internal class AwsCrateDocGenerator(private val codegenContext: ClientCodegenContext) {
private val logger: Logger = Logger.getLogger(javaClass.name)
private val awsConfigVersion by lazy {
Expand Down Expand Up @@ -154,8 +159,7 @@ internal class AwsCrateDocGenerator(private val codegenContext: ClientCodegenCon
##[#{tokio}::main]
async fn main() -> Result<(), $shortModuleName::Error> {
let config = #{aws_config}::load_from_env().await;
let client = $shortModuleName::Client::new(&config);
#{constructClient}
// ... make some calls with the client
Expand All @@ -171,6 +175,7 @@ internal class AwsCrateDocGenerator(private val codegenContext: ClientCodegenCon
true -> AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType()
else -> writable { rust("aws_config") }
},
"constructClient" to AwsDocs.constructClient(codegenContext, indent = " "),
)

template(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docsTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizationsOrElse
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

object AwsDocs {
Expand All @@ -23,6 +26,24 @@ object AwsDocs {
ShapeId.from("com.amazonaws.sso#SWBPortalService"),
).contains(codegenContext.serviceShape.id)

fun constructClient(codegenContext: ClientCodegenContext, indent: String): Writable {
val crateName = codegenContext.moduleName.toSnakeCase()
return writable {
writeCustomizationsOrElse(
codegenContext.rootDecorator.extraSections(codegenContext),
DocSection.CreateClient(crateName = crateName, indent = indent),
) {
addDependency(AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency())
rustTemplate(
"""
let config = aws_config::load_from_env().await;
let client = $crateName::Client::new(&config);
""".trimIndent().prependIndent(indent),
)
}
}
}

fun clientConstructionDocs(codegenContext: ClientCodegenContext): Writable = {
if (canRelyOnAwsConfig(codegenContext)) {
val crateName = codegenContext.moduleName.toSnakeCase()
Expand All @@ -40,8 +61,7 @@ object AwsDocs {
In the simplest case, creating a client looks as follows:
```rust,no_run
## async fn wrapper() {
let config = #{aws_config}::load_from_env().await;
let client = $crateName::Client::new(&config);
#{constructClient}
## }
```
Expand Down Expand Up @@ -76,6 +96,7 @@ object AwsDocs {
[builder pattern]: https://rust-lang.github.io/api-guidelines/type-safety.html##builders-enable-construction-of-complex-values-c-builder
""".trimIndent(),
"aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType(),
"constructClient" to constructClient(codegenContext, indent = ""),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ class IntegrationTestDependencies(
if (hasTests) {
val smithyClient = CargoDependency.smithyClient(codegenContext.runtimeConfig)
.copy(features = setOf("test-util"), scope = DependencyScope.Dev)
val smithyAsync = CargoDependency.smithyAsync(codegenContext.runtimeConfig)
.copy(features = setOf("test-util"), scope = DependencyScope.Dev)
addDependency(smithyClient)
addDependency(smithyAsync)
addDependency(CargoDependency.smithyProtocolTestHelpers(codegenContext.runtimeConfig))
addDependency(SerdeJson)
addDependency(Tokio)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.config.Confi
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
Expand Down Expand Up @@ -148,6 +149,11 @@ class ServiceSpecificDecorator(
delegateTo.protocolTestGenerator(codegenContext, baseGenerator)
}

override fun extraSections(codegenContext: ClientCodegenContext): List<AdHocCustomization> =
listOf<AdHocCustomization>().maybeApply(codegenContext.serviceShape) {
delegateTo.extraSections(codegenContext)
}

override fun operationRuntimePluginCustomizations(
codegenContext: ClientCodegenContext,
operation: OperationShape,
Expand Down
Loading

0 comments on commit 47892e6

Please sign in to comment.