From e9db7a3f6c0ac42d8beea0f7293f3588128a5280 Mon Sep 17 00:00:00 2001
From: Tobias Raabe
Date: Sat, 30 Mar 2024 09:34:32 +0100
Subject: [PATCH 1/9] TEMP COMMIT.

---
 README.md                        |  34 ++++
 src/pytask_parallel/__init__.py  |   5 +-
 src/pytask_parallel/backends.py  |  55 +++++-
 src/pytask_parallel/config.py    |  36 ++--
 src/pytask_parallel/execute.py   | 278 +------------------------------
 src/pytask_parallel/plugin.py    |   2 -
 src/pytask_parallel/processes.py | 181 ++++++++++++++++++++
 src/pytask_parallel/threads.py   |  56 +++++++
 src/pytask_parallel/utils.py     |  61 +++++++
 tests/test_backends.py           |  38 +++++
 tests/test_execute.py            |  29 ++--
 11 files changed, 466 insertions(+), 309 deletions(-)
 create mode 100644 src/pytask_parallel/processes.py
 create mode 100644 src/pytask_parallel/threads.py
 create mode 100644 src/pytask_parallel/utils.py
 create mode 100644 tests/test_backends.py

diff --git a/README.md b/README.md
index 74b38a2..325c007 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,40 @@ n_workers = 1
 parallel_backend = "processes"  # or loky or threads
 ```
 
+## Custom Executor
+
+pytask-parallel allows you to use your parallel backend. The only requirement is that
+you provide an executor that implements the interface of
+[`concurrent.futures.Executor`](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor).
+
+To register your backend, go to a module that is imported by pytask when building the
+project, for example, the `config.py`. Register a builder function for your custom
+backend.
+
+```python
+from concurrent.futures import Executor
+from concurrent.futures import ProcessPoolExecutor
+
+from pytask_parallel import ParallelBackend, registry
+
+
+def build_custom_executor(n_workers: int) -> Executor:
+    return ProcessPoolExecutor(max_workers=n_workers)
+
+
+registry.register_parallel_backend(ParallelBackend.CUSTOM, build_custom_executor)
+```
+
+Now, build the project requesting your custom backend.
+
+```console
+pytask --parallel-backend custom
+```
+
+> \[!NOTE\]
+>
+> When you request the custom backend, it is even used when `n_workers` is set to 1.
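+
+A minimal sketch of the equivalent `pyproject.toml` configuration, mirroring the
+options shown above (`"custom"` is the value of `ParallelBackend.CUSTOM`):
+
+```toml
+[tool.pytask.ini_options]
+n_workers = 2
+parallel_backend = "custom"
+```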
+ ## Some implementation details ### Parallelization and Debugging diff --git a/src/pytask_parallel/__init__.py b/src/pytask_parallel/__init__.py index c937a1f..812a622 100644 --- a/src/pytask_parallel/__init__.py +++ b/src/pytask_parallel/__init__.py @@ -2,6 +2,9 @@ from __future__ import annotations +from pytask_parallel.backends import ParallelBackend +from pytask_parallel.backends import registry + try: from ._version import version as __version__ except ImportError: @@ -10,4 +13,4 @@ __version__ = "unknown" -__all__ = ["__version__"] +__all__ = ["ParallelBackend", "__version__", "registry"] diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py index 31a1c6a..f2172ab 100644 --- a/src/pytask_parallel/backends.py +++ b/src/pytask_parallel/backends.py @@ -2,25 +2,29 @@ from __future__ import annotations +from concurrent.futures import Executor from concurrent.futures import Future from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ThreadPoolExecutor from enum import Enum from typing import Any from typing import Callable +from typing import ClassVar import cloudpickle from loky import get_reusable_executor +__all__ = ["ParallelBackend", "ParallelBackendRegistry", "registry"] -def deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any: + +def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any: """Deserialize and execute a function and keyword arguments.""" deserialized_fn = cloudpickle.loads(fn) deserialized_kwargs = cloudpickle.loads(kwargs) return deserialized_fn(**deserialized_kwargs) -class CloudpickleProcessPoolExecutor(ProcessPoolExecutor): +class _CloudpickleProcessPoolExecutor(ProcessPoolExecutor): """Patches the standard executor to serialize functions with cloudpickle.""" # The type signature is wrong for version above Py3.7. Fix when 3.7 is deprecated. @@ -32,7 +36,7 @@ def submit( # type: ignore[override] ) -> Future[Any]: """Submit a new task.""" return super().submit( - deserialize_and_run_with_cloudpickle, + _deserialize_and_run_with_cloudpickle, fn=cloudpickle.dumps(fn), kwargs=cloudpickle.dumps(kwargs), ) @@ -41,13 +45,46 @@ def submit( # type: ignore[override] class ParallelBackend(Enum): """Choices for parallel backends.""" + CUSTOM = "custom" + LOKY = "loky" PROCESSES = "processes" THREADS = "threads" - LOKY = "loky" -PARALLEL_BACKEND_BUILDER = { - ParallelBackend.PROCESSES: lambda: CloudpickleProcessPoolExecutor, - ParallelBackend.THREADS: lambda: ThreadPoolExecutor, - ParallelBackend.LOKY: lambda: get_reusable_executor, -} +class ParallelBackendRegistry: + """Registry for parallel backends.""" + + registry: ClassVar[dict[ParallelBackend, Callable[..., Executor]]] = {} + + def register_parallel_backend( + self, kind: ParallelBackend, builder: Callable[..., Executor] + ) -> None: + """Register a parallel backend.""" + self.registry[kind] = builder + + def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executor: + """Get a parallel backend.""" + __tracebackhide__ = True + try: + return self.registry[kind](n_workers=n_workers) + except KeyError: + msg = f"No registered parallel backend found for kind {kind}." + raise ValueError(msg) from None + except Exception as e: # noqa: BLE001 + msg = f"Could not instantiate parallel backend {kind}." 
+ raise ValueError(msg) from e + + +registry = ParallelBackendRegistry() + + +registry.register_parallel_backend( + ParallelBackend.PROCESSES, + lambda n_workers: _CloudpickleProcessPoolExecutor(max_workers=n_workers), +) +registry.register_parallel_backend( + ParallelBackend.THREADS, lambda n_workers: ThreadPoolExecutor(max_workers=n_workers) +) +registry.register_parallel_backend( + ParallelBackend.LOKY, lambda n_workers: get_reusable_executor(max_workers=n_workers) +) diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py index 7e4dd2e..d6e63e7 100644 --- a/src/pytask_parallel/config.py +++ b/src/pytask_parallel/config.py @@ -2,7 +2,6 @@ from __future__ import annotations -import enum import os from typing import Any @@ -17,25 +16,36 @@ def pytask_parse_config(config: dict[str, Any]) -> None: if config["n_workers"] == "auto": config["n_workers"] = max(os.cpu_count() - 1, 1) - if ( - isinstance(config["parallel_backend"], str) - and config["parallel_backend"] in ParallelBackend._value2member_map_ # noqa: SLF001 - ): + try: config["parallel_backend"] = ParallelBackend(config["parallel_backend"]) - elif ( - isinstance(config["parallel_backend"], enum.Enum) - and config["parallel_backend"] in ParallelBackend - ): - pass - else: + except ValueError: msg = f"Invalid value for 'parallel_backend'. Got {config['parallel_backend']}." - raise ValueError(msg) + raise ValueError(msg) from None config["delay"] = 0.1 -@hookimpl +@hookimpl(trylast=True) def pytask_post_parse(config: dict[str, Any]) -> None: """Disable parallelization if debugging is enabled.""" if config["pdb"] or config["trace"] or config["dry_run"]: config["n_workers"] = 1 + + if config["n_workers"] > 1: + if config["parallel_backend"] == ParallelBackend.THREADS: + from pytask_parallel import threads + + config["pm"].register(threads) + + elif config["parallel_backend"] in ( + ParallelBackend.LOKY, + ParallelBackend.PROCESSES, + ): + from pytask_parallel import processes + + config["pm"].register(processes) + + if config["n_workers"] > 1 or config["parallel_backend"] == ParallelBackend.CUSTOM: + from pytask_parallel import execute + + config["pm"].register(execute) diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py index 837b666..433bbc7 100644 --- a/src/pytask_parallel/execute.py +++ b/src/pytask_parallel/execute.py @@ -2,64 +2,28 @@ from __future__ import annotations -import inspect import sys import time -import warnings -from functools import partial from typing import TYPE_CHECKING from typing import Any -from typing import Callable -import cloudpickle from attrs import define from attrs import field from pytask import ExecutionReport -from pytask import Mark from pytask import PNode -from pytask import PTask from pytask import PythonNode from pytask import Session -from pytask import Task -from pytask import WarningReport -from pytask import console -from pytask import get_marks from pytask import hookimpl -from pytask import parse_warning_filter -from pytask import remove_internal_traceback_frames_from_exc_info -from pytask import warning_record_to_str -from pytask.tree_util import PyTree -from pytask.tree_util import tree_leaves from pytask.tree_util import tree_map -from pytask.tree_util import tree_structure -from rich.traceback import Traceback -from pytask_parallel.backends import PARALLEL_BACKEND_BUILDER from pytask_parallel.backends import ParallelBackend +from pytask_parallel.backends import registry +from pytask_parallel.utils import is_parallelized if TYPE_CHECKING: from 
concurrent.futures import Future - from pathlib import Path - from types import ModuleType from types import TracebackType - from rich.console import ConsoleOptions - - -@hookimpl -def pytask_post_parse(config: dict[str, Any]) -> None: - """Register the parallel backend.""" - if config["parallel_backend"] == ParallelBackend.THREADS: - config["pm"].register(DefaultBackendNameSpace) - else: - config["pm"].register(ProcessesNameSpace) - - if PARALLEL_BACKEND_BUILDER[config["parallel_backend"]] is None: - raise - config["_parallel_executor"] = PARALLEL_BACKEND_BUILDER[ - config["parallel_backend"] - ]() - @hookimpl(tryfirst=True) def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR0915 @@ -75,15 +39,15 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 """ __tracebackhide__ = True - if session.config["n_workers"] > 1: + if is_parallelized(session.config["n_workers"], session.config["parallel_backend"]): reports = session.execution_reports running_tasks: dict[str, Future[Any]] = {} - parallel_backend = PARALLEL_BACKEND_BUILDER[ - session.config["parallel_backend"] - ]() + parallel_backend = registry.get_parallel_backend( + session.config["parallel_backend"], n_workers=session.config["n_workers"] + ) - with parallel_backend(max_workers=session.config["n_workers"]) as executor: + with parallel_backend as executor: session.config["_parallel_executor"] = executor sleeper = _Sleeper() @@ -216,215 +180,6 @@ def _parse_future_exception( return None if exc is None else (type(exc), exc, exc.__traceback__) -class ProcessesNameSpace: - """The name space for hooks related to processes.""" - - @staticmethod - @hookimpl(tryfirst=True) - def pytask_execute_task(session: Session, task: PTask) -> Future[Any] | None: - """Execute a task. - - Take a task, pickle it and send the bytes over to another process. - - """ - if session.config["n_workers"] > 1: - kwargs = _create_kwargs_for_task(task) - - # Task modules are dynamically loaded and added to `sys.modules`. Thus, - # cloudpickle believes the module of the task function is also importable in - # the child process. We have to register the module as dynamic again, so - # that cloudpickle will pickle it with the function. See cloudpickle#417, - # pytask#373 and pytask#374. - task_module = _get_module(task.function, getattr(task, "path", None)) - cloudpickle.register_pickle_by_value(task_module) - - return session.config["_parallel_executor"].submit( - _execute_task, - task=task, - kwargs=kwargs, - show_locals=session.config["show_locals"], - console_options=console.options, - session_filterwarnings=session.config["filterwarnings"], - task_filterwarnings=get_marks(task, "filterwarnings"), - ) - return None - - -def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - msg = ( - "You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the " - "execution of tasks with pytask-parallel. Please, remove the breakpoint or run " - "the task without parallelization to debug it." - ) - raise RuntimeError(msg) - - -def _patch_set_trace_and_breakpoint() -> None: - """Patch :func:`pdb.set_trace` and :func:`breakpoint`. - - Patch sys.breakpointhook to intercept any call of breakpoint() and pdb.set_trace in - a subprocess and print a better exception message. 
- - """ - import pdb # noqa: T100 - import sys - - pdb.set_trace = _raise_exception_on_breakpoint - sys.breakpointhook = _raise_exception_on_breakpoint - - -def _execute_task( # noqa: PLR0913 - task: PTask, - kwargs: dict[str, Any], - show_locals: bool, # noqa: FBT001 - console_options: ConsoleOptions, - session_filterwarnings: tuple[str, ...], - task_filterwarnings: tuple[Mark, ...], -) -> tuple[ - PyTree[PythonNode | None], - list[WarningReport], - tuple[type[BaseException], BaseException, str] | None, -]: - """Unserialize and execute task. - - This function receives bytes and unpickles them to a task which is them execute in a - spawned process or thread. - - """ - __tracebackhide__ = True - _patch_set_trace_and_breakpoint() - - with warnings.catch_warnings(record=True) as log: - for arg in session_filterwarnings: - warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) - - # apply filters from "filterwarnings" marks - for mark in task_filterwarnings: - for arg in mark.args: - warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) - - try: - out = task.execute(**kwargs) - except Exception: # noqa: BLE001 - exc_info = sys.exc_info() - processed_exc_info = _process_exception( - exc_info, show_locals, console_options - ) - else: - _handle_task_function_return(task, out) - processed_exc_info = None - - task_display_name = getattr(task, "display_name", task.name) - warning_reports = [] - for warning_message in log: - fs_location = warning_message.filename, warning_message.lineno - warning_reports.append( - WarningReport( - message=warning_record_to_str(warning_message), - fs_location=fs_location, - id_=task_display_name, - ) - ) - - python_nodes = tree_map( - lambda x: x if isinstance(x, PythonNode) else None, task.produces - ) - - return python_nodes, warning_reports, processed_exc_info - - -def _process_exception( - exc_info: tuple[type[BaseException], BaseException, TracebackType | None], - show_locals: bool, # noqa: FBT001 - console_options: ConsoleOptions, -) -> tuple[type[BaseException], BaseException, str]: - """Process the exception and convert the traceback to a string.""" - exc_info = remove_internal_traceback_frames_from_exc_info(exc_info) - traceback = Traceback.from_exception(*exc_info, show_locals=show_locals) - segments = console.render(traceback, options=console_options) - text = "".join(segment.text for segment in segments) - return (*exc_info[:2], text) - - -def _handle_task_function_return(task: PTask, out: Any) -> None: - if "return" not in task.produces: - return - - structure_out = tree_structure(out) - structure_return = tree_structure(task.produces["return"]) - # strict must be false when none is leaf. - if not structure_return.is_prefix(structure_out, strict=False): - msg = ( - "The structure of the return annotation is not a subtree of " - "the structure of the function return.\n\nFunction return: " - f"{structure_out}\n\nReturn annotation: {structure_return}" - ) - raise ValueError(msg) - - nodes = tree_leaves(task.produces["return"]) - values = structure_return.flatten_up_to(out) - for node, value in zip(nodes, values): - node.save(value) - - -class DefaultBackendNameSpace: - """The name space for hooks related to threads.""" - - @staticmethod - @hookimpl(tryfirst=True) - def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None: - """Execute a task. - - Since threads have shared memory, it is not necessary to pickle and unpickle the - task. 
- - """ - if session.config["n_workers"] > 1: - kwargs = _create_kwargs_for_task(task) - return session.config["_parallel_executor"].submit( - _mock_processes_for_threads, task=task, **kwargs - ) - return None - - -def _mock_processes_for_threads( - task: PTask, **kwargs: Any -) -> tuple[ - None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None -]: - """Mock execution function such that it returns the same as for processes. - - The function for processes returns ``warning_reports`` and an ``exception``. With - threads, these object are collected by the main and not the subprocess. So, we just - return placeholders. - - """ - __tracebackhide__ = True - try: - out = task.function(**kwargs) - except Exception: # noqa: BLE001 - exc_info = sys.exc_info() - else: - _handle_task_function_return(task, out) - exc_info = None - return None, [], exc_info - - -def _create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]: - """Create kwargs for task function.""" - parameters = inspect.signature(task.function).parameters - - kwargs = {} - for name, value in task.depends_on.items(): - kwargs[name] = tree_map(lambda x: x.load(), value) - - for name, value in task.produces.items(): - if name in parameters: - kwargs[name] = tree_map(lambda x: x.load(), value) - - return kwargs - - @define(kw_only=True) class _Sleeper: """A sleeper that always sleeps a bit and up to 1 second if you don't wake it up. @@ -446,22 +201,3 @@ def increment(self) -> None: def sleep(self) -> None: time.sleep(self.timings[self.timing_idx]) - - -def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType: - """Get the module of a python function. - - ``functools.partial`` obfuscates the module of the function and - ``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original - function. - - We use the path from the task module to aid the search although it is not clear - whether it helps. 
- - """ - if isinstance(func, partial): - func = func.func - - if path: - return inspect.getmodule(func, path.as_posix()) - return inspect.getmodule(func) diff --git a/src/pytask_parallel/plugin.py b/src/pytask_parallel/plugin.py index 8ccf789..353cdef 100644 --- a/src/pytask_parallel/plugin.py +++ b/src/pytask_parallel/plugin.py @@ -8,7 +8,6 @@ from pytask_parallel import build from pytask_parallel import config -from pytask_parallel import execute from pytask_parallel import logging if TYPE_CHECKING: @@ -20,5 +19,4 @@ def pytask_add_hooks(pm: PluginManager) -> None: """Register plugins.""" pm.register(build) pm.register(config) - pm.register(execute) pm.register(logging) diff --git a/src/pytask_parallel/processes.py b/src/pytask_parallel/processes.py new file mode 100644 index 0000000..2421952 --- /dev/null +++ b/src/pytask_parallel/processes.py @@ -0,0 +1,181 @@ +"""Contains hooks for parallel execution of tasks with processes/loky.""" + +from __future__ import annotations + +import inspect +import sys +import warnings +from functools import partial +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +import cloudpickle +from pytask import Mark +from pytask import PTask +from pytask import PythonNode +from pytask import Session +from pytask import WarningReport +from pytask import console +from pytask import get_marks +from pytask import hookimpl +from pytask import parse_warning_filter +from pytask import remove_internal_traceback_frames_from_exc_info +from pytask import warning_record_to_str +from pytask.tree_util import PyTree +from pytask.tree_util import tree_map +from rich.traceback import Traceback + +from pytask_parallel.utils import create_kwargs_for_task +from pytask_parallel.utils import handle_task_function_return + +if TYPE_CHECKING: + from concurrent.futures import Future + from pathlib import Path + from types import ModuleType + from types import TracebackType + + from rich.console import ConsoleOptions + + +@hookimpl +def pytask_execute_task(session: Session, task: PTask) -> Future[Any]: + """Execute a task. + + Take a task, pickle it and send the bytes over to another process. + + """ + kwargs = create_kwargs_for_task(task) + + # Task modules are dynamically loaded and added to `sys.modules`. Thus, + # cloudpickle believes the module of the task function is also importable in the + # child process. We have to register the module as dynamic again, so that + # cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373 + # and pytask#374. + task_module = _get_module(task.function, getattr(task, "path", None)) + cloudpickle.register_pickle_by_value(task_module) + + return session.config["_parallel_executor"].submit( + _execute_task, + task=task, + kwargs=kwargs, + show_locals=session.config["show_locals"], + console_options=console.options, + session_filterwarnings=session.config["filterwarnings"], + task_filterwarnings=get_marks(task, "filterwarnings"), + ) + + +def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 + msg = ( + "You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the " + "execution of tasks with pytask-parallel. Please, remove the breakpoint or run " + "the task without parallelization to debug it." + ) + raise RuntimeError(msg) + + +def _patch_set_trace_and_breakpoint() -> None: + """Patch :func:`pdb.set_trace` and :func:`breakpoint`. 
+
+    Patch sys.breakpointhook to intercept any call of breakpoint() and pdb.set_trace in
+    a subprocess and print a better exception message.
+
+    """
+    import pdb  # noqa: T100
+    import sys
+
+    pdb.set_trace = _raise_exception_on_breakpoint
+    sys.breakpointhook = _raise_exception_on_breakpoint
+
+
+def _process_exception(
+    exc_info: tuple[type[BaseException], BaseException, TracebackType | None],
+    show_locals: bool,  # noqa: FBT001
+    console_options: ConsoleOptions,
+) -> tuple[type[BaseException], BaseException, str]:
+    """Process the exception and convert the traceback to a string."""
+    exc_info = remove_internal_traceback_frames_from_exc_info(exc_info)
+    traceback = Traceback.from_exception(*exc_info, show_locals=show_locals)
+    segments = console.render(traceback, options=console_options)
+    text = "".join(segment.text for segment in segments)
+    return (*exc_info[:2], text)
+
+
+def _execute_task(  # noqa: PLR0913
+    task: PTask,
+    kwargs: dict[str, Any],
+    show_locals: bool,  # noqa: FBT001
+    console_options: ConsoleOptions,
+    session_filterwarnings: tuple[str, ...],
+    task_filterwarnings: tuple[Mark, ...],
+) -> tuple[
+    PyTree[PythonNode | None],
+    list[WarningReport],
+    tuple[type[BaseException], BaseException, str] | None,
+]:
+    """Deserialize and execute a task.
+
+    This function receives bytes and unpickles them to a task which is then executed in a
+    spawned process or thread.
+
+    """
+    __tracebackhide__ = True
+    _patch_set_trace_and_breakpoint()
+
+    with warnings.catch_warnings(record=True) as log:
+        for arg in session_filterwarnings:
+            warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        # apply filters from "filterwarnings" marks
+        for mark in task_filterwarnings:
+            for arg in mark.args:
+                warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        try:
+            out = task.execute(**kwargs)
+        except Exception:  # noqa: BLE001
+            exc_info = sys.exc_info()
+            processed_exc_info = _process_exception(
+                exc_info, show_locals, console_options
+            )
+        else:
+            handle_task_function_return(task, out)
+            processed_exc_info = None
+
+        task_display_name = getattr(task, "display_name", task.name)
+        warning_reports = []
+        for warning_message in log:
+            fs_location = warning_message.filename, warning_message.lineno
+            warning_reports.append(
+                WarningReport(
+                    message=warning_record_to_str(warning_message),
+                    fs_location=fs_location,
+                    id_=task_display_name,
+                )
+            )
+
+    python_nodes = tree_map(
+        lambda x: x if isinstance(x, PythonNode) else None, task.produces
+    )
+
+    return python_nodes, warning_reports, processed_exc_info
+
+
+def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
+    """Get the module of a python function.
+
+    ``functools.partial`` obfuscates the module of the function and
+    ``inspect.getmodule`` returns :mod:`functools`. Therefore, we recover the original
+    function.
+
+    We use the path from the task module to aid the search although it is not clear
+    whether it helps.
+
+    """
+    if isinstance(func, partial):
+        func = func.func
+
+    if path:
+        return inspect.getmodule(func, path.as_posix())
+    return inspect.getmodule(func)
diff --git a/src/pytask_parallel/threads.py b/src/pytask_parallel/threads.py
new file mode 100644
index 0000000..351e495
--- /dev/null
+++ b/src/pytask_parallel/threads.py
@@ -0,0 +1,56 @@
+"""Contains functions for parallel execution of tasks with threads."""
+
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+from typing import Any
+
+from pytask import PTask
+from pytask import Session
+from pytask import Task
+from pytask import hookimpl
+
+from pytask_parallel.utils import create_kwargs_for_task
+from pytask_parallel.utils import handle_task_function_return
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+    from types import TracebackType
+
+
+@hookimpl
+def pytask_execute_task(session: Session, task: Task) -> Future[Any]:
+    """Execute a task.
+
+    Since threads have shared memory, it is not necessary to pickle and unpickle the
+    task.
+
+    """
+    kwargs = create_kwargs_for_task(task)
+    return session.config["_parallel_executor"].submit(
+        _mock_processes_for_threads, task=task, **kwargs
+    )
+
+
+def _mock_processes_for_threads(
+    task: PTask, **kwargs: Any
+) -> tuple[
+    None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None
+]:
+    """Mock execution function such that it returns the same as for processes.
+
+    The function for processes returns ``warning_reports`` and an ``exception``. With
+    threads, these objects are collected by the main process and not the subprocess.
+    So, we just return placeholders.
+
+    """
+    __tracebackhide__ = True
+    try:
+        out = task.function(**kwargs)
+    except Exception:  # noqa: BLE001
+        exc_info = sys.exc_info()
+    else:
+        handle_task_function_return(task, out)
+        exc_info = None
+    return None, [], exc_info
diff --git a/src/pytask_parallel/utils.py b/src/pytask_parallel/utils.py
new file mode 100644
index 0000000..024243d
--- /dev/null
+++ b/src/pytask_parallel/utils.py
@@ -0,0 +1,61 @@
+"""Contains utility functions."""
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING
+from typing import Any
+
+from pytask.tree_util import PyTree
+from pytask.tree_util import tree_leaves
+from pytask.tree_util import tree_map
+from pytask.tree_util import tree_structure
+
+from pytask_parallel.backends import ParallelBackend
+
+if TYPE_CHECKING:
+    from pytask import PTask
+
+
+def create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
+    """Create kwargs for task function."""
+    parameters = inspect.signature(task.function).parameters
+
+    kwargs = {}
+    for name, value in task.depends_on.items():
+        kwargs[name] = tree_map(lambda x: x.load(), value)
+
+    for name, value in task.produces.items():
+        if name in parameters:
+            kwargs[name] = tree_map(lambda x: x.load(), value)
+
+    return kwargs
+
+
+def handle_task_function_return(task: PTask, out: Any) -> None:
+    """Handle the return value of a task function."""
+    if "return" not in task.produces:
+        return
+
+    structure_out = tree_structure(out)
+    structure_return = tree_structure(task.produces["return"])
+    # strict must be false when none is leaf.
+ if not structure_return.is_prefix(structure_out, strict=False): + msg = ( + "The structure of the return annotation is not a subtree of " + "the structure of the function return.\n\nFunction return: " + f"{structure_out}\n\nReturn annotation: {structure_return}" + ) + raise ValueError(msg) + + nodes = tree_leaves(task.produces["return"]) + values = structure_return.flatten_up_to(out) + for node, value in zip(nodes, values): + node.save(value) + + +def is_parallelized(n_workers: int, parallel_backend: ParallelBackend) -> bool: + """Check if the execution is parallelized.""" + return parallel_backend == ParallelBackend.CUSTOM or ( + n_workers > 1 and parallel_backend != ParallelBackend.CUSTOM + ) diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 0000000..78c283e --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,38 @@ +import textwrap + +import pytest +from pytask import ExitCode +from pytask import cli + + +@pytest.mark.end_to_end() +def test_error_requesting_custom_backend_without_registration(runner, tmp_path): + tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") + result = runner.invoke(cli, [tmp_path.as_posix(), "--parallel-backend", "custom"]) + assert result.exit_code == ExitCode.FAILED + assert "No registered parallel backend found" in result.output + + +@pytest.mark.end_to_end() +def test_register_custom_backend(runner, tmp_path): + source = """ + from pytask_parallel import registry, ParallelBackend + from concurrent.futures import ProcessPoolExecutor + + def custom_builder(n_workers): + print("Build custom executor.") + return ProcessPoolExecutor(max_workers=n_workers) + + registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) + + + def task_example(): pass + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + result = runner.invoke( + cli, + [tmp_path.as_posix(), "--parallel-backend", "custom"], + ) + assert result.exit_code == ExitCode.OK + assert "Build custom executor." 
in result.output + assert "1 Succeeded" in result.output diff --git a/tests/test_execute.py b/tests/test_execute.py index 6010f7a..a99a0de 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -12,13 +12,11 @@ from tests.conftest import restore_sys_path_and_module_after_test_execution - -class Session: - pass +_IMPLEMENTED_BACKENDS = [p for p in ParallelBackend if p != ParallelBackend.CUSTOM] @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_parallel_execution(tmp_path, parallel_backend): source = """ from pytask import Product @@ -40,7 +38,7 @@ def task_2(path: Annotated[Path, Product] = Path("out_2.txt")): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_parallel_execution_w_cli(runner, tmp_path, parallel_backend): source = """ from pytask import Product @@ -70,7 +68,7 @@ def task_2(path: Annotated[Path, Product] = Path("out_2.txt")): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_stop_execution_when_max_failures_is_reached(tmp_path, parallel_backend): source = """ import time @@ -98,7 +96,7 @@ def task_3(): time.sleep(3) @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_priorities(tmp_path, parallel_backend): source = """ import pytask @@ -139,7 +137,7 @@ def task_5(): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) @pytest.mark.parametrize("show_locals", [True, False]) def test_rendering_of_tracebacks_with_rich( runner, tmp_path, parallel_backend, show_locals @@ -221,7 +219,7 @@ def test_sleeper(): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_that_return(runner, tmp_path, parallel_backend): source = """ from pathlib import Path @@ -241,7 +239,7 @@ def task_example() -> Annotated[str, Path("file.txt")]: @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_without_path_that_return(runner, tmp_path, parallel_backend): source = """ from pathlib import Path @@ -263,7 +261,7 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend): @pytest.mark.end_to_end() @pytest.mark.parametrize("flag", ["--pdb", "--trace", "--dry-run"]) -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_backend): tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") result = runner.invoke( @@ -277,7 +275,12 @@ def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_back @pytest.mark.end_to_end() @pytest.mark.parametrize("code", ["breakpoint()", "import pdb; pdb.set_trace()"]) @pytest.mark.parametrize( - "parallel_backend", [i for i in ParallelBackend if i != ParallelBackend.THREADS] + "parallel_backend", + [ + i + for i in ParallelBackend + if i not in 
(ParallelBackend.THREADS, ParallelBackend.CUSTOM) + ], ) def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend): tmp_path.joinpath("task_example.py").write_text(f"def task_example(): {code}") @@ -289,7 +292,7 @@ def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_partialed(runner, tmp_path, parallel_backend): source = """ from pathlib import Path From c109c6ce3028bd86ee32e65c35d1d8cb7c9e3c81 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 30 Mar 2024 10:17:46 +0100 Subject: [PATCH 2/9] Remove delayed imports. --- src/pytask_parallel/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py index d6e63e7..11774f7 100644 --- a/src/pytask_parallel/config.py +++ b/src/pytask_parallel/config.py @@ -7,6 +7,8 @@ from pytask import hookimpl +from pytask_parallel import processes +from pytask_parallel import threads from pytask_parallel.backends import ParallelBackend @@ -33,16 +35,12 @@ def pytask_post_parse(config: dict[str, Any]) -> None: if config["n_workers"] > 1: if config["parallel_backend"] == ParallelBackend.THREADS: - from pytask_parallel import threads - config["pm"].register(threads) elif config["parallel_backend"] in ( ParallelBackend.LOKY, ParallelBackend.PROCESSES, ): - from pytask_parallel import processes - config["pm"].register(processes) if config["n_workers"] > 1 or config["parallel_backend"] == ParallelBackend.CUSTOM: From f6eafa6443975552ade6e6195e8dacb73c5270b2 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 30 Mar 2024 13:56:05 +0100 Subject: [PATCH 3/9] Make parsing more robust. --- src/pytask_parallel/config.py | 10 +++++- src/pytask_parallel/custom.py | 27 +++++++++++++++ src/pytask_parallel/execute.py | 60 +++++++++++++--------------------- src/pytask_parallel/utils.py | 39 ++++++++++++++++++++++ tests/test_backends.py | 22 +++++++++++-- 5 files changed, 117 insertions(+), 41 deletions(-) create mode 100644 src/pytask_parallel/custom.py diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py index 9385d20..8e58f55 100644 --- a/src/pytask_parallel/config.py +++ b/src/pytask_parallel/config.py @@ -7,6 +7,7 @@ from pytask import hookimpl +from pytask_parallel import custom from pytask_parallel import execute from pytask_parallel import processes from pytask_parallel import threads @@ -39,9 +40,16 @@ def pytask_post_parse(config: dict[str, Any]) -> None: if config["pdb"] or config["trace"] or config["dry_run"]: config["n_workers"] = 1 - if config["n_workers"] > 1: + # Register parallel execute hook. + if config["n_workers"] > 1 or config["parallel_backend"] == ParallelBackend.CUSTOM: config["pm"].register(execute) + + # Register parallel backends. 
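+    # Threads share memory with the main process and can submit task functions
+    # directly; loky and the process pool go through the cloudpickle-based wrapper
+    # in the processes module.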
+    if config["n_workers"] > 1:
         if config["parallel_backend"] == ParallelBackend.THREADS:
             config["pm"].register(threads)
         else:
             config["pm"].register(processes)
+
+    if config["parallel_backend"] == ParallelBackend.CUSTOM:
+        config["pm"].register(custom)
diff --git a/src/pytask_parallel/custom.py b/src/pytask_parallel/custom.py
new file mode 100644
index 0000000..3ed2377
--- /dev/null
+++ b/src/pytask_parallel/custom.py
@@ -0,0 +1,27 @@
+"""Contains functions for the custom backend."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing import Any
+
+from pytask import PTask
+from pytask import Session
+from pytask import hookimpl
+
+from pytask_parallel.utils import create_kwargs_for_task
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+
+
+@hookimpl
+def pytask_execute_task(session: Session, task: PTask) -> Future[Any]:
+    """Execute a task.
+
+    The task function is submitted to the executor as-is; serializing it and
+    running any setup is left to the custom executor.
+
+    """
+    kwargs = create_kwargs_for_task(task)
+    return session.config["_parallel_executor"].submit(task.function, **kwargs)
diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py
index ceedfc8..e6cf659 100644
--- a/src/pytask_parallel/execute.py
+++ b/src/pytask_parallel/execute.py
@@ -11,17 +11,19 @@ from attrs import field
 from pytask import ExecutionReport
 from pytask import PNode
+from pytask import PTask
 from pytask import PythonNode
 from pytask import Session
 from pytask import hookimpl
+from pytask.tree_util import PyTree
 from pytask.tree_util import tree_map
+from pytask.tree_util import tree_structure
 
-from pytask_parallel.backends import ParallelBackend
 from pytask_parallel.backends import registry
+from pytask_parallel.utils import parse_future_result
 
 if TYPE_CHECKING:
     from concurrent.futures import Future
-    from types import TracebackType
 
 
 @hookimpl
@@ -86,21 +88,12 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
                 for task_name in list(running_tasks):
                     future = running_tasks[task_name]
+
                     if future.done():
-                        # An exception was thrown before the task was executed.
-                        if future.exception() is not None:
-                            exc_info = _parse_future_exception(future.exception())
-                            warning_reports = []
-                        # A task raised an exception.
-                        else:
-                            (python_nodes, warning_reports, task_exception) = (
-                                future.result()
-                            )
-                            session.warnings.extend(warning_reports)
-                            exc_info = (
-                                _parse_future_exception(future.exception())
-                                or task_exception
-                            )
+                        python_nodes, warning_reports, exc_info = parse_future_result(
+                            future
+                        )
+                        session.warnings.extend(warning_reports)
 
                         if exc_info is not None:
                             task = session.dag.nodes[task_name]["task"]
@@ -111,16 +104,7 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
                             session.scheduler.done(task_name)
                         else:
                             task = session.dag.nodes[task_name]["task"]
-
-                            # Update PythonNodes with the values from the future if
-                            # not threads.
- if ( - session.config["parallel_backend"] - != ParallelBackend.THREADS - ): - task.produces = tree_map( - _update_python_node, task.produces, python_nodes - ) + _update_python_nodes(task, python_nodes) try: session.hook.pytask_execute_task_teardown( @@ -136,8 +120,6 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 running_tasks.pop(task_name) newly_collected_reports.append(report) session.scheduler.done(task_name) - else: - pass for report in newly_collected_reports: session.hook.pytask_execute_task_process_report( @@ -158,17 +140,21 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 return True -def _update_python_node(x: PNode, y: PythonNode | None) -> PNode: - if y: - x.save(y.load()) - return x +def _update_python_nodes( + task: PTask, python_nodes: dict[str, PyTree[PythonNode | None]] | None +) -> None: + """Update the python nodes of a task with the python nodes from the future.""" + def _update_python_node(x: PNode, y: PythonNode | None) -> PNode: + if y: + x.save(y.load()) + return x -def _parse_future_exception( - exc: BaseException | None, -) -> tuple[type[BaseException], BaseException, TracebackType] | None: - """Parse a future exception into the format of ``sys.exc_info``.""" - return None if exc is None else (type(exc), exc, exc.__traceback__) + structure_python_nodes = tree_structure(python_nodes) + structure_produces = tree_structure(task.produces) + # strict must be false when none is leaf. + if structure_produces.is_prefix(structure_python_nodes, strict=False): + task.produces = tree_map(_update_python_node, task.produces, python_nodes) # type: ignore[assignment] @define(kw_only=True) diff --git a/src/pytask_parallel/utils.py b/src/pytask_parallel/utils.py index 5fd1b50..122d620 100644 --- a/src/pytask_parallel/utils.py +++ b/src/pytask_parallel/utils.py @@ -12,7 +12,39 @@ from pytask.tree_util import tree_structure if TYPE_CHECKING: + from concurrent.futures import Future + from types import TracebackType + from pytask import PTask + from pytask import PythonNode + from pytask import WarningReport + + +def parse_future_result( + future: Future[Any], +) -> tuple[ + dict[str, PyTree[PythonNode | None]] | None, + list[WarningReport], + tuple[type[BaseException], BaseException, TracebackType] | None, +]: + """Parse the result of a future.""" + # An exception was raised before the task was executed. + future_exception = future.exception() + if future_exception is not None: + exc_info = _parse_future_exception(future_exception) + return None, [], exc_info + + out = future.result() + if isinstance(out, tuple) and len(out) == 3: # noqa: PLR2004 + return out + + # What to do when the output does not match? + msg = ( + "The task function returns an unknown output format. Either return a tuple " + "with three elements, python nodes, warning reports and exception or only " + "return." 
+ ) + raise Exception(msg) # noqa: TRY002 def handle_task_function_return(task: PTask, out: Any) -> None: @@ -50,3 +82,10 @@ def create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]: kwargs[name] = tree_map(lambda x: x.load(), value) return kwargs + + +def _parse_future_exception( + exc: BaseException | None, +) -> tuple[type[BaseException], BaseException, TracebackType] | None: + """Parse a future exception into the format of ``sys.exc_info``.""" + return None if exc is None else (type(exc), exc, exc.__traceback__) diff --git a/tests/test_backends.py b/tests/test_backends.py index 78c283e..fd65ed7 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -16,16 +16,32 @@ def test_error_requesting_custom_backend_without_registration(runner, tmp_path): @pytest.mark.end_to_end() def test_register_custom_backend(runner, tmp_path): source = """ - from pytask_parallel import registry, ParallelBackend + import cloudpickle + from concurrent.futures import ProcessPoolExecutor + from loky import get_reusable_executor + from pytask_parallel import registry, ParallelBackend + + def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes): + deserialized_fn = cloudpickle.loads(fn) + deserialized_kwargs = cloudpickle.loads(kwargs) + return None, [], deserialized_fn(**deserialized_kwargs) + + class _CloudpickleProcessPoolExecutor(ProcessPoolExecutor): + + def submit(self, fn, *args, **kwargs): + return super().submit( + _deserialize_and_run_with_cloudpickle, + fn=cloudpickle.dumps(fn), + kwargs=cloudpickle.dumps(kwargs), + ) def custom_builder(n_workers): print("Build custom executor.") - return ProcessPoolExecutor(max_workers=n_workers) + return _CloudpickleProcessPoolExecutor(max_workers=n_workers) registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) - def task_example(): pass """ tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) From ab0c4b16f2c7e618c3c3e4e8751f2d4b01f8c185 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 30 Mar 2024 18:44:56 +0100 Subject: [PATCH 4/9] Add more documentation. --- README.md | 51 +++++++++++++++++++------------- src/pytask_parallel/processes.py | 9 +++++- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 325c007..8860ec6 100644 --- a/README.md +++ b/README.md @@ -47,15 +47,15 @@ $ pytask --n-workers 2 $ pytask -n auto ``` -Using processes to parallelize the execution of tasks is useful for CPU bound tasks such +Using processes to parallelize the execution of tasks is useful for CPU-bound tasks such as numerical computations. ([Here](https://stackoverflow.com/a/868577/7523785) is an -explanation on what CPU or IO bound means.) +explanation of what CPU- or IO-bound means.) -For IO bound tasks, tasks where the limiting factor are network responses, access to +For IO-bound tasks, tasks where the limiting factor is network latency and access to files, you can parallelize via threads. ```console -$ pytask --parallel-backend threads +pytask --parallel-backend threads ``` You can also set the options in a `pyproject.toml`. @@ -70,23 +70,23 @@ parallel_backend = "processes" # or loky or threads ## Custom Executor -pytask-parallel allows you to use your parallel backend. 
The only requirement is that
-you provide an executor that implements the interface of
+pytask-parallel allows you to use your parallel backend as long as it follows the
+interface defined by
 [`concurrent.futures.Executor`](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor).
 
-To register your backend, go to a module that is imported by pytask when building the
-project, for example, the `config.py`. Register a builder function for your custom
-backend.
+In some cases, adding a new backend can be as easy as registering a builder function
+that receives some arguments (currently only `n_workers`) and returns the instantiated
+executor.
 
 ```python
 from concurrent.futures import Executor
-from concurrent.futures import ProcessPoolExecutor
+from my_project.executor import CustomExecutor
 
 from pytask_parallel import ParallelBackend, registry
 
 
 def build_custom_executor(n_workers: int) -> Executor:
-    return ProcessPoolExecutor(max_workers=n_workers)
+    return CustomExecutor(max_workers=n_workers)
 
 
 registry.register_parallel_backend(ParallelBackend.CUSTOM, build_custom_executor)
 ```
 
@@ -98,9 +98,20 @@ Now, build the project requesting your custom backend.
 pytask --parallel-backend custom
 ```
 
-> \[!NOTE\]
->
-> When you request the custom backend, it is even used when `n_workers` is set to 1.
+Realistically, registering the builder is not the only adjustment needed for a nice
+user experience. There are two other important pieces that pytask-parallel does not
+implement by default since they seem more tightly coupled to your backend.
+
+1. A wrapper for the executed function that captures warnings, catches exceptions and
+   saves products of the task (within the child process!).
+
+   As an example, see `def _execute_task()` that does all that for the processes and
+   loky backend.
+
+1. To apply the wrapper, you need to write a custom hook implementation for
+   `def pytask_execute_task()`. See `def pytask_execute_task()` for an example. Use the
+   [`hook_module`](https://pytask-dev.readthedocs.io/en/stable/how_to_guides/extending_pytask.html#using-hook-module-and-hook-module)
+   configuration value to register your implementation.
 
 ## Some implementation details
 
@@ -126,12 +137,12 @@ Consult the [release notes](CHANGES.md) to find out about what is new.
 - `pytask-parallel` does not call the `pytask_execute_task_protocol` hook
   specification/entry-point because `pytask_execute_task_setup` and
   `pytask_execute_task` need to be separated from `pytask_execute_task_teardown`. Thus,
-  plugins which change this hook specification may not interact well with the
+  plugins that change this hook specification may not interact well with the
   parallelization.
 
-- There are two PRs for CPython which try to re-enable setting custom reducers which
-  should have been working, but does not. Here are the references.
+- Two PRs for CPython try to re-enable setting custom reducers which should have been
+  working but does not. Here are the references.
 
-  > - <https://bugs.python.org/issue28053>
-  > - <https://github.com/python/cpython/pull/9959>
-  > - <https://github.com/python/cpython/pull/15058>
+  - https://bugs.python.org/issue28053
+  - https://github.com/python/cpython/pull/9959
+  - https://github.com/python/cpython/pull/15058
diff --git a/src/pytask_parallel/processes.py b/src/pytask_parallel/processes.py
index 8c5abfe..b9541ba 100644
--- a/src/pytask_parallel/processes.py
+++ b/src/pytask_parallel/processes.py
@@ -106,14 +106,19 @@ def _execute_task(  # noqa: PLR0913
     spawned process or thread.
 
     """
+    # Hide this function from tracebacks.
     __tracebackhide__ = True
+
+    # Patch set_trace and breakpoint to show a better error message.
_patch_set_trace_and_breakpoint() + # Catch warnings and store them in a list. with warnings.catch_warnings(record=True) as log: + # Apply global filterwarnings. for arg in session_filterwarnings: warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) - # apply filters from "filterwarnings" marks + # Apply filters from "filterwarnings" marks for mark in task_filterwarnings: for arg in mark.args: warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) @@ -126,6 +131,7 @@ def _execute_task( # noqa: PLR0913 exc_info, show_locals, console_options ) else: + # Save products. handle_task_function_return(task, out) processed_exc_info = None @@ -141,6 +147,7 @@ def _execute_task( # noqa: PLR0913 ) ) + # Collect all PythonNodes that are products to pass values back to the main process. python_nodes = tree_map( lambda x: x if isinstance(x, PythonNode) else None, task.produces ) From 36da23f4f09536fe0baaed1f1ea0cc60d7479147 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 30 Mar 2024 19:00:36 +0100 Subject: [PATCH 5/9] debug. --- tests/test_backends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_backends.py b/tests/test_backends.py index fd65ed7..4ae8297 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -49,6 +49,7 @@ def task_example(): pass cli, [tmp_path.as_posix(), "--parallel-backend", "custom"], ) + print(result.output) # noqa: T201 assert result.exit_code == ExitCode.OK assert "Build custom executor." in result.output assert "1 Succeeded" in result.output From ab83798d03d95bde2d8772832bc58ed593d96546 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 30 Mar 2024 21:17:10 +0100 Subject: [PATCH 6/9] Fix test. --- src/pytask_parallel/utils.py | 3 +++ tests/test_backends.py | 40 ++++++++++++++++++------------------ 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/pytask_parallel/utils.py b/src/pytask_parallel/utils.py index 122d620..11cc19f 100644 --- a/src/pytask_parallel/utils.py +++ b/src/pytask_parallel/utils.py @@ -38,6 +38,9 @@ def parse_future_result( if isinstance(out, tuple) and len(out) == 3: # noqa: PLR2004 return out + if out is None: + return None, [], None + # What to do when the output does not match? msg = ( "The task function returns an unknown output format. 
Either return a tuple " diff --git a/tests/test_backends.py b/tests/test_backends.py index 4ae8297..f8d4e62 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -15,39 +15,39 @@ def test_error_requesting_custom_backend_without_registration(runner, tmp_path): @pytest.mark.end_to_end() def test_register_custom_backend(runner, tmp_path): - source = """ + hook_source = """ import cloudpickle - - from concurrent.futures import ProcessPoolExecutor from loky import get_reusable_executor - from pytask_parallel import registry, ParallelBackend + from pytask import hookimpl + from pytask_parallel import ParallelBackend + from pytask_parallel import registry + from pytask_parallel.processes import _get_module + from pytask_parallel.utils import create_kwargs_for_task + - def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes): - deserialized_fn = cloudpickle.loads(fn) - deserialized_kwargs = cloudpickle.loads(kwargs) - return None, [], deserialized_fn(**deserialized_kwargs) + @hookimpl(tryfirst=True) + def pytask_execute_task(session, task): + kwargs = create_kwargs_for_task(task) - class _CloudpickleProcessPoolExecutor(ProcessPoolExecutor): + task_module = _get_module(task.function, getattr(task, "path", None)) + cloudpickle.register_pickle_by_value(task_module) + + return session.config["_parallel_executor"].submit(task.function, **kwargs) - def submit(self, fn, *args, **kwargs): - return super().submit( - _deserialize_and_run_with_cloudpickle, - fn=cloudpickle.dumps(fn), - kwargs=cloudpickle.dumps(kwargs), - ) def custom_builder(n_workers): print("Build custom executor.") - return _CloudpickleProcessPoolExecutor(max_workers=n_workers) + return get_reusable_executor(max_workers=n_workers) - registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) - def task_example(): pass + registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) """ - tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("hook.py").write_text(textwrap.dedent(hook_source)) + + tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") result = runner.invoke( cli, - [tmp_path.as_posix(), "--parallel-backend", "custom"], + [tmp_path.as_posix(), "--parallel-backend", "custom", "--hook-module", tmp_path.joinpath("hook.py").as_posix()], ) print(result.output) # noqa: T201 assert result.exit_code == ExitCode.OK From 03414860fdbb52abc7fa9371f2613a7d2f087108 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Mar 2024 20:17:40 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_backends.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index f8d4e62..a6dab62 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -47,7 +47,13 @@ def custom_builder(n_workers): tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") result = runner.invoke( cli, - [tmp_path.as_posix(), "--parallel-backend", "custom", "--hook-module", tmp_path.joinpath("hook.py").as_posix()], + [ + tmp_path.as_posix(), + "--parallel-backend", + "custom", + "--hook-module", + tmp_path.joinpath("hook.py").as_posix(), + ], ) print(result.output) # noqa: T201 assert result.exit_code == ExitCode.OK From c441dbb75fa6ab3ab17d8ad5061840c802dc1c41 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 30 
Mar 2024 21:33:51 +0100 Subject: [PATCH 8/9] Simplify tests. --- environment.yml | 2 -- src/pytask_parallel/backends.py | 2 +- tests/test_backends.py | 21 +++++++++++++++++++-- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/environment.yml b/environment.yml index 27498f2..4def0f2 100644 --- a/environment.yml +++ b/environment.yml @@ -1,7 +1,6 @@ name: pytask-parallel channels: - - conda-forge/label/pytask_rc - conda-forge - nodefaults @@ -21,7 +20,6 @@ dependencies: - tox - ipywidgets - nbmake - - pre-commit - pytest-cov - pip: diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py index f2172ab..8b53f5c 100644 --- a/src/pytask_parallel/backends.py +++ b/src/pytask_parallel/backends.py @@ -71,7 +71,7 @@ def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executo msg = f"No registered parallel backend found for kind {kind}." raise ValueError(msg) from None except Exception as e: # noqa: BLE001 - msg = f"Could not instantiate parallel backend {kind}." + msg = f"Could not instantiate parallel backend {kind.value}." raise ValueError(msg) from e diff --git a/tests/test_backends.py b/tests/test_backends.py index a6dab62..b722ffd 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -13,6 +13,25 @@ def test_error_requesting_custom_backend_without_registration(runner, tmp_path): assert "No registered parallel backend found" in result.output +@pytest.mark.end_to_end() +def test_error_while_instantiating_custom_backend(runner, tmp_path): + hook_source = """ + from pytask_parallel import ParallelBackend, registry + + def custom_builder(n_workers): + raise Exception("ERROR") + + registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) + + def task_example(): pass + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(hook_source)) + result = runner.invoke(cli, [tmp_path.as_posix(), "--parallel-backend", "custom"]) + assert result.exit_code == ExitCode.FAILED + assert "ERROR" in result.output + assert "Could not instantiate parallel backend custom." in result.output + + @pytest.mark.end_to_end() def test_register_custom_backend(runner, tmp_path): hook_source = """ @@ -43,7 +62,6 @@ def custom_builder(n_workers): registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) """ tmp_path.joinpath("hook.py").write_text(textwrap.dedent(hook_source)) - tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") result = runner.invoke( cli, @@ -55,7 +73,6 @@ def custom_builder(n_workers): tmp_path.joinpath("hook.py").as_posix(), ], ) - print(result.output) # noqa: T201 assert result.exit_code == ExitCode.OK assert "Build custom executor." in result.output assert "1 Succeeded" in result.output From 9fe989fa1c94dac5bab8bb68064712e1485de31d Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 30 Mar 2024 21:42:21 +0100 Subject: [PATCH 9/9] Better readme. --- .pre-commit-config.yaml | 19 ++++++++++--------- README.md | 20 +++++++++++++++++--- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba399a9..4297698 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,15 +35,16 @@ repos: rev: 0.7.1 hooks: - id: nbstripout -- repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 - hooks: - - id: mdformat - additional_dependencies: [ - mdformat-gfm, - mdformat-black, - ] - args: [--wrap, "88"] +# Conflicts with admonitions. 
+# - repo: https://github.com/executablebooks/mdformat +# rev: 0.7.17 +# hooks: +# - id: mdformat +# additional_dependencies: [ +# mdformat-gfm, +# mdformat-black, +# ] +# args: [--wrap, "88"] - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: diff --git a/README.md b/README.md index 8860ec6..ee44b20 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,14 @@ parallel_backend = "processes" # or loky or threads ## Custom Executor +> [!NOTE] +> +> The interface for custom executors is rudimentary right now and there is not a lot of +> support by public functions. Please, give some feedback if you are trying or managed +> to use a custom backend. +> +> Also, please contribute your custom executors if you consider them useful to others. + pytask-parallel allows you to use your parallel backend as long as it follows the interface defined by [`concurrent.futures.Executor`](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor). @@ -105,14 +113,20 @@ it seems more tightly coupled to your backend. 1. A wrapper for the executed function that captures warnings, catches exceptions and saves products of the task (within the child process!). - As an example, see `def _execute_task()` that does all that for the processes and - loky backend. + As an example, see + [`def _execute_task()`](https://github.com/pytask-dev/pytask-parallel/blob/c441dbb75fa6ab3ab17d8ad5061840c802dc1c41/src/pytask_parallel/processes.py#L91-L155) + that does all that for the processes and loky backend. 1. To apply the wrapper, you need to write a custom hook implementation for - `def pytask_execute_task()`. See `def pytask_execute_task()` for an example. Use the + `def pytask_execute_task()`. See + [`def pytask_execute_task()`](https://github.com/pytask-dev/pytask-parallel/blob/c441dbb75fa6ab3ab17d8ad5061840c802dc1c41/src/pytask_parallel/processes.py#L41-L65) + for an example. Use the [`hook_module`](https://pytask-dev.readthedocs.io/en/stable/how_to_guides/extending_pytask.html#using-hook-module-and-hook-module) configuration value to register your implementation. +Another example of an implementation can be found as a +[test](https://github.com/pytask-dev/pytask-parallel/blob/c441dbb75fa6ab3ab17d8ad5061840c802dc1c41/tests/test_backends.py#L35-L78). + ## Some implementation details ### Parallelization and Debugging