Skip to content

Commit

Permalink
Add Connection Poisoning to aws-smithy-client
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Mar 9, 2023
1 parent 26cb37a commit e6416ae
Show file tree
Hide file tree
Showing 22 changed files with 1,177 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ private class AwsFluentClientExtensions(types: Types) {
};
let mut builder = builder
.middleware(#{DynMiddleware}::new(#{Middleware}::new()))
.reconnect_mode(retry_config.reconnect_mode())
.retry_config(retry_config.into())
.operation_timeout_config(timeout_config.into());
builder.set_sleep_impl(sleep_impl);
Expand Down
3 changes: 3 additions & 0 deletions aws/sdk/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ fun generateCargoWorkspace(services: AwsServices): String {
|]
|members = [${"\n"}${services.allModules.joinToString(",\n") { "| \"$it\"" }}
|]
|
|[patch.crates-io]
|hyper = { git = 'https://github.com/hyperium/hyper', branch = "0.14.x" }
""".trimMargin()
}

Expand Down
3 changes: 3 additions & 0 deletions aws/sdk/integration-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ members = [
"transcribestreaming",
"using-native-tls-instead-of-rustls",
]

[patch.crates-io]
hyper = { git = 'https://github.com/hyperium/hyper', branch = "0.14.x" }
62 changes: 62 additions & 0 deletions aws/sdk/integration-tests/s3/tests/reconnects.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials;
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_client::test_connection::wire_mock::{
check_matches, ReplayedEvent, WireLevelTestConnection,
};
use aws_smithy_client::{ev, match_events};
use aws_smithy_types::retry::RetryConfig;
use aws_types::region::Region;
use aws_types::SdkConfig;
use std::sync::Arc;
use tracing_subscriber::EnvFilter;

/// Drives an S3 `GetObject` call against a wire-level mock that answers 503, 503 and then a
/// successful body, and checks that every attempt is preceded by its own DNS lookup and connect
/// event rather than reusing the connection of the previous attempt.
#[tokio::test]
async fn reconnect_on_503() {
    // Trace-level logging so the recorded wire events can be correlated with client activity.
    let subscriber = tracing_subscriber::fmt().with_env_filter(EnvFilter::new("trace"));
    subscriber.init();

    // Responses the mock server replays, in order: two failures followed by the object body.
    let replay = vec![
        ReplayedEvent::status(503),
        ReplayedEvent::status(503),
        ReplayedEvent::with_body("here-is-your-object"),
    ];
    let mock = WireLevelTestConnection::spinup(replay).await;

    let config = SdkConfig::builder()
        .region(Region::from_static("us-east-2"))
        .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
        .sleep_impl(Arc::new(TokioSleep::new()))
        .endpoint_url(mock.endpoint_url())
        .http_connector(mock.http_connector())
        .retry_config(RetryConfig::standard())
        .build();
    let s3 = aws_sdk_s3::Client::new(&config);

    let output = s3
        .get_object()
        .bucket("bucket")
        .key("key")
        .send()
        .await
        .expect("succeeds after retries");
    let payload = output.body.collect().await.unwrap().to_vec();
    assert_eq!(payload, b"here-is-your-object");

    // Three attempts, each with a dns + connect pair ahead of the HTTP exchange.
    let recorded = mock.events();
    match_events!(
        ev!(dns),
        ev!(connect),
        ev!(http(503)),
        ev!(dns),
        ev!(connect),
        ev!(http(503)),
        ev!(dns),
        ev!(connect),
        ev!(http(200))
    )(&recorded);
}
4 changes: 4 additions & 0 deletions rust-runtime/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[workspace]


members = [
"inlineable",
"aws-smithy-async",
Expand All @@ -18,3 +19,6 @@ members = [
"aws-smithy-http-server",
"aws-smithy-http-server-python",
]

[patch.crates-io]
hyper = { git = 'https://github.com/hyperium/hyper', branch = "0.14.x" }
8 changes: 6 additions & 2 deletions rust-runtime/aws-smithy-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ repository = "https://github.com/awslabs/smithy-rs"

[features]
rt-tokio = ["aws-smithy-async/rt-tokio"]
test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls"]
test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls", "hyper/server", "hyper/h2"]
native-tls = ["client-hyper", "hyper-tls", "rt-tokio"]
rustls = ["client-hyper", "hyper-rustls", "rt-tokio", "lazy_static"]
client-hyper = ["hyper"]
hyper-webpki-doctest-only = ["hyper-rustls/webpki-roots"]


[dependencies]
aws-smithy-async = { path = "../aws-smithy-async" }
aws-smithy-http = { path = "../aws-smithy-http" }
Expand All @@ -25,7 +26,7 @@ bytes = "1"
fastrand = "1.4.0"
http = "0.2.3"
http-body = "0.4.4"
hyper = { version = "0.14.12", features = ["client", "http2", "http1", "tcp"], optional = true }
hyper = { version = "0.14.24", features = ["client", "http2", "http1", "tcp"], optional = true }
# cargo does not support optional test dependencies, so to completely disable rustls when
# the native-tls feature is enabled, we need to add the webpki-roots feature here.
# https://github.com/rust-lang/cargo/issues/1596
Expand All @@ -44,6 +45,9 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1.8.4", features = ["full", "test-util"] }
tower-test = "0.4.0"
tracing-subscriber = "0.3.16"
tracing-test = "0.2.4"


[package.metadata.docs.rs]
all-features = true
Expand Down
45 changes: 45 additions & 0 deletions rust-runtime/aws-smithy-client/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{bounds, erase, retry, Client};
use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_types::retry::ReconnectMode;
use aws_smithy_types::timeout::{OperationTimeoutConfig, TimeoutConfig};
use std::sync::Arc;

Expand Down Expand Up @@ -37,6 +38,12 @@ pub struct Builder<C = (), M = (), R = retry::Standard> {
retry_policy: MaybeRequiresSleep<R>,
operation_timeout_config: Option<OperationTimeoutConfig>,
sleep_impl: Option<Arc<dyn AsyncSleep>>,
reconnect_mode: Option<ReconnectMode>,
}

/// Returns the [`ReconnectMode`] a freshly created builder starts with.
///
/// Transitional default: reconnecting on transient errors is disabled unless explicitly enabled.
fn default_reconnect_mode() -> ReconnectMode {
ReconnectMode::NoReconnect
}

impl<C, M> Default for Builder<C, M>
Expand All @@ -55,6 +62,7 @@ where
),
operation_timeout_config: None,
sleep_impl: default_async_sleep(),
reconnect_mode: Some(default_reconnect_mode()),
}
}
}
Expand Down Expand Up @@ -173,6 +181,7 @@ impl<M, R> Builder<(), M, R> {
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}

Expand Down Expand Up @@ -229,6 +238,7 @@ impl<C, R> Builder<C, (), R> {
operation_timeout_config: self.operation_timeout_config,
middleware,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}

Expand Down Expand Up @@ -280,6 +290,7 @@ impl<C, M> Builder<C, M, retry::Standard> {
operation_timeout_config: self.operation_timeout_config,
middleware: self.middleware,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
}
Expand Down Expand Up @@ -347,6 +358,7 @@ impl<C, M, R> Builder<C, M, R> {
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}

Expand All @@ -361,9 +373,41 @@ impl<C, M, R> Builder<C, M, R> {
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}

/// Choose the [`ReconnectMode`] used by the retry strategy, consuming and returning the builder.
///
/// No reconnection occurs unless a mode that enables it is selected.
///
/// With reconnection enabled, a transient error poisons the connection that was in use,
/// which prevents reusing a connection to a potentially bad host.
pub fn reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self {
    self.reconnect_mode = Some(reconnect_mode);
    self
}

/// Set the [`ReconnectMode`] for the retry strategy
///
/// By default, no reconnection occurs. Passing `None` clears any previously configured mode,
/// in which case `build` falls back to [`ReconnectMode::NoReconnect`].
///
/// When enabled and a transient error is encountered, the connection in use will be poisoned.
/// This prevents reusing a connection to a potentially bad host.
pub fn set_reconnect_mode(&mut self, reconnect_mode: Option<ReconnectMode>) -> &mut Self {
self.reconnect_mode = reconnect_mode;
self
}

/// Enable reconnection on transient errors
///
/// This is shorthand for selecting [`ReconnectMode::ReconnectOnTransientError`]; it is not the
/// default. Once enabled, when a transient error is encountered, the connection in use will be
/// poisoned. This prevents reusing a connection to a potentially bad host but may increase the
/// load on the server.
pub fn reconnect_on_transient_errors(self) -> Self {
self.reconnect_mode(ReconnectMode::ReconnectOnTransientError)
}

/// Build a Smithy service [`Client`].
pub fn build(self) -> Client<C, M, R> {
let operation_timeout_config = self
Expand Down Expand Up @@ -392,6 +436,7 @@ impl<C, M, R> Builder<C, M, R> {
middleware: self.middleware,
operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode.unwrap_or(ReconnectMode::NoReconnect),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions rust-runtime/aws-smithy-client/src/erase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ where
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
}
Expand Down Expand Up @@ -101,6 +102,7 @@ where
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}

Expand Down
63 changes: 49 additions & 14 deletions rust-runtime/aws-smithy-client/src/hyper_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,22 @@ use crate::never::stream::EmptyStream;
use aws_smithy_async::future::timeout::TimedOutError;
use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep};
use aws_smithy_http::body::SdkBody;

use aws_smithy_http::result::ConnectorError;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::retry::ErrorKind;
use http::Uri;
use hyper::client::connect::{Connected, Connection};
use http::{Extensions, Uri};
use hyper::client::connect::{
capture_connection, CaptureConnection, Connected, Connection, HttpInfo,
};

use std::error::Error;
use std::fmt::Debug;

use std::sync::Arc;

use crate::erase::boxclone::BoxFuture;
use aws_smithy_http::connection::{CaptureSmithyConnection, ConnectionMetadata};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{BoxError, Service};

Expand All @@ -108,7 +117,30 @@ use tower::{BoxError, Service};
/// see [the module documentation](crate::hyper_ext).
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Adapter<C>(HttpReadTimeout<hyper::Client<ConnectTimeout<C>, SdkBody>>);
pub struct Adapter<C> {
client: HttpReadTimeout<hyper::Client<ConnectTimeout<C>, SdkBody>>,
}

/// Extract a smithy connection from a hyper CaptureConnection
///
/// Returns `None` when the capture handle does not currently hold connection metadata.
/// Otherwise the returned [`ConnectionMetadata`] carries the proxied flag, the remote address
/// (only when an [`HttpInfo`] extra is attached to the connection), and a callback that
/// poisons whichever connection the capture handle holds at the moment the callback runs.
fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option<ConnectionMetadata> {
    // The poison callback is moved into the returned metadata, so it needs an owned handle.
    // Everything else reads through the borrowed parameter, so this is the only clone required.
    let poison_handle = capture_conn.clone();

    let metadata = capture_conn.connection_metadata();
    let conn = metadata.as_ref()?;

    // `HttpInfo` is looked up in the extras the connection writes into an `Extensions` map.
    let mut extensions = Extensions::new();
    conn.get_extras(&mut extensions);
    let remote_addr = extensions.get::<HttpInfo>().map(|info| info.remote_addr());

    Some(ConnectionMetadata::new(
        conn.is_proxied(),
        remote_addr,
        // Re-read the metadata when invoked instead of capturing `conn`, which borrows
        // from a guard that does not outlive this function.
        move || match poison_handle.connection_metadata().as_ref() {
            Some(conn) => conn.poison(),
            None => tracing::trace!("no connection existed to poison"),
        },
    ))
}

impl<C> Service<http::Request<SdkBody>> for Adapter<C>
where
Expand All @@ -121,20 +153,22 @@ where
type Response = http::Response<SdkBody>;
type Error = ConnectorError;

#[allow(clippy::type_complexity)]
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>,
>;
type Future = BoxFuture<Self::Response, Self::Error>;

fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx).map_err(downcast_error)
self.client.poll_ready(cx).map_err(downcast_error)
}

fn call(&mut self, req: http::Request<SdkBody>) -> Self::Future {
let fut = self.0.call(req);
fn call(&mut self, mut req: http::Request<SdkBody>) -> Self::Future {
let capture_connection = capture_connection(&mut req);
if let Some(capture_smithy_connection) = req.extensions().get::<CaptureSmithyConnection>() {
capture_smithy_connection
.set_connection_retriever(move || extract_smithy_connection(&capture_connection));
}
let fut = self.client.call(req);
Box::pin(async move { Ok(fut.await.map_err(downcast_error)?.map(SdkBody::from)) })
}
}
Expand Down Expand Up @@ -271,7 +305,9 @@ impl Builder {
),
None => HttpReadTimeout::no_timeout(base),
};
Adapter(read_timeout)
Adapter {
client: read_timeout,
}
}

/// Set the async sleep implementation used for timeouts
Expand Down Expand Up @@ -343,7 +379,6 @@ mod timeout_middleware {
use pin_project_lite::pin_project;
use tower::BoxError;

use aws_smithy_async::future;
use aws_smithy_async::future::timeout::{TimedOutError, Timeout};
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::rt::sleep::Sleep;
Expand Down Expand Up @@ -493,7 +528,7 @@ mod timeout_middleware {
Some((sleep, duration)) => {
let sleep = sleep.sleep(*duration);
MaybeTimeoutFuture::Timeout {
timeout: future::timeout::Timeout::new(self.inner.call(req), sleep),
timeout: Timeout::new(self.inner.call(req), sleep),
error_type: "HTTP connect",
duration: *duration,
}
Expand Down Expand Up @@ -522,7 +557,7 @@ mod timeout_middleware {
Some((sleep, duration)) => {
let sleep = sleep.sleep(*duration);
MaybeTimeoutFuture::Timeout {
timeout: future::timeout::Timeout::new(self.inner.call(req), sleep),
timeout: Timeout::new(self.inner.call(req), sleep),
error_type: "HTTP read",
duration: *duration,
}
Expand Down
Loading

0 comments on commit e6416ae

Please sign in to comment.