diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 6967c64b36..e7e1a0b346 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -138,3 +138,21 @@ message = "Remove `once_cell` from public API" references = ["smithy-rs#2973"] meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" } author = "ysaito1001" + +[[aws-sdk-rust]] +message = """ +The `futures_core::stream::Stream` trait has been removed from public API. It should not affect usual SDK use cases, but it does require code upgrade for a small number of cases. The notable example is Transcribe streaming when streaming data is created via a `stream!` macro from the `async-stream` crate. The use of that macro needs to be replaced with `aws_smithy_async::future::fn_stream::FnStream`. See https://github.com/awslabs/smithy-rs/discussions/2952 for more details. +""" +references = ["smithy-rs#2910"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "ysaito1001" + +[[smithy-rs]] +message = """ +The `futures_core::stream::Stream` trait has been removed from public API. The methods that were made available through the `Stream` trait have been removed from `FnStream`, `TryFlatMap`, ByteStream`, `EventStreamSender`, and `MessageStreamAdapter`. However, we have preserved `.next()` and `.collect()` to continue supporting existing call sites in `smithy-rs` and `aws-sdk-rust`, including tests and rustdocs. If we need to support missing stream operations, we are planning to do so in an additive, backward compatible manner. + +If your code uses a `stream!` macro from the `async_stream` crate to generate stream data, it needs to be replaced by `aws_smithy_async::future::fn_steram::FnStream`. See https://github.com/awslabs/smithy-rs/discussions/2952 for more details. +""" +references = ["smithy-rs#2910"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" } +author = "ysaito1001" diff --git a/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs b/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs index 18f1d9219e..bf95910e00 100644 --- a/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs +++ b/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs @@ -14,7 +14,6 @@ use bytes::Buf; use bytes_utils::SegmentedBuf; use http::header::HeaderName; use ring::digest::{Context, Digest, SHA256}; -use tokio_stream::StreamExt; const TREE_HASH_HEADER: &str = "x-amz-sha256-tree-hash"; const X_AMZ_CONTENT_SHA256: &str = "x-amz-content-sha256"; diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index 95553d8006..a93d321a91 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -124,7 +124,7 @@ class TranscribeTestDependencies : LibRsCustomization() { override fun section(section: LibRsSection): Writable = writable { addDependency(AsyncStream) - addDependency(FuturesCore) + addDependency(FuturesCore.toDevDependency()) addDependency(Hound) } } diff --git a/aws/sdk/integration-tests/dynamodb/tests/paginators.rs b/aws/sdk/integration-tests/dynamodb/tests/paginators.rs index 807a11890d..a3d0c62473 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/paginators.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/paginators.rs @@ -6,8 +6,6 @@ use std::collections::HashMap; use std::iter::FromIterator; -use tokio_stream::StreamExt; - use aws_credential_types::Credentials; use aws_sdk_dynamodb::types::AttributeValue; use aws_sdk_dynamodb::{Client, Config}; diff --git a/aws/sdk/integration-tests/ec2/tests/paginators.rs b/aws/sdk/integration-tests/ec2/tests/paginators.rs index 83528f2075..d070971a4f 100644 --- a/aws/sdk/integration-tests/ec2/tests/paginators.rs +++ b/aws/sdk/integration-tests/ec2/tests/paginators.rs @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -use tokio_stream::StreamExt; - use aws_sdk_ec2::{config::Credentials, config::Region, types::InstanceType, Client, Config}; use aws_smithy_client::http_connector::HttpConnector; use aws_smithy_client::test_connection::TestConnection; diff --git a/aws/sdk/integration-tests/transcribestreaming/Cargo.toml b/aws/sdk/integration-tests/transcribestreaming/Cargo.toml index 181ba493cb..33dfc393d4 100644 --- a/aws/sdk/integration-tests/transcribestreaming/Cargo.toml +++ b/aws/sdk/integration-tests/transcribestreaming/Cargo.toml @@ -9,7 +9,6 @@ repository = "https://github.com/awslabs/smithy-rs" publish = false [dev-dependencies] -async-stream = "0.3.0" aws-credential-types = { path = "../../build/aws-sdk/sdk/aws-credential-types", features = ["test-util"] } aws-http = { path = "../../build/aws-sdk/sdk/aws-http" } aws-sdk-transcribestreaming = { path = "../../build/aws-sdk/sdk/transcribestreaming" } diff --git a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs index 62654ebd82..333ed88b95 100644 --- a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs +++ b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs @@ -3,11 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -use async_stream::stream; use aws_sdk_transcribestreaming::config::{Credentials, Region}; use aws_sdk_transcribestreaming::error::SdkError; use aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionOutput; -use aws_sdk_transcribestreaming::primitives::Blob; +use aws_sdk_transcribestreaming::primitives::{Blob, FnStream}; use aws_sdk_transcribestreaming::types::error::{AudioStreamError, TranscriptResultStreamError}; use aws_sdk_transcribestreaming::types::{ AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, @@ -16,7 +15,6 @@ use aws_sdk_transcribestreaming::{Client, Config}; use aws_smithy_client::dvr::{Event, ReplayingConnection}; use aws_smithy_eventstream::frame::{DecodedFrame, HeaderValue, Message, MessageFrameDecoder}; use bytes::BufMut; -use futures_core::Stream; use std::collections::{BTreeMap, BTreeSet}; use std::error::Error as StdError; @@ -24,12 +22,18 @@ const CHUNK_SIZE: usize = 8192; #[tokio::test] async fn test_success() { - let input_stream = stream! { - let pcm = pcm_data(); - for chunk in pcm.chunks(CHUNK_SIZE) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + let pcm = pcm_data(); + for chunk in pcm.chunks(CHUNK_SIZE) { + tx.send(Ok(AudioStream::AudioEvent( + AudioEvent::builder().audio_chunk(Blob::new(chunk)).build(), + ))) + .await + .expect("send should succeed"); + } + }) + }); let (replayer, mut output) = start_request("us-west-2", include_str!("success.json"), input_stream).await; @@ -65,12 +69,18 @@ async fn test_success() { #[tokio::test] async fn test_error() { - let input_stream = stream! { - let pcm = pcm_data(); - for chunk in pcm.chunks(CHUNK_SIZE).take(1) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + let pcm = pcm_data(); + for chunk in pcm.chunks(CHUNK_SIZE).take(1) { + tx.send(Ok(AudioStream::AudioEvent( + AudioEvent::builder().audio_chunk(Blob::new(chunk)).build(), + ))) + .await + .expect("send should succeed"); + } + }) + }); let (replayer, mut output) = start_request("us-east-1", include_str!("error.json"), input_stream).await; @@ -97,7 +107,7 @@ async fn test_error() { async fn start_request( region: &'static str, events_json: &str, - input_stream: impl Stream> + Send + Sync + 'static, + input_stream: FnStream>, ) -> (ReplayingConnection, StartStreamTranscriptionOutput) { let events: Vec = serde_json::from_str(events_json).unwrap(); let replayer = ReplayingConnection::new(events); diff --git a/aws/sdk/sdk-external-types.toml b/aws/sdk/sdk-external-types.toml index b484544c27..b623edd4a1 100644 --- a/aws/sdk/sdk-external-types.toml +++ b/aws/sdk/sdk-external-types.toml @@ -19,14 +19,6 @@ allowed_external_types = [ "http::uri::Uri", "http::method::Method", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Switch to AsyncIterator once standardized - "futures_core::stream::Stream", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `event-stream` feature "aws_smithy_eventstream::*", - - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Decide if we want to continue exposing tower_layer - "tower_layer::Layer", - "tower_layer::identity::Identity", - "tower_layer::stack::Stack", ] diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt index dedc50f31f..5adf9ac97c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt @@ -141,8 +141,9 @@ class PaginatorGenerator private constructor( /// Create the pagination stream /// - /// _Note:_ No requests will be dispatched until the stream is used (eg. with [`.next().await`](tokio_stream::StreamExt::next)). - pub fn send(self) -> impl #{Stream} + #{Unpin} { + /// _Note:_ No requests will be dispatched until the stream is used + /// (e.g. with [`.next().await`](aws_smithy_async::future::fn_stream::FnStream::next)). + pub fn send(self) -> #{fn_stream}::FnStream<#{item_type}> { // Move individual fields out of self for the borrow checker let builder = self.builder; let handle = self.handle; @@ -257,10 +258,11 @@ class PaginatorGenerator private constructor( impl ${paginatorName}Items { /// Create the pagination stream /// - /// _Note: No requests will be dispatched until the stream is used (eg. with [`.next().await`](tokio_stream::StreamExt::next))._ + /// _Note_: No requests will be dispatched until the stream is used + /// (e.g. with [`.next().await`](aws_smithy_async::future::fn_stream::FnStream::next)). /// - /// To read the entirety of the paginator, use [`.collect::, _>()`](tokio_stream::StreamExt::collect). - pub fn send(self) -> impl #{Stream} + #{Unpin} { + /// To read the entirety of the paginator, use [`.collect::, _>()`](aws_smithy_async::future::fn_stream::FnStream::collect). + pub fn send(self) -> #{fn_stream}::FnStream<#{item_type}> { #{fn_stream}::TryFlatMap::new(self.0.send()).flat_map(|page| #{extract_items}(page).unwrap_or_default().into_iter()) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index b2927dca8e..e0ee3314e1 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -433,7 +433,7 @@ class FluentClientGenerator( """ /// Create a paginator for this request /// - /// Paginators are used by calling [`send().await`](#{Paginator}::send) which returns a `Stream`. + /// Paginators are used by calling [`send().await`](#{Paginator}::send) which returns an [`FnStream`](aws_smithy_async::future::fn_stream::FnStream). pub fn into_paginator(self) -> #{Paginator} { #{Paginator}::new(self.handle, self.inner) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 1f60c251ac..e2283054ec 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -7,12 +7,32 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerParams +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerRenderer + +private class ClientStreamPayloadSerializerRenderer : StreamPayloadSerializerRenderer { + override fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) { + writer.rust( + "#T", + RuntimeType.futuresStreamCompatByteStream(params.runtimeConfig).toSymbol(), + ) + } + + override fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) { + writer.rust( + "#T::new(${params.payloadName!!})", + RuntimeType.futuresStreamCompatByteStream(params.runtimeConfig), + ) + } +} class ClientHttpBoundProtocolPayloadGenerator( codegenContext: ClientCodegenContext, @@ -29,16 +49,19 @@ class ClientHttpBoundProtocolPayloadGenerator( _cfg.interceptor_state().store_put(signer_sender); let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(#{FuturesStreamCompatEventStream}::new(adapter)).into(); body } """, "hyper" to CargoDependency.HyperWithStream.toType(), "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), - "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), + "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig) + .resolve("frame::DeferredSigner"), + "FuturesStreamCompatEventStream" to RuntimeType.futuresStreamCompatEventStream(codegenContext.runtimeConfig), "marshallerConstructorFn" to params.marshallerConstructorFn, "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, ) }, + ClientStreamPayloadSerializerRenderer(), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 73a5e9ebbe..d3f3819b94 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -227,6 +227,7 @@ data class CargoDependency( val Bytes: CargoDependency = CargoDependency("bytes", CratesIo("1.0.0")) val BytesUtils: CargoDependency = CargoDependency("bytes-utils", CratesIo("0.1.0")) val FastRand: CargoDependency = CargoDependency("fastrand", CratesIo("2.0.0")) + val FuturesCore: CargoDependency = CargoDependency("futures-core", CratesIo("0.3.25")) val Hex: CargoDependency = CargoDependency("hex", CratesIo("0.4.3")) val Http: CargoDependency = CargoDependency("http", CratesIo("0.2.9")) val HttpBody: CargoDependency = CargoDependency("http-body", CratesIo("0.4.4")) @@ -246,7 +247,6 @@ data class CargoDependency( val AsyncStd: CargoDependency = CargoDependency("async-std", CratesIo("1.12.0"), DependencyScope.Dev) val AsyncStream: CargoDependency = CargoDependency("async-stream", CratesIo("0.3.0"), DependencyScope.Dev) val Criterion: CargoDependency = CargoDependency("criterion", CratesIo("0.4.0"), DependencyScope.Dev) - val FuturesCore: CargoDependency = CargoDependency("futures-core", CratesIo("0.3.25"), DependencyScope.Dev) val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3.25"), DependencyScope.Dev, defaultFeatures = false) val HdrHistogram: CargoDependency = CargoDependency("hdrhistogram", CratesIo("7.5.2"), DependencyScope.Dev) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index 2341c36e5c..7aba2a5aee 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -406,9 +406,12 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) fun retryErrorKind(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("retry::ErrorKind") fun eventStreamReceiver(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("event_stream::Receiver") - fun eventStreamSender(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("event_stream::EventStreamSender") + fun futuresStreamCompatByteStream(runtimeConfig: RuntimeConfig): RuntimeType = + smithyHttp(runtimeConfig).resolve("futures_stream_adapter::FuturesStreamCompatByteStream") + fun futuresStreamCompatEventStream(runtimeConfig: RuntimeConfig): RuntimeType = + smithyHttp(runtimeConfig).resolve("futures_stream_adapter::FuturesStreamCompatEventStream") fun errorMetadata(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::ErrorMetadata") fun errorMetadataBuilder(runtimeConfig: RuntimeConfig) = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt index 14d081b2cb..22b9b0eb1e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember +import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember /** Returns true if the model has normal streaming operations (excluding event streams) */ @@ -70,4 +71,12 @@ fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model): Writab "SdkBody" to RuntimeType.smithyHttp(rc).resolve("body::SdkBody"), ) } + if (codegenContext.serviceShape.hasEventStreamOperations(model)) { + rustTemplate( + """ + pub use #{FnStream}; + """, + "FnStream" to RuntimeType.smithyAsync(rc).resolve("future::fn_stream::FnStream"), + ) + } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 0ed594fc28..303e74c48d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.MemberShape @@ -17,12 +18,15 @@ import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.AdditionalPayloadContext @@ -50,11 +54,38 @@ data class EventStreamBodyParams( val additionalPayloadContext: AdditionalPayloadContext, ) +data class StreamPayloadSerializerParams( + val symbolProvider: SymbolProvider, + val runtimeConfig: RuntimeConfig, + val member: MemberShape, + val payloadName: String?, +) + +/** + * An interface to help customize how to render a stream payload serializer. + * + * When the output of the serializer is passed to `hyper::body::Body::wrap_stream`, + * it requires what's passed to implement `futures_core::stream::Stream` trait. + * However, a certain type, such as `aws_smithy_http::byte_stream::ByteStream` does not + * implement the trait, so we need to wrap it with a new-type that does implement the trait. + * + * Each implementing type of the interface can choose whether the payload should be wrapped + * with such a new-type or should simply be used as-is. + */ +interface StreamPayloadSerializerRenderer { + /** Renders the return type of stream payload serializer **/ + fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) + + /** Renders the stream payload **/ + fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) +} + class HttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, private val protocol: Protocol, private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST, private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit, + private val streamPayloadSerializerRenderer: StreamPayloadSerializerRenderer, ) : ProtocolPayloadGenerator { private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model @@ -63,11 +94,13 @@ class HttpBoundProtocolPayloadGenerator( private val httpBindingResolver = protocol.httpBindingResolver private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) private val codegenScope = arrayOf( - "hyper" to CargoDependency.HyperWithStream.toType(), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + *preludeScope, "BuildError" to runtimeConfig.operationBuildError(), - "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), + "FuturesStreamCompatEventStream" to RuntimeType.futuresStreamCompatEventStream(runtimeConfig), "NoOpSigner" to smithyEventStream.resolve("frame::NoOpSigner"), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), + "hyper" to CargoDependency.HyperWithStream.toType(), ) private val protocolFunctions = ProtocolFunctions(codegenContext) @@ -78,6 +111,7 @@ class HttpBoundProtocolPayloadGenerator( val (shape, payloadMemberName) = when (httpMessageType) { HttpMessageType.RESPONSE -> operationShape.outputShape(model) to httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + HttpMessageType.REQUEST -> operationShape.inputShape(model) to httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName } @@ -97,6 +131,7 @@ class HttpBoundProtocolPayloadGenerator( is DocumentShape, is StructureShape, is UnionShape -> ProtocolPayloadGenerator.PayloadMetadata( takesOwnership = false, ) + is StringShape, is BlobShape -> ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = true) else -> UNREACHABLE("Unexpected payload target type: $type") } @@ -110,8 +145,19 @@ class HttpBoundProtocolPayloadGenerator( additionalPayloadContext: AdditionalPayloadContext, ) { when (httpMessageType) { - HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape, additionalPayloadContext) - HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape, additionalPayloadContext) + HttpMessageType.RESPONSE -> generateResponsePayload( + writer, + shapeName, + operationShape, + additionalPayloadContext, + ) + + HttpMessageType.REQUEST -> generateRequestPayload( + writer, + shapeName, + operationShape, + additionalPayloadContext, + ) } } @@ -119,13 +165,20 @@ class HttpBoundProtocolPayloadGenerator( writer: RustWriter, shapeName: String, operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext, ) { - val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + val payloadMemberName = + httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext) + generatePayloadMemberSerializer( + writer, + shapeName, + operationShape, + payloadMemberName, + additionalPayloadContext, + ) } } @@ -133,13 +186,24 @@ class HttpBoundProtocolPayloadGenerator( writer: RustWriter, shapeName: String, operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext, ) { - val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + val payloadMemberName = + httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() - generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape)) + generateStructureSerializer( + writer, + shapeName, + serializerGenerator.operationOutputSerializer(operationShape), + ) } else { - generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext) + generatePayloadMemberSerializer( + writer, + shapeName, + operationShape, + payloadMemberName, + additionalPayloadContext, + ) } } @@ -155,10 +219,22 @@ class HttpBoundProtocolPayloadGenerator( if (operationShape.isEventStream(model)) { if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName, additionalPayloadContext) + writer.serializeViaEventStream( + operationShape, + payloadMember, + serializerGenerator, + shapeName, + additionalPayloadContext, + ) } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output", additionalPayloadContext) + writer.serializeViaEventStream( + operationShape, + payloadMember, + serializerGenerator, + "output", + additionalPayloadContext, + ) } else { throw CodegenException("Payload serializer for event streams with an invalid configuration") } @@ -216,18 +292,33 @@ class HttpBoundProtocolPayloadGenerator( contentType, ).render() - // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the - // parameters that are not `@eventHeader` or `@eventPayload`. - renderEventStreamBody( - this, - EventStreamBodyParams( - outerName, - memberName, - marshallerConstructorFn, - errorMarshallerConstructorFn, - additionalPayloadContext, - ), - ) + val renderEventStreamBody = writable { + // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the + // parameters that are not `@eventHeader` or `@eventPayload`. + renderEventStreamBody( + this, + EventStreamBodyParams( + outerName, + memberName, + marshallerConstructorFn, + errorMarshallerConstructorFn, + additionalPayloadContext, + ), + ) + } + + if (target == CodegenTarget.CLIENT) { + // No need to wrap it with `FuturesStreamCompatEventStream` for the client since wrapping takes place + // within `renderEventStreamBody` provided by `ClientHttpBoundProtocolPayloadGenerator`. + renderEventStreamBody() + } else { + withBlockTemplate( + "#{FuturesStreamCompatEventStream}::new(", ")", + *codegenScope, + ) { + renderEventStreamBody() + } + } } private fun RustWriter.serializeViaPayload( @@ -238,17 +329,22 @@ class HttpBoundProtocolPayloadGenerator( ) { val ref = if (payloadMetadata.takesOwnership) "" else "&" val serializer = protocolFunctions.serializeFn(member, fnNameSuffix = "http_payload") { fnName -> - val outputT = if (member.isStreaming(model)) { - symbolProvider.toSymbol(member) - } else { - RuntimeType.ByteSlab.toSymbol() - } - rustBlockTemplate( - "pub fn $fnName(payload: $ref#{Member}) -> Result<#{outputT}, #{BuildError}>", + rustTemplate( + "pub(crate) fn $fnName(payload: $ref#{Member}) -> #{Result}<", "Member" to symbolProvider.toSymbol(member), - "outputT" to outputT, *codegenScope, - ) { + ) + if (member.isStreaming(model)) { + streamPayloadSerializerRenderer.renderOutputType( + this, + StreamPayloadSerializerParams(symbolProvider, runtimeConfig, member, null), + ) + } else { + rust("#T", RuntimeType.ByteSlab.toSymbol()) + } + rustTemplate(", #{BuildError}>", *codegenScope) + + withBlockTemplate("{", "}", *codegenScope) { val asRef = if (payloadMetadata.takesOwnership) "" else ".as_ref()" if (symbolProvider.toSymbol(member).isOptional()) { @@ -268,6 +364,7 @@ class HttpBoundProtocolPayloadGenerator( Vec::new() """, ) + is StructureShape -> rust("#T()", serializerGenerator.unsetStructure(targetShape)) is UnionShape -> rust("#T()", serializerGenerator.unsetUnion(targetShape)) else -> throw CodegenException("`httpPayload` on member shapes targeting shapes of type ${targetShape.type} is unsupported") @@ -303,13 +400,16 @@ class HttpBoundProtocolPayloadGenerator( is BlobShape -> { // Write the raw blob to the payload. if (member.isStreaming(model)) { - // Return the `ByteStream`. - rust(payloadName) + streamPayloadSerializerRenderer.renderPayload( + this, + StreamPayloadSerializerParams(symbolProvider, runtimeConfig, member, payloadName), + ) } else { // Convert the `Blob` into a `Vec` and return it. rust("$payloadName.into_inner()") } } + is StructureShape, is UnionShape -> { check( !((targetShape as? UnionShape)?.isEventStream() ?: false), @@ -320,12 +420,14 @@ class HttpBoundProtocolPayloadGenerator( serializer.payloadSerializer(member), ) } + is DocumentShape -> { rust( "#T($payloadName)", serializer.documentSerializer(), ) } + else -> PANIC("Unexpected payload target type: $targetShape") } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt index 8594d703c1..0832c38d79 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.core.util.toPascalCase @@ -66,6 +67,7 @@ class PythonServerEventStreamWrapperGenerator( private val pyO3 = PythonServerCargoDependency.PyO3.toType() private val codegenScope = arrayOf( + *preludeScope, "Inner" to innerT, "Error" to errorT, "SmithyPython" to PythonServerCargoDependency.smithyHttpServerPython(runtimeConfig).toType(), @@ -81,6 +83,7 @@ class PythonServerEventStreamWrapperGenerator( "Option" to RuntimeType.Option, "Arc" to RuntimeType.Arc, "Body" to RuntimeType.sdkBody(runtimeConfig), + "FnStream" to RuntimeType.smithyAsync(runtimeConfig).resolve("future::fn_stream::FnStream"), "UnmarshallMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::UnmarshallMessage"), "MarshallMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::MarshallMessage"), "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"), @@ -137,7 +140,7 @@ class PythonServerEventStreamWrapperGenerator( fn extract(obj: &'source #{PyO3}::PyAny) -> #{PyO3}::PyResult { use #{TokioStream}::StreamExt; let stream = #{PyO3Asyncio}::tokio::into_stream_v1(obj)?; - let stream = stream.filter_map(|res| { + let mut stream = stream.filter_map(|res| { #{PyO3}::Python::with_gil(|py| { // TODO(EventStreamImprovements): Add `InternalServerError` variant to all event streaming // errors and return that variant in case of errors here? @@ -166,6 +169,14 @@ class PythonServerEventStreamWrapperGenerator( }) }); + let stream = #{FnStream}::new(|tx| { + Box::pin(async move { + while let #{Some}(item) = stream.next().await { + tx.send(item).await.expect("send should succeed"); + } + }) + }); + Ok($name { inner: #{Arc}::new(#{Mutex}::new(Some(stream.into()))) }) } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt index f84d35e7cb..5569c92526 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -61,6 +61,29 @@ class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : is ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember -> writable { rust(".into()") } + + else -> emptySection + } +} + +/** + * Customization class used to determine the type of serialized stream payload and how it should be wrapped in a + * new-type wrapper to enable `futures_core::stream::Stream` trait. + */ +class PythonServerStreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { + override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { + is ServerHttpBoundProtocolSection.TypeOfSerializedStreamPayload -> writable { + // `aws_smithy_http_server_python::types::ByteStream` already implements + // `futures::stream::Stream`, so no need to wrap it in a futures' stream-compatible + // wrapper. + rust("#T", section.params.symbolProvider.toSymbol(section.params.member)) + } + + is ServerHttpBoundProtocolSection.WrapStreamPayload -> writable { + // payloadName is always non-null within WrapStreamAfterPayloadGenerated + rust(section.params.payloadName!!) + } + else -> emptySection } } @@ -91,6 +114,7 @@ class PythonServerProtocolLoader( ), additionalServerHttpBoundProtocolCustomizations = listOf( PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), ), additionalHttpBindingCustomizations = listOf( PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), @@ -103,6 +127,7 @@ class PythonServerProtocolLoader( ), additionalServerHttpBoundProtocolCustomizations = listOf( PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), ), additionalHttpBindingCustomizations = listOf( PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), @@ -115,6 +140,7 @@ class PythonServerProtocolLoader( ), additionalServerHttpBoundProtocolCustomizations = listOf( PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), ), additionalHttpBindingCustomizations = listOf( PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index c579d64415..bc6fa6ae01 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -57,6 +57,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtoc import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerParams +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerRenderer import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors @@ -85,7 +87,25 @@ import java.util.logging.Logger * Class describing a ServerHttpBoundProtocol section that can be used in a customization. */ sealed class ServerHttpBoundProtocolSection(name: String) : Section(name) { - data class AfterTimestampDeserializedMember(val shape: MemberShape) : ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember") + data class AfterTimestampDeserializedMember(val shape: MemberShape) : + ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember") + + /** + * Represent a section for rendering the return type of serialized stream payload. + * + * When overriding the `section` method, this should render [Symbol] for that return type. + */ + data class TypeOfSerializedStreamPayload(val params: StreamPayloadSerializerParams) : + ServerHttpBoundProtocolSection("TypeOfSerializedStreamPayload") + + /** + * Represent a section for rendering the serialized stream payload. + * + * When overriding the `section` method, this should render either the payload as-is or the payload wrapped + * with a new-type that implements the `futures_core::stream::Stream` trait. + */ + data class WrapStreamPayload(val params: StreamPayloadSerializerParams) : + ServerHttpBoundProtocolSection("WrapStreamPayload") } /** @@ -105,7 +125,12 @@ class ServerHttpBoundProtocolGenerator( additionalHttpBindingCustomizations: List = listOf(), ) : ServerProtocolGenerator( protocol, - ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), + ServerHttpBoundProtocolTraitImplGenerator( + codegenContext, + protocol, + customizations, + additionalHttpBindingCustomizations, + ), ) { // Define suffixes for operation input / output / error wrappers companion object { @@ -114,9 +139,40 @@ class ServerHttpBoundProtocolGenerator( } } +/** + * Server implementation of the [StreamPayloadSerializerRenderer] interface. + * + * The implementation of each method is delegated to [customizations]. Regular server codegen and python server + * have different requirements for how to render stream payload serializers, and they express their requirements + * through customizations, specifically with [TypeOfSerializedStreamPayload] and [WrapStreamPayload]. + */ +private class ServerStreamPayloadSerializerRenderer(private val customizations: List) : + StreamPayloadSerializerRenderer { + override fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) { + for (customization in customizations) { + customization.section( + ServerHttpBoundProtocolSection.TypeOfSerializedStreamPayload( + params, + ), + )(writer) + } + } + + override fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) { + for (customization in customizations) { + customization.section( + ServerHttpBoundProtocolSection.WrapStreamPayload( + params, + ), + )(writer) + } + } +} + class ServerHttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, protocol: Protocol, + customizations: List = listOf(), ) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( codegenContext, protocol, HttpMessageType.RESPONSE, renderEventStreamBody = { writer, params -> @@ -137,6 +193,7 @@ class ServerHttpBoundProtocolPayloadGenerator( "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, ) }, + ServerStreamPayloadSerializerRenderer(customizations), ) /* @@ -538,8 +595,12 @@ class ServerHttpBoundProtocolTraitImplGenerator( ?: serverRenderHttpResponseCode(httpTraitStatusCode)(this) operationShape.outputShape(model).findStreamingMember(model)?.let { - val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) - withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) { + val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol, customizations) + withBlockTemplate( + "let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", + "));", + *codegenScope, + ) { payloadGenerator.generatePayload(this, "output", operationShape) } } ?: run { @@ -576,7 +637,10 @@ class ServerHttpBoundProtocolTraitImplGenerator( * 2. The protocol-specific `Content-Type` header for the operation. * 3. Additional protocol-specific headers for errors, if [errorShape] is non-null. */ - private fun RustWriter.serverRenderResponseHeaders(operationShape: OperationShape, errorShape: StructureShape? = null) { + private fun RustWriter.serverRenderResponseHeaders( + operationShape: OperationShape, + errorShape: StructureShape? = null, + ) { val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape) val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape) if (addHeadersFn != null) { @@ -686,7 +750,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( inputShape: StructureShape, bindings: List, ) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) + val httpBindingGenerator = + ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val structuredDataParser = protocol.structuredDataParser() Attribute.AllowUnusedMut.render(this) rust( @@ -696,7 +761,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( Attribute.AllowUnusedVariables.render(this) rust("let (parts, body) = request.into_parts();") val parser = structuredDataParser.serverInputParser(operationShape) - val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null + val noInputs = + model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { // `null` is only returned by Smithy when there are no members, but we know there's at least one, since @@ -707,7 +773,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ #{SmithyHttpServer}::protocol::content_type_header_classifier( - &parts.headers, + &parts.headers, Some("$expectedRequestContentType"), )?; input = #{parser}(bytes.as_ref(), input)?; @@ -719,7 +785,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( } for (binding in bindings) { val member = binding.member - val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + val parsedValue = + serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) if (parsedValue != null) { rust("if let Some(value) = ") parsedValue(this) @@ -809,10 +876,12 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } } + HttpLocation.DOCUMENT, HttpLocation.LABEL, HttpLocation.QUERY, HttpLocation.QUERY_PARAMS -> { // All of these are handled separately. null } + else -> { logger.warning("[rust-server-codegen] ${operationShape.id}: request parsing does not currently support ${binding.location} bindings") null @@ -838,15 +907,24 @@ class ServerHttpBoundProtocolTraitImplGenerator( } val restAfterGreedyLabel = if (greedyLabelIndex >= 0) { - httpTrait.uri.segments.slice((greedyLabelIndex + 1) until httpTrait.uri.segments.size).joinToString(prefix = "/", separator = "/") + httpTrait.uri.segments.slice((greedyLabelIndex + 1) until httpTrait.uri.segments.size) + .joinToString(prefix = "/", separator = "/") } else { "" } val labeledNames = segments .mapIndexed { index, segment -> - if (segment.isLabel) { "m$index" } else { "_" } + if (segment.isLabel) { + "m$index" + } else { + "_" + } } - .joinToString(prefix = (if (segments.size > 1) "(" else ""), separator = ",", postfix = (if (segments.size > 1) ")" else "")) + .joinToString( + prefix = (if (segments.size > 1) "(" else ""), + separator = ",", + postfix = (if (segments.size > 1) ")" else ""), + ) val nomParser = segments .map { segment -> if (segment.isGreedyLabel) { @@ -1011,6 +1089,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rust("let v = v.into_owned();") } } + memberShape.isTimestampShape -> { val index = HttpBindingIndex.of(model) val timestampFormat = @@ -1019,7 +1098,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( it.location, protocol.defaultTimestampFormat, ) - val timestampFormatType = RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat) + val timestampFormatType = RuntimeType.parseTimestampFormat( + CodegenTarget.SERVER, + runtimeConfig, + timestampFormat, + ) rustTemplate( """ let v = #{DateTime}::from_str(&v, #{format})? @@ -1028,10 +1111,15 @@ class ServerHttpBoundProtocolTraitImplGenerator( "format" to timestampFormatType, ) for (customization in customizations) { - customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(it.member))(this) + customization.section( + ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember( + it.member, + ), + )(this) } rust(";") } + else -> { // Number or boolean. rust( """ @@ -1054,9 +1142,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( QueryParamsTargetMapValueType.STRING -> { rust("query_params.${if (hasConstrainedTarget) "0." else ""}entry(String::from(k)).or_insert_with(|| String::from(v));") } + QueryParamsTargetMapValueType.LIST, QueryParamsTargetMapValueType.SET -> { if (hasConstrainedTarget) { - val collectionShape = model.expectShape(target.value.target, CollectionShape::class.java) + val collectionShape = + model.expectShape(target.value.target, CollectionShape::class.java) val collectionSymbol = unconstrainedShapeSymbolProvider.toSymbol(collectionShape) rust( // `or_insert_with` instead of `or_insert` to avoid the allocation when the entry is @@ -1091,7 +1181,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not // implemented yet. val hasConstrainedTarget = - model.expectShape(binding.member.target, CollectionShape::class.java).canReachConstrainedShape(model, symbolProvider) + model.expectShape(binding.member.target, CollectionShape::class.java) + .canReachConstrainedShape(model, symbolProvider) val memberName = unconstrainedShapeSymbolProvider.toMemberName(binding.member) val isOptional = unconstrainedShapeSymbolProvider.toSymbol(binding.member).isOptional() rustBlock("if !$memberName.is_empty()") { @@ -1119,8 +1210,13 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } - private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) + private fun serverRenderHeaderParser( + writer: RustWriter, + binding: HttpBindingDescriptor, + operationShape: OperationShape, + ) { + val httpBindingGenerator = + ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ @@ -1131,7 +1227,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun serverRenderPrefixHeadersParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { + private fun serverRenderPrefixHeadersParser( + writer: RustWriter, + binding: HttpBindingDescriptor, + operationShape: OperationShape, + ) { check(binding.location == HttpLocation.PREFIX_HEADERS) val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) @@ -1168,6 +1268,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rust("let value = value.to_owned();") } } + target.isTimestampShape -> { val index = HttpBindingIndex.of(model) val timestampFormat = @@ -1176,7 +1277,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( binding.location, protocol.defaultTimestampFormat, ) - val timestampFormatType = RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat) + val timestampFormatType = + RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat) if (percentDecoding) { rustTemplate( @@ -1197,10 +1299,15 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } for (customization in customizations) { - customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(binding.member))(this) + customization.section( + ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember( + binding.member, + ), + )(this) } rust(";") } + else -> { check(target is NumberShape || target is BooleanShape) rustTemplate( @@ -1230,9 +1337,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> { RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol() } + RestXmlTrait.ID -> { RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol() } + else -> { TODO("Protocol ${codegenContext.protocol} not supported yet") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index e52d9e3a3b..e13220687d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,21 +9,60 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator +class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { + override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { + is ServerHttpBoundProtocolSection.TypeOfSerializedStreamPayload -> writable { + rust( + "#T", + RuntimeType.futuresStreamCompatByteStream(section.params.runtimeConfig).toSymbol(), + ) + } + + is ServerHttpBoundProtocolSection.WrapStreamPayload -> writable { + rustTemplate( + "#{FuturesStreamCompatByteStream}::new(${section.params.payloadName!!})", + "FuturesStreamCompatByteStream" to RuntimeType.futuresStreamCompatByteStream(section.params.runtimeConfig), + ) + } + + else -> emptySection + } +} + class ServerProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { companion object { val DefaultProtocols = mapOf( - RestJson1Trait.ID to ServerRestJsonFactory(), - RestXmlTrait.ID to ServerRestXmlFactory(), - AwsJson1_0Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json10), - AwsJson1_1Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json11), + RestJson1Trait.ID to ServerRestJsonFactory( + additionalServerHttpBoundProtocolCustomizations = listOf( + StreamPayloadSerializerCustomization(), + ), + ), + RestXmlTrait.ID to ServerRestXmlFactory( + additionalServerHttpBoundProtocolCustomizations = listOf( + StreamPayloadSerializerCustomization(), + ), + ), + AwsJson1_0Trait.ID to ServerAwsJsonFactory( + AwsJsonVersion.Json10, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), + AwsJson1_1Trait.ID to ServerAwsJsonFactory( + AwsJsonVersion.Json11, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt index f5b3be454f..9207c56046 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt @@ -15,11 +15,17 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser * RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator] * with RestXml specific configurations. */ -class ServerRestXmlFactory : ProtocolGeneratorFactory { +class ServerRestXmlFactory( + private val additionalServerHttpBoundProtocolCustomizations: List = listOf(), +) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRestXmlProtocol(codegenContext) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = - ServerHttpBoundProtocolGenerator(codegenContext, ServerRestXmlProtocol(codegenContext)) + ServerHttpBoundProtocolGenerator( + codegenContext, + ServerRestXmlProtocol(codegenContext), + additionalServerHttpBoundProtocolCustomizations, + ) override fun support(): ProtocolSupport { return ProtocolSupport( diff --git a/examples/pokemon-service-common/Cargo.toml b/examples/pokemon-service-common/Cargo.toml index f2c86eee0e..c0a1aa676c 100644 --- a/examples/pokemon-service-common/Cargo.toml +++ b/examples/pokemon-service-common/Cargo.toml @@ -7,7 +7,6 @@ authors = ["Smithy-rs Server Team "] description = "A smithy Rust service to retrieve information about Pokémon." [dependencies] -async-stream = "0.3" http = "0.2.9" rand = "0.8" tracing = "0.1" @@ -16,6 +15,7 @@ tokio = { version = "1", default-features = false, features = ["time"] } tower = "0.4" # Local paths +aws-smithy-async = { path = "../../rust-runtime/aws-smithy-async" } aws-smithy-client = { path = "../../rust-runtime/aws-smithy-client" } aws-smithy-http = { path = "../../rust-runtime/aws-smithy-http" } aws-smithy-http-server = { path = "../../rust-runtime/aws-smithy-http-server" } diff --git a/examples/pokemon-service-common/src/lib.rs b/examples/pokemon-service-common/src/lib.rs index cb5c4bd604..08d96189c8 100644 --- a/examples/pokemon-service-common/src/lib.rs +++ b/examples/pokemon-service-common/src/lib.rs @@ -14,7 +14,7 @@ use std::{ sync::{atomic::AtomicUsize, Arc}, }; -use async_stream::stream; +use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_client::{conns, hyper_ext::Adapter}; use aws_smithy_http::{body::SdkBody, byte_stream::ByteStream}; use aws_smithy_http_server::Extension; @@ -242,59 +242,63 @@ pub async fn capture_pokemon( }, )); } - let output_stream = stream! { - loop { - use std::time::Duration; - match input.events.recv().await { - Ok(maybe_event) => match maybe_event { - Some(event) => { - let capturing_event = event.as_event(); - if let Ok(attempt) = capturing_event { - let payload = attempt.payload.clone().unwrap_or_else(|| CapturingPayload::builder().build()); - let pokeball = payload.pokeball().unwrap_or(""); - if ! matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { - yield Err( + let output_stream = FnStream::new(|tx| { + Box::pin(async move { + loop { + use std::time::Duration; + match input.events.recv().await { + Ok(maybe_event) => match maybe_event { + Some(event) => { + let capturing_event = event.as_event(); + if let Ok(attempt) = capturing_event { + let payload = attempt + .payload + .clone() + .unwrap_or_else(|| CapturingPayload::builder().build()); + let pokeball = payload.pokeball().unwrap_or(""); + if !matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { + tx.send(Err( crate::error::CapturePokemonEventsError::InvalidPokeballError( crate::error::InvalidPokeballError { pokeball: pokeball.to_owned() } ) - ); - } else { - let captured = match pokeball { - "Master Ball" => true, - "Great Ball" => rand::thread_rng().gen_range(0..100) > 33, - "Fast Ball" => rand::thread_rng().gen_range(0..100) > 66, - _ => unreachable!("invalid pokeball"), - }; - // Only support Kanto - tokio::time::sleep(Duration::from_millis(1000)).await; - // Will it capture the Pokémon? - if captured { - let shiny = rand::thread_rng().gen_range(0..4096) == 0; - let pokemon = payload - .name() - .unwrap_or("") - .to_string(); - let pokedex: Vec = (0..255).collect(); - yield Ok(crate::model::CapturePokemonEvents::Event( - crate::model::CaptureEvent { - name: Some(pokemon), - shiny: Some(shiny), - pokedex_update: Some(Blob::new(pokedex)), - captured: Some(true), - } - )); + )).await.expect("send should succeed"); + } else { + let captured = match pokeball { + "Master Ball" => true, + "Great Ball" => rand::thread_rng().gen_range(0..100) > 33, + "Fast Ball" => rand::thread_rng().gen_range(0..100) > 66, + _ => unreachable!("invalid pokeball"), + }; + // Only support Kanto + tokio::time::sleep(Duration::from_millis(1000)).await; + // Will it capture the Pokémon? + if captured { + let shiny = rand::thread_rng().gen_range(0..4096) == 0; + let pokemon = payload.name().unwrap_or("").to_string(); + let pokedex: Vec = (0..255).collect(); + tx.send(Ok(crate::model::CapturePokemonEvents::Event( + crate::model::CaptureEvent { + name: Some(pokemon), + shiny: Some(shiny), + pokedex_update: Some(Blob::new(pokedex)), + captured: Some(true), + }, + ))) + .await + .expect("send should succeed"); + } } } } - } - None => break, - }, - Err(e) => println!("{:?}", e), + None => break, + }, + Err(e) => println!("{:?}", e), + } } - } - }; + }) + }); Ok(output::CapturePokemonOutput::builder() .events(output_stream.into()) .build() diff --git a/examples/pokemon-service/Cargo.toml b/examples/pokemon-service/Cargo.toml index d3bc81ea0b..23c3704392 100644 --- a/examples/pokemon-service/Cargo.toml +++ b/examples/pokemon-service/Cargo.toml @@ -20,7 +20,6 @@ pokemon-service-common = { path = "../pokemon-service-common/" } [dev-dependencies] assert_cmd = "2.0" -async-stream = "0.3" rand = "0.8.5" serial_test = "1.0.0" @@ -31,6 +30,7 @@ hyper = { version = "0.14.26", features = ["server", "client"] } hyper-rustls = { version = "0.24", features = ["http2"] } # Local paths +aws-smithy-async = { path = "../../rust-runtime/aws-smithy-async/" } aws-smithy-client = { path = "../../rust-runtime/aws-smithy-client/", features = ["rustls"] } aws-smithy-http = { path = "../../rust-runtime/aws-smithy-http/" } aws-smithy-types = { path = "../../rust-runtime/aws-smithy-types/" } diff --git a/examples/pokemon-service/tests/event_streaming.rs b/examples/pokemon-service/tests/event_streaming.rs index 664827620b..9b86596c4e 100644 --- a/examples/pokemon-service/tests/event_streaming.rs +++ b/examples/pokemon-service/tests/event_streaming.rs @@ -5,7 +5,7 @@ pub mod common; -use async_stream::stream; +use aws_smithy_async::future::fn_stream::FnStream; use rand::Rng; use serial_test::serial; @@ -40,35 +40,56 @@ async fn event_stream_test() { let client = common::client(); let mut team = vec![]; - let input_stream = stream! { - // Always Pikachu - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Pikachu") - .pokeball("Master Ball") - .build()) - .build() - )); - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Regieleki") - .pokeball("Fast Ball") - .build()) - .build() - )); - yield Err(AttemptCapturingPokemonEventError::MasterBallUnsuccessful(MasterBallUnsuccessful::builder().build())); - // The next event should not happen - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Charizard") - .pokeball("Great Ball") - .build()) - .build() - )); - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + // Always Pikachu + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Pikachu") + .pokeball("Master Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Regieleki") + .pokeball("Fast Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + + tx.send(Err( + AttemptCapturingPokemonEventError::MasterBallUnsuccessful( + MasterBallUnsuccessful::builder().build(), + ), + )) + .await + .expect("send should succeed"); + // The next event should not happen + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Charizard") + .pokeball("Great Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); // Throw many! let mut output = common::client() @@ -112,16 +133,22 @@ async fn event_stream_test() { while team.len() < 6 { let pokeball = get_pokeball(); let pokemon = get_pokemon_to_capture(); - let input_stream = stream! { - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name(pokemon) - .pokeball(pokeball) - .build()) - .build() - )) - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name(pokemon) + .pokeball(pokeball) + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); let mut output = client .capture_pokemon() .region("Kanto") diff --git a/examples/python/pokemon-service-test/Cargo.toml b/examples/python/pokemon-service-test/Cargo.toml index b4084185c2..2821d8c0b8 100644 --- a/examples/python/pokemon-service-test/Cargo.toml +++ b/examples/python/pokemon-service-test/Cargo.toml @@ -8,7 +8,6 @@ description = "Run tests against the Python server implementation" [dev-dependencies] rand = "0.8" -async-stream = "0.3" command-group = "2.1.0" tokio = { version = "1.20.1", features = ["full"] } serial_test = "2.0.0" @@ -17,6 +16,7 @@ tokio-rustls = "0.24.0" hyper-rustls = { version = "0.24", features = ["http2"] } # Local paths +aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async/" } aws-smithy-client = { path = "../../../rust-runtime/aws-smithy-client/", features = ["rustls"] } aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http/" } aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types/" } diff --git a/examples/python/pokemon-service-test/tests/simple_integration_test.rs b/examples/python/pokemon-service-test/tests/simple_integration_test.rs index 39f979e9a3..a5ed604679 100644 --- a/examples/python/pokemon-service-test/tests/simple_integration_test.rs +++ b/examples/python/pokemon-service-test/tests/simple_integration_test.rs @@ -7,7 +7,7 @@ // These tests only have access to your crate's public API. // See: https://doc.rust-lang.org/book/ch11-03-test-organization.html#integration-tests -use async_stream::stream; +use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_types::error::display::DisplayErrorContext; use rand::Rng; use serial_test::serial; @@ -75,35 +75,55 @@ async fn event_stream_test() { let _program = PokemonService::run().await; let mut team = vec![]; - let input_stream = stream! { - // Always Pikachu - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Pikachu") - .pokeball("Master Ball") - .build()) - .build() - )); - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Regieleki") - .pokeball("Fast Ball") - .build()) - .build() - )); - yield Err(AttemptCapturingPokemonEventError::MasterBallUnsuccessful(MasterBallUnsuccessful::builder().build())); - // The next event should not happen - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Charizard") - .pokeball("Great Ball") - .build()) - .build() - )); - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + // Always Pikachu + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Pikachu") + .pokeball("Master Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Regieleki") + .pokeball("Fast Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + tx.send(Err( + AttemptCapturingPokemonEventError::MasterBallUnsuccessful( + MasterBallUnsuccessful::builder().build(), + ), + )) + .await + .expect("send should succeed"); + // The next event should not happen + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Charizard") + .pokeball("Great Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); // Throw many! let mut output = client() @@ -147,16 +167,22 @@ async fn event_stream_test() { while team.len() < 6 { let pokeball = get_pokeball(); let pokemon = get_pokemon_to_capture(); - let input_stream = stream! { - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name(pokemon) - .pokeball(pokeball) - .build()) - .build() - )) - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name(pokemon) + .pokeball(pokeball) + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); let mut output = client() .capture_pokemon() .region("Kanto") diff --git a/rust-runtime/aws-smithy-async/Cargo.toml b/rust-runtime/aws-smithy-async/Cargo.toml index c95862d9ff..fd51b4fb1e 100644 --- a/rust-runtime/aws-smithy-async/Cargo.toml +++ b/rust-runtime/aws-smithy-async/Cargo.toml @@ -14,13 +14,18 @@ test-util = [] [dependencies] pin-project-lite = "0.2" tokio = { version = "1.23.1", features = ["sync"] } -tokio-stream = { version = "0.1.5", default-features = false } futures-util = { version = "0.3.16", default-features = false } [dev-dependencies] +pin-utils = "0.1" tokio = { version = "1.23.1", features = ["rt", "macros", "test-util"] } tokio-test = "0.4.2" +# futures-util is used by `now_or_later`, for instance, but the tooling +# reports a false positive, saying it is unused. +[package.metadata.cargo-udeps.ignore] +normal = ["futures-util"] + [package.metadata.docs.rs] all-features = true targets = ["x86_64-unknown-linux-gnu"] diff --git a/rust-runtime/aws-smithy-async/external-types.toml b/rust-runtime/aws-smithy-async/external-types.toml index 424f7dc1db..464456a2dc 100644 --- a/rust-runtime/aws-smithy-async/external-types.toml +++ b/rust-runtime/aws-smithy-async/external-types.toml @@ -2,7 +2,4 @@ allowed_external_types = [ "aws_smithy_types::config_bag::storable::Storable", "aws_smithy_types::config_bag::storable::StoreReplace", "aws_smithy_types::config_bag::storable::Storer", - - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Switch to AsyncIterator once standardized - "futures_core::stream::Stream", ] diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs index 804b08f6bb..f236c33775 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -6,12 +6,14 @@ //! Utility to drive a stream with an async function and a channel. use crate::future::rendezvous; -use futures_util::StreamExt; use pin_project_lite::pin_project; +use std::fmt; +use std::future::poll_fn; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio_stream::{Iter, Once, Stream}; + +pub mod collect; pin_project! { /// Utility to drive a stream with an async function and a channel. @@ -24,12 +26,14 @@ pin_project! { /// /// If `tx.send` returns an error, the function MUST return immediately. /// + /// Note `FnStream` is only `Send` but not `Sync` because `generator` is a boxed future that + /// is `Send` and returns `()` as output when it is done. + /// /// # Examples /// ```no_run - /// use tokio_stream::StreamExt; /// # async fn docs() { /// use aws_smithy_async::future::fn_stream::FnStream; - /// let stream = FnStream::new(|tx| Box::pin(async move { + /// let mut stream = FnStream::new(|tx| Box::pin(async move { /// if let Err(_) = tx.send("Hello!").await { /// return; /// } @@ -39,52 +43,88 @@ pin_project! { /// })); /// assert_eq!(stream.collect::>().await, vec!["Hello!", "Goodbye!"]); /// # } - pub struct FnStream { + pub struct FnStream { #[pin] rx: rendezvous::Receiver, - #[pin] - generator: Option, + generator: Option + Send + 'static>>>, } } -impl FnStream { +impl fmt::Debug for FnStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let item_typename = std::any::type_name::(); + write!(f, "FnStream<{item_typename}>") + } +} + +impl FnStream { /// Creates a new function based stream driven by `generator`. /// /// For examples, see the documentation for [`FnStream`] pub fn new(generator: T) -> Self where - T: FnOnce(rendezvous::Sender) -> F, + T: FnOnce(rendezvous::Sender) -> Pin + Send + 'static>>, { let (tx, rx) = rendezvous::channel::(); Self { rx, - generator: Some(generator(tx)), + generator: Some(Box::pin(generator(tx))), } } -} -impl Stream for FnStream -where - F: Future, -{ - type Item = Item; + /// Creates unreadable `FnStream` but useful to pass to `std::mem::swap` when extracting an + /// owned `FnStream` from a mutable reference. + pub fn taken() -> Self { + Self::new(|_tx| Box::pin(async move {})) + } + + /// Consumes and returns the next `Item` from this stream. + pub async fn next(&mut self) -> Option + where + Self: Unpin, + { + let mut me = Pin::new(self); + poll_fn(|cx| me.as_mut().poll_next(cx)).await + } - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + /// Attempts to pull out the next value of this stream, returning `None` if the stream is + /// exhausted. + pub fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut me = self.project(); match me.rx.poll_recv(cx) { Poll::Ready(item) => Poll::Ready(item), Poll::Pending => { - if let Some(generator) = me.generator.as_mut().as_pin_mut() { - if generator.poll(cx).is_ready() { - // if the generator returned ready we MUST NOT poll it again—doing so - // will cause a panic. - me.generator.set(None); + if let Some(generator) = me.generator { + if generator.as_mut().poll(cx).is_ready() { + // `generator` keeps writing items to `tx` and will not be `Poll::Ready` + // until it is done writing to `tx`. Once it is done, it returns `()` + // as output and is `Poll::Ready`, at which point we MUST NOT poll it again + // since doing so will cause a panic. + *me.generator = None; } } Poll::Pending } } } + + /// Consumes this stream and gathers elements into a collection. + pub async fn collect>(mut self) -> T { + let mut collection = T::initialize(); + while let Some(item) = self.next().await { + if !T::extend(&mut collection, item) { + break; + } + } + T::finalize(collection) + } +} + +impl FnStream> { + /// Yields the next item in the stream or returns an error if an error is encountered. + pub async fn try_next(&mut self) -> Result, E> { + self.next().await.transpose() + } } /// Utility wrapper to flatten paginated results @@ -93,62 +133,50 @@ where /// is present in each item. This provides `items()` which can wrap an stream of `Result` /// and produce a stream of `Result`. #[derive(Debug)] -pub struct TryFlatMap(I); +pub struct TryFlatMap(FnStream>); -impl TryFlatMap { - /// Create a `TryFlatMap` that wraps the input - pub fn new(i: I) -> Self { - Self(i) +impl TryFlatMap { + /// Creates a `TryFlatMap` that wraps the input. + pub fn new(stream: FnStream>) -> Self { + Self(stream) } - /// Produce a new [`Stream`] by mapping this stream with `map` then flattening the result - pub fn flat_map(self, map: M) -> impl Stream> + /// Produces a new [`FnStream`] by mapping this stream with `map` then flattening the result. + pub fn flat_map(mut self, map: M) -> FnStream> where - I: Stream>, - M: Fn(Page) -> Iter, - Iter: IntoIterator, + Page: Send + 'static, + Err: Send + 'static, + M: Fn(Page) -> Iter + Send + 'static, + Item: Send + 'static, + Iter: IntoIterator + Send, + ::IntoIter: Send, { - self.0.flat_map(move |page| match page { - Ok(page) => OnceOrMany::Many { - many: tokio_stream::iter(map(page).into_iter().map(Ok)), - }, - Err(e) => OnceOrMany::Once { - once: tokio_stream::once(Err(e)), - }, + FnStream::new(|tx| { + Box::pin(async move { + while let Some(page) = self.0.next().await { + match page { + Ok(page) => { + let mapped = map(page); + for item in mapped.into_iter() { + let _ = tx.send(Ok(item)).await; + } + } + Err(e) => { + let _ = tx.send(Err(e)).await; + break; + } + } + } + }) as Pin + Send>> }) } } -pin_project! { - /// Helper enum to to support returning `Once` and `Iter` from `Items::items` - #[project = OnceOrManyProj] - enum OnceOrMany { - Many { #[pin] many: Iter }, - Once { #[pin] once: Once }, - } -} - -impl Stream for OnceOrMany -where - Iter: Iterator, -{ - type Item = Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let me = self.project(); - match me { - OnceOrManyProj::Many { many } => many.poll_next(cx), - OnceOrManyProj::Once { once } => once.poll_next(cx), - } - } -} - #[cfg(test)] mod test { use crate::future::fn_stream::{FnStream, TryFlatMap}; use std::sync::{Arc, Mutex}; use std::time::Duration; - use tokio_stream::StreamExt; /// basic test of FnStream functionality #[tokio::test] @@ -168,7 +196,24 @@ mod test { while let Some(value) = stream.next().await { out.push(value); } - assert_eq!(out, vec!["1", "2", "3"]); + assert_eq!(vec!["1", "2", "3"], out); + } + + #[tokio::test] + async fn fn_stream_try_next() { + tokio::time::pause(); + let mut stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(1)).await.unwrap(); + tx.send(Ok(2)).await.unwrap(); + tx.send(Err("err")).await.unwrap(); + }) + }); + let mut out = vec![]; + while let Ok(value) = stream.try_next().await { + out.push(value); + } + assert_eq!(vec![Some(1), Some(2)], out); } // smithy-rs#1902: there was a bug where we could continue to poll the generator after it @@ -183,10 +228,16 @@ mod test { Box::leak(Box::new(tx)); }) }); - assert_eq!(stream.next().await, Some("blah")); + assert_eq!(Some("blah"), stream.next().await); let mut test_stream = tokio_test::task::spawn(stream); - assert!(test_stream.poll_next().is_pending()); - assert!(test_stream.poll_next().is_pending()); + let _ = test_stream.enter(|ctx, pin| { + let polled = pin.poll_next(ctx); + assert!(polled.is_pending()); + }); + let _ = test_stream.enter(|ctx, pin| { + let polled = pin.poll_next(ctx); + assert!(polled.is_pending()); + }); } /// Tests that the generator will not advance until demand exists @@ -209,13 +260,13 @@ mod test { stream.next().await.expect("ready"); assert_eq!(*progress.lock().unwrap(), 1); - assert_eq!(stream.next().await.expect("ready"), "2"); - assert_eq!(*progress.lock().unwrap(), 2); + assert_eq!("2", stream.next().await.expect("ready")); + assert_eq!(2, *progress.lock().unwrap()); let _ = stream.next().await.expect("ready"); - assert_eq!(*progress.lock().unwrap(), 3); - assert_eq!(stream.next().await, None); - assert_eq!(*progress.lock().unwrap(), 4); + assert_eq!(3, *progress.lock().unwrap()); + assert_eq!(None, stream.next().await); + assert_eq!(4, *progress.lock().unwrap()); } #[tokio::test] @@ -238,7 +289,7 @@ mod test { while let Some(Ok(value)) = stream.next().await { out.push(value); } - assert_eq!(out, vec![0, 1]); + assert_eq!(vec![0, 1], out); } #[tokio::test] @@ -262,12 +313,12 @@ mod test { }) }); assert_eq!( - TryFlatMap(stream) + Ok(vec![1, 2, 3, 4, 5, 6]), + TryFlatMap::new(stream) .flat_map(|output| output.items.into_iter()) .collect::, &str>>() .await, - Ok(vec![1, 2, 3, 4, 5, 6]) - ) + ); } #[tokio::test] @@ -287,11 +338,11 @@ mod test { }) }); assert_eq!( - TryFlatMap(stream) + Err("bummer"), + TryFlatMap::new(stream) .flat_map(|output| output.items.into_iter()) .collect::, &str>>() - .await, - Err("bummer") + .await ) } } diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs new file mode 100644 index 0000000000..a07909b999 --- /dev/null +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Module to extend the functionality of `FnStream` to allow for collecting elements of the stream +//! into collection. +//! +//! Majority of the code is borrowed from +//! + +pub(crate) mod sealed { + /// A trait that signifies that elements can be collected into `T`. + /// + /// Currently the trait may not be implemented by clients so we can make changes in the future + /// without breaking code depending on it. + #[doc(hidden)] + pub trait Collectable { + type Collection; + + fn initialize() -> Self::Collection; + + fn extend(collection: &mut Self::Collection, item: T) -> bool; + + fn finalize(collection: Self::Collection) -> Self; + } +} + +impl sealed::Collectable for Vec { + type Collection = Self; + + fn initialize() -> Self::Collection { + Vec::default() + } + + fn extend(collection: &mut Self::Collection, item: T) -> bool { + collection.push(item); + true + } + + fn finalize(collection: Self::Collection) -> Self { + collection + } +} + +impl sealed::Collectable> for Result +where + U: sealed::Collectable, +{ + type Collection = Result; + + fn initialize() -> Self::Collection { + Ok(U::initialize()) + } + + fn extend(collection: &mut Self::Collection, item: Result) -> bool { + match item { + Ok(item) => { + let collection = collection.as_mut().ok().expect("invalid state"); + U::extend(collection, item) + } + Err(e) => { + *collection = Err(e); + false + } + } + } + + fn finalize(collection: Self::Collection) -> Self { + match collection { + Ok(collection) => Ok(U::finalize(collection)), + err @ Err(_) => Err(err.map(drop).unwrap_err()), + } + } +} diff --git a/rust-runtime/aws-smithy-async/src/future/rendezvous.rs b/rust-runtime/aws-smithy-async/src/future/rendezvous.rs index 16456f123e..f2342543f9 100644 --- a/rust-runtime/aws-smithy-async/src/future/rendezvous.rs +++ b/rust-runtime/aws-smithy-async/src/future/rendezvous.rs @@ -12,6 +12,7 @@ //! Rendezvous channels should be used with care—it's inherently easy to deadlock unless they're being //! used from separate tasks or an a coroutine setup (e.g. [`crate::future::fn_stream::FnStream`]) +use std::future::poll_fn; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::Semaphore; @@ -104,7 +105,11 @@ pub struct Receiver { impl Receiver { /// Polls to receive an item from the channel - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + pub async fn recv(&mut self) -> Option { + poll_fn(|cx| self.poll_recv(cx)).await + } + + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { // This uses `needs_permit` to track whether this is the first poll since we last returned an item. // If it is, we will grant a permit to the semaphore. Otherwise, we'll just forward the response through. let resp = self.chan.poll_recv(cx); @@ -124,13 +129,8 @@ impl Receiver { #[cfg(test)] mod test { - use crate::future::rendezvous::{channel, Receiver}; + use crate::future::rendezvous::channel; use std::sync::{Arc, Mutex}; - use tokio::macros::support::poll_fn; - - async fn recv(rx: &mut Receiver) -> Option { - poll_fn(|cx| rx.poll_recv(cx)).await - } #[tokio::test] async fn send_blocks_caller() { @@ -145,11 +145,11 @@ mod test { *idone.lock().unwrap() = 3; }); assert_eq!(*done.lock().unwrap(), 0); - assert_eq!(recv(&mut rx).await, Some(0)); + assert_eq!(rx.recv().await, Some(0)); assert_eq!(*done.lock().unwrap(), 1); - assert_eq!(recv(&mut rx).await, Some(1)); + assert_eq!(rx.recv().await, Some(1)); assert_eq!(*done.lock().unwrap(), 2); - assert_eq!(recv(&mut rx).await, None); + assert_eq!(rx.recv().await, None); assert_eq!(*done.lock().unwrap(), 3); let _ = send.await; } diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index a2fa308512..a274efe086 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -31,7 +31,6 @@ use pyo3::{ prelude::*, }; use tokio::{runtime::Handle, sync::Mutex}; -use tokio_stream::StreamExt; use crate::PyError; diff --git a/rust-runtime/aws-smithy-http/Cargo.toml b/rust-runtime/aws-smithy-http/Cargo.toml index 0238f5dfc6..72d03493eb 100644 --- a/rust-runtime/aws-smithy-http/Cargo.toml +++ b/rust-runtime/aws-smithy-http/Cargo.toml @@ -15,6 +15,7 @@ rt-tokio = ["dep:tokio-util", "dep:tokio", "tokio?/rt", "tokio?/fs", "tokio?/io- event-stream = ["aws-smithy-eventstream"] [dependencies] +aws-smithy-async = { path = "../aws-smithy-async" } aws-smithy-eventstream = { path = "../aws-smithy-eventstream", optional = true } aws-smithy-types = { path = "../aws-smithy-types" } bytes = "1" @@ -36,7 +37,6 @@ tokio = { version = "1.23.1", optional = true } tokio-util = { version = "0.7", optional = true } [dev-dependencies] -async-stream = "0.3" futures-util = { version = "0.3.16", default-features = false } hyper = { version = "0.14.26", features = ["stream"] } pretty_assertions = "1.3" diff --git a/rust-runtime/aws-smithy-http/external-types.toml b/rust-runtime/aws-smithy-http/external-types.toml index cf745485bb..951deca06e 100644 --- a/rust-runtime/aws-smithy-http/external-types.toml +++ b/rust-runtime/aws-smithy-http/external-types.toml @@ -20,12 +20,10 @@ allowed_external_types = [ # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Feature gate Tokio `AsyncRead` "tokio::io::async_read::AsyncRead", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Switch to AsyncIterator once standardized - "futures_core::stream::Stream", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Feature gate references to Tokio `File` "tokio::fs::file::File", # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `event-stream` feature "aws_smithy_eventstream::*", + "aws_smithy_async::*", ] diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index e067018a9d..ed35149621 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -48,7 +48,8 @@ //! //! ### Stream a ByteStream into a file //! The previous example is recommended in cases where loading the entire file into memory first is desirable. For extremely large -//! files, you may wish to stream the data directly to the file system, chunk by chunk. This is posible using the `futures::Stream` implementation. +//! files, you may wish to stream the data directly to the file system, chunk by chunk. +//! This is possible using the [`.next()`](crate::byte_stream::ByteStream::next). //! //! ```no_run //! use bytes::{Buf, Bytes}; @@ -128,6 +129,7 @@ use bytes::Bytes; use bytes_utils::SegmentedBuf; use http_body::Body; use pin_project_lite::pin_project; +use std::future::poll_fn; use std::io::IoSlice; use std::pin::Pin; use std::task::{Context, Poll}; @@ -166,9 +168,7 @@ pin_project! { /// println!("first chunk: {:?}", data.chunk()); /// } /// ``` - /// 2. Via [`impl Stream`](futures_core::Stream): - /// - /// _Note: An import of `StreamExt` is required to use `.try_next()`._ + /// 2. Via [`.next()`](crate::byte_stream::ByteStream::next) or [`.try_next()`](crate::byte_stream::ByteStream::try_next): /// /// For use-cases where holding the entire ByteStream in memory is unnecessary, use the /// `Stream` implementation: @@ -183,7 +183,6 @@ pin_project! { /// # } /// use aws_smithy_http::byte_stream::{ByteStream, AggregatedBytes, error::Error}; /// use aws_smithy_http::body::SdkBody; - /// use tokio_stream::StreamExt; /// /// async fn example() -> Result<(), Error> { /// let mut stream = ByteStream::from(vec![1, 2, 3, 4, 5, 99]); @@ -276,7 +275,7 @@ impl ByteStream { } } - /// Consumes the ByteStream, returning the wrapped SdkBody + /// Consume the `ByteStream`, returning the wrapped SdkBody. // Backwards compatibility note: Because SdkBody has a dyn variant, // we will always be able to implement this method, even if we stop using // SdkBody as the internal representation @@ -284,6 +283,31 @@ impl ByteStream { self.inner.body } + /// Return the next item in the `ByteStream`. + pub async fn next(&mut self) -> Option> { + Some(self.inner.next().await?.map_err(Error::streaming)) + } + + /// Attempts to pull out the next value of this stream, returning `None` if the stream is + /// exhausted. + pub fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.project().inner.poll_next(cx).map_err(Error::streaming) + } + + /// Consume and return the next item in the `ByteStream` or return an error if an error is + /// encountered. + pub async fn try_next(&mut self) -> Result, Error> { + self.next().await.transpose() + } + + /// Return the bounds on the remaining length of the `ByteStream`. + pub fn size_hint(&self) -> (u64, Option) { + self.inner.size_hint() + } + /// Read all the data from this `ByteStream` into memory /// /// If an error in the underlying stream is encountered, `ByteStreamError` is returned. @@ -393,7 +417,9 @@ impl ByteStream { /// # } /// ``` pub fn into_async_read(self) -> impl tokio::io::AsyncRead { - tokio_util::io::StreamReader::new(self) + tokio_util::io::StreamReader::new( + crate::futures_stream_adapter::FuturesStreamCompatByteStream::new(self), + ) } /// Given a function to modify an [`SdkBody`], run it on the `SdkBody` inside this `Bytestream`. @@ -442,18 +468,6 @@ impl From for ByteStream { } } -impl futures_core::stream::Stream for ByteStream { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_next(cx).map_err(Error::streaming) - } - - fn size_hint(&self) -> (usize, Option) { - self.inner.size_hint() - } -} - /// Non-contiguous Binary Data Storage /// /// When data is read from the network, it is read in a sequence of chunks that are not in @@ -524,6 +538,25 @@ impl Inner { Self { body } } + async fn next(&mut self) -> Option> + where + Self: Unpin, + B: http_body::Body, + { + let mut me = Pin::new(self); + poll_fn(|cx| me.as_mut().poll_next(cx)).await + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> + where + B: http_body::Body, + { + self.project().body.poll_data(cx) + } + async fn collect(self) -> Result where B: http_body::Body, @@ -536,34 +569,13 @@ impl Inner { } Ok(AggregatedBytes(output)) } -} - -const SIZE_HINT_32_BIT_PANIC_MESSAGE: &str = r#" -You're running a 32-bit system and this stream's length is too large to be represented with a usize. -Please limit stream length to less than 4.294Gb or run this program on a 64-bit computer architecture. -"#; - -impl futures_core::stream::Stream for Inner -where - B: http_body::Body, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().body.poll_data(cx) - } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (u64, Option) + where + B: http_body::Body, + { let size_hint = http_body::Body::size_hint(&self.body); - let lower = size_hint.lower().try_into(); - let upper = size_hint.upper().map(|u| u.try_into()).transpose(); - - match (lower, upper) { - (Ok(lower), Ok(upper)) => (lower, upper), - (Err(_), _) | (_, Err(_)) => { - panic!("{}", SIZE_HINT_32_BIT_PANIC_MESSAGE) - } - } + (size_hint.lower(), size_hint.upper()) } } diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index d19690e727..4274e210af 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -4,20 +4,25 @@ */ use crate::result::SdkError; +use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_eventstream::frame::{MarshallMessage, SignMessage}; use bytes::Bytes; -use futures_core::Stream; use std::error::Error as StdError; use std::fmt; use std::fmt::Debug; +use std::future::poll_fn; use std::marker::PhantomData; use std::pin::Pin; +use std::sync::Mutex; use std::task::{Context, Poll}; use tracing::trace; /// Input type for Event Streams. pub struct EventStreamSender { - input_stream: Pin> + Send + Sync>>, + // `FnStream` does not have a `Sync` bound but this struct needs to be `Sync` + // as demonstrated by a unit test `event_stream_sender_send_sync`. + // Wrapping `input_stream` with a `Mutex` will make `EventStreamSender` `Sync`. + input_stream: Mutex>>, } impl Debug for EventStreamSender { @@ -36,17 +41,19 @@ impl EventStreamSender { error_marshaller: impl MarshallMessage + Send + Sync + 'static, signer: impl SignMessage + Send + Sync + 'static, ) -> MessageStreamAdapter { - MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream) + MessageStreamAdapter::new( + marshaller, + error_marshaller, + signer, + std::mem::replace(&mut *self.input_stream.lock().unwrap(), FnStream::taken()), + ) } } -impl From for EventStreamSender -where - S: Stream> + Send + Sync + 'static, -{ - fn from(stream: S) -> Self { +impl From>> for EventStreamSender { + fn from(stream: FnStream>) -> Self { EventStreamSender { - input_stream: Box::pin(stream), + input_stream: Mutex::new(stream), } } } @@ -109,24 +116,24 @@ impl fmt::Display for MessageStreamError { /// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be /// marshalled into an Event Stream frame, (e.g., if the message payload was too large). #[allow(missing_debug_implementations)] -pub struct MessageStreamAdapter { +pub struct MessageStreamAdapter { marshaller: Box + Send + Sync>, error_marshaller: Box + Send + Sync>, signer: Box, - stream: Pin> + Send>>, + stream: FnStream>, end_signal_sent: bool, _phantom: PhantomData, } impl Unpin for MessageStreamAdapter {} -impl MessageStreamAdapter { +impl MessageStreamAdapter { /// Create a new `MessageStreamAdapter`. pub fn new( marshaller: impl MarshallMessage + Send + Sync + 'static, error_marshaller: impl MarshallMessage + Send + Sync + 'static, signer: impl SignMessage + Send + Sync + 'static, - stream: Pin> + Send>>, + stream: FnStream>, ) -> Self { MessageStreamAdapter { marshaller: Box::new(marshaller), @@ -139,11 +146,20 @@ impl MessageStreamAdapter { } } -impl Stream for MessageStreamAdapter { - type Item = Result>; +impl MessageStreamAdapter { + /// Consumes and returns the next item from this stream. + pub async fn next(&mut self) -> Option>> { + let mut me = Pin::new(self); + poll_fn(|cx| me.as_mut().poll_next(cx)).await + } - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.stream.as_mut().poll_next(cx) { + /// Attempts to pull out the next value of this stream, returning `None` if the stream is + /// exhausted. + pub fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>> { + match Pin::new(&mut self.stream).as_mut().poll_next(cx) { Poll::Ready(message_option) => { if let Some(message_result) = message_option { let message = match message_result { @@ -196,14 +212,11 @@ mod tests { use super::MarshallMessage; use crate::event_stream::{EventStreamSender, MessageStreamAdapter}; use crate::result::SdkError; - use async_stream::stream; + use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_eventstream::error::Error as EventStreamError; use aws_smithy_eventstream::frame::{ Header, HeaderValue, Message, NoOpSigner, SignMessage, SignMessageError, }; - use bytes::Bytes; - use futures_core::Stream; - use futures_util::stream::StreamExt; use std::error::Error as StdError; #[derive(Debug)] @@ -268,34 +281,28 @@ mod tests { #[test] fn event_stream_sender_send_sync() { - check_send_sync(EventStreamSender::from(stream! { - yield Result::<_, SignMessageError>::Ok(TestMessage("test".into())); - })); - } - - fn check_compatible_with_hyper_wrap_stream(stream: S) -> S - where - S: Stream> + Send + 'static, - O: Into + 'static, - E: Into> + 'static, - { - stream + check_send_sync(EventStreamSender::from(FnStream::new(|tx| { + Box::pin(async move { + let message = Result::<_, TestServiceError>::Ok(TestMessage("test".into())); + tx.send(message).await.expect("failed to send"); + }) + }))); } #[tokio::test] async fn message_stream_adapter_success() { - let stream = stream! { - yield Ok(TestMessage("test".into())); - }; - let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( + let stream = FnStream::new(|tx| { + Box::pin(async move { + let message = Ok(TestMessage("test".into())); + tx.send(message).await.expect("failed to send"); + }) + }); + let mut adapter = MessageStreamAdapter::::new( Marshaller, ErrorMarshaller, TestSigner, - Box::pin(stream), - )); + stream, + ); let mut sent_bytes = adapter.next().await.unwrap().unwrap(); let sent = Message::read_from(&mut sent_bytes).unwrap(); @@ -313,18 +320,19 @@ mod tests { #[tokio::test] async fn message_stream_adapter_construction_failure() { - let stream = stream! { - yield Err(TestServiceError); - }; - let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( + let stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Err(TestServiceError)) + .await + .expect("failed to send"); + }) + }); + let mut adapter = MessageStreamAdapter::::new( Marshaller, ErrorMarshaller, NoOpSigner {}, - Box::pin(stream), - )); + stream, + ); let result = adapter.next().await.unwrap(); assert!(result.is_err()); @@ -340,11 +348,15 @@ mod tests { fn check(input: impl Into>) { let _: EventStreamSender = input.into(); } - check(stream! { - yield Ok(TestMessage("test".into())); - }); - check(stream! { - yield Err(TestServiceError); - }); + check(FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(TestMessage("test".into()))).await.unwrap(); + }) + })); + check(FnStream::new(|tx| { + Box::pin(async move { + tx.send(Err(TestServiceError)).await.unwrap(); + }) + })); } } diff --git a/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs b/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs new file mode 100644 index 0000000000..96551714f5 --- /dev/null +++ b/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs @@ -0,0 +1,149 @@ +// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT. +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::body::SdkBody; +use crate::byte_stream::error::Error as ByteStreamError; +use crate::byte_stream::ByteStream; +use bytes::Bytes; +use futures_core::stream::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// New-type wrapper to enable the impl of the `futures_core::stream::Stream` trait +/// +/// [`ByteStream`] no longer implements `futures_core::stream::Stream` so we wrap it in the +/// new-type to enable the trait when it is required. +/// +/// This is meant to be used by codegen code, and users should not need to use it directly. +pub struct FuturesStreamCompatByteStream(ByteStream); + +impl FuturesStreamCompatByteStream { + /// Creates a new `FuturesStreamCompatByteStream` by wrapping `stream`. + pub fn new(stream: ByteStream) -> Self { + Self(stream) + } + + /// Returns [`SdkBody`] of the wrapped [`ByteStream`]. + pub fn into_inner(self) -> SdkBody { + self.0.into_inner() + } +} + +impl Stream for FuturesStreamCompatByteStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) + } +} + +#[cfg(feature = "event-stream")] +/// New-type wrapper to enable the impl of the `futures_core::stream::Stream` trait +/// +/// [`crate::event_stream::MessageStreamAdapter`] no longer implements `futures_core::stream::Stream` +/// so we wrap it in the new-type to enable the trait when it is required. +/// +/// This is meant to be used by codegen code, and users should not need to use it directly. +pub struct FuturesStreamCompatEventStream(crate::event_stream::MessageStreamAdapter); + +#[cfg(feature = "event-stream")] +impl FuturesStreamCompatEventStream { + /// Creates a new `FuturesStreamCompatEventStream` by wrapping `adapter`. + pub fn new(adapter: crate::event_stream::MessageStreamAdapter) -> Self { + Self(adapter) + } +} + +#[cfg(feature = "event-stream")] +impl Stream + for FuturesStreamCompatEventStream +{ + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use aws_smithy_async::future::fn_stream::FnStream; + use futures_core::stream::Stream; + + fn check_compatible_with_hyper_wrap_stream(stream: S) -> S + where + S: Stream> + Send + 'static, + O: Into + 'static, + E: Into> + 'static, + { + stream + } + + #[test] + fn test_byte_stream_stream_can_be_made_compatible_with_hyper_wrap_stream() { + let stream = ByteStream::from_static(b"Hello world"); + check_compatible_with_hyper_wrap_stream(FuturesStreamCompatByteStream::new(stream)); + } + + #[cfg(feature = "event-stream")] + mod tests_event_stream { + use aws_smithy_eventstream::error::Error; + use aws_smithy_eventstream::frame::MarshallMessage; + use aws_smithy_eventstream::frame::Message; + + #[derive(Debug, Eq, PartialEq)] + pub(crate) struct TestMessage(pub(crate) String); + + #[derive(Debug)] + pub(crate) struct Marshaller; + impl MarshallMessage for Marshaller { + type Input = TestMessage; + + fn marshall(&self, input: Self::Input) -> Result { + Ok(Message::new(input.0.as_bytes().to_vec())) + } + } + #[derive(Debug)] + pub(crate) struct ErrorMarshaller; + impl MarshallMessage for ErrorMarshaller { + type Input = TestServiceError; + + fn marshall(&self, _input: Self::Input) -> Result { + Err(Message::read_from(&b""[..]).expect_err("this should always fail")) + } + } + + #[derive(Debug)] + pub(crate) struct TestServiceError; + impl std::fmt::Display for TestServiceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestServiceError") + } + } + impl std::error::Error for TestServiceError {} + } + use tests_event_stream::*; + + #[test] + #[cfg(feature = "event-stream")] + fn test_message_adapter_stream_can_be_made_compatible_with_hyper_wrap_stream() { + let stream = FnStream::new(|tx| { + Box::pin(async move { + let message = Ok(TestMessage("test".into())); + tx.send(message).await.expect("failed to send"); + }) + }); + check_compatible_with_hyper_wrap_stream(FuturesStreamCompatEventStream( + crate::event_stream::MessageStreamAdapter::::new( + Marshaller, + ErrorMarshaller, + aws_smithy_eventstream::frame::NoOpSigner {}, + stream, + ), + )); + } +} diff --git a/rust-runtime/aws-smithy-http/src/lib.rs b/rust-runtime/aws-smithy-http/src/lib.rs index b24ecd3bf1..5a832e7973 100644 --- a/rust-runtime/aws-smithy-http/src/lib.rs +++ b/rust-runtime/aws-smithy-http/src/lib.rs @@ -27,6 +27,8 @@ pub mod body; pub mod endpoint; +#[doc(hidden)] +pub mod futures_stream_adapter; pub mod header; pub mod http; pub mod label; diff --git a/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs index d50c4f2be8..11914660ad 100644 --- a/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs @@ -10,7 +10,6 @@ use aws_sdk_ec2 as ec2; use aws_sdk_ec2::types::InstanceType; use crate::CanaryEnv; -use tokio_stream::StreamExt; mk_canary!( "ec2_paginator", diff --git a/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs b/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs index 8f6420fc1b..0857f54c94 100644 --- a/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs @@ -5,7 +5,6 @@ use crate::canary::CanaryError; use crate::mk_canary; -use async_stream::stream; use aws_config::SdkConfig; use aws_sdk_transcribestreaming as transcribe; use bytes::BufMut; @@ -31,12 +30,18 @@ pub async fn transcribe_canary( client: transcribe::Client, expected_transcribe_result: String, ) -> anyhow::Result<()> { - let input_stream = stream! { - let pcm = pcm_data(); - for chunk in pcm.chunks(CHUNK_SIZE) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - }; + let input_stream = transcribe::primitives::FnStream::new(|tx| { + Box::pin(async move { + let pcm = pcm_data(); + for chunk in pcm.chunks(CHUNK_SIZE) { + tx.send(Ok(AudioStream::AudioEvent( + AudioEvent::builder().audio_chunk(Blob::new(chunk)).build(), + ))) + .await + .expect("send should succeed"); + } + }) + }); let mut output = client .start_stream_transcription() diff --git a/tools/ci-cdk/canary-lambda/src/main.rs b/tools/ci-cdk/canary-lambda/src/main.rs index 688462031d..d42fb84f53 100644 --- a/tools/ci-cdk/canary-lambda/src/main.rs +++ b/tools/ci-cdk/canary-lambda/src/main.rs @@ -26,11 +26,11 @@ mod latest; #[cfg(feature = "latest")] pub(crate) use latest as current_canary; -// NOTE: This module can be deleted 3 releases after release-2023-01-26 -#[cfg(feature = "release-2023-01-26")] -mod release_2023_01_26; -#[cfg(feature = "release-2023-01-26")] -pub(crate) use release_2023_01_26 as current_canary; +// NOTE: This module can be deleted 3 releases after release-2023-08-03 +#[cfg(feature = "release-2023-08-03")] +mod release_2023_08_03; +#[cfg(feature = "release-2023-08-03")] +pub(crate) use release_2023_08_03 as current_canary; #[tokio::main] async fn main() -> Result<(), Error> { diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03.rs similarity index 100% rename from tools/ci-cdk/canary-lambda/src/release_2023_01_26.rs rename to tools/ci-cdk/canary-lambda/src/release_2023_08_03.rs diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs similarity index 92% rename from tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs rename to tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs index 72c9b40ed0..66df5a03e4 100644 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs @@ -7,7 +7,7 @@ use crate::mk_canary; use anyhow::bail; use aws_sdk_ec2 as ec2; -use aws_sdk_ec2::model::InstanceType; +use aws_sdk_ec2::types::InstanceType; use crate::CanaryEnv; use tokio_stream::StreamExt; @@ -30,7 +30,7 @@ pub async fn paginator_canary(client: ec2::Client, page_size: usize) -> anyhow:: let mut num_pages = 0; while let Some(page) = history.try_next().await? { let items_in_page = page.spot_price_history.unwrap_or_default().len(); - if items_in_page > page_size as usize { + if items_in_page > page_size { bail!( "failed to retrieve results of correct page size (expected {}, got {})", page_size, @@ -60,7 +60,7 @@ pub async fn paginator_canary(client: ec2::Client, page_size: usize) -> anyhow:: #[cfg(test)] mod test { - use crate::paginator_canary::paginator_canary; + use crate::current_canary::paginator_canary::paginator_canary; #[tokio::test] async fn test_paginator() { diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/s3_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/s3_canary.rs similarity index 98% rename from tools/ci-cdk/canary-lambda/src/release_2023_01_26/s3_canary.rs rename to tools/ci-cdk/canary-lambda/src/release_2023_08_03/s3_canary.rs index 70e3d18c55..fbcba976d8 100644 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/s3_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/s3_canary.rs @@ -8,8 +8,8 @@ use crate::{mk_canary, CanaryEnv}; use anyhow::Context; use aws_config::SdkConfig; use aws_sdk_s3 as s3; -use aws_sdk_s3::presigning::config::PresigningConfig; -use s3::types::ByteStream; +use s3::presigning::PresigningConfig; +use s3::primitives::ByteStream; use std::time::Duration; use uuid::Uuid; diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/transcribe_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/transcribe_canary.rs similarity index 97% rename from tools/ci-cdk/canary-lambda/src/release_2023_01_26/transcribe_canary.rs rename to tools/ci-cdk/canary-lambda/src/release_2023_08_03/transcribe_canary.rs index 554f4c3ddf..8f6420fc1b 100644 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/transcribe_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/transcribe_canary.rs @@ -9,10 +9,10 @@ use async_stream::stream; use aws_config::SdkConfig; use aws_sdk_transcribestreaming as transcribe; use bytes::BufMut; -use transcribe::model::{ +use transcribe::primitives::Blob; +use transcribe::types::{ AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, }; -use transcribe::types::Blob; const CHUNK_SIZE: usize = 8192; use crate::canary::CanaryEnv; diff --git a/tools/ci-cdk/canary-runner/src/build_bundle.rs b/tools/ci-cdk/canary-runner/src/build_bundle.rs index 464ee2e4ad..635cec4e7d 100644 --- a/tools/ci-cdk/canary-runner/src/build_bundle.rs +++ b/tools/ci-cdk/canary-runner/src/build_bundle.rs @@ -63,9 +63,10 @@ const REQUIRED_SDK_CRATES: &[&str] = &[ "aws-sdk-transcribestreaming", ]; +// The elements in this `Vec` should be sorted in an ascending order by the release date. lazy_static! { static ref NOTABLE_SDK_RELEASE_TAGS: Vec = vec![ - ReleaseTag::from_str("release-2023-01-26").unwrap(), // last version before the crate reorg + ReleaseTag::from_str("release-2023-08-03").unwrap(), // last version before `Stream` trait removal ]; } @@ -112,38 +113,58 @@ enum CrateSource { }, } -fn enabled_features(crate_source: &CrateSource) -> Vec { - let mut enabled = Vec::new(); +fn enabled_feature(crate_source: &CrateSource) -> String { if let CrateSource::VersionsManifest { release_tag, .. } = crate_source { - // we want to select the newest module specified after this release + // we want to select the oldest module specified after this release for notable in NOTABLE_SDK_RELEASE_TAGS.iter() { tracing::debug!(release_tag = ?release_tag, notable = ?notable, "considering if release tag came before notable release"); if release_tag <= notable { tracing::debug!("selecting {} as chosen release", notable); - enabled.push(notable.as_str().into()); - break; + return notable.as_str().into(); } } } - if enabled.is_empty() { - enabled.push("latest".into()); - } - enabled + "latest".into() } fn generate_crate_manifest(crate_source: CrateSource) -> Result { let mut output = BASE_MANIFEST.to_string(); - for &sdk_crate in REQUIRED_SDK_CRATES { + write_dependencies(REQUIRED_SDK_CRATES, &mut output, &crate_source)?; + write!(output, "\n[features]\n").unwrap(); + writeln!(output, "latest = []").unwrap(); + for release_tag in NOTABLE_SDK_RELEASE_TAGS.iter() { + writeln!( + output, + "\"{release_tag}\" = []", + release_tag = release_tag.as_str() + ) + .unwrap(); + } + writeln!( + output, + "default = [\"{enabled}\"]", + enabled = enabled_feature(&crate_source) + ) + .unwrap(); + Ok(output) +} + +fn write_dependencies( + required_crates: &[&str], + output: &mut String, + crate_source: &CrateSource, +) -> Result<()> { + for &required_crate in required_crates { match &crate_source { CrateSource::Path(path) => { - let path_name = match sdk_crate.strip_prefix("aws-sdk-") { + let path_name = match required_crate.strip_prefix("aws-sdk-") { Some(path) => path, - None => sdk_crate, + None => required_crate, }; let crate_path = path.join(path_name); writeln!( output, - r#"{sdk_crate} = {{ path = "{path}" }}"#, + r#"{required_crate} = {{ path = "{path}" }}"#, path = crate_path.to_string_lossy() ) .unwrap() @@ -151,40 +172,20 @@ fn generate_crate_manifest(crate_source: CrateSource) -> Result { CrateSource::VersionsManifest { versions, release_tag, - } => match versions.crates.get(sdk_crate) { + } => match versions.crates.get(required_crate) { Some(version) => writeln!( output, - r#"{sdk_crate} = "{version}""#, + r#"{required_crate} = "{version}""#, version = version.version ) .unwrap(), None => { - bail!("Couldn't find `{sdk_crate}` in versions.toml for `{release_tag}`") + bail!("Couldn't find `{required_crate}` in versions.toml for `{release_tag}`") } }, } } - write!(output, "\n[features]\n").unwrap(); - writeln!(output, "latest = []").unwrap(); - for release_tag in NOTABLE_SDK_RELEASE_TAGS.iter() { - writeln!( - output, - "\"{release_tag}\" = []", - release_tag = release_tag.as_str() - ) - .unwrap(); - } - writeln!( - output, - "default = [{enabled}]", - enabled = enabled_features(&crate_source) - .into_iter() - .map(|f| format!("\"{f}\"")) - .collect::>() - .join(", ") - ) - .unwrap(); - Ok(output) + Ok(()) } fn sha1_file(path: &Path) -> Result { @@ -441,7 +442,7 @@ aws-sdk-transcribestreaming = { path = "some/sdk/path/transcribestreaming" } [features] latest = [] -"release-2023-01-26" = [] +"release-2023-08-03" = [] default = ["latest"] "#, generate_crate_manifest(CrateSource::Path("some/sdk/path".into())).expect("success") @@ -505,7 +506,7 @@ aws-sdk-transcribestreaming = "0.16.0" [features] latest = [] -"release-2023-01-26" = [] +"release-2023-08-03" = [] default = ["latest"] "#, generate_crate_manifest(CrateSource::VersionsManifest { @@ -523,7 +524,7 @@ default = ["latest"] .collect(), release: None, }, - release_tag: ReleaseTag::from_str("release-2023-05-26").unwrap(), + release_tag: ReleaseTag::from_str("release-2023-08-26").unwrap(), }) .expect("success") ); @@ -577,26 +578,25 @@ default = ["latest"] release: None, }; assert_eq!( - enabled_features(&CrateSource::VersionsManifest { + "latest".to_string(), + enabled_feature(&CrateSource::VersionsManifest { versions: versions.clone(), - release_tag: "release-2023-02-23".parse().unwrap(), + release_tag: "release-9999-12-31".parse().unwrap(), }), - vec!["latest".to_string()] ); - assert_eq!( - enabled_features(&CrateSource::VersionsManifest { + "release-2023-08-03".to_string(), + enabled_feature(&CrateSource::VersionsManifest { versions: versions.clone(), - release_tag: "release-2023-01-26".parse().unwrap(), + release_tag: "release-2023-08-03".parse().unwrap(), }), - vec!["release-2023-01-26".to_string()] ); assert_eq!( - enabled_features(&CrateSource::VersionsManifest { + "release-2023-08-03".to_string(), + enabled_feature(&CrateSource::VersionsManifest { versions, release_tag: "release-2023-01-13".parse().unwrap(), }), - vec!["release-2023-01-26".to_string()] ); } }