diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py index 3906955b2..23c8f50b0 100644 --- a/pydra/engine/submitter.py +++ b/pydra/engine/submitter.py @@ -1,9 +1,10 @@ """Handle execution backends.""" import asyncio +import typing as ty import pickle from uuid import uuid4 -from .workers import WORKERS +from .workers import Worker, WORKERS from .core import is_workflow from .helpers import get_open_loop, load_and_run_async @@ -16,24 +17,34 @@ class Submitter: """Send a task to the execution backend.""" - def __init__(self, plugin="cf", **kwargs): + def __init__(self, plugin: ty.Union[str, ty.Type[Worker]] = "cf", **kwargs): """ Initialize task submission. Parameters ---------- - plugin : :obj:`str` - The identifier of the execution backend. + plugin : :obj:`str` or :obj:`ty.Type[pydra.engine.core.Worker]` + Either the identifier of the execution backend or the worker class itself. Default is ``cf`` (Concurrent Futures). + **kwargs + Additional keyword arguments to pass to the worker. """ self.loop = get_open_loop() self._own_loop = not self.loop.is_running() - self.plugin = plugin - try: - self.worker = WORKERS[self.plugin](**kwargs) - except KeyError: - raise NotImplementedError(f"No worker for {self.plugin}") + if isinstance(plugin, str): + self.plugin = plugin + try: + worker_cls = WORKERS[self.plugin] + except KeyError: + raise NotImplementedError(f"No worker for '{self.plugin}' plugin") + else: + try: + self.plugin = plugin.plugin_name + except AttributeError: + raise ValueError("Worker class must have a 'plugin_name' str attribute") + worker_cls = plugin + self.worker = worker_cls(**kwargs) self.worker.loop = self.loop def __call__(self, runnable, cache_locations=None, rerun=False, environment=None): diff --git a/pydra/engine/tests/test_submitter.py b/pydra/engine/tests/test_submitter.py index d65247e96..a3219521a 100644 --- a/pydra/engine/tests/test_submitter.py +++ b/pydra/engine/tests/test_submitter.py @@ -2,6 +2,8 @@ import re import subprocess as sp import time +import os +from unittest.mock import patch import pytest @@ -12,8 +14,9 @@ gen_basic_wf_with_threadcount, gen_basic_wf_with_threadcount_concurrent, ) -from ..core import Workflow +from ..core import Workflow, TaskBase from ..submitter import Submitter +from ..workers import SerialWorker from ... import mark from pathlib import Path from datetime import datetime @@ -612,3 +615,61 @@ def alter_input(x): @mark.task def to_tuple(x, y): return (x, y) + + +class BYOAddVarWorker(SerialWorker): + """A dummy worker that adds 1 to the output of the task""" + + plugin_name = "byo_add_env_var" + + def __init__(self, add_var, **kwargs): + super().__init__(**kwargs) + self.add_var = add_var + + async def exec_serial(self, runnable, rerun=False, environment=None): + if isinstance(runnable, TaskBase): + with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}): + result = runnable._run(rerun, environment=environment) + return result + else: # it could be tuple that includes pickle files with tasks and inputs + return super().exec_serial(runnable, rerun, environment) + + +@mark.task +def add_env_var_task(x: int) -> int: + return x + int(os.environ.get("BYO_ADD_VAR", 0)) + + +def test_byo_worker(): + + task1 = add_env_var_task(x=1) + + with Submitter(plugin=BYOAddVarWorker, add_var=10) as sub: + assert sub.plugin == "byo_add_env_var" + result = task1(submitter=sub) + + assert result.output.out == 11 + + task2 = add_env_var_task(x=2) + + with Submitter(plugin="serial") as sub: + result = task2(submitter=sub) + + assert result.output.out == 2 + + +def test_bad_builtin_worker(): + + with pytest.raises(NotImplementedError, match="No worker for 'bad-worker' plugin"): + Submitter(plugin="bad-worker") + + +def test_bad_byo_worker(): + + class BadWorker: + pass + + with pytest.raises( + ValueError, match="Worker class must have a 'plugin_name' str attribute" + ): + Submitter(plugin=BadWorker) diff --git a/pydra/engine/workers.py b/pydra/engine/workers.py index 155a2800d..eaa40beb0 100644 --- a/pydra/engine/workers.py +++ b/pydra/engine/workers.py @@ -128,6 +128,8 @@ async def fetch_finished(self, futures): class SerialWorker(Worker): """A worker to execute linearly.""" + plugin_name = "serial" + def __init__(self, **kwargs): """Initialize worker.""" logger.debug("Initialize SerialWorker") @@ -157,6 +159,8 @@ async def fetch_finished(self, futures): class ConcurrentFuturesWorker(Worker): """A worker to execute in parallel using Python's concurrent futures.""" + plugin_name = "cf" + def __init__(self, n_procs=None): """Initialize Worker.""" super().__init__() @@ -192,6 +196,7 @@ def close(self): class SlurmWorker(DistributedWorker): """A worker to execute tasks on SLURM systems.""" + plugin_name = "slurm" _cmd = "sbatch" _sacct_re = re.compile( "(?P\\d*) +(?P\\w*)\\+? +" "(?P\\d+):\\d+" @@ -367,6 +372,8 @@ async def _verify_exit_code(self, jobid): class SGEWorker(DistributedWorker): """A worker to execute tasks on SLURM systems.""" + plugin_name = "sge" + _cmd = "qsub" _sacct_re = re.compile( "(?P\\d*) +(?P\\w*)\\+? +" "(?P\\d+):\\d+" @@ -860,6 +867,8 @@ class DaskWorker(Worker): This is an experimental implementation with limited testing. """ + plugin_name = "dask" + def __init__(self, **kwargs): """Initialize Worker.""" super().__init__() @@ -898,7 +907,7 @@ def close(self): class PsijWorker(Worker): """A worker to execute tasks using PSI/J.""" - def __init__(self, subtype, **kwargs): + def __init__(self, **kwargs): """ Initialize PsijWorker. @@ -915,15 +924,6 @@ def __init__(self, subtype, **kwargs): logger.debug("Initialize PsijWorker") self.psij = psij - # Check if the provided subtype is valid - valid_subtypes = ["local", "slurm"] - if subtype not in valid_subtypes: - raise ValueError( - f"Invalid 'subtype' provided. Available options: {', '.join(valid_subtypes)}" - ) - - self.subtype = subtype - def run_el(self, interface, rerun=False, **kwargs): """Run a task.""" return self.exec_psij(interface, rerun=rerun) @@ -1039,14 +1039,29 @@ def close(self): pass +class PsijLocalWorker(PsijWorker): + """A worker to execute tasks using PSI/J on the local machine.""" + + subtype = "local" + plugin_name = f"psij-{subtype}" + + +class PsijSlurmWorker(PsijWorker): + """A worker to execute tasks using PSI/J using SLURM.""" + + subtype = "slurm" + plugin_name = f"psij-{subtype}" + + WORKERS = { - "serial": SerialWorker, - "cf": ConcurrentFuturesWorker, - "slurm": SlurmWorker, - "dask": DaskWorker, - "sge": SGEWorker, - **{ - "psij-" + subtype: lambda subtype=subtype: PsijWorker(subtype=subtype) - for subtype in ["local", "slurm"] - }, + w.plugin_name: w + for w in ( + SerialWorker, + ConcurrentFuturesWorker, + SlurmWorker, + DaskWorker, + SGEWorker, + PsijLocalWorker, + PsijSlurmWorker, + ) }