Skip to content

Commit

Permalink
Python: Add tests and fix the issues with Timestamp and ByteStream (
Browse files Browse the repository at this point in the history
#2431)

* Add `timestamp` type test and fix the conversion error

* Add some tests for `ByteStream` and fix async issues

* Use `__anext__` method instead of `anext`

---------

Co-authored-by: Matteo Bigoi <[email protected]>
  • Loading branch information
unexge and crisidev authored Mar 7, 2023
1 parent 7ce8032 commit 26cb37a
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError
Expand Down Expand Up @@ -179,7 +178,6 @@ class JsonSerializerGenerator(
private val serializerUtil = SerializerUtil(model)
private val operationSerModule = RustModule.private("operation_ser")
private val jsonSerModule = RustModule.private("json_ser")
private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig)

/**
* Reusable structure serializer implementation that can be used to generate serializing code for
Expand Down Expand Up @@ -407,11 +405,7 @@ class JsonSerializerGenerator(
val timestampFormat =
httpBindingResolver.timestampFormat(context.shape, HttpLocation.DOCUMENT, EPOCH_SECONDS)
val timestampFormatType = RuntimeType.timestampFormat(runtimeConfig, timestampFormat)
rustTemplate(
"$writer.date_time(${value.asRef()}#{ConvertInto:W}, #{FormatType})?;",
"FormatType" to timestampFormatType,
"ConvertInto" to typeConversionGenerator.convertViaInto(target),
)
rust("$writer.date_time(${value.asRef()}, #T)?;", timestampFormatType)
}

is CollectionShape -> jsonArrayWriter(context) { arrayName ->
Expand Down
3 changes: 3 additions & 0 deletions codegen-server/python/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ dependencies {
implementation(project(":codegen-server"))
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")

// `smithy.framework#ValidationException` is defined here, which is used in `PythonServerTypesTest`.
testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion")
}

tasks.compileKotlin { kotlinOptions.jvmTarget = "1.8" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,95 @@ internal class PythonServerTypesTest {

cargoTest(testDir)
}

@Test
fun `timestamp type`() {
val model = """
namespace test
use aws.protocols#restJson1
use smithy.framework#ValidationException
@restJson1
service Service {
operations: [
Echo,
],
}
@http(method: "POST", uri: "/echo")
operation Echo {
input: EchoInput,
output: EchoOutput,
errors: [ValidationException],
}
structure EchoInput {
@required
value: Timestamp,
opt_value: Timestamp,
}
structure EchoOutput {
@required
value: Timestamp,
opt_value: Timestamp,
}
""".asSmithyModel()

val (pluginCtx, testDir) = generatePythonServerPluginContext(model)
executePythonServerCodegenVisitor(pluginCtx)

val writer = RustWriter.forModule("service")
writer.tokioTest("timestamp_type") {
rust(
"""
use tower::Service as _;
use pyo3::{types::IntoPyDict, IntoPy, Python};
use hyper::{Body, Request, body};
use crate::{input, output, python_types};
pyo3::prepare_freethreaded_python();
let mut service = Service::builder_without_plugins()
.echo(|input: input::EchoInput| async {
Ok(Python::with_gil(|py| {
let globals = [
("EchoOutput", py.get_type::<output::EchoOutput>()),
("DateTime", py.get_type::<python_types::DateTime>()),
].into_py_dict(py);
let locals = [("input", input.into_py(py))].into_py_dict(py);
py.run("assert input.value.secs() == 1676298520", Some(globals), Some(locals)).unwrap();
py.run("output = EchoOutput(value=input.value, opt_value=DateTime.from_secs(1677771678))", Some(globals), Some(locals)).unwrap();
locals
.get_item("output")
.unwrap()
.extract::<output::EchoOutput>()
.unwrap()
}))
})
.build()
.unwrap();
let req = Request::builder()
.method("POST")
.uri("/echo")
.body(Body::from("{\"value\":1676298520}"))
.unwrap();
let res = service.call(req).await.unwrap();
assert!(res.status().is_success());
let body = body::to_bytes(res.into_body()).await.unwrap();
let body = std::str::from_utf8(&body).unwrap();
assert!(body.contains("\"value\":1676298520"));
assert!(body.contains("\"opt_value\":1677771678"));
""".trimIndent(),
)
}

testDir.resolve("src/service.rs").appendText(writer.toString())

cargoTest(testDir)
}
}
10 changes: 9 additions & 1 deletion rust-runtime/aws-smithy-http-server-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,22 @@ pretty_assertions = "1"
futures-util = "0.3"
tower-test = "0.4"
tokio-test = "0.4"
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] }
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime", "unstable-streams"] }
rcgen = "0.10.0"
hyper-rustls = { version = "0.23.1", features = ["http2"] }

# PyO3 Asyncio tests cannot use Cargo's default testing harness because `asyncio`
# wants to control the main thread. So we need to use testing harness provided by `pyo3_asyncio`
# for the async Python tests. For more detail see:
# https://docs.rs/pyo3-asyncio/0.18.0/pyo3_asyncio/testing/index.html#pyo3-asyncio-testing-utilities
[[test]]
name = "middleware_tests"
path = "src/middleware/pytests/harness.rs"
harness = false
[[test]]
name = "python_tests"
path = "src/pytests/harness.rs"
harness = false

[package.metadata.docs.rs]
all-features = true
Expand Down
151 changes: 151 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/src/pytests/bytestream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use std::io;

use futures::StreamExt;
use futures_util::stream;
use hyper::Body;
use pyo3::{prelude::*, py_run};

use aws_smithy_http::body::SdkBody;
use aws_smithy_http_server_python::types::ByteStream;

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_synchronously() -> PyResult<()> {
let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
Python::with_gil(|py| {
let bytestream = bytestream.into_py(py);
py_run!(
py,
bytestream,
r#"
assert next(bytestream) == b"hello"
assert next(bytestream) == b" "
assert next(bytestream) == b"world"
try:
next(bytestream)
assert False, "iteration should stop by now"
except StopIteration:
pass
"#
);
Ok(())
})
}

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_synchronously_with_loop() -> PyResult<()> {
let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
Python::with_gil(|py| {
let bytestream = bytestream.into_py(py);
py_run!(
py,
bytestream,
r#"
total = []
for chunk in bytestream:
total.append(chunk)
assert total == [b"hello", b" ", b"world"]
"#
);
Ok(())
})
}

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_asynchronously() -> PyResult<()> {
let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
Python::with_gil(|py| {
let bytestream = bytestream.into_py(py);
py_run!(
py,
bytestream,
r#"
import asyncio
async def main(bytestream):
assert await bytestream.__anext__() == b"hello"
assert await bytestream.__anext__() == b" "
assert await bytestream.__anext__() == b"world"
try:
await bytestream.__anext__()
assert False, "iteration should stop by now"
except StopAsyncIteration:
pass
asyncio.run(main(bytestream))
"#
);
Ok(())
})
}

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_asynchronously_with_loop() -> PyResult<()> {
let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
Python::with_gil(|py| {
let bytestream = bytestream.into_py(py);
py_run!(
py,
bytestream,
r#"
import asyncio
async def main(bytestream):
total = []
async for chunk in bytestream:
total.append(chunk)
assert total == [b"hello", b" ", b"world"]
asyncio.run(main(bytestream))
"#
);
Ok(())
})
}

#[pyo3_asyncio::tokio::test]
async fn streaming_back_to_rust_from_python() -> PyResult<()> {
let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
let py_stream = Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
async def handler(bytestream):
async for chunk in bytestream:
yield "🐍 " + chunk.decode("utf-8")
yield "Hello from Python!"
"#,
"",
"",
)?;
let handler = module.getattr("handler")?;
let output = handler.call1((bytestream,))?;
Ok::<_, PyErr>(pyo3_asyncio::tokio::into_stream_v2(output))
})??;

let mut py_stream = py_stream.map(|v| Python::with_gil(|py| v.extract::<String>(py).unwrap()));

assert_eq!(py_stream.next().await, Some("🐍 hello".to_string()));
assert_eq!(py_stream.next().await, Some("🐍 ".to_string()));
assert_eq!(py_stream.next().await, Some("🐍 world".to_string()));
assert_eq!(
py_stream.next().await,
Some("Hello from Python!".to_string())
);
assert_eq!(py_stream.next().await, None);

Ok(())
}

fn streaming_bytestream_from_vec(chunks: Vec<&'static str>) -> ByteStream {
let stream = stream::iter(chunks.into_iter().map(|v| Ok::<_, io::Error>(v)));
let body = Body::wrap_stream(stream);
ByteStream::new(SdkBody::from(body))
}
11 changes: 11 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/src/pytests/harness.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#[pyo3_asyncio::tokio::main]
async fn main() -> pyo3::PyResult<()> {
pyo3_asyncio::testing::main().await
}

mod bytestream;
Loading

0 comments on commit 26cb37a

Please sign in to comment.