Skip to content

Commit

Permalink
add the ability to BYO your own worker, i.e. without it needing to mo…
Browse files Browse the repository at this point in the history
…nkey patch the pydra.engine.workers.WORKERS dict
  • Loading branch information
tclose committed Feb 25, 2024
1 parent 1720ba6 commit 2186e06
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 19 deletions.
26 changes: 17 additions & 9 deletions pydra/engine/submitter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Handle execution backends."""

import asyncio
import typing as ty
import pickle
from uuid import uuid4
from .workers import WORKERS
from .workers import Worker, WORKERS
from .core import is_workflow
from .helpers import get_open_loop, load_and_run_async

Expand All @@ -16,24 +17,31 @@
class Submitter:
"""Send a task to the execution backend."""

def __init__(self, plugin: ty.Union[str, Worker] = "cf", **kwargs):
    """
    Initialize task submission.

    Parameters
    ----------
    plugin : :obj:`str` or :obj:`pydra.engine.workers.Worker`
        Either the identifier of the execution backend or the backend itself.
        Default is ``cf`` (Concurrent Futures).
    **kwargs
        Passed to the worker class constructor when ``plugin`` is an
        identifier. Not used when a worker instance is supplied, since
        that instance has already been constructed.

    Raises
    ------
    ValueError
        If a worker instance without a ``plugin_name`` attribute is given.
    NotImplementedError
        If ``plugin`` is an identifier with no entry in ``WORKERS``.
    """
    self.loop = get_open_loop()
    self._own_loop = not self.loop.is_running()
    if isinstance(plugin, Worker):
        # Bring-your-own worker: use the instance as-is and take the
        # plugin identifier from it.
        try:
            self.plugin = plugin.plugin_name
        except AttributeError as err:
            raise ValueError(
                "Worker must have a 'plugin_name' str attribute"
            ) from err
        self.worker = plugin
    else:
        self.plugin = plugin
        # Only the registry lookup is guarded, so that a KeyError raised
        # inside a worker's constructor is not misreported as an unknown
        # plugin.
        try:
            worker_cls = WORKERS[self.plugin]
        except KeyError as err:
            raise NotImplementedError(f"No worker for {self.plugin}") from err
        self.worker = worker_cls(**kwargs)
    self.worker.loop = self.loop

def __call__(self, runnable, cache_locations=None, rerun=False, environment=None):
Expand Down
46 changes: 45 additions & 1 deletion pydra/engine/tests/test_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import re
import subprocess as sp
import time
import os
from unittest.mock import patch

import pytest

Expand All @@ -12,8 +14,9 @@
gen_basic_wf_with_threadcount,
gen_basic_wf_with_threadcount_concurrent,
)
from ..core import Workflow
from ..core import Workflow, TaskBase
from ..submitter import Submitter
from ..workers import SerialWorker
from ... import mark
from pathlib import Path
from datetime import datetime
Expand Down Expand Up @@ -612,3 +615,44 @@ def alter_input(x):
@mark.task
def to_tuple(x, y):
return (x, y)


class BYOAdd10Worker(SerialWorker):
    """A dummy worker that runs tasks with the BYO_ADD_VAR env var set to 10."""

    plugin_name = "byo_add_env_var"

    async def exec_serial(self, runnable, rerun=False, environment=None):
        """Run ``runnable``, patching ``BYO_ADD_VAR=10`` into the environment.

        ``TaskBase`` instances are run directly while ``os.environ`` is
        temporarily patched; anything else is delegated to the parent class.
        """
        if isinstance(runnable, TaskBase):
            with patch.dict(os.environ, {"BYO_ADD_VAR": "10"}):
                result = runnable._run(rerun, environment=environment)
            return result
        # it could be tuple that includes pickle files with tasks and inputs
        # NOTE(review): assumes SerialWorker.exec_serial is a coroutine
        # function like this override; awaiting it returns the actual result
        # rather than an un-awaited coroutine object.
        return await super().exec_serial(runnable, rerun, environment)


@mark.task
def add_env_var_task(x: int) -> int:
    """Return ``x`` plus the integer value of ``BYO_ADD_VAR`` (0 when unset)."""
    increment = int(os.environ.get("BYO_ADD_VAR", "0"))
    return x + increment


def test_byo_worker():
    """A worker instance passed as ``plugin`` is used; "serial" is unaffected."""
    # Custom worker instance: the env var is patched in, so 1 + 10 == 11.
    byo_task = add_env_var_task(x=1)
    with Submitter(plugin=BYOAdd10Worker()) as byo_submitter:
        assert byo_submitter.plugin == "byo_add_env_var"
        byo_result = byo_task(submitter=byo_submitter)
    assert byo_result.output.out == 11

    # Built-in serial worker selected by name: no env var, so 2 + 0 == 2.
    plain_task = add_env_var_task(x=2)
    with Submitter(plugin="serial") as serial_submitter:
        plain_result = plain_task(submitter=serial_submitter)
    assert plain_result.output.out == 2
48 changes: 39 additions & 9 deletions pydra/engine/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async def fetch_finished(self, futures):
class SerialWorker(Worker):
"""A worker to execute linearly."""

plugin_name = "serial"

def __init__(self, **kwargs):
"""Initialize worker."""
logger.debug("Initialize SerialWorker")
Expand Down Expand Up @@ -157,6 +159,8 @@ async def fetch_finished(self, futures):
class ConcurrentFuturesWorker(Worker):
"""A worker to execute in parallel using Python's concurrent futures."""

