From 841763342cf696c18ad0f5c6cb7a4146254df43e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Fri, 21 Oct 2022 15:03:25 +0100 Subject: [PATCH] Introduce `FuncMetadata` to represent some information about a Python function --- .../generators/PythonApplicationGenerator.kt | 33 +------- .../aws-smithy-http-server-python/src/lib.rs | 1 + .../src/middleware/handler.rs | 12 ++- .../src/middleware/layer.rs | 7 +- .../src/middleware/pytests/layer.rs | 62 +++++--------- .../src/server.rs | 25 ++---- .../aws-smithy-http-server-python/src/util.rs | 80 +++++++++++++++++++ 7 files changed, 121 insertions(+), 99 deletions(-) create mode 100644 rust-runtime/aws-smithy-http-server-python/src/util.rs diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 9fe5c3d17b5..55ced060b2e 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -96,7 +96,6 @@ class PythonApplicationGenerator( renderAppStruct(writer) renderAppDefault(writer) renderAppClone(writer) - renderAppImpl(writer) renderPyAppTrait(writer) renderPyMethods(writer) } @@ -154,30 +153,6 @@ class PythonApplicationGenerator( ) } - private fun renderAppImpl(writer: RustWriter) { - writer.rustBlockTemplate( - """ - impl App - """, - *codegenScope, - ) { - rustTemplate( - """ - // Check if a Python function is a coroutine. Since the function has not run yet, - // we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`. - fn is_coroutine(&self, py: #{pyo3}::Python, func: &#{pyo3}::PyObject) -> #{pyo3}::PyResult { - let inspect = py.import("inspect")?; - // NOTE: that `asyncio.iscoroutine()` doesn't work here. - inspect - .call_method1("iscoroutinefunction", (func,))? - .extract::() - } - """, - *codegenScope, - ) - } - } - private fun renderPyAppTrait(writer: RustWriter) { writer.rustBlockTemplate( """ @@ -282,13 +257,7 @@ class PythonApplicationGenerator( /// Register a Python function to be executed inside a Tower middleware layer. ##[pyo3(text_signature = "(${'$'}self, func)")] pub fn middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { - let name = func.getattr(py, "__name__")?.extract::(py)?; - let is_coroutine = self.is_coroutine(py, &func)?; - let handler = #{SmithyPython}::PyMiddlewareHandler { - name, - func, - is_coroutine, - }; + let handler = #{SmithyPython}::PyMiddlewareHandler::new(py, func)?; tracing::info!( "registering middleware function `{}`, coroutine: {}", handler.name, diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 15334a10416..ca3ad82abfb 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -17,6 +17,7 @@ pub mod middleware; mod server; mod socket; pub mod types; +pub(crate) mod util; #[doc(inline)] pub use error::{PyError, PyMiddlewareException}; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index 3ce9121876b..22bcfedbcd6 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -11,6 +11,8 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyFunction}; use pyo3_asyncio::TaskLocals; use tower::{util::BoxService, BoxError, Service}; +use crate::util::func_metadata; + use super::{PyMiddlewareError, PyRequest, PyResponse}; type PyNextInner = BoxService, Response, BoxError>; @@ -56,11 +58,19 @@ impl PyNext { pub struct PyMiddlewareHandler { pub name: String, pub func: PyObject, - // TODO: use `inspect.iscoroutinefunction(object)` to detect if it is coroutine? pub is_coroutine: bool, } impl PyMiddlewareHandler { + pub fn new(py: Python, func: PyObject) -> PyResult { + let func_metadata = func_metadata(py, &func)?; + Ok(Self { + name: func_metadata.name, + func, + is_coroutine: func_metadata.is_coroutine, + }) + } + pub async fn call( self, req: Request, diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index ff40c2c89a0..ba80f324938 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -105,8 +105,11 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let clone = self.inner.clone(); - let inner = std::mem::replace(&mut self.inner, clone); + let inner = { + // https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + let clone = self.inner.clone(); + std::mem::replace(&mut self.inner, clone) + }; let handler = self.handler.clone(); let handler_name = handler.name.clone(); let next = BoxService::new(inner.map_err(|err| err.into())); diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs index cc61aea82c7..9a23844282e 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs @@ -21,7 +21,7 @@ async fn identity_middleware() -> PyResult<()> { let locals = Python::with_gil(|py| { Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) })?; - let py_handler = Python::with_gil(|py| { + let handler = Python::with_gil(|py| { let module = PyModule::from_code( py, r#" @@ -31,14 +31,9 @@ async def identity_middleware(request, next): "", "", )?; - Ok::<_, PyErr>(module.getattr("identity_middleware")?.into()) + let handler = module.getattr("identity_middleware")?.into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) })?; - let handler = PyMiddlewareHandler { - func: py_handler, - name: "identity_middleware".to_string(), - is_coroutine: true, - }; - let layer = PyMiddlewareLayer::::new(handler, locals); let (mut service, mut handle) = mock::spawn_with(|svc| { let svc = svc.map_err(|err| panic!("service failed: {err}")); @@ -76,7 +71,7 @@ async fn returning_response_from_python_middleware() -> PyResult<()> { let locals = Python::with_gil(|py| { Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) })?; - let py_handler = Python::with_gil(|py| { + let handler = Python::with_gil(|py| { let globals = [("Response", py.get_type::())].into_py_dict(py); let locals = PyDict::new(py); py.run( @@ -87,13 +82,9 @@ def middleware(request, next): Some(globals), Some(locals), )?; - Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into()) + let handler = locals.get_item("middleware").unwrap().into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) })?; - let handler = PyMiddlewareHandler { - func: py_handler, - name: "middleware".to_string(), - is_coroutine: false, - }; let layer = PyMiddlewareLayer::::new(handler, locals); let (mut service, _handle) = mock::spawn_with(|svc| { @@ -120,7 +111,7 @@ async fn convert_exception_from_middleware_to_protocol_specific_response() -> Py let locals = Python::with_gil(|py| { Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) })?; - let py_handler = Python::with_gil(|py| { + let handler = Python::with_gil(|py| { let locals = PyDict::new(py); py.run( r#" @@ -130,13 +121,9 @@ def middleware(request, next): None, Some(locals), )?; - Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into()) + let handler = locals.get_item("middleware").unwrap().into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) })?; - let handler = PyMiddlewareHandler { - func: py_handler, - name: "middleware".to_string(), - is_coroutine: false, - }; let layer = PyMiddlewareLayer::::new(handler, locals); let (mut service, _handle) = mock::spawn_with(|svc| { @@ -163,7 +150,7 @@ async fn uses_status_code_and_message_from_middleware_exception() -> PyResult<() let locals = Python::with_gil(|py| { Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) })?; - let py_handler = Python::with_gil(|py| { + let handler = Python::with_gil(|py| { let globals = [( "MiddlewareException", py.get_type::(), @@ -178,13 +165,9 @@ def middleware(request, next): Some(globals), Some(locals), )?; - Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into()) + let handler = locals.get_item("middleware").unwrap().into(); + Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) })?; - let handler = PyMiddlewareHandler { - func: py_handler, - name: "middleware".to_string(), - is_coroutine: false, - }; let layer = PyMiddlewareLayer::::new(handler, locals); let (mut service, _handle) = mock::spawn_with(|svc| { @@ -211,7 +194,7 @@ async fn nested_middlewares() -> PyResult<()> { let locals = Python::with_gil(|py| { Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?)) })?; - let (first_py_handler, second_py_handler) = Python::with_gil(|py| { + let (first_handler, second_handler) = Python::with_gil(|py| { let globals = [("Response", py.get_type::())].into_py_dict(py); let locals = PyDict::new(py); @@ -226,21 +209,12 @@ def second_middleware(request, next): Some(globals), Some(locals), )?; - Ok::<_, PyErr>(( - locals.get_item("first_middleware").unwrap().into(), - locals.get_item("second_middleware").unwrap().into(), - )) + let first_handler = + PyMiddlewareHandler::new(py, locals.get_item("first_middleware").unwrap().into())?; + let second_handler = + PyMiddlewareHandler::new(py, locals.get_item("second_middleware").unwrap().into())?; + Ok::<_, PyErr>((first_handler, second_handler)) })?; - let first_handler = PyMiddlewareHandler { - func: first_py_handler, - name: "first_middleware".to_string(), - is_coroutine: true, - }; - let second_handler = PyMiddlewareHandler { - func: second_py_handler, - name: "second_middleware".to_string(), - is_coroutine: false, - }; let layer = Stack::new( PyMiddlewareLayer::::new(second_handler, locals.clone()), diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 2b48e272238..2e51eb2879e 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -17,7 +17,7 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::{util::BoxCloneService, ServiceBuilder}; -use crate::PySocket; +use crate::{util::func_metadata, PySocket}; /// A Python handler function representation. /// @@ -28,6 +28,7 @@ use crate::PySocket; #[derive(Debug, Clone)] pub struct PyHandler { pub func: PyObject, + // Number of args is needed to decide whether handler accepts context as an argument pub args: usize, pub is_coroutine: bool, } @@ -258,16 +259,6 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } - // Check if a Python function is a coroutine. Since the function has not run yet, - // we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`. - fn is_coroutine(&self, py: Python, func: &PyObject) -> PyResult { - let inspect = py.import("inspect")?; - // NOTE: that `asyncio.iscoroutine()` doesn't work here. - inspect - .call_method1("iscoroutinefunction", (func,))? - .extract::() - } - /// Register a Python function to be executed inside the Smithy Rust handler. /// /// There are some information needed to execute the Python code from a Rust handler, @@ -275,17 +266,11 @@ event_loop.add_signal_handler(signal.SIGINT, /// the number of arguments available, which tells us if the handler wants the state to be /// passed or not. fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> { - let is_coroutine = self.is_coroutine(py, &func)?; - // Find number of expected methods (a Python implementation could not accept the context). - let inspect = py.import("inspect")?; - let func_args = inspect - .call_method1("getargs", (func.getattr(py, "__code__")?,))? - .getattr("args")? - .extract::>()?; + let func_metadata = func_metadata(py, &func)?; let handler = PyHandler { func, - is_coroutine, - args: func_args.len(), + is_coroutine: func_metadata.is_coroutine, + args: func_metadata.num_args, }; tracing::info!( "Registering handler function `{name}`, coroutine: {}, arguments: {}", diff --git a/rust-runtime/aws-smithy-http-server-python/src/util.rs b/rust-runtime/aws-smithy-http-server-python/src/util.rs new file mode 100644 index 00000000000..118c9cac58b --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/util.rs @@ -0,0 +1,80 @@ +use pyo3::prelude::*; + +// Captures some information about a Python function. +#[derive(Debug, PartialEq)] +pub struct FuncMetadata { + pub name: String, + pub is_coroutine: bool, + pub num_args: usize, +} + +// Returns `FuncMetadata` for given `func`. +pub fn func_metadata(py: Python, func: &PyObject) -> PyResult { + let name = func.getattr(py, "__name__")?.extract::(py)?; + let is_coroutine = is_coroutine(py, func)?; + let inspect = py.import("inspect")?; + let args = inspect + .call_method1("getargs", (func.getattr(py, "__code__")?,))? + .getattr("args")? + .extract::>()?; + Ok(FuncMetadata { + name, + is_coroutine, + num_args: args.len(), + }) +} + +// Check if a Python function is a coroutine. Since the function has not run yet, +// we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`. +fn is_coroutine(py: Python, func: &PyObject) -> PyResult { + let inspect = py.import("inspect")?; + // NOTE: that `asyncio.iscoroutine()` doesn't work here. + inspect + .call_method1("iscoroutinefunction", (func,))? + .extract::() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn function_metadata() -> PyResult<()> { + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +def regular_func(first_arg, second_arg): + pass + +async def async_func(): + pass +"#, + "", + "", + )?; + + let regular_func = module.getattr("regular_func")?.into_py(py); + assert_eq!( + FuncMetadata { + name: "regular_func".to_string(), + is_coroutine: false, + num_args: 2, + }, + func_metadata(py, ®ular_func)? + ); + + let async_func = module.getattr("async_func")?.into_py(py); + assert_eq!( + FuncMetadata { + name: "async_func".to_string(), + is_coroutine: true, + num_args: 0, + }, + func_metadata(py, &async_func)? + ); + + Ok(()) + }) + } +}