Skip to content

Commit

Permalink
Propagate any spawned asyncio task error upwards
Browse files Browse the repository at this point in the history
This should mostly maintain top level SC principles for any task spawned
using `tractor.to_asyncio.run()`. When the `asyncio` task completes make
sure to cancel the pertaining `trio` cancel scope and raise any error
that may have resulted.

Resolves #120
  • Loading branch information
goodboy committed Sep 18, 2021
1 parent 6559f98 commit 760aa40
Showing 1 changed file with 99 additions and 0 deletions.
99 changes: 99 additions & 0 deletions tractor/to_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Infection apis for ``asyncio`` loops running ``trio`` using guest mode.
"""
import asyncio
import inspect
from typing import (
Any,
Callable,
AsyncGenerator,
Awaitable,
Union,
)

import trio

from ._state import current_actor


__all__ = ['run']


async def _invoke(
from_trio: trio.abc.ReceiveChannel,
to_trio: asyncio.Queue,
coro: Awaitable,
) -> Union[AsyncGenerator, Awaitable]:
"""Await or stream awaiable object based on type into
``trio`` memory channel.
"""
async def stream_from_gen(c):
async for item in c:
to_trio.send_nowait(item)

async def just_return(c):
to_trio.send_nowait(await c)

if inspect.isasyncgen(coro):
return await stream_from_gen(coro)
elif inspect.iscoroutine(coro):
return await coro


async def run(
func: Callable,
qsize: int = 2**10,
**kwargs,
) -> Any:
"""Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``.
"""
assert current_actor()._infected_aio

# ITC (inter task comms)
from_trio = asyncio.Queue(qsize)
to_trio, from_aio = trio.open_memory_channel(qsize)

# allow target func to accept/stream results manually
kwargs['to_trio'] = to_trio
kwargs['from_trio'] = to_trio

coro = func(**kwargs)

cancel_scope = trio.CancelScope()

# start the asyncio task we submitted from trio
# TODO: try out ``anyio`` asyncio based tg here
task = asyncio.create_task(_invoke(from_trio, to_trio, coro))
err = None

# XXX: I'm not sure this actually does anything...
def cancel_trio(task):
"""Cancel the calling ``trio`` task on error.
"""
nonlocal err
err = task.exception()
cancel_scope.cancel()

task.add_done_callback(cancel_trio)

# determine return type async func vs. gen
if inspect.isasyncgen(coro):
# simple async func
async def result():
with cancel_scope:
return await from_aio.get()
if cancel_scope.cancelled_caught and err:
raise err

elif inspect.iscoroutine(coro):
# asycn gen
async def result():
with cancel_scope:
async with from_aio:
async for item in from_aio:
yield item
if cancel_scope.cancelled_caught and err:
raise err

return result()

0 comments on commit 760aa40

Please sign in to comment.