11//! Python coroutine implementation, used notably when wrapping `async fn` 
22//! with `#[pyfunction]`/`#[pymethods]`. 
3- use  std:: task:: Waker ; 
43use  std:: { 
54    future:: Future , 
65    panic, 
76    pin:: Pin , 
87    sync:: Arc , 
9-     task:: { Context ,  Poll } , 
8+     task:: { Context ,  Poll ,   Waker } , 
109} ; 
1110
1211use  pyo3_macros:: { pyclass,  pymethods} ; 
1312
1413use  crate :: { 
15-     coroutine:: waker:: AsyncioWaker , 
14+     coroutine:: waker:: CoroutineWaker , 
1615    exceptions:: { PyAttributeError ,  PyRuntimeError ,  PyStopIteration } , 
1716    pyclass:: IterNextOutput , 
18-     types:: { PyIterator ,   PyString } , 
19-     IntoPy ,  Py ,  PyAny ,   PyErr ,  PyObject ,  PyResult ,  Python , 
17+     types:: PyString , 
18+     IntoPy ,  Py ,  PyErr ,  PyObject ,  PyResult ,  Python , 
2019} ; 
2120
21+ mod  asyncio; 
2222pub ( crate )  mod  cancel; 
23- mod  waker; 
23+ pub ( crate )   mod  waker; 
2424
2525use  crate :: coroutine:: cancel:: ThrowCallback ; 
2626use  crate :: panic:: PanicException ; 
@@ -36,7 +36,7 @@ pub struct Coroutine {
3636    throw_callback :  Option < ThrowCallback > , 
3737    allow_threads :  bool , 
3838    future :  Option < Pin < Box < dyn  Future < Output  = PyResult < PyObject > >  + Send > > > , 
39-     waker :  Option < Arc < AsyncioWaker > > , 
39+     waker :  Option < Arc < CoroutineWaker > > , 
4040} 
4141
4242impl  Coroutine  { 
@@ -73,33 +73,37 @@ impl Coroutine {
7373        } 
7474    } 
7575
76-     fn  poll ( 
76+     fn  poll_inner ( 
7777        & mut  self , 
7878        py :  Python < ' _ > , 
79-         throw :  Option < PyObject > , 
79+         mut   sent_result :  Option < Result < PyObject ,   PyObject > > , 
8080    )  -> PyResult < IterNextOutput < PyObject ,  PyObject > >  { 
8181        // raise if the coroutine has already been run to completion 
8282        let  future_rs = match  self . future  { 
8383            Some ( ref  mut  fut)  => fut, 
8484            None  => return  Err ( PyRuntimeError :: new_err ( COROUTINE_REUSED_ERROR ) ) , 
8585        } ; 
86-         // reraise thrown exception it 
87-         match  ( throw,  & self . throw_callback )  { 
88-             ( Some ( exc) ,  Some ( cb) )  => cb. throw ( exc. as_ref ( py) ) , 
89-             ( Some ( exc) ,  None )  => { 
90-                 self . close ( ) ; 
91-                 return  Err ( PyErr :: from_value ( exc. as_ref ( py) ) ) ; 
86+         // if the future is not pending on a Python awaitable, 
87+         // execute throw callback or complete on close 
88+         if  !matches ! ( self . waker,  Some ( ref w)  if  w. yielded_from_awaitable( py) )  { 
89+             match  ( sent_result,  & self . throw_callback )  { 
90+                 ( res @ Some ( Ok ( _) ) ,  _)  => sent_result = res, 
91+                 ( Some ( Err ( err) ) ,  Some ( cb) )  => { 
92+                     cb. throw ( err. as_ref ( py) ) ; 
93+                     sent_result = Some ( Ok ( py. None ( ) . into ( ) ) ) ; 
94+                 } 
95+                 ( Some ( Err ( err) ) ,  None )  => return  Err ( PyErr :: from_value ( err. as_ref ( py) ) ) , 
96+                 ( None ,  _)  => return  Ok ( IterNextOutput :: Return ( py. None ( ) . into ( ) ) ) , 
9297            } 
93-             _ => { } 
9498        } 
9599        // create a new waker, or try to reset it in place 
96100        if  let  Some ( waker)  = self . waker . as_mut ( ) . and_then ( Arc :: get_mut)  { 
97-             waker. reset ( ) ; 
101+             waker. reset ( sent_result ) ; 
98102        }  else  { 
99-             self . waker  = Some ( Arc :: new ( AsyncioWaker :: new ( ) ) ) ; 
103+             self . waker  = Some ( Arc :: new ( CoroutineWaker :: new ( sent_result ) ) ) ; 
100104        } 
101105        let  waker = Waker :: from ( self . waker . clone ( ) . unwrap ( ) ) ; 
102-         // poll the Rust future and forward its results if ready 
106+         // poll the Rust future and forward its results if ready; otherwise, yield from waker  
103107        // polling is UnwindSafe because the future is dropped in case of panic 
104108        let  poll = || { 
105109            if  self . allow_threads  { 
@@ -109,29 +113,27 @@ impl Coroutine {
109113            } 
110114        } ; 
111115        match  panic:: catch_unwind ( panic:: AssertUnwindSafe ( poll) )  { 
112-             Ok ( Poll :: Ready ( res) )  => { 
113-                 self . close ( ) ; 
114-                 return  Ok ( IterNextOutput :: Return ( res?) ) ; 
115-             } 
116-             Err ( err)  => { 
117-                 self . close ( ) ; 
118-                 return  Err ( PanicException :: from_panic_payload ( err) ) ; 
119-             } 
120-             _ => { } 
116+             Err ( err)  => Err ( PanicException :: from_panic_payload ( err) ) , 
117+             Ok ( Poll :: Ready ( res) )  => Ok ( IterNextOutput :: Return ( res?) ) , 
118+             Ok ( Poll :: Pending )  => match  self . waker . as_ref ( ) . unwrap ( ) . yield_ ( py)  { 
119+                 Ok ( to_yield)  => Ok ( IterNextOutput :: Yield ( to_yield) ) , 
120+                 Err ( err)  => Err ( err) , 
121+             } , 
121122        } 
122-         // otherwise, initialize the waker `asyncio.Future` 
123-         if  let  Some ( future)  = self . waker . as_ref ( ) . unwrap ( ) . initialize_future ( py) ? { 
124-             // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__` 
125-             // and will yield itself if its result has not been set in polling above 
126-             if  let  Some ( future)  = PyIterator :: from_object ( future) . unwrap ( ) . next ( )  { 
127-                 // future has not been leaked into Python for now, and Rust code can only call 
128-                 // `set_result(None)` in `Wake` implementation, so it's safe to unwrap 
129-                 return  Ok ( IterNextOutput :: Yield ( future. unwrap ( ) . into ( ) ) ) ; 
130-             } 
123+     } 
124+ 
125+     fn  poll ( 
126+         & mut  self , 
127+         py :  Python < ' _ > , 
128+         sent_result :  Option < Result < PyObject ,  PyObject > > , 
129+     )  -> PyResult < IterNextOutput < PyObject ,  PyObject > >  { 
130+         let  result = self . poll_inner ( py,  sent_result) ; 
131+         if  matches ! ( result,  Ok ( IterNextOutput :: Return ( _) )  | Err ( _) )  { 
132+             // the Rust future is dropped, and the field set to `None` 
133+             // to indicate the coroutine has been run to completion 
134+             drop ( self . future . take ( ) ) ; 
131135        } 
132-         // if waker has been waken during future polling, this is roughly equivalent to 
133-         // `await asyncio.sleep(0)`, so just yield `None`. 
134-         Ok ( IterNextOutput :: Yield ( py. None ( ) . into ( ) ) ) 
136+         result
135137    } 
136138} 
137139
@@ -163,25 +165,24 @@ impl Coroutine {
163165        } 
164166    } 
165167
166-     fn  send ( & mut  self ,  py :  Python < ' _ > ,  _value :   & PyAny )  -> PyResult < PyObject >  { 
167-         iter_result ( self . poll ( py,  None ) ?) 
168+     fn  send ( & mut  self ,  py :  Python < ' _ > ,  value :   PyObject )  -> PyResult < PyObject >  { 
169+         iter_result ( self . poll ( py,  Some ( Ok ( value ) ) ) ?) 
168170    } 
169171
170172    fn  throw ( & mut  self ,  py :  Python < ' _ > ,  exc :  PyObject )  -> PyResult < PyObject >  { 
171-         iter_result ( self . poll ( py,  Some ( exc) ) ?) 
173+         iter_result ( self . poll ( py,  Some ( Err ( exc) ) ) ?) 
172174    } 
173175
174-     fn  close ( & mut  self )  { 
175-         // the Rust future is dropped, and the field set to `None` 
176-         // to indicate the coroutine has been run to completion 
177-         drop ( self . future . take ( ) ) ; 
176+     fn  close ( & mut  self ,  py :  Python < ' _ > )  -> PyResult < ( ) >  { 
177+         self . poll ( py,  None ) ?; 
178+         Ok ( ( ) ) 
178179    } 
179180
180181    fn  __await__ ( self_ :  Py < Self > )  -> Py < Self >  { 
181182        self_
182183    } 
183184
184185    fn  __next__ ( & mut  self ,  py :  Python < ' _ > )  -> PyResult < IterNextOutput < PyObject ,  PyObject > >  { 
185-         self . poll ( py,  None ) 
186+         self . poll ( py,  Some ( Ok ( py . None ( ) . into ( ) ) ) ) 
186187    } 
187188} 
0 commit comments