Skip to content

Commit

Permalink
Make Lambda related types feature gated
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Dec 9, 2022
1 parent e2c3b31 commit 152b8e2
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 115 deletions.
93 changes: 17 additions & 76 deletions rust-runtime/aws-smithy-http-server-python/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@

//! Python context definition.
use std::collections::HashSet;

use pyo3::{types::PyDict, PyObject, PyResult, Python, ToPyObject};

use crate::{rich_py_err, util::is_optional_of};

use super::lambda::PyLambdaContext;
use http::Extensions;
use pyo3::{PyObject, PyResult, Python, ToPyObject};

#[cfg(feature = "aws-lambda")]
mod lambda;
pub mod layer;
#[cfg(test)]
mod testing;
Expand All @@ -38,38 +35,23 @@ pub struct PyContext {
// We could introduce a registry to keep track of every injectable type but I'm not sure that is the best way to do it,
// so until we found a good way to achive that, I didn't want to introduce any abstraction here and
// keep it simple because we only have one field that is injectable.
lambda_ctx_fields: HashSet<String>,
#[cfg(feature = "aws-lambda")]
lambda_ctx: lambda::PyContextLambda,
}

impl PyContext {
pub fn new(inner: PyObject) -> PyResult<Self> {
let lambda_ctx_fields = Python::with_gil(|py| get_lambda_ctx_fields(py, &inner))?;
Ok(Self {
#[cfg(feature = "aws-lambda")]
lambda_ctx: lambda::PyContextLambda::new(inner.clone())?,
inner,
lambda_ctx_fields,
})
}

/// Returns true if custom context class provided by the user injects [PyLambdaContext].
pub fn has_lambda_context_fields(&self) -> bool {
!self.lambda_ctx_fields.is_empty()
}

/// Sets given `lambda_ctx` to user provided context class.
pub fn set_lambda_context(&self, lambda_ctx: Option<PyLambdaContext>) {
if !self.has_lambda_context_fields() {
// Return early without acquiring GIL
return;
}

let inner = &self.inner;
Python::with_gil(|py| {
for field in self.lambda_ctx_fields.iter() {
if let Err(err) = inner.setattr(py, field.as_str(), lambda_ctx.clone()) {
tracing::warn!(field = ?field, error = ?rich_py_err(err), "could not inject `LambdaContext` to context")
}
}
});
pub fn populate_from_extensions(&self, _ext: &Extensions) {
#[cfg(feature = "aws-lambda")]
self.lambda_ctx
.populate_from_extensions(self.inner.clone(), _ext);
}
}

Expand All @@ -79,47 +61,22 @@ impl ToPyObject for PyContext {
}
}

// Inspects the given `PyObject` to detect fields that type-hinted `PyLambdaContext`.
fn get_lambda_ctx_fields(py: Python, ctx: &PyObject) -> PyResult<HashSet<String>> {
let typing = py.import("typing")?;
let hints = match typing
.call_method1("get_type_hints", (ctx,))
.and_then(|res| res.extract::<&PyDict>())
{
Ok(hints) => hints,
Err(_) => {
// `get_type_hints` could fail if `ctx` is `None`, which is the default value
// for the context if user does not provide a custom class.
// In that case, this is not really an error and we should just return an empty set.
return Ok(HashSet::new());
}
};

let mut fields = HashSet::new();
for (key, value) in hints {
if is_optional_of::<PyLambdaContext>(py, value)? {
fields.insert(key.to_string());
}
}
Ok(fields)
}

