Skip to content

Commit 4902e18

Browse files
authored
Merge pull request #318 from goodboy/aio_error_propagation
Add context test that opens an inter-task-channel that errors
2 parents 80121ed + 05790a2 commit 4902e18

File tree

4 files changed

+160
-20
lines changed

4 files changed

+160
-20
lines changed

examples/infected_asyncio_echo_server.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
async def aio_echo_server(
1414
to_trio: trio.MemorySendChannel,
1515
from_trio: asyncio.Queue,
16+
1617
) -> None:
1718

1819
# a first message must be sent **from** this ``asyncio``

nooz/318.bug.rst

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Fix a previously undetected ``trio``-``asyncio`` task lifetime linking
2+
issue with the ``to_asyncio.open_channel_from()`` api where both sides
3+
were not properly waiting/signalling termination and it was possible
4+
for ``asyncio``-side errors to not propagate due to a race condition.
5+
6+
The implementation fix summary is:
7+
- add state to signal the end of the ``trio`` side task to be
8+
read by the ``asyncio`` side and always cancel any ongoing
9+
task in such cases.
10+
- always wait on the ``asyncio`` task termination from the ``trio``
11+
side on error before maybe raising said error.
12+
- always close the ``trio`` mem chan on exit to ensure the other
13+
side can detect it and follow.

tests/test_infected_asyncio.py

+92-3
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,25 @@
1111
import pytest
1212
import trio
1313
import tractor
14-
from tractor import to_asyncio
15-
from tractor import RemoteActorError
14+
from tractor import (
15+
to_asyncio,
16+
RemoteActorError,
17+
)
1618
from tractor.trionics import BroadcastReceiver
1719

1820

19-
async def sleep_and_err(sleep_for: float = 0.1):
21+
async def sleep_and_err(
22+
sleep_for: float = 0.1,
23+
24+
# just signature placeholders for compat with
25+
# ``to_asyncio.open_channel_from()``
26+
to_trio: Optional[trio.MemorySendChannel] = None,
27+
from_trio: Optional[asyncio.Queue] = None,
28+
29+
):
30+
if to_trio:
31+
to_trio.send_nowait('start')
32+
2033
await asyncio.sleep(sleep_for)
2134
assert 0
2235

@@ -146,6 +159,80 @@ async def main():
146159
trio.run(main)
147160

148161

