-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Propagate any spawned
asyncio
task error upwards
This should mostly maintain top level SC principles for any task spawned using `tractor.to_asyncio.run()`. When the `asyncio` task completes make sure to cancel the pertaining `trio` cancel scope and raise any error that may have resulted. Resolves #120
Showing
1 changed file
with
99 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
""" | ||
Infection apis for ``asyncio`` loops running ``trio`` using guest mode. | ||
""" | ||
import asyncio | ||
import inspect | ||
from typing import ( | ||
Any, | ||
Callable, | ||
AsyncGenerator, | ||
Awaitable, | ||
Union, | ||
) | ||
|
||
import trio | ||
|
||
from ._state import current_actor | ||
|
||
|
||
__all__ = ['run'] | ||
|
||
|
||
async def _invoke( | ||
from_trio: trio.abc.ReceiveChannel, | ||
to_trio: asyncio.Queue, | ||
coro: Awaitable, | ||
) -> Union[AsyncGenerator, Awaitable]: | ||
"""Await or stream awaiable object based on type into | ||
``trio`` memory channel. | ||
""" | ||
async def stream_from_gen(c): | ||
async for item in c: | ||
to_trio.send_nowait(item) | ||
|
||
async def just_return(c): | ||
to_trio.send_nowait(await c) | ||
|
||
if inspect.isasyncgen(coro): | ||
return await stream_from_gen(coro) | ||
elif inspect.iscoroutine(coro): | ||
return await coro | ||
|
||
|
||
async def run( | ||
func: Callable, | ||
qsize: int = 2**10, | ||
**kwargs, | ||
) -> Any: | ||
"""Run an ``asyncio`` async function or generator in a task, return | ||
or stream the result back to ``trio``. | ||
""" | ||
assert current_actor()._infected_aio | ||
|
||
# ITC (inter task comms) | ||
from_trio = asyncio.Queue(qsize) | ||
to_trio, from_aio = trio.open_memory_channel(qsize) | ||
|
||
# allow target func to accept/stream results manually | ||
kwargs['to_trio'] = to_trio | ||
kwargs['from_trio'] = to_trio | ||
|
||
coro = func(**kwargs) | ||
|
||
cancel_scope = trio.CancelScope() | ||
|
||
# start the asyncio task we submitted from trio | ||
# TODO: try out ``anyio`` asyncio based tg here | ||
task = asyncio.create_task(_invoke(from_trio, to_trio, coro)) | ||
err = None | ||
|
||
# XXX: I'm not sure this actually does anything... | ||
def cancel_trio(task): | ||
"""Cancel the calling ``trio`` task on error. | ||
""" | ||
nonlocal err | ||
err = task.exception() | ||
cancel_scope.cancel() | ||
|
||
task.add_done_callback(cancel_trio) | ||
|
||
# determine return type async func vs. gen | ||
if inspect.isasyncgen(coro): | ||
# simple async func | ||
async def result(): | ||
with cancel_scope: | ||
return await from_aio.get() | ||
if cancel_scope.cancelled_caught and err: | ||
raise err | ||
|
||
elif inspect.iscoroutine(coro): | ||
# asycn gen | ||
async def result(): | ||
with cancel_scope: | ||
async with from_aio: | ||
async for item in from_aio: | ||
yield item | ||
if cancel_scope.cancelled_caught and err: | ||
raise err | ||
|
||
return result() |