Skip to content

Commit 1bf4f4c

Browse files
committed
feat: add coroutine::await_in_coroutine to await awaitables in coroutine context
1 parent f51c087 commit 1bf4f4c

File tree

13 files changed

+730
-217
lines changed

13 files changed

+730
-217
lines changed

guide/src/SUMMARY.md

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,36 @@
66

77
- [Getting started](getting-started.md)
88
- [Using Rust from Python](rust-from-python.md)
9-
- [Python modules](module.md)
10-
- [Python functions](function.md)
11-
- [Function signatures](function/signature.md)
12-
- [Error handling](function/error-handling.md)
13-
- [Python classes](class.md)
14-
- [Class customizations](class/protocols.md)
15-
- [Basic object customization](class/object.md)
16-
- [Emulating numeric types](class/numeric.md)
17-
- [Emulating callable objects](class/call.md)
9+
- [Python modules](module.md)
10+
- [Python functions](function.md)
11+
- [Function signatures](function/signature.md)
12+
- [Error handling](function/error-handling.md)
13+
- [Python classes](class.md)
14+
- [Class customizations](class/protocols.md)
15+
- [Basic object customization](class/object.md)
16+
- [Emulating numeric types](class/numeric.md)
17+
- [Emulating callable objects](class/call.md)
1818
- [Calling Python from Rust](python-from-rust.md)
19-
- [Python object types](types.md)
20-
- [Python exceptions](exception.md)
21-
- [Calling Python functions](python-from-rust/function-calls.md)
22-
- [Executing existing Python code](python-from-rust/calling-existing-code.md)
19+
- [Python object types](types.md)
20+
- [Python exceptions](exception.md)
21+
- [Calling Python functions](python-from-rust/function-calls.md)
22+
- [Executing existing Python code](python-from-rust/calling-existing-code.md)
2323
- [Type conversions](conversions.md)
24-
- [Mapping of Rust types to Python types](conversions/tables.md)
25-
- [Conversion traits](conversions/traits.md)
24+
- [Mapping of Rust types to Python types](conversions/tables.md)
25+
- [Conversion traits](conversions/traits.md)
2626
- [Using `async` and `await`](async-await.md)
27+
- [Awaiting Python awaitables](async-await/awaiting_python_awaitables)
2728
- [Parallelism](parallelism.md)
2829
- [Debugging](debugging.md)
2930
- [Features reference](features.md)
3031
- [Memory management](memory.md)
3132
- [Performance](performance.md)
3233
- [Advanced topics](advanced.md)
3334
- [Building and distribution](building-and-distribution.md)
34-
- [Supporting multiple Python versions](building-and-distribution/multiple-python-versions.md)
35+
- [Supporting multiple Python versions](building-and-distribution/multiple-python-versions.md)
3536
- [Useful crates](ecosystem.md)
36-
- [Logging](ecosystem/logging.md)
37-
- [Using `async` and `await`](ecosystem/async-await.md)
37+
- [Logging](ecosystem/logging.md)
38+
- [Using `async` and `await`](ecosystem/async-await.md)
3839
- [FAQ and troubleshooting](faq.md)
3940