#[cfg(test)]
mod tests {
use http::Extensions;
use pyo3::{prelude::*, py_run};

use crate::context::testing::{get_context, py_lambda_ctx};
use super::testing::get_context;

#[test]
fn py_context_with_lambda_context() -> PyResult<()> {
fn py_context() -> PyResult<()> {
pyo3::prepare_freethreaded_python();

let ctx = get_context(
r#"
class Context:
foo: int = 0
bar: str = 'qux'
lambda_ctx: typing.Optional[LambdaContext]
ctx = Context()
ctx.foo = 42
Expand All @@ -132,35 +89,20 @@ ctx.foo = 42
r#"
assert ctx.foo == 42
assert ctx.bar == 'qux'
assert not hasattr(ctx, 'lambda_ctx')
"#
);
});

ctx.set_lambda_context(Some(py_lambda_ctx("my-req-id", "123")));
Python::with_gil(|py| {
py_run!(
py,
ctx,
r#"
assert ctx.lambda_ctx.request_id == "my-req-id"
assert ctx.lambda_ctx.deadline == 123
# Make some modifications
ctx.foo += 1
ctx.bar = 'baz'
"#
);
});

// Assume we are getting a new request but that one doesn't have a `LambdaContext`,
// in that case we should make fields `None` and shouldn't leak the previous `LambdaContext`.
ctx.set_lambda_context(None);
ctx.populate_from_extensions(&Extensions::new());

Python::with_gil(|py| {
py_run!(
py,
ctx,
r#"
assert ctx.lambda_ctx is None
# Make sure we are preserving any modifications
assert ctx.foo == 43
assert ctx.bar == 'baz'
Expand All @@ -179,7 +121,6 @@ assert ctx.bar == 'baz'
pyo3::prepare_freethreaded_python();

let ctx = get_context("ctx = None");
ctx.set_lambda_context(Some(py_lambda_ctx("my-req-id", "123")));
Python::with_gil(|py| {
py_run!(py, ctx, "assert ctx is None");
});
Expand Down
167 changes: 167 additions & 0 deletions rust-runtime/aws-smithy-http-server-python/src/context/lambda.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

//! Support for injecting [PyLambdaContext] to [super::PyContext].
use std::collections::HashSet;

use http::Extensions;
use lambda_http::Context as LambdaContext;
use pyo3::{types::PyDict, PyObject, PyResult, Python};

use crate::{lambda::PyLambdaContext, rich_py_err, util::is_optional_of};

#[derive(Clone)]
pub struct PyContextLambda {
fields: HashSet<String>,
}

impl PyContextLambda {
pub fn new(ctx: PyObject) -> PyResult<Self> {
let fields = Python::with_gil(|py| get_lambda_ctx_fields(py, &ctx))?;
Ok(Self { fields })
}

pub fn populate_from_extensions(&self, ctx: PyObject, ext: &Extensions) {
if self.fields.is_empty() {
// Return early without acquiring GIL
return;
}

let lambda_ctx = ext
.get::<LambdaContext>()
.cloned()
.map(PyLambdaContext::new);

Python::with_gil(|py| {
for field in self.fields.iter() {
if let Err(err) = ctx.setattr(py, field.as_str(), lambda_ctx.clone()) {
tracing::warn!(field = ?field, error = ?rich_py_err(err), "could not inject `LambdaContext` to context")
}
}
});
}
}

// Inspects the given `PyObject` to detect fields that type-hinted `PyLambdaContext`.
fn get_lambda_ctx_fields(py: Python, ctx: &PyObject) -> PyResult<HashSet<String>> {
let typing = py.import("typing")?;
let hints = match typing
.call_method1("get_type_hints", (ctx,))
.and_then(|res| res.extract::<&PyDict>())
{
Ok(hints) => hints,
Err(_) => {
// `get_type_hints` could fail if `ctx` is `None`, which is the default value
// for the context if user does not provide a custom class.
// In that case, this is not really an error and we should just return an empty set.
return Ok(HashSet::new());
}
};

let mut fields = HashSet::new();
for (key, value) in hints {
if is_optional_of::<PyLambdaContext>(py, value)? {
fields.insert(key.to_string());
}
}
Ok(fields)
}

#[cfg(test)]
mod tests {
use http::Extensions;
use lambda_http::Context as LambdaContext;
use pyo3::{prelude::*, py_run};

use crate::context::testing::{get_context, lambda_ctx};

#[test]
fn py_context_with_lambda_context() -> PyResult<()> {
pyo3::prepare_freethreaded_python();

let ctx = get_context(
r#"
class Context:
foo: int = 0
bar: str = 'qux'
lambda_ctx: typing.Optional[LambdaContext]
ctx = Context()
ctx.foo = 42
"#,
);
Python::with_gil(|py| {
py_run!(
py,
ctx,
r#"
assert ctx.foo == 42
assert ctx.bar == 'qux'
assert not hasattr(ctx, 'lambda_ctx')
"#
);
});

ctx.populate_from_extensions(&extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123")));
Python::with_gil(|py| {
py_run!(
py,
ctx,
r#"
assert ctx.lambda_ctx.request_id == "my-req-id"
assert ctx.lambda_ctx.deadline == 123
# Make some modifications
ctx.foo += 1
ctx.bar = 'baz'
"#
);
});

// Assume we are getting a new request but that one doesn't have a `LambdaContext`,
// in that case we should make fields `None` and shouldn't leak the previous `LambdaContext`.
ctx.populate_from_extensions(&empty_extensions());
Python::with_gil(|py| {
py_run!(
py,
ctx,
r#"
assert ctx.lambda_ctx is None
# Make sure we are preserving any modifications
assert ctx.foo == 43
assert ctx.bar == 'baz'
"#
);
});

Ok(())
}

#[test]
fn works_with_none() -> PyResult<()> {
// Users can set context to `None` by explicity or implicitly by not providing a custom context class,
// it shouldn't be fail in that case.

pyo3::prepare_freethreaded_python();

let ctx = get_context("ctx = None");
ctx.populate_from_extensions(&extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123")));
Python::with_gil(|py| {
py_run!(py, ctx, "assert ctx is None");
});

Ok(())
}

fn extensions_with_lambda_ctx(ctx: LambdaContext) -> Extensions {
let mut exts = empty_extensions();
exts.insert(ctx);
exts
}

fn empty_extensions() -> Extensions {
Extensions::new()
}
}
Loading

0 comments on commit 152b8e2

Please sign in to comment.