Skip to content

Commit

Permalink
Implemented voluntary cancellation in worker threads (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm authored Nov 22, 2023
1 parent c360b99 commit 3186fb9
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Running asynchronous code from other threads

.. autofunction:: anyio.from_thread.run
.. autofunction:: anyio.from_thread.run_sync
.. autofunction:: anyio.from_thread.check_cancelled
.. autofunction:: anyio.from_thread.start_blocking_portal

.. autoclass:: anyio.from_thread.BlockingPortal
Expand Down
20 changes: 20 additions & 0 deletions docs/threads.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,23 @@ maximum of 40 threads to be spawned. You can adjust this limit like this::

.. note:: AnyIO's default thread pool limiter does not affect the default thread pool
executor on :mod:`asyncio`.

Reacting to cancellation in worker threads
------------------------------------------

While there is no mechanism in Python to cancel code running in a thread, AnyIO provides a
mechanism that allows user code to voluntarily check if the host task's scope has been cancelled,
and if it has, raise a cancellation exception. This can be done by simply calling
:func:`from_thread.check_cancelled`::

from anyio import to_thread, from_thread

def sync_function():
while True:
from_thread.check_cancelled()
print("Not cancelled yet")
sleep(1)

async def foo():
with move_on_after(3):
await to_thread.run_sync(sync_function)
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Call ``trio.to_thread.run_sync()`` using the ``abandon_on_cancel`` keyword argument
instead of ``cancellable``
- Removed a checkpoint when exiting a task group
- Renamed the ``cancellable`` argument in ``anyio.to_thread.run_sync()`` to
``abandon_on_cancel`` (and deprecated the old parameter name)
- Bumped minimum version of Trio to v0.23
- Added support for voluntary thread cancellation via
``anyio.from_thread.check_cancelled()``
- Bumped minimum version of trio to v0.23
- Exposed the ``ResourceGuard`` class in the public API
- Fixed ``RuntimeError: Runner is closed`` when running higher-scoped async generator
Expand Down
47 changes: 40 additions & 7 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import sniffio

from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
from .._core._eventloop import claim_worker_thread
from .._core._eventloop import claim_worker_thread, threadlocals
from .._core._exceptions import (
BrokenResourceError,
BusyResourceError,
Expand Down Expand Up @@ -783,7 +783,7 @@ def __init__(
self.idle_workers = idle_workers
self.loop = root_task._loop
self.queue: Queue[
tuple[Context, Callable, tuple, asyncio.Future] | None
tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
] = Queue(2)
self.idle_since = AsyncIOBackend.current_time()
self.stopping = False
Expand Down Expand Up @@ -814,14 +814,17 @@ def run(self) -> None:
# Shutdown command received
return

context, func, args, future = item
context, func, args, future, cancel_scope = item
if not future.cancelled():
result = None
exception: BaseException | None = None
threadlocals.current_cancel_scope = cancel_scope
try:
result = context.run(func, *args)
except BaseException as exc:
exception = exc
finally:
del threadlocals.current_cancel_scope

if not self.loop.is_closed():
self.loop.call_soon_threadsafe(
Expand Down Expand Up @@ -2045,7 +2048,7 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
await cls.checkpoint()
Expand All @@ -2062,7 +2065,7 @@ async def run_sync_in_worker_thread(
_threadpool_workers.set(workers)

async with limiter or cls.current_default_thread_limiter():
with CancelScope(shield=not cancellable):
with CancelScope(shield=not abandon_on_cancel) as scope:
future: asyncio.Future = asyncio.Future()
root_task = find_root_task()
if not idle_workers:
Expand Down Expand Up @@ -2091,21 +2094,51 @@ async def run_sync_in_worker_thread(

context = copy_context()
context.run(sniffio.current_async_library_cvar.set, None)
worker.queue.put_nowait((context, func, args, future))
if abandon_on_cancel or scope._parent_scope is None:
worker_scope = scope
else:
worker_scope = scope._parent_scope

worker.queue.put_nowait((context, func, args, future, worker_scope))
return await future

@classmethod
def check_cancelled(cls) -> None:
scope: CancelScope | None = threadlocals.current_cancel_scope
while scope is not None:
if scope.cancel_called:
raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")

if scope.shield:
return

scope = scope._parent_scope

@classmethod
def run_async_from_thread(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple[Any, ...],
token: object,
) -> T_Retval:
async def task_wrapper(scope: CancelScope) -> T_Retval:
__tracebackhide__ = True
task = cast(asyncio.Task, current_task())
_task_states[task] = TaskState(None, scope)
scope._tasks.add(task)
try:
return await func(*args)
except CancelledError as exc:
raise concurrent.futures.CancelledError(str(exc)) from None
finally:
scope._tasks.discard(task)

loop = cast(AbstractEventLoop, token)
context = copy_context()
context.run(sniffio.current_async_library_cvar.set, "asyncio")
wrapper = task_wrapper(threadlocals.current_cancel_scope)
f: concurrent.futures.Future[T_Retval] = context.run(
asyncio.run_coroutine_threadsafe, func(*args), loop
asyncio.run_coroutine_threadsafe, wrapper, loop
)
return f.result()

Expand Down
13 changes: 8 additions & 5 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import trio.lowlevel
from outcome import Error, Outcome, Value
from trio.lowlevel import (
TrioToken,
current_root_task,
current_task,
wait_readable,
Expand Down Expand Up @@ -869,7 +868,7 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
def wrapper() -> T_Retval:
Expand All @@ -879,24 +878,28 @@ def wrapper() -> T_Retval:
token = TrioBackend.current_token()
return await run_sync(
wrapper,
abandon_on_cancel=cancellable,
abandon_on_cancel=abandon_on_cancel,
limiter=cast(trio.CapacityLimiter, limiter),
)

@classmethod
def check_cancelled(cls) -> None:
trio.from_thread.check_cancelled()

@classmethod
def run_async_from_thread(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple[Any, ...],
token: object,
) -> T_Retval:
return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token))
return trio.from_thread.run(func, *args)

@classmethod
def run_sync_from_thread(
cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
) -> T_Retval:
return trio.from_thread.run_sync(func, *args, trio_token=cast(TrioToken, token))
return trio.from_thread.run_sync(func, *args)

@classmethod
def create_blocking_portal(cls) -> abc.BlockingPortal:
Expand Down
44 changes: 27 additions & 17 deletions src/anyio/_core/_fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ class _PathIterator(AsyncIterator["Path"]):
iterator: Iterator[PathLike[str]]

async def __anext__(self) -> Path:
nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True)
nextval = await to_thread.run_sync(
next, self.iterator, None, abandon_on_cancel=True
)
if nextval is None:
raise StopAsyncIteration from None

Expand Down Expand Up @@ -386,17 +388,19 @@ async def cwd(cls) -> Path:
return cls(path)

async def exists(self) -> bool:
return await to_thread.run_sync(self._path.exists, cancellable=True)
return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True)

async def expanduser(self) -> Path:
return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True))
return Path(
await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True)
)

