diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ba399a9..4297698 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,15 +35,16 @@ repos:
   rev: 0.7.1
   hooks:
     - id: nbstripout
-- repo: https://github.com/executablebooks/mdformat
-  rev: 0.7.17
-  hooks:
-    - id: mdformat
-      additional_dependencies: [
-        mdformat-gfm,
-        mdformat-black,
-      ]
-      args: [--wrap, "88"]
+# Conflicts with admonitions.
+# - repo: https://github.com/executablebooks/mdformat
+#   rev: 0.7.17
+#   hooks:
+#     - id: mdformat
+#       additional_dependencies: [
+#         mdformat-gfm,
+#         mdformat-black,
+#       ]
+#       args: [--wrap, "88"]
 - repo: https://github.com/codespell-project/codespell
   rev: v2.2.6
   hooks:
diff --git a/README.md b/README.md
index 74b38a2..ee44b20 100644
--- a/README.md
+++ b/README.md
@@ -47,15 +47,15 @@
 $ pytask --n-workers 2
 $ pytask -n auto
 ```

-Using processes to parallelize the execution of tasks is useful for CPU bound tasks such
+Using processes to parallelize the execution of tasks is useful for CPU-bound tasks such
 as numerical computations. ([Here](https://stackoverflow.com/a/868577/7523785) is an
-explanation on what CPU or IO bound means.)
+explanation of what CPU- or IO-bound means.)

-For IO bound tasks, tasks where the limiting factor are network responses, access to
+For IO-bound tasks, i.e. tasks where the limiting factor is network latency or access to
 files, you can parallelize via threads.

 ```console
-$ pytask --parallel-backend threads
+pytask --parallel-backend threads
 ```

 You can also set the options in a `pyproject.toml`.

@@ -68,6 +68,65 @@
 n_workers = 1
 parallel_backend = "processes"  # or loky or threads
 ```

+## Custom Executor
+
+> [!NOTE]
+>
+> The interface for custom executors is rudimentary right now, and there is little
+> support from public functions. Please give feedback if you are trying to use, or
+> have managed to use, a custom backend.
+>
+> Also, please contribute your custom executors if you consider them useful to others.
+
+pytask-parallel allows you to use your own parallel backend as long as it follows the
+interface defined by
+[`concurrent.futures.Executor`](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor).
+
+In some cases, adding a new backend can be as easy as registering a builder function
+that receives some arguments (currently only `n_workers`) and returns the instantiated
+executor.
+
+```python
+from concurrent.futures import Executor
+from my_project.executor import CustomExecutor
+
+from pytask_parallel import ParallelBackend, registry
+
+
+def build_custom_executor(n_workers: int) -> Executor:
+    return CustomExecutor(max_workers=n_workers)
+
+
+registry.register_parallel_backend(ParallelBackend.CUSTOM, build_custom_executor)
+```
+
+Now, build the project requesting your custom backend.
+
+```console
+pytask --parallel-backend custom
+```
+
+Realistically, registering the executor is not the only adjustment needed for a nice
+user experience. Two other pieces are important, which pytask-parallel does not
+implement by default since they are more tightly coupled to your backend.
+
+1. A wrapper for the executed function that captures warnings, catches exceptions and
+   saves products of the task (within the child process!).
+
+   As an example, see
+   [`def _execute_task()`](https://github.com/pytask-dev/pytask-parallel/blob/c441dbb75fa6ab3ab17d8ad5061840c802dc1c41/src/pytask_parallel/processes.py#L91-L155),
+   which does all of that for the processes and loky backends.
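+
+   A minimal sketch of such a wrapper is shown below. It is illustrative, not part of
+   the pytask-parallel API: the name `wrap_task` and the return format are assumptions
+   you can change freely. For process-based executors, remember that tracebacks are
+   not picklable; render them to a string first, as `def _process_exception()` in
+   `processes.py` does.
+
+   ```python
+   import sys
+   import warnings
+
+
+   def wrap_task(task, kwargs):
+       """Run a task and capture warnings, exceptions, and products."""
+       with warnings.catch_warnings(record=True) as log:
+           try:
+               out = task.execute(**kwargs)
+               # Save products created from the return value here, for example
+               # with ``handle_task_function_return`` from ``pytask_parallel.utils``.
+               exc_info = None
+           except Exception:
+               out = None
+               exc_info = sys.exc_info()
+       messages = [str(warning.message) for warning in log]
+       return out, messages, exc_info
+   ```
+
+1.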
To apply the wrapper, you need to write a custom hook implementation for
+   `def pytask_execute_task()`. See
+   [`def pytask_execute_task()`](https://github.com/pytask-dev/pytask-parallel/blob/c441dbb75fa6ab3ab17d8ad5061840c802dc1c41/src/pytask_parallel/processes.py#L41-L65)
+   for an example. Use the
+   [`hook_module`](https://pytask-dev.readthedocs.io/en/stable/how_to_guides/extending_pytask.html#using-hook-module-and-hook-module)
+   configuration value to register your implementation.
+
+Another example of an implementation can be found as a
+[test](https://github.com/pytask-dev/pytask-parallel/blob/c441dbb75fa6ab3ab17d8ad5061840c802dc1c41/tests/test_backends.py#L35-L78).
+
 ## Some implementation details

 ### Parallelization and Debugging

@@ -92,12 +151,12 @@ Consult the [release notes](CHANGES.md) to find out about what is new.

 - `pytask-parallel` does not call the `pytask_execute_task_protocol` hook
   specification/entry-point because `pytask_execute_task_setup` and
   `pytask_execute_task` need to be separated from `pytask_execute_task_teardown`. Thus,
-  plugins which change this hook specification may not interact well with the
+  plugins that change this hook specification may not interact well with the
   parallelization.
-- There are two PRs for CPython which try to re-enable setting custom reducers which
-  should have been working, but does not. Here are the references.
+- Two PRs for CPython try to re-enable setting custom reducers, a feature that should
+  have been working but does not. Here are the references.

-  > -
-  > -
-  > -
+  - https://bugs.python.org/issue28053
+  - https://github.com/python/cpython/pull/9959
+  - https://github.com/python/cpython/pull/15058
diff --git a/environment.yml b/environment.yml
index 27498f2..4def0f2 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,7 +1,6 @@
 name: pytask-parallel
 channels:
-  - conda-forge/label/pytask_rc
   - conda-forge
   - nodefaults

@@ -21,7 +20,6 @@ dependencies:
   - tox
   - ipywidgets
   - nbmake
-  - pre-commit
   - pytest-cov

   - pip:
diff --git a/src/pytask_parallel/__init__.py b/src/pytask_parallel/__init__.py
index c937a1f..812a622 100644
--- a/src/pytask_parallel/__init__.py
+++ b/src/pytask_parallel/__init__.py
@@ -2,6 +2,9 @@

 from __future__ import annotations

+from pytask_parallel.backends import ParallelBackend
+from pytask_parallel.backends import registry
+
 try:
     from ._version import version as __version__
 except ImportError:
@@ -10,4 +13,4 @@

     __version__ = "unknown"

-__all__ = ["__version__"]
+__all__ = ["ParallelBackend", "__version__", "registry"]
diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py
index 31a1c6a..8b53f5c 100644
--- a/src/pytask_parallel/backends.py
+++ b/src/pytask_parallel/backends.py
@@ -2,25 +2,29 @@

 from __future__ import annotations

+from concurrent.futures import Executor
 from concurrent.futures import Future
 from concurrent.futures import ProcessPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from typing import Any
 from typing import Callable
+from typing import ClassVar

 import cloudpickle
 from loky import get_reusable_executor

+__all__ = ["ParallelBackend", "ParallelBackendRegistry", "registry"]

-def deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
+
+def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
     """Deserialize and execute a function and keyword arguments."""
     deserialized_fn = cloudpickle.loads(fn)
     deserialized_kwargs = cloudpickle.loads(kwargs)
     return deserialized_fn(**deserialized_kwargs)


-class 
CloudpickleProcessPoolExecutor(ProcessPoolExecutor): +class _CloudpickleProcessPoolExecutor(ProcessPoolExecutor): """Patches the standard executor to serialize functions with cloudpickle.""" # The type signature is wrong for version above Py3.7. Fix when 3.7 is deprecated. @@ -32,7 +36,7 @@ def submit( # type: ignore[override] ) -> Future[Any]: """Submit a new task.""" return super().submit( - deserialize_and_run_with_cloudpickle, + _deserialize_and_run_with_cloudpickle, fn=cloudpickle.dumps(fn), kwargs=cloudpickle.dumps(kwargs), ) @@ -41,13 +45,46 @@ def submit( # type: ignore[override] class ParallelBackend(Enum): """Choices for parallel backends.""" + CUSTOM = "custom" + LOKY = "loky" PROCESSES = "processes" THREADS = "threads" - LOKY = "loky" -PARALLEL_BACKEND_BUILDER = { - ParallelBackend.PROCESSES: lambda: CloudpickleProcessPoolExecutor, - ParallelBackend.THREADS: lambda: ThreadPoolExecutor, - ParallelBackend.LOKY: lambda: get_reusable_executor, -} +class ParallelBackendRegistry: + """Registry for parallel backends.""" + + registry: ClassVar[dict[ParallelBackend, Callable[..., Executor]]] = {} + + def register_parallel_backend( + self, kind: ParallelBackend, builder: Callable[..., Executor] + ) -> None: + """Register a parallel backend.""" + self.registry[kind] = builder + + def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executor: + """Get a parallel backend.""" + __tracebackhide__ = True + try: + return self.registry[kind](n_workers=n_workers) + except KeyError: + msg = f"No registered parallel backend found for kind {kind}." + raise ValueError(msg) from None + except Exception as e: # noqa: BLE001 + msg = f"Could not instantiate parallel backend {kind.value}." + raise ValueError(msg) from e + + +registry = ParallelBackendRegistry() + + +registry.register_parallel_backend( + ParallelBackend.PROCESSES, + lambda n_workers: _CloudpickleProcessPoolExecutor(max_workers=n_workers), +) +registry.register_parallel_backend( + ParallelBackend.THREADS, lambda n_workers: ThreadPoolExecutor(max_workers=n_workers) +) +registry.register_parallel_backend( + ParallelBackend.LOKY, lambda n_workers: get_reusable_executor(max_workers=n_workers) +) diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py index 271662c..8e58f55 100644 --- a/src/pytask_parallel/config.py +++ b/src/pytask_parallel/config.py @@ -7,6 +7,7 @@ from pytask import hookimpl +from pytask_parallel import custom from pytask_parallel import execute from pytask_parallel import processes from pytask_parallel import threads @@ -33,15 +34,22 @@ def pytask_parse_config(config: dict[str, Any]) -> None: config["delay"] = 0.1 -@hookimpl +@hookimpl(trylast=True) def pytask_post_parse(config: dict[str, Any]) -> None: """Register the parallel backend if debugging is not enabled.""" if config["pdb"] or config["trace"] or config["dry_run"]: config["n_workers"] = 1 - if config["n_workers"] > 1: + # Register parallel execute hook. + if config["n_workers"] > 1 or config["parallel_backend"] == ParallelBackend.CUSTOM: config["pm"].register(execute) + + # Register parallel backends. 
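+    # Threads get their own implementation. Processes and loky share the
+    # implementation in the processes module, which pickles tasks to the workers.
+    # The custom module only submits tasks to the registered executor.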
+    if config["n_workers"] > 1:
         if config["parallel_backend"] == ParallelBackend.THREADS:
             config["pm"].register(threads)
         else:
             config["pm"].register(processes)
+
+    if config["parallel_backend"] == ParallelBackend.CUSTOM:
+        config["pm"].register(custom)
diff --git a/src/pytask_parallel/custom.py b/src/pytask_parallel/custom.py
new file mode 100644
index 0000000..3ed2377
--- /dev/null
+++ b/src/pytask_parallel/custom.py
@@ -0,0 +1,27 @@
+"""Contains functions for the custom backend."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing import Any
+
+from pytask import PTask
+from pytask import Session
+from pytask import hookimpl
+
+from pytask_parallel.utils import create_kwargs_for_task
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+
+
+@hookimpl
+def pytask_execute_task(session: Session, task: PTask) -> Future[Any]:
+    """Execute a task.
+
+    The task function and its kwargs are submitted to the custom executor as they
+    are. Serializing them, if necessary, is the responsibility of the executor.
+
+    """
+    kwargs = create_kwargs_for_task(task)
+    return session.config["_parallel_executor"].submit(task.function, **kwargs)
diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py
index 7145da4..e6cf659 100644
--- a/src/pytask_parallel/execute.py
+++ b/src/pytask_parallel/execute.py
@@ -11,17 +11,19 @@
 from attrs import field
 from pytask import ExecutionReport
 from pytask import PNode
+from pytask import PTask
 from pytask import PythonNode
 from pytask import Session
 from pytask import hookimpl
+from pytask.tree_util import PyTree
 from pytask.tree_util import tree_map
+from pytask.tree_util import tree_structure

-from pytask_parallel.backends import PARALLEL_BACKEND_BUILDER
-from pytask_parallel.backends import ParallelBackend
+from pytask_parallel.backends import registry
+from pytask_parallel.utils import parse_future_result

 if TYPE_CHECKING:
     from concurrent.futures import Future
-    from types import TracebackType


 @hookimpl
@@ -41,9 +43,11 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
     reports = session.execution_reports
     running_tasks: dict[str, Future[Any]] = {}

-    parallel_backend = PARALLEL_BACKEND_BUILDER[session.config["parallel_backend"]]()
+    parallel_backend = registry.get_parallel_backend(
+        session.config["parallel_backend"], n_workers=session.config["n_workers"]
+    )

-    with parallel_backend(max_workers=session.config["n_workers"]) as executor:
+    with parallel_backend as executor:
         session.config["_parallel_executor"] = executor

         sleeper = _Sleeper()
@@ -84,21 +88,12 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
             for task_name in list(running_tasks):
                 future = running_tasks[task_name]
+
                 if future.done():
-                    # An exception was thrown before the task was executed.
-                    if future.exception() is not None:
-                        exc_info = _parse_future_exception(future.exception())
-                        warning_reports = []
-                    # A task raised an exception.
- else: - (python_nodes, warning_reports, task_exception) = ( - future.result() - ) - session.warnings.extend(warning_reports) - exc_info = ( - _parse_future_exception(future.exception()) - or task_exception - ) + python_nodes, warnings_reports, exc_info = parse_future_result( + future + ) + session.warnings.extend(warnings_reports) if exc_info is not None: task = session.dag.nodes[task_name]["task"] @@ -109,16 +104,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 session.scheduler.done(task_name) else: task = session.dag.nodes[task_name]["task"] - - # Update PythonNodes with the values from the future if - # not threads. - if ( - session.config["parallel_backend"] - != ParallelBackend.THREADS - ): - task.produces = tree_map( - _update_python_node, task.produces, python_nodes - ) + _update_python_nodes(task, python_nodes) try: session.hook.pytask_execute_task_teardown( @@ -134,8 +120,6 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 running_tasks.pop(task_name) newly_collected_reports.append(report) session.scheduler.done(task_name) - else: - pass for report in newly_collected_reports: session.hook.pytask_execute_task_process_report( @@ -156,17 +140,21 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 return True -def _update_python_node(x: PNode, y: PythonNode | None) -> PNode: - if y: - x.save(y.load()) - return x +def _update_python_nodes( + task: PTask, python_nodes: dict[str, PyTree[PythonNode | None]] | None +) -> None: + """Update the python nodes of a task with the python nodes from the future.""" + def _update_python_node(x: PNode, y: PythonNode | None) -> PNode: + if y: + x.save(y.load()) + return x -def _parse_future_exception( - exc: BaseException | None, -) -> tuple[type[BaseException], BaseException, TracebackType] | None: - """Parse a future exception into the format of ``sys.exc_info``.""" - return None if exc is None else (type(exc), exc, exc.__traceback__) + structure_python_nodes = tree_structure(python_nodes) + structure_produces = tree_structure(task.produces) + # strict must be false when none is leaf. 
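+    # The tree of python nodes contains ``None`` for every product that is not a
+    # ``PythonNode``, so the structures only match when the comparison is not strict.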
+ if structure_produces.is_prefix(structure_python_nodes, strict=False): + task.produces = tree_map(_update_python_node, task.produces, python_nodes) # type: ignore[assignment] @define(kw_only=True) diff --git a/src/pytask_parallel/processes.py b/src/pytask_parallel/processes.py index c6de65a..b9541ba 100644 --- a/src/pytask_parallel/processes.py +++ b/src/pytask_parallel/processes.py @@ -1,180 +1,187 @@ -"""Contains functions related to processes and loky.""" - -from __future__ import annotations - -import inspect -import sys -import warnings -from functools import partial -from typing import TYPE_CHECKING -from typing import Any -from typing import Callable - -import cloudpickle -from pytask import Mark -from pytask import PTask -from pytask import PythonNode -from pytask import Session -from pytask import WarningReport -from pytask import console -from pytask import get_marks -from pytask import hookimpl -from pytask import parse_warning_filter -from pytask import remove_internal_traceback_frames_from_exc_info -from pytask import warning_record_to_str -from pytask.tree_util import PyTree -from pytask.tree_util import tree_map -from rich.traceback import Traceback - -from pytask_parallel.utils import create_kwargs_for_task -from pytask_parallel.utils import handle_task_function_return - -if TYPE_CHECKING: - from concurrent.futures import Future - from pathlib import Path - from types import ModuleType - from types import TracebackType - - from rich.console import ConsoleOptions - - -@hookimpl -def pytask_execute_task(session: Session, task: PTask) -> Future[Any]: - """Execute a task. - - Take a task, pickle it and send the bytes over to another process. - - """ - kwargs = create_kwargs_for_task(task) - - # Task modules are dynamically loaded and added to `sys.modules`. Thus, cloudpickle - # believes the module of the task function is also importable in the child process. - # We have to register the module as dynamic again, so that cloudpickle will pickle - # it with the function. See cloudpickle#417, pytask#373 and pytask#374. - task_module = _get_module(task.function, getattr(task, "path", None)) - cloudpickle.register_pickle_by_value(task_module) - - return session.config["_parallel_executor"].submit( - _execute_task, - task=task, - kwargs=kwargs, - show_locals=session.config["show_locals"], - console_options=console.options, - session_filterwarnings=session.config["filterwarnings"], - task_filterwarnings=get_marks(task, "filterwarnings"), - ) - - -def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - msg = ( - "You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the " - "execution of tasks with pytask-parallel. Please, remove the breakpoint or run " - "the task without parallelization to debug it." - ) - raise RuntimeError(msg) - - -def _patch_set_trace_and_breakpoint() -> None: - """Patch :func:`pdb.set_trace` and :func:`breakpoint`. - - Patch sys.breakpointhook to intercept any call of breakpoint() and pdb.set_trace in - a subprocess and print a better exception message. 
- - """ - import pdb # noqa: T100 - import sys - - pdb.set_trace = _raise_exception_on_breakpoint - sys.breakpointhook = _raise_exception_on_breakpoint - - -def _execute_task( # noqa: PLR0913 - task: PTask, - kwargs: dict[str, Any], - show_locals: bool, # noqa: FBT001 - console_options: ConsoleOptions, - session_filterwarnings: tuple[str, ...], - task_filterwarnings: tuple[Mark, ...], -) -> tuple[ - PyTree[PythonNode | None], - list[WarningReport], - tuple[type[BaseException], BaseException, str] | None, -]: - """Unserialize and execute task. - - This function receives bytes and unpickles them to a task which is them execute in a - spawned process or thread. - - """ - __tracebackhide__ = True - _patch_set_trace_and_breakpoint() - - with warnings.catch_warnings(record=True) as log: - for arg in session_filterwarnings: - warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) - - # apply filters from "filterwarnings" marks - for mark in task_filterwarnings: - for arg in mark.args: - warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) - - try: - out = task.execute(**kwargs) - except Exception: # noqa: BLE001 - exc_info = sys.exc_info() - processed_exc_info = _process_exception( - exc_info, show_locals, console_options - ) - else: - handle_task_function_return(task, out) - processed_exc_info = None - - task_display_name = getattr(task, "display_name", task.name) - warning_reports = [] - for warning_message in log: - fs_location = warning_message.filename, warning_message.lineno - warning_reports.append( - WarningReport( - message=warning_record_to_str(warning_message), - fs_location=fs_location, - id_=task_display_name, - ) - ) - - python_nodes = tree_map( - lambda x: x if isinstance(x, PythonNode) else None, task.produces - ) - - return python_nodes, warning_reports, processed_exc_info - - -def _process_exception( - exc_info: tuple[type[BaseException], BaseException, TracebackType | None], - show_locals: bool, # noqa: FBT001 - console_options: ConsoleOptions, -) -> tuple[type[BaseException], BaseException, str]: - """Process the exception and convert the traceback to a string.""" - exc_info = remove_internal_traceback_frames_from_exc_info(exc_info) - traceback = Traceback.from_exception(*exc_info, show_locals=show_locals) - segments = console.render(traceback, options=console_options) - text = "".join(segment.text for segment in segments) - return (*exc_info[:2], text) - - -def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType: - """Get the module of a python function. - - ``functools.partial`` obfuscates the module of the function and - ``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original - function. - - We use the path from the task module to aid the search although it is not clear - whether it helps. 
- - """ - if isinstance(func, partial): - func = func.func - - if path: - return inspect.getmodule(func, path.as_posix()) - return inspect.getmodule(func) +"""Contains functions related to processes and loky.""" + +from __future__ import annotations + +import inspect +import sys +import warnings +from functools import partial +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +import cloudpickle +from pytask import Mark +from pytask import PTask +from pytask import PythonNode +from pytask import Session +from pytask import WarningReport +from pytask import console +from pytask import get_marks +from pytask import hookimpl +from pytask import parse_warning_filter +from pytask import remove_internal_traceback_frames_from_exc_info +from pytask import warning_record_to_str +from pytask.tree_util import PyTree +from pytask.tree_util import tree_map +from rich.traceback import Traceback + +from pytask_parallel.utils import create_kwargs_for_task +from pytask_parallel.utils import handle_task_function_return + +if TYPE_CHECKING: + from concurrent.futures import Future + from pathlib import Path + from types import ModuleType + from types import TracebackType + + from rich.console import ConsoleOptions + + +@hookimpl +def pytask_execute_task(session: Session, task: PTask) -> Future[Any]: + """Execute a task. + + Take a task, pickle it and send the bytes over to another process. + + """ + kwargs = create_kwargs_for_task(task) + + # Task modules are dynamically loaded and added to `sys.modules`. Thus, cloudpickle + # believes the module of the task function is also importable in the child process. + # We have to register the module as dynamic again, so that cloudpickle will pickle + # it with the function. See cloudpickle#417, pytask#373 and pytask#374. + task_module = _get_module(task.function, getattr(task, "path", None)) + cloudpickle.register_pickle_by_value(task_module) + + return session.config["_parallel_executor"].submit( + _execute_task, + task=task, + kwargs=kwargs, + show_locals=session.config["show_locals"], + console_options=console.options, + session_filterwarnings=session.config["filterwarnings"], + task_filterwarnings=get_marks(task, "filterwarnings"), + ) + + +def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 + msg = ( + "You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the " + "execution of tasks with pytask-parallel. Please, remove the breakpoint or run " + "the task without parallelization to debug it." + ) + raise RuntimeError(msg) + + +def _patch_set_trace_and_breakpoint() -> None: + """Patch :func:`pdb.set_trace` and :func:`breakpoint`. + + Patch sys.breakpointhook to intercept any call of breakpoint() and pdb.set_trace in + a subprocess and print a better exception message. + + """ + import pdb # noqa: T100 + import sys + + pdb.set_trace = _raise_exception_on_breakpoint + sys.breakpointhook = _raise_exception_on_breakpoint + + +def _execute_task( # noqa: PLR0913 + task: PTask, + kwargs: dict[str, Any], + show_locals: bool, # noqa: FBT001 + console_options: ConsoleOptions, + session_filterwarnings: tuple[str, ...], + task_filterwarnings: tuple[Mark, ...], +) -> tuple[ + PyTree[PythonNode | None], + list[WarningReport], + tuple[type[BaseException], BaseException, str] | None, +]: + """Unserialize and execute task. + + This function receives bytes and unpickles them to a task which is them execute in a + spawned process or thread. + + """ + # Hide this function from tracebacks. 
+ __tracebackhide__ = True + + # Patch set_trace and breakpoint to show a better error message. + _patch_set_trace_and_breakpoint() + + # Catch warnings and store them in a list. + with warnings.catch_warnings(record=True) as log: + # Apply global filterwarnings. + for arg in session_filterwarnings: + warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) + + # Apply filters from "filterwarnings" marks + for mark in task_filterwarnings: + for arg in mark.args: + warnings.filterwarnings(*parse_warning_filter(arg, escape=False)) + + try: + out = task.execute(**kwargs) + except Exception: # noqa: BLE001 + exc_info = sys.exc_info() + processed_exc_info = _process_exception( + exc_info, show_locals, console_options + ) + else: + # Save products. + handle_task_function_return(task, out) + processed_exc_info = None + + task_display_name = getattr(task, "display_name", task.name) + warning_reports = [] + for warning_message in log: + fs_location = warning_message.filename, warning_message.lineno + warning_reports.append( + WarningReport( + message=warning_record_to_str(warning_message), + fs_location=fs_location, + id_=task_display_name, + ) + ) + + # Collect all PythonNodes that are products to pass values back to the main process. + python_nodes = tree_map( + lambda x: x if isinstance(x, PythonNode) else None, task.produces + ) + + return python_nodes, warning_reports, processed_exc_info + + +def _process_exception( + exc_info: tuple[type[BaseException], BaseException, TracebackType | None], + show_locals: bool, # noqa: FBT001 + console_options: ConsoleOptions, +) -> tuple[type[BaseException], BaseException, str]: + """Process the exception and convert the traceback to a string.""" + exc_info = remove_internal_traceback_frames_from_exc_info(exc_info) + traceback = Traceback.from_exception(*exc_info, show_locals=show_locals) + segments = console.render(traceback, options=console_options) + text = "".join(segment.text for segment in segments) + return (*exc_info[:2], text) + + +def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType: + """Get the module of a python function. + + ``functools.partial`` obfuscates the module of the function and + ``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original + function. + + We use the path from the task module to aid the search although it is not clear + whether it helps. + + """ + if isinstance(func, partial): + func = func.func + + if path: + return inspect.getmodule(func, path.as_posix()) + return inspect.getmodule(func) diff --git a/src/pytask_parallel/threads.py b/src/pytask_parallel/threads.py index 4c2db4a..2e3c356 100644 --- a/src/pytask_parallel/threads.py +++ b/src/pytask_parallel/threads.py @@ -1,55 +1,55 @@ -"""Contains functions for the threads backend.""" - -from __future__ import annotations - -import sys -from typing import TYPE_CHECKING -from typing import Any - -from pytask import PTask -from pytask import Session -from pytask import hookimpl - -from pytask_parallel.utils import create_kwargs_for_task -from pytask_parallel.utils import handle_task_function_return - -if TYPE_CHECKING: - from concurrent.futures import Future - from types import TracebackType - - -@hookimpl -def pytask_execute_task(session: Session, task: PTask) -> Future[Any]: - """Execute a task. - - Since threads have shared memory, it is not necessary to pickle and unpickle the - task. 
- - """ - kwargs = create_kwargs_for_task(task) - return session.config["_parallel_executor"].submit( - _mock_processes_for_threads, task=task, **kwargs - ) - - -def _mock_processes_for_threads( - task: PTask, **kwargs: Any -) -> tuple[ - None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None -]: - """Mock execution function such that it returns the same as for processes. - - The function for processes returns ``warning_reports`` and an ``exception``. With - threads, these object are collected by the main and not the subprocess. So, we just - return placeholders. - - """ - __tracebackhide__ = True - try: - out = task.function(**kwargs) - except Exception: # noqa: BLE001 - exc_info = sys.exc_info() - else: - handle_task_function_return(task, out) - exc_info = None - return None, [], exc_info +"""Contains functions for the threads backend.""" + +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING +from typing import Any + +from pytask import PTask +from pytask import Session +from pytask import hookimpl + +from pytask_parallel.utils import create_kwargs_for_task +from pytask_parallel.utils import handle_task_function_return + +if TYPE_CHECKING: + from concurrent.futures import Future + from types import TracebackType + + +@hookimpl +def pytask_execute_task(session: Session, task: PTask) -> Future[Any]: + """Execute a task. + + Since threads have shared memory, it is not necessary to pickle and unpickle the + task. + + """ + kwargs = create_kwargs_for_task(task) + return session.config["_parallel_executor"].submit( + _mock_processes_for_threads, task=task, **kwargs + ) + + +def _mock_processes_for_threads( + task: PTask, **kwargs: Any +) -> tuple[ + None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None +]: + """Mock execution function such that it returns the same as for processes. + + The function for processes returns ``warning_reports`` and an ``exception``. With + threads, these object are collected by the main and not the subprocess. So, we just + return placeholders. + + """ + __tracebackhide__ = True + try: + out = task.function(**kwargs) + except Exception: # noqa: BLE001 + exc_info = sys.exc_info() + else: + handle_task_function_return(task, out) + exc_info = None + return None, [], exc_info diff --git a/src/pytask_parallel/utils.py b/src/pytask_parallel/utils.py index 4472c5e..11cc19f 100644 --- a/src/pytask_parallel/utils.py +++ b/src/pytask_parallel/utils.py @@ -1,52 +1,94 @@ -"""Contains utility functions.""" - -from __future__ import annotations - -import inspect -from typing import TYPE_CHECKING -from typing import Any - -from pytask.tree_util import PyTree -from pytask.tree_util import tree_leaves -from pytask.tree_util import tree_map -from pytask.tree_util import tree_structure - -if TYPE_CHECKING: - from pytask import PTask - - -def handle_task_function_return(task: PTask, out: Any) -> None: - """Handle the return value of a task function.""" - if "return" not in task.produces: - return - - structure_out = tree_structure(out) - structure_return = tree_structure(task.produces["return"]) - # strict must be false when none is leaf. 
-    if not structure_return.is_prefix(structure_out, strict=False):
-        msg = (
-            "The structure of the return annotation is not a subtree of "
-            "the structure of the function return.\n\nFunction return: "
-            f"{structure_out}\n\nReturn annotation: {structure_return}"
-        )
-        raise ValueError(msg)
-
-    nodes = tree_leaves(task.produces["return"])
-    values = structure_return.flatten_up_to(out)
-    for node, value in zip(nodes, values):
-        node.save(value)
-
-
-def create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
-    """Create kwargs for task function."""
-    parameters = inspect.signature(task.function).parameters
-
-    kwargs = {}
-    for name, value in task.depends_on.items():
-        kwargs[name] = tree_map(lambda x: x.load(), value)
-
-    for name, value in task.produces.items():
-        if name in parameters:
-            kwargs[name] = tree_map(lambda x: x.load(), value)
-
-    return kwargs
+"""Contains utility functions."""
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING
+from typing import Any
+
+from pytask.tree_util import PyTree
+from pytask.tree_util import tree_leaves
+from pytask.tree_util import tree_map
+from pytask.tree_util import tree_structure
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+    from types import TracebackType
+
+    from pytask import PTask
+    from pytask import PythonNode
+    from pytask import WarningReport
+
+
+def parse_future_result(
+    future: Future[Any],
+) -> tuple[
+    dict[str, PyTree[PythonNode | None]] | None,
+    list[WarningReport],
+    tuple[type[BaseException], BaseException, TracebackType] | None,
+]:
+    """Parse the result of a future."""
+    # An exception was raised before the task was executed.
+    future_exception = future.exception()
+    if future_exception is not None:
+        exc_info = _parse_future_exception(future_exception)
+        return None, [], exc_info
+
+    out = future.result()
+    if isinstance(out, tuple) and len(out) == 3:  # noqa: PLR2004
+        return out
+
+    if out is None:
+        return None, [], None
+
+    # The output has an unexpected format.
+    msg = (
+        "The task function returned an unknown output format. Return either a "
+        "tuple of three elements (python nodes, warning reports, and exception "
+        "info) or None."
+    )
+    raise Exception(msg)  # noqa: TRY002
+
+
+def handle_task_function_return(task: PTask, out: Any) -> None:
+    """Handle the return value of a task function."""
+    if "return" not in task.produces:
+        return
+
+    structure_out = tree_structure(out)
+    structure_return = tree_structure(task.produces["return"])
+    # strict must be False when None is a leaf.
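+    # Example: a task annotated with ``-> Annotated[tuple[str, str], (node_a, node_b)]``
+    # must return a 2-tuple, so that the annotation's structure is a prefix of the
+    # structure of the returned value.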
+ if not structure_return.is_prefix(structure_out, strict=False): + msg = ( + "The structure of the return annotation is not a subtree of " + "the structure of the function return.\n\nFunction return: " + f"{structure_out}\n\nReturn annotation: {structure_return}" + ) + raise ValueError(msg) + + nodes = tree_leaves(task.produces["return"]) + values = structure_return.flatten_up_to(out) + for node, value in zip(nodes, values): + node.save(value) + + +def create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]: + """Create kwargs for task function.""" + parameters = inspect.signature(task.function).parameters + + kwargs = {} + for name, value in task.depends_on.items(): + kwargs[name] = tree_map(lambda x: x.load(), value) + + for name, value in task.produces.items(): + if name in parameters: + kwargs[name] = tree_map(lambda x: x.load(), value) + + return kwargs + + +def _parse_future_exception( + exc: BaseException | None, +) -> tuple[type[BaseException], BaseException, TracebackType] | None: + """Parse a future exception into the format of ``sys.exc_info``.""" + return None if exc is None else (type(exc), exc, exc.__traceback__) diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 0000000..b722ffd --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,78 @@ +import textwrap + +import pytest +from pytask import ExitCode +from pytask import cli + + +@pytest.mark.end_to_end() +def test_error_requesting_custom_backend_without_registration(runner, tmp_path): + tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") + result = runner.invoke(cli, [tmp_path.as_posix(), "--parallel-backend", "custom"]) + assert result.exit_code == ExitCode.FAILED + assert "No registered parallel backend found" in result.output + + +@pytest.mark.end_to_end() +def test_error_while_instantiating_custom_backend(runner, tmp_path): + hook_source = """ + from pytask_parallel import ParallelBackend, registry + + def custom_builder(n_workers): + raise Exception("ERROR") + + registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) + + def task_example(): pass + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(hook_source)) + result = runner.invoke(cli, [tmp_path.as_posix(), "--parallel-backend", "custom"]) + assert result.exit_code == ExitCode.FAILED + assert "ERROR" in result.output + assert "Could not instantiate parallel backend custom." 
in result.output + + +@pytest.mark.end_to_end() +def test_register_custom_backend(runner, tmp_path): + hook_source = """ + import cloudpickle + from loky import get_reusable_executor + from pytask import hookimpl + from pytask_parallel import ParallelBackend + from pytask_parallel import registry + from pytask_parallel.processes import _get_module + from pytask_parallel.utils import create_kwargs_for_task + + + @hookimpl(tryfirst=True) + def pytask_execute_task(session, task): + kwargs = create_kwargs_for_task(task) + + task_module = _get_module(task.function, getattr(task, "path", None)) + cloudpickle.register_pickle_by_value(task_module) + + return session.config["_parallel_executor"].submit(task.function, **kwargs) + + + def custom_builder(n_workers): + print("Build custom executor.") + return get_reusable_executor(max_workers=n_workers) + + + registry.register_parallel_backend(ParallelBackend.CUSTOM, custom_builder) + """ + tmp_path.joinpath("hook.py").write_text(textwrap.dedent(hook_source)) + tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") + result = runner.invoke( + cli, + [ + tmp_path.as_posix(), + "--parallel-backend", + "custom", + "--hook-module", + tmp_path.joinpath("hook.py").as_posix(), + ], + ) + assert result.exit_code == ExitCode.OK + assert "Build custom executor." in result.output + assert "1 Succeeded" in result.output diff --git a/tests/test_execute.py b/tests/test_execute.py index 6010f7a..a99a0de 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -12,13 +12,11 @@ from tests.conftest import restore_sys_path_and_module_after_test_execution - -class Session: - pass +_IMPLEMENTED_BACKENDS = [p for p in ParallelBackend if p != ParallelBackend.CUSTOM] @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_parallel_execution(tmp_path, parallel_backend): source = """ from pytask import Product @@ -40,7 +38,7 @@ def task_2(path: Annotated[Path, Product] = Path("out_2.txt")): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_parallel_execution_w_cli(runner, tmp_path, parallel_backend): source = """ from pytask import Product @@ -70,7 +68,7 @@ def task_2(path: Annotated[Path, Product] = Path("out_2.txt")): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_stop_execution_when_max_failures_is_reached(tmp_path, parallel_backend): source = """ import time @@ -98,7 +96,7 @@ def task_3(): time.sleep(3) @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_priorities(tmp_path, parallel_backend): source = """ import pytask @@ -139,7 +137,7 @@ def task_5(): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) @pytest.mark.parametrize("show_locals", [True, False]) def test_rendering_of_tracebacks_with_rich( runner, tmp_path, parallel_backend, show_locals @@ -221,7 +219,7 @@ def test_sleeper(): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_that_return(runner, 
tmp_path, parallel_backend): source = """ from pathlib import Path @@ -241,7 +239,7 @@ def task_example() -> Annotated[str, Path("file.txt")]: @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_without_path_that_return(runner, tmp_path, parallel_backend): source = """ from pathlib import Path @@ -263,7 +261,7 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend): @pytest.mark.end_to_end() @pytest.mark.parametrize("flag", ["--pdb", "--trace", "--dry-run"]) -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_backend): tmp_path.joinpath("task_example.py").write_text("def task_example(): pass") result = runner.invoke( @@ -277,7 +275,12 @@ def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_back @pytest.mark.end_to_end() @pytest.mark.parametrize("code", ["breakpoint()", "import pdb; pdb.set_trace()"]) @pytest.mark.parametrize( - "parallel_backend", [i for i in ParallelBackend if i != ParallelBackend.THREADS] + "parallel_backend", + [ + i + for i in ParallelBackend + if i not in (ParallelBackend.THREADS, ParallelBackend.CUSTOM) + ], ) def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend): tmp_path.joinpath("task_example.py").write_text(f"def task_example(): {code}") @@ -289,7 +292,7 @@ def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend): @pytest.mark.end_to_end() -@pytest.mark.parametrize("parallel_backend", ParallelBackend) +@pytest.mark.parametrize("parallel_backend", _IMPLEMENTED_BACKENDS) def test_task_partialed(runner, tmp_path, parallel_backend): source = """ from pathlib import Path