diff --git a/aws/sdk/integration-tests/s3/Cargo.toml b/aws/sdk/integration-tests/s3/Cargo.toml index 5093b00f740..33cadc62cf1 100644 --- a/aws/sdk/integration-tests/s3/Cargo.toml +++ b/aws/sdk/integration-tests/s3/Cargo.toml @@ -48,3 +48,6 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter", "json"] } # If you're writing a test with this, take heed! `no-env-filter` means you'll be capturing # logs from everything that speaks, so be specific with your asserts. tracing-test = { version = "0.2.4", features = ["no-env-filter"] } + +[dependencies] +once_cell = "1.18.0" diff --git a/aws/sdk/integration-tests/s3/tests/throughput-timeout.rs b/aws/sdk/integration-tests/s3/tests/throughput-timeout.rs index 0a459eda717..d07410d1fa4 100644 --- a/aws/sdk/integration-tests/s3/tests/throughput-timeout.rs +++ b/aws/sdk/integration-tests/s3/tests/throughput-timeout.rs @@ -6,31 +6,35 @@ use aws_sdk_sts::error::DisplayErrorContext; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_async::test_util::instant_time_and_sleep; +use aws_smithy_async::time::SharedTimeSource; use aws_smithy_http::body::SdkBody; use aws_smithy_http::byte_stream::ByteStream; use aws_smithy_runtime::client::http::body::minimum_throughput::MinimumThroughputBody; use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs; +use aws_smithy_runtime_api::shared::IntoShared; +use aws_types::sdk_config::SharedAsyncSleep; +use bytes::Bytes; +use once_cell::sync::Lazy; use std::convert::Infallible; use std::time::{Duration, UNIX_EPOCH}; #[should_panic = "minimum throughput was specified at 2 B/s, but throughput of 1.5 B/s was observed"] #[tokio::test] -async fn test_throughput_timeout_happens_for_slow_stream() { +async fn test_throughput_timeout_less_than() { let _logs = capture_test_logs(); let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH); let shared_sleep = sleep.clone(); - // Will send ~1 byte per second because ASCII digits have a size - // of 1 byte and we sleep for 1 second after every digit we send. + // Will send ~1 byte per second. let stream = futures_util::stream::unfold(1, move |state| { let sleep = shared_sleep.clone(); async move { - if state > 100 { + if state > 255 { None } else { sleep.sleep(Duration::from_secs(1)).await; Some(( - Result::::Ok(state.to_string()), + Result::<_, Infallible>::Ok(Bytes::from(vec![state as u8])), state + 1, )) } @@ -39,7 +43,7 @@ async fn test_throughput_timeout_happens_for_slow_stream() { let body = ByteStream::new(SdkBody::from(hyper::body::Body::wrap_stream(stream))); let body = body.map(move |body| { let ts = time_source.clone(); - // Throw an error if the stream sends less than 2 bytes per second at any point + // Throw an error if the stream sends less than 2 bytes per second let minimum_throughput = (2u64, Duration::from_secs(1)); SdkBody::from_dyn(aws_smithy_http::body::BoxBody::new( MinimumThroughputBody::new(ts, body, minimum_throughput), @@ -53,23 +57,25 @@ async fn test_throughput_timeout_happens_for_slow_stream() { } } +const EXPECTED_BYTES: Lazy> = Lazy::new(|| (1..=255).map(|i| i as u8).collect::>()); + #[tokio::test] -async fn test_throughput_timeout_doesnt_happen_for_fast_stream() { +async fn test_throughput_timeout_equal_to() { let _logs = capture_test_logs(); let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH); + let time_source: SharedTimeSource = time_source.into_shared(); + let sleep: SharedAsyncSleep = sleep.into_shared(); - let shared_sleep = sleep.clone(); - // Will send ~1 byte per second because ASCII digits have a size - // of 1 byte and we sleep for 1 second after every digit we send. + // Will send ~1 byte per second. let stream = futures_util::stream::unfold(1, move |state| { - let sleep = shared_sleep.clone(); + let sleep = sleep.clone(); async move { - if state > 100 { + if state > 255 { None } else { sleep.sleep(Duration::from_secs(1)).await; Some(( - Result::::Ok(state.to_string()), + Result::<_, Infallible>::Ok(Bytes::from(vec![state as u8])), state + 1, )) } @@ -77,16 +83,60 @@ async fn test_throughput_timeout_doesnt_happen_for_fast_stream() { }); let body = ByteStream::new(SdkBody::from(hyper::body::Body::wrap_stream(stream))); let body = body.map(move |body| { - let ts = time_source.clone(); - // Throw an error if the stream sends less than 1 bytes per 2s at any point + let time_source = time_source.clone(); + // Throw an error if the stream sends less than 1 byte per second + let minimum_throughput = (1u64, Duration::from_secs(1)); + SdkBody::from_dyn(aws_smithy_http::body::BoxBody::new( + MinimumThroughputBody::new(time_source, body, minimum_throughput), + )) + }); + // assert_eq!(255.0, time_source.seconds_since_unix_epoch()); + // assert_eq!(Duration::from_secs(255), sleep.total_duration()); + let res = body + .collect() + .await + .expect("no streaming error occurs because data is sent fast enough") + .to_vec(); + assert_eq!(*EXPECTED_BYTES, res); +} + +#[tokio::test] +async fn test_throughput_timeout_greater_than() { + let _logs = capture_test_logs(); + let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH); + let time_source: SharedTimeSource = time_source.into_shared(); + let sleep: SharedAsyncSleep = sleep.into_shared(); + + // Will send ~1 byte per second. + let stream = futures_util::stream::unfold(1, move |state| { + let sleep = sleep.clone(); + async move { + if state > 255 { + None + } else { + sleep.sleep(Duration::from_secs(1)).await; + Some(( + Result::<_, Infallible>::Ok(Bytes::from(vec![state as u8])), + state + 1, + )) + } + } + }); + let body = ByteStream::new(SdkBody::from(hyper::body::Body::wrap_stream(stream))); + let body = body.map(move |body| { + let time_source = time_source.clone(); + // Throw an error if the stream sends less than 1 byte per 2s let minimum_throughput = (1u64, Duration::from_secs(2)); SdkBody::from_dyn(aws_smithy_http::body::BoxBody::new( - MinimumThroughputBody::new(ts, body, minimum_throughput), + MinimumThroughputBody::new(time_source, body, minimum_throughput), )) }); - assert_eq!(Duration::from_secs(100), sleep.total_duration()); - let _res = body + // assert_eq!(255.0, time_source.seconds_since_unix_epoch()); + // assert_eq!(Duration::from_secs(255), sleep.total_duration()); + let res = body .collect() .await - .expect("no streaming error occurs because data is sent fast enough"); + .expect("no streaming error occurs because data is sent fast enough") + .to_vec(); + assert_eq!(*EXPECTED_BYTES, res); } diff --git a/rust-runtime/aws-smithy-async/src/test_util.rs b/rust-runtime/aws-smithy-async/src/test_util.rs index dd7eacfd895..2627d64f215 100644 --- a/rust-runtime/aws-smithy-async/src/test_util.rs +++ b/rust-runtime/aws-smithy-async/src/test_util.rs @@ -74,6 +74,12 @@ impl TimeSource for ManualTimeSource { } } +impl TimeSource for Arc { + fn now(&self) -> SystemTime { + self._now(&self.log.lock().unwrap()) + } +} + /// A sleep implementation where calls to [`AsyncSleep::sleep`] block until [`SleepGate::expect_sleep`] is called /// /// Create a [`ControlledSleep`] with [`controlled_time_and_sleep`]