From 152b8e2db43dfe1bc33cc627218423afcf94106c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Fri, 9 Dec 2022 11:48:44 +0000 Subject: [PATCH] Make Lambda related types feature gated --- .../src/context.rs | 93 ++-------- .../src/context/lambda.rs | 167 ++++++++++++++++++ .../src/context/layer.rs | 96 +++++++--- .../src/context/testing.rs | 21 +-- .../aws-smithy-http-server-python/src/lib.rs | 1 + .../aws-smithy-http-server-python/src/util.rs | 1 + 6 files changed, 264 insertions(+), 115 deletions(-) create mode 100644 rust-runtime/aws-smithy-http-server-python/src/context/lambda.rs diff --git a/rust-runtime/aws-smithy-http-server-python/src/context.rs b/rust-runtime/aws-smithy-http-server-python/src/context.rs index b067ab20ea0..a2b2028158d 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/context.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/context.rs @@ -5,14 +5,11 @@ //! Python context definition. -use std::collections::HashSet; - -use pyo3::{types::PyDict, PyObject, PyResult, Python, ToPyObject}; - -use crate::{rich_py_err, util::is_optional_of}; - -use super::lambda::PyLambdaContext; +use http::Extensions; +use pyo3::{PyObject, PyResult, Python, ToPyObject}; +#[cfg(feature = "aws-lambda")] +mod lambda; pub mod layer; #[cfg(test)] mod testing; @@ -38,38 +35,23 @@ pub struct PyContext { // We could introduce a registry to keep track of every injectable type but I'm not sure that is the best way to do it, // so until we found a good way to achive that, I didn't want to introduce any abstraction here and // keep it simple because we only have one field that is injectable. - lambda_ctx_fields: HashSet, + #[cfg(feature = "aws-lambda")] + lambda_ctx: lambda::PyContextLambda, } impl PyContext { pub fn new(inner: PyObject) -> PyResult { - let lambda_ctx_fields = Python::with_gil(|py| get_lambda_ctx_fields(py, &inner))?; Ok(Self { + #[cfg(feature = "aws-lambda")] + lambda_ctx: lambda::PyContextLambda::new(inner.clone())?, inner, - lambda_ctx_fields, }) } - /// Returns true if custom context class provided by the user injects [PyLambdaContext]. - pub fn has_lambda_context_fields(&self) -> bool { - !self.lambda_ctx_fields.is_empty() - } - - /// Sets given `lambda_ctx` to user provided context class. - pub fn set_lambda_context(&self, lambda_ctx: Option) { - if !self.has_lambda_context_fields() { - // Return early without acquiring GIL - return; - } - - let inner = &self.inner; - Python::with_gil(|py| { - for field in self.lambda_ctx_fields.iter() { - if let Err(err) = inner.setattr(py, field.as_str(), lambda_ctx.clone()) { - tracing::warn!(field = ?field, error = ?rich_py_err(err), "could not inject `LambdaContext` to context") - } - } - }); + pub fn populate_from_extensions(&self, _ext: &Extensions) { + #[cfg(feature = "aws-lambda")] + self.lambda_ctx + .populate_from_extensions(self.inner.clone(), _ext); } } @@ -79,39 +61,15 @@ impl ToPyObject for PyContext { } } -// Inspects the given `PyObject` to detect fields that type-hinted `PyLambdaContext`. -fn get_lambda_ctx_fields(py: Python, ctx: &PyObject) -> PyResult> { - let typing = py.import("typing")?; - let hints = match typing - .call_method1("get_type_hints", (ctx,)) - .and_then(|res| res.extract::<&PyDict>()) - { - Ok(hints) => hints, - Err(_) => { - // `get_type_hints` could fail if `ctx` is `None`, which is the default value - // for the context if user does not provide a custom class. - // In that case, this is not really an error and we should just return an empty set. - return Ok(HashSet::new()); - } - }; - - let mut fields = HashSet::new(); - for (key, value) in hints { - if is_optional_of::(py, value)? { - fields.insert(key.to_string()); - } - } - Ok(fields) -} - #[cfg(test)] mod tests { + use http::Extensions; use pyo3::{prelude::*, py_run}; - use crate::context::testing::{get_context, py_lambda_ctx}; + use super::testing::get_context; #[test] - fn py_context_with_lambda_context() -> PyResult<()> { + fn py_context() -> PyResult<()> { pyo3::prepare_freethreaded_python(); let ctx = get_context( @@ -119,7 +77,6 @@ mod tests { class Context: foo: int = 0 bar: str = 'qux' - lambda_ctx: typing.Optional[LambdaContext] ctx = Context() ctx.foo = 42 @@ -132,19 +89,6 @@ ctx.foo = 42 r#" assert ctx.foo == 42 assert ctx.bar == 'qux' -assert not hasattr(ctx, 'lambda_ctx') -"# - ); - }); - - ctx.set_lambda_context(Some(py_lambda_ctx("my-req-id", "123"))); - Python::with_gil(|py| { - py_run!( - py, - ctx, - r#" -assert ctx.lambda_ctx.request_id == "my-req-id" -assert ctx.lambda_ctx.deadline == 123 # Make some modifications ctx.foo += 1 ctx.bar = 'baz' @@ -152,15 +96,13 @@ ctx.bar = 'baz' ); }); - // Assume we are getting a new request but that one doesn't have a `LambdaContext`, - // in that case we should make fields `None` and shouldn't leak the previous `LambdaContext`. - ctx.set_lambda_context(None); + ctx.populate_from_extensions(&Extensions::new()); + Python::with_gil(|py| { py_run!( py, ctx, r#" -assert ctx.lambda_ctx is None # Make sure we are preserving any modifications assert ctx.foo == 43 assert ctx.bar == 'baz' @@ -179,7 +121,6 @@ assert ctx.bar == 'baz' pyo3::prepare_freethreaded_python(); let ctx = get_context("ctx = None"); - ctx.set_lambda_context(Some(py_lambda_ctx("my-req-id", "123"))); Python::with_gil(|py| { py_run!(py, ctx, "assert ctx is None"); }); diff --git a/rust-runtime/aws-smithy-http-server-python/src/context/lambda.rs b/rust-runtime/aws-smithy-http-server-python/src/context/lambda.rs new file mode 100644 index 00000000000..65abac71b6c --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/context/lambda.rs @@ -0,0 +1,167 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Support for injecting [PyLambdaContext] to [super::PyContext]. + +use std::collections::HashSet; + +use http::Extensions; +use lambda_http::Context as LambdaContext; +use pyo3::{types::PyDict, PyObject, PyResult, Python}; + +use crate::{lambda::PyLambdaContext, rich_py_err, util::is_optional_of}; + +#[derive(Clone)] +pub struct PyContextLambda { + fields: HashSet, +} + +impl PyContextLambda { + pub fn new(ctx: PyObject) -> PyResult { + let fields = Python::with_gil(|py| get_lambda_ctx_fields(py, &ctx))?; + Ok(Self { fields }) + } + + pub fn populate_from_extensions(&self, ctx: PyObject, ext: &Extensions) { + if self.fields.is_empty() { + // Return early without acquiring GIL + return; + } + + let lambda_ctx = ext + .get::() + .cloned() + .map(PyLambdaContext::new); + + Python::with_gil(|py| { + for field in self.fields.iter() { + if let Err(err) = ctx.setattr(py, field.as_str(), lambda_ctx.clone()) { + tracing::warn!(field = ?field, error = ?rich_py_err(err), "could not inject `LambdaContext` to context") + } + } + }); + } +} + +// Inspects the given `PyObject` to detect fields that type-hinted `PyLambdaContext`. +fn get_lambda_ctx_fields(py: Python, ctx: &PyObject) -> PyResult> { + let typing = py.import("typing")?; + let hints = match typing + .call_method1("get_type_hints", (ctx,)) + .and_then(|res| res.extract::<&PyDict>()) + { + Ok(hints) => hints, + Err(_) => { + // `get_type_hints` could fail if `ctx` is `None`, which is the default value + // for the context if user does not provide a custom class. + // In that case, this is not really an error and we should just return an empty set. + return Ok(HashSet::new()); + } + }; + + let mut fields = HashSet::new(); + for (key, value) in hints { + if is_optional_of::(py, value)? { + fields.insert(key.to_string()); + } + } + Ok(fields) +} + +#[cfg(test)] +mod tests { + use http::Extensions; + use lambda_http::Context as LambdaContext; + use pyo3::{prelude::*, py_run}; + + use crate::context::testing::{get_context, lambda_ctx}; + + #[test] + fn py_context_with_lambda_context() -> PyResult<()> { + pyo3::prepare_freethreaded_python(); + + let ctx = get_context( + r#" +class Context: + foo: int = 0 + bar: str = 'qux' + lambda_ctx: typing.Optional[LambdaContext] + +ctx = Context() +ctx.foo = 42 +"#, + ); + Python::with_gil(|py| { + py_run!( + py, + ctx, + r#" +assert ctx.foo == 42 +assert ctx.bar == 'qux' +assert not hasattr(ctx, 'lambda_ctx') +"# + ); + }); + + ctx.populate_from_extensions(&extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123"))); + Python::with_gil(|py| { + py_run!( + py, + ctx, + r#" +assert ctx.lambda_ctx.request_id == "my-req-id" +assert ctx.lambda_ctx.deadline == 123 +# Make some modifications +ctx.foo += 1 +ctx.bar = 'baz' +"# + ); + }); + + // Assume we are getting a new request but that one doesn't have a `LambdaContext`, + // in that case we should make fields `None` and shouldn't leak the previous `LambdaContext`. + ctx.populate_from_extensions(&empty_extensions()); + Python::with_gil(|py| { + py_run!( + py, + ctx, + r#" +assert ctx.lambda_ctx is None +# Make sure we are preserving any modifications +assert ctx.foo == 43 +assert ctx.bar == 'baz' +"# + ); + }); + + Ok(()) + } + + #[test] + fn works_with_none() -> PyResult<()> { + // Users can set context to `None` by explicity or implicitly by not providing a custom context class, + // it shouldn't be fail in that case. + + pyo3::prepare_freethreaded_python(); + + let ctx = get_context("ctx = None"); + ctx.populate_from_extensions(&extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123"))); + Python::with_gil(|py| { + py_run!(py, ctx, "assert ctx is None"); + }); + + Ok(()) + } + + fn extensions_with_lambda_ctx(ctx: LambdaContext) -> Extensions { + let mut exts = empty_extensions(); + exts.insert(ctx); + exts + } + + fn empty_extensions() -> Extensions { + Extensions::new() + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/context/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/context/layer.rs index 00232e4a2a6..f3d223b2b2d 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/context/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/context/layer.rs @@ -8,11 +8,8 @@ use std::task::{Context, Poll}; use http::{Request, Response}; -use lambda_http::Context as LambdaContext; use tower::{Layer, Service}; -use crate::lambda::PyLambdaContext; - use super::PyContext; /// AddPyContextLayer is a [tower::Layer] that populates given [PyContext] from the [Request] @@ -58,15 +55,7 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - if self.ctx.has_lambda_context_fields() { - let py_lambda_ctx = req - .extensions() - .get::() - .cloned() - .map(PyLambdaContext::new); - self.ctx.set_lambda_context(py_lambda_ctx); - } - + self.ctx.populate_from_extensions(req.extensions()); req.extensions_mut().insert(self.ctx.clone()); self.inner.call(req) } @@ -74,60 +63,109 @@ where #[cfg(test)] mod tests { + use std::convert::Infallible; + use http::{Request, Response}; use hyper::Body; use pyo3::prelude::*; use pyo3::types::IntoPyDict; - use std::convert::Infallible; use tower::{service_fn, ServiceBuilder, ServiceExt}; + use crate::context::testing::get_context; + use super::*; - use crate::context::testing::{get_context, lambda_ctx}; #[tokio::test] - async fn populates_lambda_context() { + async fn injects_context_to_req_extensions() { pyo3::prepare_freethreaded_python(); let ctx = get_context( r#" class Context: counter: int = 42 - lambda_ctx: typing.Optional[LambdaContext] = None ctx = Context() -"#, + "#, ); let svc = ServiceBuilder::new() .layer(AddPyContextLayer::new(ctx)) .service(service_fn(|req: Request| async move { let ctx = req.extensions().get::().unwrap(); - let (req_id, counter) = Python::with_gil(|py| { + let counter = Python::with_gil(|py| { let locals = [("ctx", ctx)].into_py_dict(py); py.run( r#" -req_id = ctx.lambda_ctx.request_id ctx.counter += 1 counter = ctx.counter -"#, + "#, None, Some(locals), ) .unwrap(); - ( - locals.get_item("req_id").unwrap().to_string(), - locals.get_item("counter").unwrap().to_string(), - ) + locals.get_item("counter").unwrap().to_string() }); - Ok::<_, Infallible>(Response::new((req_id, counter))) + Ok::<_, Infallible>(Response::new(counter)) })); - let mut req = Request::new(Body::empty()); - req.extensions_mut().insert(lambda_ctx("my-req-id", "178")); - + let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap().into_body(); + assert_eq!("43".to_string(), res); + } + + #[cfg(feature = "aws-lambda")] + mod lambda { - assert_eq!(("my-req-id".to_string(), "43".to_string()), res); + use crate::context::testing::lambda_ctx; + + use super::*; + + #[tokio::test] + async fn populates_lambda_context() { + pyo3::prepare_freethreaded_python(); + + let ctx = get_context( + r#" +class Context: + counter: int = 42 + lambda_ctx: typing.Optional[LambdaContext] = None + +ctx = Context() + "#, + ); + + let svc = ServiceBuilder::new() + .layer(AddPyContextLayer::new(ctx)) + .service(service_fn(|req: Request| async move { + let ctx = req.extensions().get::().unwrap(); + let (req_id, counter) = Python::with_gil(|py| { + let locals = [("ctx", ctx)].into_py_dict(py); + py.run( + r#" +req_id = ctx.lambda_ctx.request_id +ctx.counter += 1 +counter = ctx.counter + "#, + None, + Some(locals), + ) + .unwrap(); + + ( + locals.get_item("req_id").unwrap().to_string(), + locals.get_item("counter").unwrap().to_string(), + ) + }); + Ok::<_, Infallible>(Response::new((req_id, counter))) + })); + + let mut req = Request::new(Body::empty()); + req.extensions_mut().insert(lambda_ctx("my-req-id", "178")); + + let res = svc.oneshot(req).await.unwrap().into_body(); + + assert_eq!(("my-req-id".to_string(), "43".to_string()), res); + } } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/context/testing.rs b/rust-runtime/aws-smithy-http-server-python/src/context/testing.rs index cd3e1313cd1..65baefa3809 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/context/testing.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/context/testing.rs @@ -5,20 +5,22 @@ //! Testing utilities for [PyContext]. -use http::{header::HeaderName, HeaderMap, HeaderValue}; -use lambda_http::Context; use pyo3::{ types::{PyDict, PyModule}, IntoPy, PyErr, Python, }; -use super::{PyContext, PyLambdaContext}; +use super::PyContext; pub fn get_context(code: &str) -> PyContext { let inner = Python::with_gil(|py| { let globals = PyModule::import(py, "__main__")?.dict(); globals.set_item("typing", py.import("typing")?)?; - globals.set_item("LambdaContext", py.get_type::())?; + #[cfg(feature = "aws-lambda")] + globals.set_item( + "LambdaContext", + py.get_type::(), + )?; let locals = PyDict::new(py); py.run(code, Some(globals), Some(locals))?; let context = locals @@ -31,7 +33,10 @@ pub fn get_context(code: &str) -> PyContext { PyContext::new(inner).unwrap() } -pub fn lambda_ctx(req_id: &'static str, deadline_ms: &'static str) -> Context { +#[cfg(feature = "aws-lambda")] +pub fn lambda_ctx(req_id: &'static str, deadline_ms: &'static str) -> lambda_http::Context { + use http::{header::HeaderName, HeaderMap, HeaderValue}; + let headers = HeaderMap::from_iter([ ( HeaderName::from_static("lambda-runtime-aws-request-id"), @@ -42,9 +47,5 @@ pub fn lambda_ctx(req_id: &'static str, deadline_ms: &'static str) -> Context { HeaderValue::from_static(deadline_ms), ), ]); - Context::try_from(headers).unwrap() -} - -pub fn py_lambda_ctx(req_id: &'static str, deadline_ms: &'static str) -> PyLambdaContext { - PyLambdaContext::new(lambda_ctx(req_id, deadline_ms)) + lambda_http::Context::try_from(headers).unwrap() } diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 1a33c58acbd..0e13dca6f9d 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -15,6 +15,7 @@ pub mod context; mod error; +#[cfg(feature = "aws-lambda")] pub mod lambda; pub mod logging; pub mod middleware; diff --git a/rust-runtime/aws-smithy-http-server-python/src/util.rs b/rust-runtime/aws-smithy-http-server-python/src/util.rs index df0b8eafc56..b420c26fd45 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/util.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/util.rs @@ -43,6 +43,7 @@ fn is_coroutine(py: Python, func: &PyObject) -> PyResult { } // Checks whether given Python type is `Optional[T]`. +#[allow(unused)] pub fn is_optional_of(py: Python, ty: &PyAny) -> PyResult { // for reference: https://stackoverflow.com/a/56833826