@@ -11,21 +11,29 @@ use std::{
1111use pyo3_macros:: { pyclass, pymethods} ;
1212
1313use crate :: {
14- coroutine:: { cancel:: ThrowCallback , waker:: AsyncioWaker } ,
14+ coroutine:: { cancel:: ThrowCallback , waker:: CoroutineWaker } ,
1515 exceptions:: { PyAttributeError , PyRuntimeError , PyStopIteration } ,
1616 panic:: PanicException ,
17- types:: { string:: PyStringMethods , PyIterator , PyString } ,
18- Bound , IntoPy , Py , PyAny , PyErr , PyObject , PyResult , Python ,
17+ types:: { string:: PyStringMethods , PyString } ,
18+ IntoPy , Py , PyErr , PyObject , PyResult , Python ,
1919} ;
2020
21+ mod asyncio;
2122pub ( crate ) mod cancel;
22- mod waker;
23+ pub ( crate ) mod waker;
2324
24- use crate :: marker:: Ungil ;
2525pub use cancel:: CancelHandle ;
2626
27+ use crate :: { exceptions:: PyGeneratorExit , marker:: Ungil } ;
28+
2729const COROUTINE_REUSED_ERROR : & str = "cannot reuse already awaited coroutine" ;
2830
31+ pub ( crate ) enum CoroOp {
32+ Send ( PyObject ) ,
33+ Throw ( PyObject ) ,
34+ Close ,
35+ }
36+
2937trait CoroutineFuture : Send {
3038 fn poll ( self : Pin < & mut Self > , py : Python < ' _ > , waker : & Waker ) -> Poll < PyResult < PyObject > > ;
3139}
@@ -69,7 +77,7 @@ pub struct Coroutine {
6977 qualname_prefix : Option < & ' static str > ,
7078 throw_callback : Option < ThrowCallback > ,
7179 future : Option < Pin < Box < dyn CoroutineFuture > > > ,
72- waker : Option < Arc < AsyncioWaker > > ,
80+ waker : Option < Arc < CoroutineWaker > > ,
7381}
7482
7583impl Coroutine {
@@ -104,58 +112,55 @@ impl Coroutine {
104112 }
105113 }
106114
107- fn poll ( & mut self , py : Python < ' _ > , throw : Option < PyObject > ) -> PyResult < PyObject > {
115+ fn poll_inner ( & mut self , py : Python < ' _ > , mut op : CoroOp ) -> PyResult < PyObject > {
108116 // raise if the coroutine has already been run to completion
109117 let future_rs = match self . future {
110118 Some ( ref mut fut) => fut,
111119 None => return Err ( PyRuntimeError :: new_err ( COROUTINE_REUSED_ERROR ) ) ,
112120 } ;
113- // reraise thrown exception it
114- match ( throw, & self . throw_callback ) {
115- ( Some ( exc) , Some ( cb) ) => cb. throw ( exc) ,
116- ( Some ( exc) , None ) => {
117- self . close ( ) ;
118- return Err ( PyErr :: from_value_bound ( exc. into_bound ( py) ) ) ;
119- }
120- ( None , _) => { }
121+ // if the future is not pending on a Python awaitable,
122+ // execute throw callback or complete on close
123+ if !matches ! ( self . waker, Some ( ref w) if w. is_delegated( py) ) {
124+ match op {
125+ send @ CoroOp :: Send ( _) => op = send,
126+ CoroOp :: Throw ( exc) => match & self . throw_callback {
127+ Some ( cb) => {
128+ cb. throw ( exc. clone_ref ( py) ) ;
129+ op = CoroOp :: Send ( py. None ( ) ) ;
130+ }
131+ None => return Err ( PyErr :: from_value_bound ( exc. into_bound ( py) ) ) ,
132+ } ,
133+ CoroOp :: Close => return Err ( PyGeneratorExit :: new_err ( py. None ( ) ) ) ,
134+ } ;
121135 }
122136 // create a new waker, or try to reset it in place
123137 if let Some ( waker) = self . waker . as_mut ( ) . and_then ( Arc :: get_mut) {
124- waker. reset ( ) ;
138+ waker. reset ( op ) ;
125139 } else {
126- self . waker = Some ( Arc :: new ( AsyncioWaker :: new ( ) ) ) ;
140+ self . waker = Some ( Arc :: new ( CoroutineWaker :: new ( op ) ) ) ;
127141 }
128- // poll the future and forward its results if ready
142+ // poll the future and forward its results if ready; otherwise, yield from waker
129143 // polling is UnwindSafe because the future is dropped in case of panic
130144 let waker = Waker :: from ( self . waker . clone ( ) . unwrap ( ) ) ;
131145 let poll = || future_rs. as_mut ( ) . poll ( py, & waker) ;
132146 match panic:: catch_unwind ( panic:: AssertUnwindSafe ( poll) ) {
133- Ok ( Poll :: Ready ( res) ) => {
134- self . close ( ) ;
135- return Err ( PyStopIteration :: new_err ( res?) ) ;
136- }
137- Err ( err) => {
138- self . close ( ) ;
139- return Err ( PanicException :: from_panic_payload ( err) ) ;
140- }
141- _ => { }
147+ Err ( err) => Err ( PanicException :: from_panic_payload ( err) ) ,
148+ Ok ( Poll :: Ready ( res) ) => Err ( PyStopIteration :: new_err ( res?) ) ,
149+ Ok ( Poll :: Pending ) => match self . waker . as_ref ( ) . unwrap ( ) . yield_ ( py) {
150+ Ok ( to_yield) => Ok ( to_yield) ,
151+ Err ( err) => Err ( err) ,
152+ } ,
142153 }
143- // otherwise, initialize the waker `asyncio.Future`
144- if let Some ( future) = self . waker . as_ref ( ) . unwrap ( ) . initialize_future ( py) ? {
145- // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
146- // and will yield itself if its result has not been set in polling above
147- if let Some ( future) = PyIterator :: from_bound_object ( & future. as_borrowed ( ) )
148- . unwrap ( )
149- . next ( )
150- {
151- // future has not been leaked into Python for now, and Rust code can only call
152- // `set_result(None)` in `Wake` implementation, so it's safe to unwrap
153- return Ok ( future. unwrap ( ) . into ( ) ) ;
154- }
154+ }
155+
156+ fn poll ( & mut self , py : Python < ' _ > , op : CoroOp ) -> PyResult < PyObject > {
157+ let result = self . poll_inner ( py, op) ;
158+ if result. is_err ( ) {
159+ // the Rust future is dropped, and the field set to `None`
160+ // to indicate the coroutine has been run to completion
161+ drop ( self . future . take ( ) ) ;
155162 }
156- // if waker has been waken during future polling, this is roughly equivalent to
157- // `await asyncio.sleep(0)`, so just yield `None`.
158- Ok ( py. None ( ) . into_py ( py) )
163+ result
159164 }
160165}
161166
@@ -180,25 +185,27 @@ impl Coroutine {
180185 }
181186 }
182187
183- fn send ( & mut self , py : Python < ' _ > , _value : & Bound < ' _ , PyAny > ) -> PyResult < PyObject > {
184- self . poll ( py, None )
188+ fn send ( & mut self , py : Python < ' _ > , value : PyObject ) -> PyResult < PyObject > {
189+ self . poll ( py, CoroOp :: Send ( value ) )
185190 }
186191
187192 fn throw ( & mut self , py : Python < ' _ > , exc : PyObject ) -> PyResult < PyObject > {
188- self . poll ( py, Some ( exc) )
193+ self . poll ( py, CoroOp :: Throw ( exc) )
189194 }
190195
191- fn close ( & mut self ) {
192- // the Rust future is dropped, and the field set to `None`
193- // to indicate the coroutine has been run to completion
194- drop ( self . future . take ( ) ) ;
196+ fn close ( & mut self , py : Python < ' _ > ) -> PyResult < ( ) > {
197+ match self . poll ( py, CoroOp :: Close ) {
198+ Ok ( _) => Ok ( ( ) ) ,
199+ Err ( err) if err. is_instance_of :: < PyGeneratorExit > ( py) => Ok ( ( ) ) ,
200+ Err ( err) => Err ( err) ,
201+ }
195202 }
196203
197204 fn __await__ ( self_ : Py < Self > ) -> Py < Self > {
198205 self_
199206 }
200207
201208 fn __next__ ( & mut self , py : Python < ' _ > ) -> PyResult < PyObject > {
202- self . poll ( py, None )
209+ self . poll ( py, CoroOp :: Send ( py . None ( ) ) )
203210 }
204211}
0 commit comments