Skip to content

Commit 3186fb9

Browse files
authored
Implemented voluntary cancellation in worker threads (#629)
1 parent c360b99 commit 3186fb9

12 files changed

+260
-45
lines changed

docs/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Running asynchronous code from other threads
6161

6262
.. autofunction:: anyio.from_thread.run
6363
.. autofunction:: anyio.from_thread.run_sync
64+
.. autofunction:: anyio.from_thread.check_cancelled
6465
.. autofunction:: anyio.from_thread.start_blocking_portal
6566

6667
.. autoclass:: anyio.from_thread.BlockingPortal

docs/threads.rst

+20
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,23 @@ maximum of 40 threads to be spawned. You can adjust this limit like this::
205205

206206
.. note:: AnyIO's default thread pool limiter does not affect the default thread pool
207207
executor on :mod:`asyncio`.
208+
209+
Reacting to cancellation in worker threads
210+
------------------------------------------
211+
212+
While there is no mechanism in Python to cancel code running in a thread, AnyIO provides a
213+
mechanism that allows user code to voluntarily check if the host task's scope has been cancelled,
214+
and if it has, raise a cancellation exception. This can be done by simply calling
215+
:func:`from_thread.check_cancelled`::
216+
217+
from anyio import to_thread, from_thread
218+
219+
def sync_function():
220+
while True:
221+
from_thread.check_cancelled()
222+
print("Not cancelled yet")
223+
sleep(1)
224+
225+
async def foo():
226+
with move_on_after(3):
227+
await to_thread.run_sync(sync_function)

docs/versionhistory.rst

+5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
1010
- Call ``trio.to_thread.run_sync()`` using the ``abandon_on_cancel`` keyword argument
1111
instead of ``cancellable``
1212
- Removed a checkpoint when exiting a task group
13+
- Renamed the ``cancellable`` argument in ``anyio.to_thread.run_sync()`` to
14+
``abandon_on_cancel`` (and deprecated the old parameter name)
15+
- Bumped minimum version of Trio to v0.23
16+
- Added support for voluntary thread cancellation via
17+
``anyio.from_thread.check_cancelled()``
1318
- Bumped minimum version of trio to v0.23
1419
- Exposed the ``ResourceGuard`` class in the public API
1520
- Fixed ``RuntimeError: Runner is closed`` when running higher-scoped async generator

src/anyio/_backends/_asyncio.py

+40-7
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
import sniffio
6060

6161
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
62-
from .._core._eventloop import claim_worker_thread
62+
from .._core._eventloop import claim_worker_thread, threadlocals
6363
from .._core._exceptions import (
6464
BrokenResourceError,
6565
BusyResourceError,
@@ -783,7 +783,7 @@ def __init__(
783783
self.idle_workers = idle_workers
784784
self.loop = root_task._loop
785785
self.queue: Queue[
786-
tuple[Context, Callable, tuple, asyncio.Future] | None
786+
tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
787787
] = Queue(2)
788788
self.idle_since = AsyncIOBackend.current_time()
789789
self.stopping = False
@@ -814,14 +814,17 @@ def run(self) -> None:
814814
# Shutdown command received
815815
return
816816

817-
context, func, args, future = item
817+
context, func, args, future, cancel_scope = item
818818
if not future.cancelled():
819819
result = None
820820
exception: BaseException | None = None
821+
threadlocals.current_cancel_scope = cancel_scope
821822
try:
822823
result = context.run(func, *args)
823824
except BaseException as exc:
824825
exception = exc
826+
finally:
827+
del threadlocals.current_cancel_scope
825828

826829
if not self.loop.is_closed():
827830
self.loop.call_soon_threadsafe(
@@ -2045,7 +2048,7 @@ async def run_sync_in_worker_thread(
20452048
cls,
20462049
func: Callable[..., T_Retval],
20472050
args: tuple[Any, ...],
2048-
cancellable: bool = False,
2051+
abandon_on_cancel: bool = False,
20492052
limiter: abc.CapacityLimiter | None = None,
20502053
) -> T_Retval:
20512054
await cls.checkpoint()
@@ -2062,7 +2065,7 @@ async def run_sync_in_worker_thread(
20622065
_threadpool_workers.set(workers)
20632066

20642067
async with limiter or cls.current_default_thread_limiter():
2065-
with CancelScope(shield=not cancellable):
2068+
with CancelScope(shield=not abandon_on_cancel) as scope:
20662069
future: asyncio.Future = asyncio.Future()
20672070
root_task = find_root_task()
20682071
if not idle_workers:
@@ -2091,21 +2094,51 @@ async def run_sync_in_worker_thread(
20912094

20922095
context = copy_context()
20932096
context.run(sniffio.current_async_library_cvar.set, None)
2094-
worker.queue.put_nowait((context, func, args, future))
2097+
if abandon_on_cancel or scope._parent_scope is None:
2098+
worker_scope = scope
2099+
else:
2100+
worker_scope = scope._parent_scope
2101+
2102+
worker.queue.put_nowait((context, func, args, future, worker_scope))
20952103
return await future
20962104

2105+
@classmethod
2106+
def check_cancelled(cls) -> None:
2107+
scope: CancelScope | None = threadlocals.current_cancel_scope
2108+
while scope is not None:
2109+
if scope.cancel_called:
2110+
raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
2111+
2112+
if scope.shield:
2113+
return
2114+
2115+
scope = scope._parent_scope
2116+
20972117
@classmethod
20982118
def run_async_from_thread(
20992119
cls,
21002120
func: Callable[..., Awaitable[T_Retval]],
21012121
args: tuple[Any, ...],
21022122
token: object,
21032123
) -> T_Retval:
2124+
async def task_wrapper(scope: CancelScope) -> T_Retval:
2125+
__tracebackhide__ = True
2126+
task = cast(asyncio.Task, current_task())
2127+
_task_states[task] = TaskState(None, scope)
2128+
scope._tasks.add(task)
2129+
try:
2130+
return await func(*args)
2131+
except CancelledError as exc:
2132+
raise concurrent.futures.CancelledError(str(exc)) from None
2133+
finally:
2134+
scope._tasks.discard(task)
2135+
21042136
loop = cast(AbstractEventLoop, token)
21052137
context = copy_context()
21062138
context.run(sniffio.current_async_library_cvar.set, "asyncio")
2139+
wrapper = task_wrapper(threadlocals.current_cancel_scope)
21072140
f: concurrent.futures.Future[T_Retval] = context.run(
2108-
asyncio.run_coroutine_threadsafe, func(*args), loop
2141+
asyncio.run_coroutine_threadsafe, wrapper, loop
21092142
)
21102143
return f.result()
21112144

src/anyio/_backends/_trio.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import trio.lowlevel
3737
from outcome import Error, Outcome, Value
3838
from trio.lowlevel import (
39-
TrioToken,
4039
current_root_task,
4140
current_task,
4241
wait_readable,
@@ -869,7 +868,7 @@ async def run_sync_in_worker_thread(
869868
cls,
870869
func: Callable[..., T_Retval],
871870
args: tuple[Any, ...],
872-
cancellable: bool = False,
871+
abandon_on_cancel: bool = False,
873872
limiter: abc.CapacityLimiter | None = None,
874873
) -> T_Retval:
875874
def wrapper() -> T_Retval:
@@ -879,24 +878,28 @@ def wrapper() -> T_Retval:
879878
token = TrioBackend.current_token()
880879
return await run_sync(
881880
wrapper,
882-
abandon_on_cancel=cancellable,
881+
abandon_on_cancel=abandon_on_cancel,
883882
limiter=cast(trio.CapacityLimiter, limiter),
884883
)
885884

885+
@classmethod
886+
def check_cancelled(cls) -> None:
887+
trio.from_thread.check_cancelled()
888+
886889
@classmethod
887890
def run_async_from_thread(
888891
cls,
889892
func: Callable[..., Awaitable[T_Retval]],
890893
args: tuple[Any, ...],
891894
token: object,
892895
) -> T_Retval:
893-
return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token))
896+
return trio.from_thread.run(func, *args)
894897

895898
@classmethod
896899
def run_sync_from_thread(
897900
cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
898901
) -> T_Retval:
899-
return trio.from_thread.run_sync(func, *args, trio_token=cast(TrioToken, token))
902+
return trio.from_thread.run_sync(func, *args)
900903

901904
@classmethod
902905
def create_blocking_portal(cls) -> abc.BlockingPortal:

src/anyio/_core/_fileio.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ class _PathIterator(AsyncIterator["Path"]):
205205
iterator: Iterator[PathLike[str]]
206206

207207
async def __anext__(self) -> Path:
208-
nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True)
208+
nextval = await to_thread.run_sync(
209+
next, self.iterator, None, abandon_on_cancel=True
210+
)
209211
if nextval is None:
210212
raise StopAsyncIteration from None
211213

@@ -386,17 +388,19 @@ async def cwd(cls) -> Path:
386388
return cls(path)
387389

388390
async def exists(self) -> bool:
389-
return await to_thread.run_sync(self._path.exists, cancellable=True)
391+
return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True)
390392

391393
async def expanduser(self) -> Path:
392-
return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True))
394+
return Path(
395+
await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True)
396+
)
393397

