23
23
from contextlib import asynccontextmanager as acm
24
24
from dataclasses import dataclass
25
25
import inspect
26
- import traceback
27
26
from typing import (
28
27
Any ,
29
28
Callable ,
@@ -63,6 +62,7 @@ class LinkedTaskChannel(trio.abc.Channel):
63
62
64
63
_trio_cs : trio .CancelScope
65
64
_aio_task_complete : trio .Event
65
+ _trio_exited : bool = False
66
66
67
67
# set after ``asyncio.create_task()``
68
68
_aio_task : Optional [asyncio .Task ] = None
@@ -73,7 +73,13 @@ async def aclose(self) -> None:
73
73
await self ._from_aio .aclose ()
74
74
75
75
async def receive (self ) -> Any :
76
- async with translate_aio_errors (self ):
76
+ async with translate_aio_errors (
77
+ self ,
78
+
79
+ # XXX: obviously this will deadlock if an on-going stream is
80
+ # being procesed.
81
+ # wait_on_aio_task=False,
82
+ ):
77
83
78
84
# TODO: do we need this to guarantee asyncio code get's
79
85
# cancelled in the case where the trio side somehow creates
@@ -210,10 +216,8 @@ async def wait_on_coro_final_result(
210
216
orig = result = id (coro )
211
217
try :
212
218
result = await coro
213
- except GeneratorExit :
214
- # no need to relay error
215
- raise
216
219
except BaseException as aio_err :
220
+ log .exception ('asyncio task errored' )
217
221
chan ._aio_err = aio_err
218
222
raise
219
223
@@ -237,6 +241,7 @@ async def wait_on_coro_final_result(
237
241
to_trio .close ()
238
242
239
243
aio_task_complete .set ()
244
+ log .runtime (f'`asyncio` task: { task .get_name ()} is complete' )
240
245
241
246
# start the asyncio task we submitted from trio
242
247
if not inspect .isawaitable (coro ):
@@ -291,10 +296,12 @@ def cancel_trio(task: asyncio.Task) -> None:
291
296
elif task_err is None :
292
297
assert aio_err
293
298
aio_err .with_traceback (aio_err .__traceback__ )
294
- msg = '' .join (traceback .format_exception (type (aio_err )))
295
- log .error (
296
- f'infected task errorred:\n { msg } '
297
- )
299
+ log .error ('infected task errorred' )
300
+
301
+ # XXX: alway cancel the scope on error
302
+ # in case the trio task is blocking
303
+ # on a checkpoint.
304
+ cancel_scope .cancel ()
298
305
299
306
# raise any ``asyncio`` side error.
300
307
raise aio_err
@@ -307,6 +314,7 @@ def cancel_trio(task: asyncio.Task) -> None:
307
314
async def translate_aio_errors (
308
315
309
316
chan : LinkedTaskChannel ,
317
+ wait_on_aio_task : bool = False ,
310
318
311
319
) -> AsyncIterator [None ]:
312
320
'''
@@ -318,6 +326,7 @@ async def translate_aio_errors(
318
326
319
327
aio_err : Optional [BaseException ] = None
320
328
329
+ # TODO: make thisi a channel method?
321
330
def maybe_raise_aio_err (
322
331
err : Optional [Exception ] = None
323
332
) -> None :
@@ -367,13 +376,30 @@ def maybe_raise_aio_err(
367
376
raise
368
377
369
378
finally :
370
- # always cancel the ``asyncio`` task if we've made it this far
371
- # and it's not done.
372
- if not task .done () and aio_err :
379
+ if (
380
+ # NOTE: always cancel the ``asyncio`` task if we've made it
381
+ # this far and it's not done.
382
+ not task .done () and aio_err
383
+
384
+ # or the trio side has exited it's surrounding cancel scope
385
+ # indicating the lifetime of the ``asyncio``-side task
386
+ # should also be terminated.
387
+ or chan ._trio_exited
388
+ ):
389
+ log .runtime (
390
+ f'Cancelling `asyncio`-task: { task .get_name ()} '
391
+ )
373
392
# assert not aio_err, 'WTF how did asyncio do this?!'
374
393
task .cancel ()
375
394
376
- # if any ``asyncio`` error was caught, raise it here inline
395
+ # Required to sync with the far end ``asyncio``-task to ensure
396
+ # any error is captured (via monkeypatching the
397
+ # ``channel._aio_err``) before calling ``maybe_raise_aio_err()``
398
+ # below!
399
+ if wait_on_aio_task :
400
+ await chan ._aio_task_complete .wait ()
401
+
402
+ # NOTE: if any ``asyncio`` error was caught, raise it here inline
377
403
# here in the ``trio`` task
378
404
maybe_raise_aio_err ()
379
405
@@ -398,7 +424,10 @@ async def run_task(
398
424
** kwargs ,
399
425
)
400
426
with chan ._from_aio :
401
- async with translate_aio_errors (chan ):
427
+ async with translate_aio_errors (
428
+ chan ,
429
+ wait_on_aio_task = True ,
430
+ ):
402
431
# return single value that is the output from the
403
432
# ``asyncio`` function-as-task. Expect the mem chan api to
404
433
# do the job of handling cross-framework cancellations
@@ -426,13 +455,21 @@ async def open_channel_from(
426
455
** kwargs ,
427
456
)
428
457
async with chan ._from_aio :
429
- async with translate_aio_errors (chan ):
458
+ async with translate_aio_errors (
459
+ chan ,
460
+ wait_on_aio_task = True ,
461
+ ):
430
462
# sync to a "started()"-like first delivered value from the
431
463
# ``asyncio`` task.
432
464
first = await chan .receive ()
433
465
434
466
# deliver stream handle upward
435
- yield first , chan
467
+ try :
468
+ with chan ._trio_cs :
469
+ yield first , chan
470
+ finally :
471
+ chan ._trio_exited = True
472
+ chan ._to_trio .close ()
436
473
437
474
438
475
def run_as_asyncio_guest (
@@ -482,7 +519,7 @@ def trio_done_callback(main_outcome):
482
519
main_outcome .unwrap ()
483
520
else :
484
521
trio_done_fut .set_result (main_outcome )
485
- print (f"trio_main finished: { main_outcome !r} " )
522
+ log . runtime (f"trio_main finished: { main_outcome !r} " )
486
523
487
524
# start the infection: run trio on the asyncio loop in "guest mode"
488
525
log .info (f"Infecting asyncio process with { trio_main } " )
0 commit comments