4041
---
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Awaiting Python awaitables
2+
3+
Python awaitable can be awaited on Rust side
4+
using [`await_in_coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/function.await_in_coroutine).
5+
6+
```rust
7+
# # ![allow(dead_code)]
8+
# #[cfg(feature = "experimental-async")] {
9+
use pyo3::{prelude::*, coroutine::await_in_coroutine};
10+
11+
#[pyfunction]
12+
async fn wrap_awaitable(awaitable: PyObject) -> PyResult<PyObject> {
13+
Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?.await
14+
}
15+
# }
16+
```
17+
18+
Behind the scene, `await_in_coroutine` calls the `__await__` method of the Python awaitable (or `__iter__` for
19+
generator-based coroutine).
20+
21+
## Restrictions
22+
23+
As the name suggests, `await_in_coroutine` resulting future can only be awaited in coroutine context. Otherwise, it
24+
panics.
25+
26+
```rust
27+
# # ![allow(dead_code)]
28+
# #[cfg(feature = "experimental-async")] {
29+
use pyo3::{prelude::*, coroutine::await_in_coroutine};
30+
31+
#[pyfunction]
32+
fn block_on(awaitable: PyObject) -> PyResult<PyObject> {
33+
let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?;
34+
futures::executor::block_on(future) // ERROR: PyFuture must be awaited in coroutine context
35+
}
36+
# }
37+
```
38+
39+
The future must also be the only one to be awaited at a time; it means that it's forbidden to await it in a `select!`.
40+
Otherwise, it panics.
41+
42+
```rust
43+
# # ![allow(dead_code)]
44+
# #[cfg(feature = "experimental-async")] {
45+
use futures::FutureExt;
46+
use pyo3::{prelude::*, coroutine::await_in_coroutine};
47+
48+
#[pyfunction]
49+
async fn select(awaitable: PyObject) -> PyResult<PyObject> {
50+
let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?;
51+
futures::select_biased! {
52+
_ = std::future::pending::<()>().fuse() => unreachable!(),
53+
res = future.fuse() => res, // ERROR: Python awaitable mixed with Rust future
54+
}
55+
}
56+
# }
57+
```
58+
59+
These restrictions exist because awaiting a `await_in_coroutine` future strongly binds it to the
60+
enclosing coroutine. The coroutine will then delegate its `send`/`throw`/`close` methods to the
61+
awaited future. If it was awaited in a `select!`, `Coroutine::send` would no able to know if
62+
the value passed would have to be delegated or not.

newsfragments/3611.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `coroutine::await_in_coroutine` to await awaitables in coroutine context

pyo3-ffi/src/abstract_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,11 @@ extern "C" {
129129
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
130130
#[cfg(all(not(PyPy), Py_3_10))]
131131
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
132-
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
132+
pub fn PyIter_Send(
133+
iter: *mut PyObject,
134+
arg: *mut PyObject,
135+
presult: *mut *mut PyObject,
136+
) -> c_int;
133137

134138
#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
135139
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;

src/coroutine.rs

Lines changed: 60 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,30 @@ use std::{
1111
use pyo3_macros::{pyclass, pymethods};
1212

1313
use crate::{
14-
coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
15-
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
14+
coroutine::waker::CoroutineWaker,
15+
exceptions::{PyAttributeError, PyGeneratorExit, PyRuntimeError, PyStopIteration},
16+
marker::Ungil,
1617
panic::PanicException,
17-
types::{string::PyStringMethods, PyIterator, PyString},
18-
Bound, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
18+
types::{string::PyStringMethods, PyString},
19+
IntoPy, Py, PyErr, PyObject, PyResult, Python,
1920
};
2021

21-
pub(crate) mod cancel;
22+
mod asyncio;
23+
mod awaitable;
24+
mod cancel;
2225
mod waker;
2326

24-
use crate::marker::Ungil;
25-
pub use cancel::CancelHandle;
27+
pub use awaitable::await_in_coroutine;
28+
pub use cancel::{CancelHandle, ThrowCallback};
2629

2730
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
2831

32+
pub(crate) enum CoroOp {
33+
Send(PyObject),
34+
Throw(PyObject),
35+
Close,
36+
}
37+
2938
trait CoroutineFuture: Send {
3039
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>>;
3140
}
@@ -69,7 +78,7 @@ pub struct Coroutine {
6978
qualname_prefix: Option<&'static str>,
7079
throw_callback: Option<ThrowCallback>,
7180
future: Option<Pin<Box<dyn CoroutineFuture>>>,
72-
waker: Option<Arc<AsyncioWaker>>,
81+
waker: Option<Arc<CoroutineWaker>>,
7382
}
7483

7584
impl Coroutine {
@@ -104,58 +113,55 @@ impl Coroutine {
104113
}
105114
}
106115

107-
fn poll(&mut self, py: Python<'_>, throw: Option<PyObject>) -> PyResult<PyObject> {
116+
fn poll_inner(&mut self, py: Python<'_>, mut op: CoroOp) -> PyResult<PyObject> {
108117
// raise if the coroutine has already been run to completion
109118
let future_rs = match self.future {
110119
Some(ref mut fut) => fut,
111120
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
112121
};
113-
// reraise thrown exception it
114-
match (throw, &self.throw_callback) {
115-
(Some(exc), Some(cb)) => cb.throw(exc),
116-
(Some(exc), None) => {
117-
self.close();
118-
return Err(PyErr::from_value_bound(exc.into_bound(py)));
119-
}
120-
(None, _) => {}
122+
// if the future is not pending on a Python awaitable,
123+
// execute throw callback or complete on close
124+
if !matches!(self.waker, Some(ref w) if w.is_delegated(py)) {
125+
match op {
126+
send @ CoroOp::Send(_) => op = send,
127+
CoroOp::Throw(exc) => match &self.throw_callback {
128+
Some(cb) => {
129+
cb.throw(exc.clone_ref(py));
130+
op = CoroOp::Send(py.None());
131+
}
132+
None => return Err(PyErr::from_value_bound(exc.into_bound(py))),
133+
},
134+
CoroOp::Close => return Err(PyGeneratorExit::new_err(py.None())),
135+
};
121136
}
122137
// create a new waker, or try to reset it in place
123138
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
124-
waker.reset();
139+
waker.reset(op);
125140
} else {
126-
self.waker = Some(Arc::new(AsyncioWaker::new()));
141+
self.waker = Some(Arc::new(CoroutineWaker::new(op)));
127142
}
128-
// poll the future and forward its results if ready
143+
// poll the future and forward its results if ready; otherwise, yield from waker
129144
// polling is UnwindSafe because the future is dropped in case of panic
130145
let waker = Waker::from(self.waker.clone().unwrap());
131146
let poll = || future_rs.as_mut().poll(py, &waker);
132147
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
133-
Ok(Poll::Ready(res)) => {
134-
self.close();
135-
return Err(PyStopIteration::new_err(res?));
136-
}
137-
Err(err) => {
138-
self.close();
139-
return Err(PanicException::from_panic_payload(err));
140-
}
141-
_ => {}
148+
Err(err) => Err(PanicException::from_panic_payload(err)),
149+
Ok(Poll::Ready(res)) => Err(PyStopIteration::new_err(res?)),
150+
Ok(Poll::Pending) => match self.waker.as_ref().unwrap().yield_(py) {
151+
Ok(to_yield) => Ok(to_yield),
152+
Err(err) => Err(err),
153+
},
142154
}
143-
// otherwise, initialize the waker `asyncio.Future`
144-
if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
145-
// `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
146-
// and will yield itself if its result has not been set in polling above
147-
if let Some(future) = PyIterator::from_bound_object(&future.as_borrowed())
148-
.unwrap()
149-
.next()
150-
{
151-
// future has not been leaked into Python for now, and Rust code can only call
152-
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap
153-
return Ok(future.unwrap().into());
154-
}
155+
}
156+
157+
fn poll(&mut self, py: Python<'_>, op: CoroOp) -> PyResult<PyObject> {
158+
let result = self.poll_inner(py, op);
159+
if result.is_err() {
160+
// the Rust future is dropped, and the field set to `None`
161+
// to indicate the coroutine has been run to completion
162+
drop(self.future.take());
155163
}
156-
// if waker has been waken during future polling, this is roughly equivalent to
157-
// `await asyncio.sleep(0)`, so just yield `None`.
158-
Ok(py.None().into_py(py))
164+
result
159165
}
160166
}
161167

