Skip to content

Commit

Permalink
Introduce FuncMetadata to represent some information about a Python…
Browse files Browse the repository at this point in the history
… function
  • Loading branch information
unexge committed Oct 21, 2022
1 parent 67630d5 commit 8417633
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ class PythonApplicationGenerator(
renderAppStruct(writer)
renderAppDefault(writer)
renderAppClone(writer)
renderAppImpl(writer)
renderPyAppTrait(writer)
renderPyMethods(writer)
}
Expand Down Expand Up @@ -154,30 +153,6 @@ class PythonApplicationGenerator(
)
}

private fun renderAppImpl(writer: RustWriter) {
writer.rustBlockTemplate(
"""
impl App
""",
*codegenScope,
) {
rustTemplate(
"""
// Check if a Python function is a coroutine. Since the function has not run yet,
// we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`.
fn is_coroutine(&self, py: #{pyo3}::Python, func: &#{pyo3}::PyObject) -> #{pyo3}::PyResult<bool> {
let inspect = py.import("inspect")?;
// NOTE: that `asyncio.iscoroutine()` doesn't work here.
inspect
.call_method1("iscoroutinefunction", (func,))?
.extract::<bool>()
}
""",
*codegenScope,
)
}
}

private fun renderPyAppTrait(writer: RustWriter) {
writer.rustBlockTemplate(
"""
Expand Down Expand Up @@ -282,13 +257,7 @@ class PythonApplicationGenerator(
/// Register a Python function to be executed inside a Tower middleware layer.
##[pyo3(text_signature = "(${'$'}self, func)")]
pub fn middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
let name = func.getattr(py, "__name__")?.extract::<String>(py)?;
let is_coroutine = self.is_coroutine(py, &func)?;
let handler = #{SmithyPython}::PyMiddlewareHandler {
name,
func,
is_coroutine,
};
let handler = #{SmithyPython}::PyMiddlewareHandler::new(py, func)?;
tracing::info!(
"registering middleware function `{}`, coroutine: {}",
handler.name,
Expand Down
1 change: 1 addition & 0 deletions rust-runtime/aws-smithy-http-server-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod middleware;
mod server;
mod socket;
pub mod types;
pub(crate) mod util;

#[doc(inline)]
pub use error::{PyError, PyMiddlewareException};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyFunction};
use pyo3_asyncio::TaskLocals;
use tower::{util::BoxService, BoxError, Service};

use crate::util::func_metadata;

use super::{PyMiddlewareError, PyRequest, PyResponse};

type PyNextInner = BoxService<Request<Body>, Response<BoxBody>, BoxError>;
Expand Down Expand Up @@ -56,11 +58,19 @@ impl PyNext {
pub struct PyMiddlewareHandler {
pub name: String,
pub func: PyObject,
// TODO: use `inspect.iscoroutinefunction(object)` to detect if it is coroutine?
pub is_coroutine: bool,
}

impl PyMiddlewareHandler {
pub fn new(py: Python, func: PyObject) -> PyResult<Self> {
let func_metadata = func_metadata(py, &func)?;
Ok(Self {
name: func_metadata.name,
func,
is_coroutine: func_metadata.is_coroutine,
})
}

pub async fn call(
self,
req: Request<Body>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,11 @@ where
}

fn call(&mut self, req: Request<Body>) -> Self::Future {
let clone = self.inner.clone();
let inner = std::mem::replace(&mut self.inner, clone);
let inner = {
// https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
let clone = self.inner.clone();
std::mem::replace(&mut self.inner, clone)
};
let handler = self.handler.clone();
let handler_name = handler.name.clone();
let next = BoxService::new(inner.map_err(|err| err.into()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn identity_middleware() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let py_handler = Python::with_gil(|py| {
let handler = Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
Expand All @@ -31,14 +31,9 @@ async def identity_middleware(request, next):
"",
"",
)?;
Ok::<_, PyErr>(module.getattr("identity_middleware")?.into())
let handler = module.getattr("identity_middleware")?.into();
Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?)
})?;
let handler = PyMiddlewareHandler {
func: py_handler,
name: "identity_middleware".to_string(),
is_coroutine: true,
};

let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, mut handle) = mock::spawn_with(|svc| {
let svc = svc.map_err(|err| panic!("service failed: {err}"));
Expand Down Expand Up @@ -76,7 +71,7 @@ async fn returning_response_from_python_middleware() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let py_handler = Python::with_gil(|py| {
let handler = Python::with_gil(|py| {
let globals = [("Response", py.get_type::<PyResponse>())].into_py_dict(py);
let locals = PyDict::new(py);
py.run(
Expand All @@ -87,13 +82,9 @@ def middleware(request, next):
Some(globals),
Some(locals),
)?;
Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into())
let handler = locals.get_item("middleware").unwrap().into();
Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?)
})?;
let handler = PyMiddlewareHandler {
func: py_handler,
name: "middleware".to_string(),
is_coroutine: false,
};

let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, _handle) = mock::spawn_with(|svc| {
Expand All @@ -120,7 +111,7 @@ async fn convert_exception_from_middleware_to_protocol_specific_response() -> Py
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let py_handler = Python::with_gil(|py| {
let handler = Python::with_gil(|py| {
let locals = PyDict::new(py);
py.run(
r#"
Expand All @@ -130,13 +121,9 @@ def middleware(request, next):
None,
Some(locals),
)?;
Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into())
let handler = locals.get_item("middleware").unwrap().into();
Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?)
})?;
let handler = PyMiddlewareHandler {
func: py_handler,
name: "middleware".to_string(),
is_coroutine: false,
};

let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, _handle) = mock::spawn_with(|svc| {
Expand All @@ -163,7 +150,7 @@ async fn uses_status_code_and_message_from_middleware_exception() -> PyResult<()
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let py_handler = Python::with_gil(|py| {
let handler = Python::with_gil(|py| {
let globals = [(
"MiddlewareException",
py.get_type::<PyMiddlewareException>(),
Expand All @@ -178,13 +165,9 @@ def middleware(request, next):
Some(globals),
Some(locals),
)?;
Ok::<_, PyErr>(locals.get_item("middleware").unwrap().into())
let handler = locals.get_item("middleware").unwrap().into();
Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?)
})?;
let handler = PyMiddlewareHandler {
func: py_handler,
name: "middleware".to_string(),
is_coroutine: false,
};

let layer = PyMiddlewareLayer::<RestJson1>::new(handler, locals);
let (mut service, _handle) = mock::spawn_with(|svc| {
Expand All @@ -211,7 +194,7 @@ async fn nested_middlewares() -> PyResult<()> {
let locals = Python::with_gil(|py| {
Ok::<_, PyErr>(TaskLocals::new(pyo3_asyncio::tokio::get_current_loop(py)?))
})?;
let (first_py_handler, second_py_handler) = Python::with_gil(|py| {
let (first_handler, second_handler) = Python::with_gil(|py| {
let globals = [("Response", py.get_type::<PyResponse>())].into_py_dict(py);
let locals = PyDict::new(py);

Expand All @@ -226,21 +209,12 @@ def second_middleware(request, next):
Some(globals),
Some(locals),
)?;
Ok::<_, PyErr>((
locals.get_item("first_middleware").unwrap().into(),
locals.get_item("second_middleware").unwrap().into(),
))
let first_handler =
PyMiddlewareHandler::new(py, locals.get_item("first_middleware").unwrap().into())?;
let second_handler =
PyMiddlewareHandler::new(py, locals.get_item("second_middleware").unwrap().into())?;
Ok::<_, PyErr>((first_handler, second_handler))
})?;
let first_handler = PyMiddlewareHandler {
func: first_py_handler,
name: "first_middleware".to_string(),
is_coroutine: true,
};
let second_handler = PyMiddlewareHandler {
func: second_py_handler,
name: "second_middleware".to_string(),
is_coroutine: false,
};

let layer = Stack::new(
PyMiddlewareLayer::<RestJson1>::new(second_handler, locals.clone()),
Expand Down
25 changes: 5 additions & 20 deletions rust-runtime/aws-smithy-http-server-python/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use signal_hook::{consts::*, iterator::Signals};
use tokio::runtime;
use tower::{util::BoxCloneService, ServiceBuilder};

use crate::PySocket;
use crate::{util::func_metadata, PySocket};

/// A Python handler function representation.
///
Expand All @@ -28,6 +28,7 @@ use crate::PySocket;
#[derive(Debug, Clone)]
pub struct PyHandler {
pub func: PyObject,
// Number of args is needed to decide whether handler accepts context as an argument
pub args: usize,
pub is_coroutine: bool,
}
Expand Down Expand Up @@ -258,34 +259,18 @@ event_loop.add_signal_handler(signal.SIGINT,
Ok(())
}

// Check if a Python function is a coroutine. Since the function has not run yet,
// we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`.
fn is_coroutine(&self, py: Python, func: &PyObject) -> PyResult<bool> {
let inspect = py.import("inspect")?;
// NOTE: that `asyncio.iscoroutine()` doesn't work here.
inspect
.call_method1("iscoroutinefunction", (func,))?
.extract::<bool>()
}

/// Register a Python function to be executed inside the Smithy Rust handler.
///
/// There are some information needed to execute the Python code from a Rust handler,
/// such has if the registered function needs to be awaited (if it is a coroutine) and
/// the number of arguments available, which tells us if the handler wants the state to be
/// passed or not.
fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> {
let is_coroutine = self.is_coroutine(py, &func)?;
// Find number of expected methods (a Python implementation could not accept the context).
let inspect = py.import("inspect")?;
let func_args = inspect
.call_method1("getargs", (func.getattr(py, "__code__")?,))?
.getattr("args")?
.extract::<Vec<String>>()?;
let func_metadata = func_metadata(py, &func)?;
let handler = PyHandler {
func,
is_coroutine,
args: func_args.len(),
is_coroutine: func_metadata.is_coroutine,
args: func_metadata.num_args,
};
tracing::info!(
"Registering handler function `{name}`, coroutine: {}, arguments: {}",
Expand Down
80 changes: 80 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/src/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use pyo3::prelude::*;

// Captures some information about a Python function.
#[derive(Debug, PartialEq)]
pub struct FuncMetadata {
pub name: String,
pub is_coroutine: bool,
pub num_args: usize,
}

// Returns `FuncMetadata` for given `func`.
pub fn func_metadata(py: Python, func: &PyObject) -> PyResult<FuncMetadata> {
let name = func.getattr(py, "__name__")?.extract::<String>(py)?;
let is_coroutine = is_coroutine(py, func)?;
let inspect = py.import("inspect")?;
let args = inspect
.call_method1("getargs", (func.getattr(py, "__code__")?,))?
.getattr("args")?
.extract::<Vec<String>>()?;
Ok(FuncMetadata {
name,
is_coroutine,
num_args: args.len(),
})
}

// Check if a Python function is a coroutine. Since the function has not run yet,
// we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`.
fn is_coroutine(py: Python, func: &PyObject) -> PyResult<bool> {
let inspect = py.import("inspect")?;
// NOTE: that `asyncio.iscoroutine()` doesn't work here.
inspect
.call_method1("iscoroutinefunction", (func,))?
.extract::<bool>()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn function_metadata() -> PyResult<()> {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
def regular_func(first_arg, second_arg):
pass
async def async_func():
pass
"#,
"",
"",
)?;

let regular_func = module.getattr("regular_func")?.into_py(py);
assert_eq!(
FuncMetadata {
name: "regular_func".to_string(),
is_coroutine: false,
num_args: 2,
},
func_metadata(py, &regular_func)?
);

let async_func = module.getattr("async_func")?.into_py(py);
assert_eq!(
FuncMetadata {
name: "async_func".to_string(),
is_coroutine: true,
num_args: 0,
},
func_metadata(py, &async_func)?
);

Ok(())
})
}
}

0 comments on commit 8417633

Please sign in to comment.