def glob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.glob(pattern)
return _PathIterator(gen)

async def group(self) -> str:
return await to_thread.run_sync(self._path.group, cancellable=True)
return await to_thread.run_sync(self._path.group, abandon_on_cancel=True)

async def hardlink_to(self, target: str | pathlib.Path | Path) -> None:
if isinstance(target, Path):
Expand All @@ -413,31 +417,37 @@ def is_absolute(self) -> bool:
return self._path.is_absolute()

async def is_block_device(self) -> bool:
return await to_thread.run_sync(self._path.is_block_device, cancellable=True)
return await to_thread.run_sync(
self._path.is_block_device, abandon_on_cancel=True
)

async def is_char_device(self) -> bool:
return await to_thread.run_sync(self._path.is_char_device, cancellable=True)
return await to_thread.run_sync(
self._path.is_char_device, abandon_on_cancel=True
)

async def is_dir(self) -> bool:
return await to_thread.run_sync(self._path.is_dir, cancellable=True)
return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True)

async def is_fifo(self) -> bool:
return await to_thread.run_sync(self._path.is_fifo, cancellable=True)
return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True)

async def is_file(self) -> bool:
return await to_thread.run_sync(self._path.is_file, cancellable=True)
return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True)

async def is_mount(self) -> bool:
return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True)
return await to_thread.run_sync(
os.path.ismount, self._path, abandon_on_cancel=True
)

def is_reserved(self) -> bool:
return self._path.is_reserved()

async def is_socket(self) -> bool:
return await to_thread.run_sync(self._path.is_socket, cancellable=True)
return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True)

async def is_symlink(self) -> bool:
return await to_thread.run_sync(self._path.is_symlink, cancellable=True)
return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True)

def iterdir(self) -> AsyncIterator[Path]:
gen = self._path.iterdir()
Expand All @@ -450,7 +460,7 @@ async def lchmod(self, mode: int) -> None:
await to_thread.run_sync(self._path.lchmod, mode)

async def lstat(self) -> os.stat_result:
return await to_thread.run_sync(self._path.lstat, cancellable=True)
return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True)

async def mkdir(
self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False
Expand Down Expand Up @@ -493,7 +503,7 @@ async def open(
return AsyncFile(fp)

async def owner(self) -> str:
return await to_thread.run_sync(self._path.owner, cancellable=True)
return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True)

async def read_bytes(self) -> bytes:
return await to_thread.run_sync(self._path.read_bytes)
Expand Down Expand Up @@ -526,7 +536,7 @@ async def replace(self, target: str | pathlib.PurePath | Path) -> Path:

async def resolve(self, strict: bool = False) -> Path:
func = partial(self._path.resolve, strict=strict)
return Path(await to_thread.run_sync(func, cancellable=True))
return Path(await to_thread.run_sync(func, abandon_on_cancel=True))

def rglob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.rglob(pattern)
Expand All @@ -542,12 +552,12 @@ async def samefile(
other_path = other_path._path

return await to_thread.run_sync(
self._path.samefile, other_path, cancellable=True
self._path.samefile, other_path, abandon_on_cancel=True
)

async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result:
func = partial(os.stat, follow_symlinks=follow_symlinks)
return await to_thread.run_sync(func, self._path, cancellable=True)
return await to_thread.run_sync(func, self._path, abandon_on_cancel=True)

async def symlink_to(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,9 +693,9 @@ async def setup_unix_local_socket(

if path_str is not None:
try:
await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True)
await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True)
if mode is not None:
await to_thread.run_sync(chmod, path_str, mode, cancellable=True)
await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True)
except BaseException:
raw_socket.close()
raise
Expand Down
7 changes: 6 additions & 1 deletion src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,16 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
pass

@classmethod
@abstractmethod
def check_cancelled(cls) -> None:
pass

@classmethod
@abstractmethod
def run_async_from_thread(
Expand Down
Loading

0 comments on commit 3186fb9

Please sign in to comment.