@@ -180,25 +186,27 @@ impl Coroutine {
180186
}
181187
}
182188

183-
fn send(&mut self, py: Python<'_>, _value: &Bound<'_, PyAny>) -> PyResult<PyObject> {
184-
self.poll(py, None)
189+
fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult<PyObject> {
190+
self.poll(py, CoroOp::Send(value))
185191
}
186192

187193
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
188-
self.poll(py, Some(exc))
194+
self.poll(py, CoroOp::Throw(exc))
189195
}
190196

191-
fn close(&mut self) {
192-
// the Rust future is dropped, and the field set to `None`
193-
// to indicate the coroutine has been run to completion
194-
drop(self.future.take());
197+
fn close(&mut self, py: Python<'_>) -> PyResult<()> {
198+
match self.poll(py, CoroOp::Close) {
199+
Ok(_) => Ok(()),
200+
Err(err) if err.is_instance_of::<PyGeneratorExit>(py) => Ok(()),
201+
Err(err) => Err(err),
202+
}
195203
}
196204

197205
fn __await__(self_: Py<Self>) -> Py<Self> {
198206
self_
199207
}
200208

201209
fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
202-
self.poll(py, None)
210+
self.poll(py, CoroOp::Send(py.None()))
203211
}
204212
}

0 commit comments

Comments
 (0)