Skip to content

Commit

Permalink
Add support for TimeStreamWrite and TimeStreamQuery
Browse files Browse the repository at this point in the history
This adds support for TSW and TSQ by adding endpoint discovery as a customization. This is made much simpler by the fact that endpoint discovery for these services **has no parameters** which
means that there is no complexity from caching the returned endpoint.

Customers call `.enable_endpoint_discovery()` on the client to create a version of the client with endpoint discovery enabled. This returns a new client and a Reloader from which
customers must spawn the reload task if they want endpoint discovery to rerun.
  • Loading branch information
rcoh committed May 17, 2023
1 parent 33e1a67 commit 0eef47b
Show file tree
Hide file tree
Showing 10 changed files with 6,972 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ where
.map(|(creds, _expiry)| creds)
}

/// Attempts to load the cached value if it has been set
///
/// # Panics
/// This function panics if it is called from an asynchronous context
pub fn try_blocking_get(&self) -> Option<T> {
self.value.blocking_read().get().map(|(v, _exp)| v.clone())
}

/// Attempts to refresh the cached value with the given future.
/// If multiple threads attempt to refresh at the same time, one of them will win,
/// and the others will await that thread's result rather than multiple refreshes occurring.
Expand Down
1 change: 1 addition & 0 deletions aws/rust-runtime/aws-inlineable/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ aws-smithy-client = { path = "../../../rust-runtime/aws-smithy-client" }
aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" }
aws-smithy-http-tower= { path = "../../../rust-runtime/aws-smithy-http-tower" }
aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types" }
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async" }
aws-types = { path = "../aws-types" }
bytes = "1"
bytes-utils = "0.1.1"
Expand Down
170 changes: 170 additions & 0 deletions aws/rust-runtime/aws-inlineable/src/endpoint_discovery.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

//! Maintain a cache of discovered endpoints
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_client::erase::boxclone::BoxFuture;
use aws_smithy_http::endpoint::{ResolveEndpoint, ResolveEndpointError};
use aws_smithy_types::endpoint::Endpoint;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::oneshot::{Receiver, Sender};

/// Endpoint reloader
#[must_use]
pub struct ReloadEndpoint {
loader: Box<dyn Fn() -> BoxFuture<(Endpoint, SystemTime), ResolveEndpointError> + Send + Sync>,
endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
error: Arc<Mutex<Option<ResolveEndpointError>>>,
rx: Receiver<()>,
sleep: Arc<dyn AsyncSleep>,
}

impl Debug for ReloadEndpoint {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ReloadEndpoint").finish()
}
}

impl ReloadEndpoint {
/// Reload the endpoint once
pub async fn reload_once(&self) {
match (self.loader)().await {
Ok((endpoint, expiry)) => {
*self.endpoint.lock().unwrap() = Some(ExpiringEndpoint { endpoint, expiry })
}
Err(err) => *self.error.lock().unwrap() = Some(err),
}
}

/// An infinite loop task that will reload the endpoint
///
/// This task will terminate when the corresponding [`EndpointCache`] is dropped.
pub async fn reload_task(mut self) {
loop {
match self.rx.try_recv() {
Ok(_) | Err(TryRecvError::Closed) => break,
_ => {}
}
let should_reload = self
.endpoint
.lock()
.unwrap()
.as_ref()
.map(|e| e.is_expired())
.unwrap_or(true);
if should_reload {
self.reload_once().await;
}
self.sleep.sleep(Duration::from_secs(60)).await
}
}
}

#[derive(Debug, Clone)]
pub(crate) struct EndpointCache {
error: Arc<Mutex<Option<ResolveEndpointError>>>,
endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
// When the sender is dropped, this allows the reload loop to stop
_drop_guard: Arc<Sender<()>>,
}

impl<T> ResolveEndpoint<T> for EndpointCache {
fn resolve_endpoint(&self, _params: &T) -> aws_smithy_http::endpoint::Result {
self.resolve_endpoint()
}
}

#[derive(Debug)]
struct ExpiringEndpoint {
endpoint: Endpoint,
expiry: SystemTime,
}

impl ExpiringEndpoint {
fn is_expired(&self) -> bool {
match SystemTime::now().duration_since(self.expiry) {
Err(e) => true,
Ok(t) => t < Duration::from_secs(120),
}
}
}

pub(crate) async fn create_cache<F>(
loader_fn: impl Fn() -> F + Send + Sync + 'static,
sleep: Arc<dyn AsyncSleep>,
) -> Result<(EndpointCache, ReloadEndpoint), ResolveEndpointError>
where
F: Future<Output = Result<(Endpoint, SystemTime), ResolveEndpointError>> + Send + 'static,
{
let error_holder = Arc::new(Mutex::new(None));
let endpoint_holder = Arc::new(Mutex::new(None));
let (tx, rx) = tokio::sync::oneshot::channel();
let cache = EndpointCache {
error: error_holder.clone(),
endpoint: endpoint_holder.clone(),
_drop_guard: Arc::new(tx),
};
let reloader = ReloadEndpoint {
loader: Box::new(move || Box::pin((loader_fn)()) as _),
endpoint: endpoint_holder,
error: error_holder,
rx,
sleep,
};
reloader.reload_once().await;
if let Err(e) = cache.resolve_endpoint() {
return Err(e);
}
Ok((cache, reloader))
}

