@@ -11,21 +11,30 @@ use std::{
1111use pyo3_macros:: { pyclass, pymethods} ;
1212
1313use crate :: {
14- coroutine:: { cancel:: ThrowCallback , waker:: AsyncioWaker } ,
15- exceptions:: { PyAttributeError , PyRuntimeError , PyStopIteration } ,
14+ coroutine:: waker:: CoroutineWaker ,
15+ exceptions:: { PyAttributeError , PyGeneratorExit , PyRuntimeError , PyStopIteration } ,
16+ marker:: Ungil ,
1617 panic:: PanicException ,
17- types:: { string:: PyStringMethods , PyIterator , PyString } ,
18- Bound , IntoPy , Py , PyAny , PyErr , PyObject , PyResult , Python ,
18+ types:: { string:: PyStringMethods , PyString } ,
19+ IntoPy , Py , PyErr , PyObject , PyResult , Python ,
1920} ;
2021
21- pub ( crate ) mod cancel;
22+ mod asyncio;
23+ mod awaitable;
24+ mod cancel;
2225mod waker;
2326
24- use crate :: marker :: Ungil ;
25- pub use cancel:: CancelHandle ;
27+ pub use awaitable :: await_in_coroutine ;
28+ pub use cancel:: { CancelHandle , ThrowCallback } ;
2629
2730const COROUTINE_REUSED_ERROR : & str = "cannot reuse already awaited coroutine" ;
2831
32+ pub ( crate ) enum CoroOp {
33+ Send ( PyObject ) ,
34+ Throw ( PyObject ) ,
35+ Close ,
36+ }
37+
2938trait CoroutineFuture : Send {
3039 fn poll ( self : Pin < & mut Self > , py : Python < ' _ > , waker : & Waker ) -> Poll < PyResult < PyObject > > ;
3140}
@@ -69,7 +78,7 @@ pub struct Coroutine {
6978 qualname_prefix : Option < & ' static str > ,
7079 throw_callback : Option < ThrowCallback > ,
7180 future : Option < Pin < Box < dyn CoroutineFuture > > > ,
72- waker : Option < Arc < AsyncioWaker > > ,
81+ waker : Option < Arc < CoroutineWaker > > ,
7382}
7483
7584impl Coroutine {
@@ -104,58 +113,55 @@ impl Coroutine {
104113 }
105114 }
106115
107- fn poll ( & mut self , py : Python < ' _ > , throw : Option < PyObject > ) -> PyResult < PyObject > {
116+ fn poll_inner ( & mut self , py : Python < ' _ > , mut op : CoroOp ) -> PyResult < PyObject > {
108117 // raise if the coroutine has already been run to completion
109118 let future_rs = match self . future {
110119 Some ( ref mut fut) => fut,
111120 None => return Err ( PyRuntimeError :: new_err ( COROUTINE_REUSED_ERROR ) ) ,
112121 } ;
113- // reraise thrown exception it
114- match ( throw, & self . throw_callback ) {
115- ( Some ( exc) , Some ( cb) ) => cb. throw ( exc) ,
116- ( Some ( exc) , None ) => {
117- self . close ( ) ;
118- return Err ( PyErr :: from_value_bound ( exc. into_bound ( py) ) ) ;
119- }
120- ( None , _) => { }
122+ // if the future is not pending on a Python awaitable,
123+ // execute throw callback or complete on close
124+ if !matches ! ( self . waker, Some ( ref w) if w. is_delegated( py) ) {
125+ match op {
126+ send @ CoroOp :: Send ( _) => op = send,
127+ CoroOp :: Throw ( exc) => match & self . throw_callback {
128+ Some ( cb) => {
129+ cb. throw ( exc. clone_ref ( py) ) ;
130+ op = CoroOp :: Send ( py. None ( ) ) ;
131+ }
132+ None => return Err ( PyErr :: from_value_bound ( exc. into_bound ( py) ) ) ,
133+ } ,
134+ CoroOp :: Close => return Err ( PyGeneratorExit :: new_err ( py. None ( ) ) ) ,
135+ } ;
121136 }
122137 // create a new waker, or try to reset it in place
123138 if let Some ( waker) = self . waker . as_mut ( ) . and_then ( Arc :: get_mut) {
124- waker. reset ( ) ;
139+ waker. reset ( op ) ;
125140 } else {
126- self . waker = Some ( Arc :: new ( AsyncioWaker :: new ( ) ) ) ;
141+ self . waker = Some ( Arc :: new ( CoroutineWaker :: new ( op ) ) ) ;
127142 }
128- // poll the future and forward its results if ready
143+ // poll the future and forward its results if ready; otherwise, yield from waker
129144 // polling is UnwindSafe because the future is dropped in case of panic
130145 let waker = Waker :: from ( self . waker . clone ( ) . unwrap ( ) ) ;
131146 let poll = || future_rs. as_mut ( ) . poll ( py, & waker) ;
132147 match panic:: catch_unwind ( panic:: AssertUnwindSafe ( poll) ) {
133- Ok ( Poll :: Ready ( res) ) => {
134- self . close ( ) ;
135- return Err ( PyStopIteration :: new_err ( res?) ) ;
136- }
137- Err ( err) => {
138- self . close ( ) ;
139- return Err ( PanicException :: from_panic_payload ( err) ) ;
140- }
141- _ => { }
148+ Err ( err) => Err ( PanicException :: from_panic_payload ( err) ) ,
149+ Ok ( Poll :: Ready ( res) ) => Err ( PyStopIteration :: new_err ( res?) ) ,
150+ Ok ( Poll :: Pending ) => match self . waker . as_ref ( ) . unwrap ( ) . yield_ ( py) {
151+ Ok ( to_yield) => Ok ( to_yield) ,
152+ Err ( err) => Err ( err) ,
153+ } ,
142154 }
143- // otherwise, initialize the waker `asyncio.Future`
144- if let Some ( future) = self . waker . as_ref ( ) . unwrap ( ) . initialize_future ( py) ? {
145- // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
146- // and will yield itself if its result has not been set in polling above
147- if let Some ( future) = PyIterator :: from_bound_object ( & future. as_borrowed ( ) )
148- . unwrap ( )
149- . next ( )
150- {
151- // future has not been leaked into Python for now, and Rust code can only call
152- // `set_result(None)` in `Wake` implementation, so it's safe to unwrap
153- return Ok ( future. unwrap ( ) . into ( ) ) ;
154- }
155+ }
156+
157+ fn poll ( & mut self , py : Python < ' _ > , op : CoroOp ) -> PyResult < PyObject > {
158+ let result = self . poll_inner ( py, op) ;
159+ if result. is_err ( ) {
160+ // the Rust future is dropped, and the field set to `None`
161+ // to indicate the coroutine has been run to completion
162+ drop ( self . future . take ( ) ) ;
155163 }
156- // if waker has been waken during future polling, this is roughly equivalent to
157- // `await asyncio.sleep(0)`, so just yield `None`.
158- Ok ( py. None ( ) . into_py ( py) )
164+ result
159165 }
160166}
161167
@@ -180,25 +186,27 @@ impl Coroutine {
180186 }
181187 }
182188
183- fn send ( & mut self , py : Python < ' _ > , _value : & Bound < ' _ , PyAny > ) -> PyResult < PyObject > {
184- self . poll ( py, None )
189+ fn send ( & mut self , py : Python < ' _ > , value : PyObject ) -> PyResult < PyObject > {
190+ self . poll ( py, CoroOp :: Send ( value ) )
185191 }
186192
187193 fn throw ( & mut self , py : Python < ' _ > , exc : PyObject ) -> PyResult < PyObject > {
188- self . poll ( py, Some ( exc) )
194+ self . poll ( py, CoroOp :: Throw ( exc) )
189195 }
190196
191- fn close ( & mut self ) {
192- // the Rust future is dropped, and the field set to `None`
193- // to indicate the coroutine has been run to completion
194- drop ( self . future . take ( ) ) ;
197+ fn close ( & mut self , py : Python < ' _ > ) -> PyResult < ( ) > {
198+ match self . poll ( py, CoroOp :: Close ) {
199+ Ok ( _) => Ok ( ( ) ) ,
200+ Err ( err) if err. is_instance_of :: < PyGeneratorExit > ( py) => Ok ( ( ) ) ,
201+ Err ( err) => Err ( err) ,
202+ }
195203 }
196204
197205 fn __await__ ( self_ : Py < Self > ) -> Py < Self > {
198206 self_
199207 }
200208
201209 fn __next__ ( & mut self , py : Python < ' _ > ) -> PyResult < PyObject > {
202- self . poll ( py, None )
210+ self . poll ( py, CoroOp :: Send ( py . None ( ) ) )
203211 }
204212}
0 commit comments