162+
@tractor.context
163+
async def trio_ctx(
164+
ctx: tractor.Context,
165+
):
166+
167+
await ctx.started('start')
168+
169+
# this will block until the ``asyncio`` task sends a "first"
170+
# message.
171+
with trio.fail_after(2):
172+
async with (
173+
tractor.to_asyncio.open_channel_from(
174+
sleep_and_err,
175+
) as (first, chan),
176+
177+
trio.open_nursery() as n,
178+
):
179+
180+
assert first == 'start'
181+
182+
# spawn another asyncio task for the cuck of it.
183+
n.start_soon(
184+
tractor.to_asyncio.run_task,
185+
sleep_forever,
186+
)
187+
await trio.sleep_forever()
188+
189+
190+
@pytest.mark.parametrize(
191+
'parent_cancels', [False, True],
192+
ids='parent_actor_cancels_child={}'.format
193+
)
194+
def test_context_spawns_aio_task_that_errors(
195+
arb_addr,
196+
parent_cancels: bool,
197+
):
198+
'''
199+
Verify that spawning a task via an intertask channel ctx mngr that
200+
errors correctly propagates the error back from the `asyncio`-side
201+
task.
202+
203+
'''
204+
async def main():
205+
206+
async with tractor.open_nursery() as n:
207+
p = await n.start_actor(
208+
'aio_daemon',
209+
enable_modules=[__name__],
210+
infect_asyncio=True,
211+
# debug_mode=True,
212+
loglevel='cancel',
213+
)
214+
async with p.open_context(
215+
trio_ctx,
216+
) as (ctx, first):
217+
218+
assert first == 'start'
219+
220+
if parent_cancels:
221+
await p.cancel_actor()
222+
223+
await trio.sleep_forever()
224+
225+
with pytest.raises(RemoteActorError) as excinfo:
226+
trio.run(main)
227+
228+
err = excinfo.value
229+
assert isinstance(err, RemoteActorError)
230+
if parent_cancels:
231+
assert err.type == trio.Cancelled
232+
else:
233+
assert err.type == AssertionError
234+
235+
149236
async def aio_cancel():
150237
''''
151238
Cancel urself boi.
@@ -385,6 +472,8 @@ async def aio_echo_server(
385472
print('breaking aio echo loop')
386473
break
387474

475+
print('exiting asyncio task')
476+
388477
async with to_asyncio.open_channel_from(
389478
aio_echo_server,
390479
) as (first, chan):

tractor/to_asyncio.py

+54-17
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from contextlib import asynccontextmanager as acm
2424
from dataclasses import dataclass
2525
import inspect
26-
import traceback
2726
from typing import (
2827
Any,
2928
Callable,
@@ -63,6 +62,7 @@ class LinkedTaskChannel(trio.abc.Channel):
6362

6463
_trio_cs: trio.CancelScope
6564
_aio_task_complete: trio.Event
65+
_trio_exited: bool = False
6666

6767
# set after ``asyncio.create_task()``
6868
_aio_task: Optional[asyncio.Task] = None
@@ -73,7 +73,13 @@ async def aclose(self) -> None:
7373
await self._from_aio.aclose()
7474

7575
async def receive(self) -> Any:
76-
async with translate_aio_errors(self):
76+
async with translate_aio_errors(
77+
self,
78+
79+
# XXX: obviously this will deadlock if an on-going stream is
80+
# being processed.
81+
# wait_on_aio_task=False,
82+
):
7783

7884
# TODO: do we need this to guarantee asyncio code gets
7985
# cancelled in the case where the trio side somehow creates
@@ -210,10 +216,8 @@ async def wait_on_coro_final_result(
210216
orig = result = id(coro)
211217
try:
212218
result = await coro
213-
except GeneratorExit:
214-
# no need to relay error
215-
raise
216219
except BaseException as aio_err:
220+
log.exception('asyncio task errored')
217221
chan._aio_err = aio_err
218222
raise
219223

@@ -237,6 +241,7 @@ async def wait_on_coro_final_result(
237241
to_trio.close()
238242

239243
aio_task_complete.set()
244+
log.runtime(f'`asyncio` task: {task.get_name()} is complete')
240245

241246
# start the asyncio task we submitted from trio
242247
if not inspect.isawaitable(coro):
@@ -291,10 +296,12 @@ def cancel_trio(task: asyncio.Task) -> None:
291296
elif task_err is None:
292297
assert aio_err
293298
aio_err.with_traceback(aio_err.__traceback__)
294-
msg = ''.join(traceback.format_exception(type(aio_err)))
295-
log.error(
296-
f'infected task errorred:\n{msg}'
297-
)
299+
log.error('infected task errorred')
300+
301+
# XXX: always cancel the scope on error
302+
# in case the trio task is blocking
303+
# on a checkpoint.
304+
cancel_scope.cancel()
298305

299306
# raise any ``asyncio`` side error.
300307
raise aio_err
@@ -307,6 +314,7 @@ def cancel_trio(task: asyncio.Task) -> None:
307314
async def translate_aio_errors(
308315

309316
chan: LinkedTaskChannel,
317+
wait_on_aio_task: bool = False,
310318

311319
) -> AsyncIterator[None]:
312320
'''
@@ -318,6 +326,7 @@ async def translate_aio_errors(
318326

319327
aio_err: Optional[BaseException] = None
320328

329+
# TODO: make this a channel method?
321330
def maybe_raise_aio_err(
322331
err: Optional[Exception] = None
323332
) -> None:
@@ -367,13 +376,30 @@ def maybe_raise_aio_err(
367376
raise
368377

369378
finally:
370-
# always cancel the ``asyncio`` task if we've made it this far
371-
# and it's not done.
372-
if not task.done() and aio_err:
379+
if (
380+
# NOTE: always cancel the ``asyncio`` task if we've made it
381+
# this far and it's not done.
382+
not task.done() and aio_err
383+
384+
# or the trio side has exited its surrounding cancel scope
385+
# indicating the lifetime of the ``asyncio``-side task
386+
# should also be terminated.
387+
or chan._trio_exited
388+
):
389+
log.runtime(
390+
f'Cancelling `asyncio`-task: {task.get_name()}'
391+
)
373392
# assert not aio_err, 'WTF how did asyncio do this?!'
374393
task.cancel()
375394

376-
# if any ``asyncio`` error was caught, raise it here inline
395+
# Required to sync with the far end ``asyncio``-task to ensure
396+
# any error is captured (via monkeypatching the
397+
# ``channel._aio_err``) before calling ``maybe_raise_aio_err()``
398+
# below!
399+
if wait_on_aio_task:
400+
await chan._aio_task_complete.wait()
401+
402+
# NOTE: if any ``asyncio`` error was caught, raise it inline
377403
# here in the ``trio`` task
378404
maybe_raise_aio_err()
379405

@@ -398,7 +424,10 @@ async def run_task(
398424
**kwargs,
399425
)
400426
with chan._from_aio:
401-
async with translate_aio_errors(chan):
427+
async with translate_aio_errors(
428+
chan,
429+
wait_on_aio_task=True,
430+
):
402431
# return single value that is the output from the
403432
# ``asyncio`` function-as-task. Expect the mem chan api to
404433
# do the job of handling cross-framework cancellations
@@ -426,13 +455,21 @@ async def open_channel_from(
426455
**kwargs,
427456
)
428457
async with chan._from_aio:
429-
async with translate_aio_errors(chan):
458+
async with translate_aio_errors(
459+
chan,
460+
wait_on_aio_task=True,
461+
):
430462
# sync to a "started()"-like first delivered value from the
431463
# ``asyncio`` task.
432464
first = await chan.receive()
433465

434466
# deliver stream handle upward
435-
yield first, chan
467+
try:
468+
with chan._trio_cs:
469+
yield first, chan
470+
finally:
471+
chan._trio_exited = True
472+
chan._to_trio.close()
436473

437474

438475
def run_as_asyncio_guest(
@@ -482,7 +519,7 @@ def trio_done_callback(main_outcome):
482519
main_outcome.unwrap()
483520
else:
484521
trio_done_fut.set_result(main_outcome)
485-
print(f"trio_main finished: {main_outcome!r}")
522+
log.runtime(f"trio_main finished: {main_outcome!r}")
486523

487524
# start the infection: run trio on the asyncio loop in "guest mode"
488525
log.info(f"Infecting asyncio process with {trio_main}")

0 commit comments

Comments
 (0)