impl EndpointCache {
fn resolve_endpoint(&self) -> aws_smithy_http::endpoint::Result {
self.endpoint
.lock()
.unwrap()
.as_ref()
.map(|e| e.endpoint.clone())
.ok_or_else(|| {
self.error
.lock()
.unwrap()
.take()
.unwrap_or_else(|| ResolveEndpointError::message("no endpoint loaded"))
})
}
}

#[cfg(test)]
mod test {
use crate::endpoint_discovery::{create_cache, EndpointCache};
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_http::endpoint::ResolveEndpointError;
use std::sync::Arc;

fn check_send<T: Send>() {}

fn check_send_v<T: Send>(t: T) -> T {
t
}

#[tokio::test]
async fn check_traits() {
// check_send::<EndpointCache>();

let (cache, reloader) = create_cache(
|| async { Err(ResolveEndpointError::message("stub")) },
Arc::new(TokioSleep::new()),
)
.await
.unwrap();
check_send_v(reloader.reload_task());
}
}
2 changes: 2 additions & 0 deletions aws/rust-runtime/aws-inlineable/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ pub mod route53_resource_id_preprocessor;

/// Convert a streaming `SdkBody` into an aws-chunked streaming body with checksum trailers
pub mod http_body_checksum;

pub mod endpoint_discovery;
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import software.amazon.smithy.rustsdk.customize.s3.S3ExtendedRequestIdDecorator
import software.amazon.smithy.rustsdk.customize.s3control.S3ControlDecorator
import software.amazon.smithy.rustsdk.customize.sso.SSODecorator
import software.amazon.smithy.rustsdk.customize.sts.STSDecorator
import software.amazon.smithy.rustsdk.customize.timestream.TimestreamDecorator
import software.amazon.smithy.rustsdk.endpoints.AwsEndpointsStdLib
import software.amazon.smithy.rustsdk.endpoints.OperationInputTestDecorator
import software.amazon.smithy.rustsdk.endpoints.RequireEndpointRules
Expand Down Expand Up @@ -69,6 +70,8 @@ val DECORATORS: List<ClientCodegenDecorator> = listOf(
S3ControlDecorator().onlyApplyTo("com.amazonaws.s3control#AWSS3ControlServiceV20180820"),
STSDecorator().onlyApplyTo("com.amazonaws.sts#AWSSecurityTokenServiceV20110615"),
SSODecorator().onlyApplyTo("com.amazonaws.sso#SWBPortalService"),
TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamwrite#Timestream_20181101"),
TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamquery#Timestream_20181101"),

// Only build docs-rs for linux to reduce load on docs.rs
listOf(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rustsdk.customize.timestream

import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.toType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rustsdk.InlineAwsDependency

/** Adds Endpoint Discovery Utility to Timestream */
class TimestreamDecorator : ClientCodegenDecorator {
override val name: String = "Timestream"
override val order: Byte = 0

override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) {
val endpointDiscovery = InlineAwsDependency.forRustFile(
"endpoint_discovery",
Visibility.PUBLIC,
CargoDependency.Tokio.copy(scope = DependencyScope.Compile, features = setOf("sync")),
)
rustCrate.lib {
rustTemplate(
"""
async fn resolve_endpoint(client: &crate::Client) -> Result<(#{Endpoint}, #{SystemTime}), #{ResolveEndpointError}> {
let describe_endpoints =
client.describe_endpoints().send().await.map_err(|e| {
#{ResolveEndpointError}::from_source("failed to call describe_endpoints", e)
})?;
let endpoint = describe_endpoints.endpoints().unwrap().get(0).unwrap();
let expiry =
#{SystemTime}::now() + #{Duration}::from_secs(endpoint.cache_period_in_minutes() as u64 * 60);
Ok((
#{Endpoint}::builder()
.url(format!("https://{}", endpoint.address().unwrap()))
.build(),
expiry,
))
}
impl Client {
pub async fn enable_endpoint_discovery(self) -> Result<(Self, #{endpoint_discovery}::ReloadEndpoint), #{ResolveEndpointError}> {
let mut new_conf = self.conf().clone();
let sleep = self.conf().sleep_impl().expect("sleep impl must be provided");
let (resolver, reloader) = #{endpoint_discovery}::create_cache(
move || {
let client = self.clone();
async move { resolve_endpoint(&client).await }
},
sleep,
)
.await?;
new_conf.endpoint_resolver = std::sync::Arc::new(resolver);
Ok((Self::from_conf(new_conf), reloader))
}
}
""",
"endpoint_discovery" to endpointDiscovery.toType(),
"SystemTime" to RuntimeType.std.resolve("time::SystemTime"),
"Duration" to RuntimeType.std.resolve("time::Duration"),
*Types(codegenContext.runtimeConfig).toArray(),
)
}
}
}
Loading

0 comments on commit 0eef47b

Please sign in to comment.