394398
def glob(self, pattern: str) -> AsyncIterator[Path]:
395399
gen = self._path.glob(pattern)
396400
return _PathIterator(gen)
397401

398402
async def group(self) -> str:
399-
return await to_thread.run_sync(self._path.group, cancellable=True)
403+
return await to_thread.run_sync(self._path.group, abandon_on_cancel=True)
400404

401405
async def hardlink_to(self, target: str | pathlib.Path | Path) -> None:
402406
if isinstance(target, Path):
@@ -413,31 +417,37 @@ def is_absolute(self) -> bool:
413417
return self._path.is_absolute()
414418

415419
async def is_block_device(self) -> bool:
416-
return await to_thread.run_sync(self._path.is_block_device, cancellable=True)
420+
return await to_thread.run_sync(
421+
self._path.is_block_device, abandon_on_cancel=True
422+
)
417423

418424
async def is_char_device(self) -> bool:
419-
return await to_thread.run_sync(self._path.is_char_device, cancellable=True)
425+
return await to_thread.run_sync(
426+
self._path.is_char_device, abandon_on_cancel=True
427+
)
420428

421429
async def is_dir(self) -> bool:
422-
return await to_thread.run_sync(self._path.is_dir, cancellable=True)
430+
return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True)
423431

424432
async def is_fifo(self) -> bool:
425-
return await to_thread.run_sync(self._path.is_fifo, cancellable=True)
433+
return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True)
426434