plugin_name = "cf"

def __init__(self, n_procs=None):
"""Initialize Worker."""
super().__init__()
Expand Down Expand Up @@ -192,6 +196,7 @@ def close(self):
class SlurmWorker(DistributedWorker):
"""A worker to execute tasks on SLURM systems."""

plugin_name = "slurm"
_cmd = "sbatch"
_sacct_re = re.compile(
"(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
Expand Down Expand Up @@ -367,6 +372,8 @@ async def _verify_exit_code(self, jobid):
class SGEWorker(DistributedWorker):
"""A worker to execute tasks on SLURM systems."""

plugin_name = "sge"

_cmd = "qsub"
_sacct_re = re.compile(
"(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
Expand Down Expand Up @@ -860,6 +867,8 @@ class DaskWorker(Worker):
This is an experimental implementation with limited testing.
"""

plugin_name = "dask"

def __init__(self, **kwargs):
"""Initialize Worker."""
super().__init__()
Expand Down Expand Up @@ -1039,14 +1048,35 @@ def close(self):
pass


class PsijLocalWorker(PsijWorker):
    """PSI/J-backed worker that runs tasks on the local machine."""

    plugin_name = "psij-local"

    def __init__(self, **kwargs):
        """Set up the worker with the ``local`` subtype, forwarding ``kwargs``."""
        super(PsijLocalWorker, self).__init__(subtype="local", **kwargs)


class PsijSlurmWorker(PsijWorker):
    """A worker to execute tasks using PSI/J using SLURM."""

    plugin_name = "psij-slurm"

    def __init__(self, **kwargs):
        """Initialize PsijSlurmWorker, forwarding ``kwargs`` to PsijWorker."""
        # The subtype must be "slurm" to match the plugin name; passing
        # "local" here would make "psij-slurm" run on the local machine.
        super().__init__(subtype="slurm", **kwargs)

Check warning on line 1068 in pydra/engine/workers.py

View check run for this annotation

Codecov / codecov/patch

pydra/engine/workers.py#L1068

Added line #L1068 was not covered by tests


# Registry of built-in workers, keyed by each class's ``plugin_name`` so the
# identifier is declared in one place (on the worker class itself).
WORKERS = {
    w.plugin_name: w
    for w in (
        SerialWorker,
        ConcurrentFuturesWorker,
        SlurmWorker,
        DaskWorker,
        SGEWorker,
        PsijLocalWorker,
        PsijSlurmWorker,
    )
}

0 comments on commit 2186e06

Please sign in to comment.