427435
async def is_file(self) -> bool:
428-
return await to_thread.run_sync(self._path.is_file, cancellable=True)
436+
return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True)
429437

430438
async def is_mount(self) -> bool:
431-
return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True)
439+
return await to_thread.run_sync(
440+
os.path.ismount, self._path, abandon_on_cancel=True
441+
)
432442

433443
def is_reserved(self) -> bool:
434444
return self._path.is_reserved()
435445

436446
async def is_socket(self) -> bool:
437-
return await to_thread.run_sync(self._path.is_socket, cancellable=True)
447+
return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True)
438448

439449
async def is_symlink(self) -> bool:
440-
return await to_thread.run_sync(self._path.is_symlink, cancellable=True)
450+
return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True)
441451

442452
def iterdir(self) -> AsyncIterator[Path]:
443453
gen = self._path.iterdir()
@@ -450,7 +460,7 @@ async def lchmod(self, mode: int) -> None:
450460
await to_thread.run_sync(self._path.lchmod, mode)
451461

452462
async def lstat(self) -> os.stat_result:
453-
return await to_thread.run_sync(self._path.lstat, cancellable=True)
463+
return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True)
454464

455465
async def mkdir(
456466
self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False
@@ -493,7 +503,7 @@ async def open(
493503
return AsyncFile(fp)
494504

495505
async def owner(self) -> str:
496-
return await to_thread.run_sync(self._path.owner, cancellable=True)
506+
return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True)
497507

498508
async def read_bytes(self) -> bytes:
499509
return await to_thread.run_sync(self._path.read_bytes)
@@ -526,7 +536,7 @@ async def replace(self, target: str | pathlib.PurePath | Path) -> Path:
526536

527537
async def resolve(self, strict: bool = False) -> Path:
528538
func = partial(self._path.resolve, strict=strict)
529-
return Path(await to_thread.run_sync(func, cancellable=True))
539+
return Path(await to_thread.run_sync(func, abandon_on_cancel=True))
530540

531541
def rglob(self, pattern: str) -> AsyncIterator[Path]:
532542
gen = self._path.rglob(pattern)
@@ -542,12 +552,12 @@ async def samefile(
542552
other_path = other_path._path
543553

544554
return await to_thread.run_sync(
545-
self._path.samefile, other_path, cancellable=True
555+
self._path.samefile, other_path, abandon_on_cancel=True
546556
)
547557

548558
async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result:
549559
func = partial(os.stat, follow_symlinks=follow_symlinks)
550-
return await to_thread.run_sync(func, self._path, cancellable=True)
560+
return await to_thread.run_sync(func, self._path, abandon_on_cancel=True)
551561

552562
async def symlink_to(
553563
self,

src/anyio/_core/_sockets.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -693,9 +693,9 @@ async def setup_unix_local_socket(
693693

694694
if path_str is not None:
695695
try:
696-
await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True)
696+
await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True)
697697
if mode is not None:
698-
await to_thread.run_sync(chmod, path_str, mode, cancellable=True)
698+
await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True)
699699
except BaseException:
700700
raw_socket.close()
701701
raise

src/anyio/abc/_eventloop.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,16 @@ async def run_sync_in_worker_thread(
171171
cls,
172172
func: Callable[..., T_Retval],
173173
args: tuple[Any, ...],
174-
cancellable: bool = False,
174+
abandon_on_cancel: bool = False,
175175
limiter: CapacityLimiter | None = None,
176176
) -> T_Retval:
177177
pass
178178

179+
@classmethod
180+
@abstractmethod
181+
def check_cancelled(cls) -> None:
182+
pass
183+
179184
@classmethod
180185
@abstractmethod
181186
def run_async_from_thread(

0 commit comments

Comments
 (0)