diff --git a/distributed/client.py b/distributed/client.py
index d3c07b59795..a0fb7b78852 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -74,6 +74,7 @@ from distributed.diagnostics.plugin import (
     ForwardLoggingPlugin,
     NannyPlugin,
+    SchedulerUploadFile,
     UploadFile,
     WorkerPlugin,
     _get_plugin_name,
@@ -3734,10 +3735,18 @@ def upload_file(self, filename, **kwargs):
         >>> from mylibrary import myfunc  # doctest: +SKIP
         >>> L = client.map(myfunc, seq)  # doctest: +SKIP
         """
-        return self.register_worker_plugin(
-            UploadFile(filename),
-            name=filename + str(uuid.uuid4()),
-        )
+        name = filename + str(uuid.uuid4())
+
+        async def _():
+            results = await asyncio.gather(
+                self.register_scheduler_plugin(
+                    SchedulerUploadFile(filename), name=name
+                ),
+                self.register_worker_plugin(UploadFile(filename), name=name),
+            )
+            return results[1]  # Results from workers upload
+
+        return self.sync(_)
 
     async def _rebalance(self, futures=None, workers=None):
         if futures is not None:
diff --git a/distributed/core.py b/distributed/core.py
index 40372cb8fa8..75243e409ea 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -3,7 +3,9 @@
 import asyncio
 import inspect
 import logging
+import os
 import sys
+import tempfile
 import threading
 import traceback
 import types
@@ -24,6 +26,7 @@
 from dask.utils import parse_timedelta
 
 from distributed import profile, protocol
+from distributed.collections import LRU
 from distributed.comm import (
     Comm,
     CommClosedError,
@@ -35,16 +38,21 @@
 )
 from distributed.compatibility import PeriodicCallback
 from distributed.counter import Counter
+from distributed.diskutils import WorkDir, WorkSpace
 from distributed.metrics import context_meter, time
+from distributed.protocol import pickle
 from distributed.system_monitor import SystemMonitor
 from distributed.utils import (
     NoOpAwaitable,
     get_traceback,
     has_keyword,
+    import_file,
     iscoroutinefunction,
+    offload,
     recursive_to_dict,
     truncate_exception,
     wait_for,
+    warn_on_duration,
 )
 
 if TYPE_CHECKING:
@@ -56,6 +64,21 @@
 Coro = Coroutine[Any, Any, T]
 
 
+cache_loads = LRU(maxsize=100)
+
+
+def loads_function(bytes_object):
+    """Load a function from bytes, cache bytes"""
+    if len(bytes_object) < 100000:
+        try:
+            result = cache_loads[bytes_object]
+        except KeyError:
+            result = pickle.loads(bytes_object)
+            cache_loads[bytes_object] = result
+        return result
+    return pickle.loads(bytes_object)
+
+
 class Status(Enum):
     """
     This Enum contains the various states a cluster, worker, scheduler and nanny can be
@@ -303,6 +326,9 @@ class Server:
 
     default_ip = ""
     default_port = 0
+    local_directory: str
+    _workspace: WorkSpace
+    _workdir: None | WorkDir
 
     def __init__(
         self,
@@ -316,7 +342,39 @@ def __init__(
         connection_args=None,
         timeout=None,
         io_loop=None,
+        local_directory=None,
+        needs_workdir=True,
     ):
+        if local_directory is None:
+            local_directory = (
+                dask.config.get("temporary-directory") or tempfile.gettempdir()
+            )
+
+        if "dask-scratch-space" not in str(local_directory):
+            local_directory = os.path.join(local_directory, "dask-scratch-space")
+
+        self._original_local_dir = local_directory
+
+        with warn_on_duration(
+            "1s",
+            "Creating scratch directories is taking a surprisingly long time. ({duration:.2f}s) "
+            "This is often due to running workers on a network file system. "
+            "Consider specifying a local-directory to point workers to write "
+            "scratch data to a local disk.",
+        ):
+            self._workspace = WorkSpace(local_directory)
+
+        if not needs_workdir:  # eg. Nanny will not need a WorkDir
+            self._workdir = None
+            self.local_directory = self._workspace.base_dir
+        else:
+            name = type(self).__name__.lower()
+            self._workdir = self._workspace.new_work_dir(prefix=f"{name}-")
+            self.local_directory = self._workdir.dir_path
+
+        if self.local_directory not in sys.path:
+            sys.path.insert(0, self.local_directory)
+
         if io_loop is not None:
             warnings.warn(
                 "The io_loop kwarg to Server is ignored and will be deprecated",
@@ -437,6 +495,35 @@ def set_thread_ident():
 
         self.__stopped = False
 
+    async def upload_file(
+        self, filename: str, data: str | bytes, load: bool = True
+    ) -> dict[str, Any]:
+        out_filename = os.path.join(self.local_directory, filename)
+
+        def func(data):
+            if isinstance(data, str):
+                data = data.encode()
+            with open(out_filename, "wb") as f:
+                f.write(data)
+                f.flush()
+                os.fsync(f.fileno())
+            return data
+
+        if len(data) < 10000:
+            data = func(data)
+        else:
+            data = await offload(func, data)
+
+        if load:
+            try:
+                import_file(out_filename)
+                cache_loads.data.clear()
+            except Exception as e:
+                logger.exception(e)
+                raise e
+
+        return {"status": "OK", "nbytes": len(data)}
+
     def _shift_counters(self):
         for counter in self.counters.values():
             counter.shift()
@@ -573,6 +660,9 @@ def stop(self):
         if self.__stopped:
             return
 
+        if self._workdir is not None:
+            self._workdir.release()
+
         self.monitor.close()
         self.__stopped = True
 
diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py
index 0ac63597ae7..cbd7f195162 100644
--- a/distributed/diagnostics/plugin.py
+++ b/distributed/diagnostics/plugin.py
@@ -311,6 +311,21 @@ def _get_plugin_name(plugin: SchedulerPlugin | WorkerPlugin | NannyPlugin) -> st
     return funcname(type(plugin)) + "-" + str(uuid.uuid4())
 
 
+class SchedulerUploadFile(SchedulerPlugin):
+    name = "upload_file"
+
+    def __init__(self, filepath):
+        """
+        Initialize the plugin by reading in the data from the given file.
+        """
+        self.filename = os.path.basename(filepath)
+        with open(filepath, "rb") as f:
+            self.data = f.read()
+
+    async def start(self, scheduler: Scheduler) -> None:
+        await scheduler.upload_file(self.filename, self.data)
+
+
 class PackageInstall(WorkerPlugin, abc.ABC):
     """Abstract parent class for a worker plugin to install a set of packages
 
diff --git a/distributed/nanny.py b/distributed/nanny.py
index 7238c171d42..5bc120c3ee7 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -8,7 +8,6 @@
 import multiprocessing
 import os
 import shutil
-import tempfile
 import threading
 import uuid
 import warnings
@@ -39,7 +38,6 @@
     error_message,
 )
 from distributed.diagnostics.plugin import _get_plugin_name
-from distributed.diskutils import WorkSpace
 from distributed.metrics import time
 from distributed.node import ServerNode
 from distributed.process import AsyncProcess
@@ -175,19 +173,6 @@ def __init__(  # type: ignore[no-untyped-def]
         assert isinstance(self.security, Security)
         self.connection_args = self.security.get_connection_args("worker")
 
-        if local_directory is None:
-            local_directory = (
-                dask.config.get("temporary-directory") or tempfile.gettempdir()
-            )
-            self._original_local_dir = local_directory
-            local_directory = os.path.join(local_directory, "dask-worker-space")
-        else:
-            self._original_local_dir = local_directory
-
-        # Create directory if it doesn't exist and test for write access.
-        # In case of PermissionError, change the name.
-        self.local_directory = WorkSpace(local_directory).base_dir
-
         self.preload = preload
         if self.preload is None:
             self.preload = dask.config.get("distributed.worker.preload")
@@ -200,6 +185,25 @@ def __init__(  # type: ignore[no-untyped-def]
         if preload_nanny_argv is None:
             preload_nanny_argv = dask.config.get("distributed.nanny.preload-argv")
 
+        handlers = {
+            "instantiate": self.instantiate,
+            "kill": self.kill,
+            "restart": self.restart,
+            "get_logs": self.get_logs,
+            # cannot call it 'close' on the rpc side for naming conflict
+            "terminate": self.close,
+            "close_gracefully": self.close_gracefully,
+            "run": self.run,
+            "plugin_add": self.plugin_add,
+            "plugin_remove": self.plugin_remove,
+        }
+        super().__init__(
+            handlers=handlers,
+            connection_args=self.connection_args,
+            local_directory=local_directory,
+            needs_workdir=False,
+        )
+
         self.preloads = preloading.process_preloads(
             self, preload_nanny, preload_nanny_argv, file_dir=self.local_directory
         )
@@ -221,7 +225,7 @@ def __init__(  # type: ignore[no-untyped-def]
                 protocol = protocol_address[0]
 
         self._given_worker_port = worker_port
-        self.nthreads = nthreads or CPU_COUNT
+        self.nthreads: int = nthreads or CPU_COUNT
         self.reconnect = reconnect
         self.validate = validate
         self.resources = resources
@@ -256,23 +260,7 @@ def __init__(  # type: ignore[no-untyped-def]
             stack.enter_context(silence_logging_cmgr(level=silence_logs))
         self.silence_logs = silence_logs
 
-        handlers = {
-            "instantiate": self.instantiate,
-            "kill": self.kill,
-            "restart": self.restart,
-            "get_logs": self.get_logs,
-            # cannot call it 'close' on the rpc side for naming conflict
-            "terminate": self.close,
-            "close_gracefully": self.close_gracefully,
-            "run": self.run,
-            "plugin_add": self.plugin_add,
-            "plugin_remove": self.plugin_remove,
-        }
-
         self.plugins: dict[str, NannyPlugin] = {}
-
-        super().__init__(handlers=handlers, connection_args=self.connection_args)
-
         self.scheduler = self.rpc(self.scheduler_addr)
         self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit)
 
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index b37cea6e810..276f5ff621a 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -1486,9 +1486,16 @@ def g():
 
     with save_sys_modules():
         for value in [123, 456]:
-            with tmp_text("myfile.py", f"def f():\n    return {value}") as fn:
+            code = f"def f():\n    return {value}"
+            with tmp_text("myfile.py", code) as fn:
                 await c.upload_file(fn)
 
+            # Confirm workers _and_ scheduler got the file
+            for server in [s, a, b]:
+                file = pathlib.Path(server.local_directory).joinpath("myfile.py")
+                assert file.is_file()
+                assert file.read_text() == code
+
             x = c.submit(g, pure=False)
             result = await x
             assert result == value
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index a2c37f331c8..76b28a303a0 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -363,23 +363,23 @@ async def test_local_directory(s):
         with dask.config.set(temporary_directory=fn):
             async with Nanny(s.address) as n:
                 assert n.local_directory.startswith(fn)
-                assert "dask-worker-space" in n.local_directory
-                assert n.process.worker_dir.count("dask-worker-space") == 1
+                assert "dask-scratch-space" in n.local_directory
+                assert n.process.worker_dir.count("dask-scratch-space") == 1
 
 
 @pytest.mark.skipif(WINDOWS, reason="Need POSIX filesystem permissions and UIDs")
 @gen_cluster(nthreads=[])
 async def test_unwriteable_dask_worker_space(s, tmp_path):
-    os.mkdir(f"{tmp_path}/dask-worker-space", mode=0o500)
+    os.mkdir(f"{tmp_path}/dask-scratch-space", mode=0o500)
     with pytest.raises(PermissionError):
-        open(f"{tmp_path}/dask-worker-space/tryme", "w")
+        open(f"{tmp_path}/dask-scratch-space/tryme", "w")
 
     with dask.config.set(temporary_directory=tmp_path):
         async with Nanny(s.address) as n:
             assert n.local_directory == os.path.join(
-                tmp_path, f"dask-worker-space-{os.getuid()}"
+                tmp_path, f"dask-scratch-space-{os.getuid()}"
             )
-            assert n.process.worker_dir.count(f"dask-worker-space-{os.getuid()}") == 1
+            assert n.process.worker_dir.count(f"dask-scratch-space-{os.getuid()}") == 1
 
 
 def _noop(x):
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index e85fd0b34ef..d40d7fb9041 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -217,8 +217,9 @@ def g():
     result = await future
     assert result == 123
 
-    await c.close()
     await s.close()
+    await c.close()
+
     assert not os.path.exists(os.path.join(a.local_directory, "foobar.py"))
 
 
@@ -978,7 +979,7 @@ async def test_worker_dir(c, s, a, b):
 
 @gen_cluster(client=True, nthreads=[], config={"temporary-directory": None})
 async def test_default_worker_dir(c, s):
-    expect = os.path.join(tempfile.gettempdir(), "dask-worker-space")
+    expect = os.path.join(tempfile.gettempdir(), "dask-scratch-space")
 
     async with Worker(s.address) as w:
         assert os.path.dirname(w.local_directory) == expect
@@ -1385,7 +1386,7 @@ async def test_local_directory(s, tmp_path):
     with dask.config.set(temporary_directory=str(tmp_path)):
         async with Worker(s.address) as w:
             assert w.local_directory.startswith(str(tmp_path))
-            assert "dask-worker-space" in w.local_directory
+            assert "dask-scratch-space" in w.local_directory
 
 
 @gen_cluster(nthreads=[])
@@ -1393,7 +1394,7 @@ async def test_local_directory_make_new_directory(s, tmp_path):
     async with Worker(s.address, local_directory=str(tmp_path / "foo" / "bar")) as w:
         assert w.local_directory.startswith(str(tmp_path))
         assert "foo" in w.local_directory
-        assert "dask-worker-space" in w.local_directory
+        assert "dask-scratch-space" in w.local_directory
 
 
 @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")
diff --git a/distributed/worker.py b/distributed/worker.py
index 6818b7232f6..9c6de71fed8 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -11,7 +11,6 @@
 import pathlib
 import random
 import sys
-import tempfile
 import threading
 import warnings
 import weakref
@@ -74,13 +73,14 @@
     coerce_to_address,
     context_meter_to_server_digest,
     error_message,
+    loads_function,
     pingpong,
 )
 from distributed.core import rpc as RPCType
 from distributed.core import send_recv
 from distributed.diagnostics import nvml, rmm
 from distributed.diagnostics.plugin import _get_plugin_name
-from distributed.diskutils import WorkDir, WorkSpace
+from distributed.diskutils import WorkSpace
 from distributed.http import get_handlers
 from distributed.metrics import context_meter, thread_time, time
 from distributed.node import ServerNode
@@ -97,7 +97,6 @@
     _maybe_complex,
     get_ip,
     has_arg,
-    import_file,
     in_async_call,
     is_python_shutting_down,
     iscoroutinefunction,
@@ -111,7 +110,6 @@
     silence_logging_cmgr,
     thread_state,
     wait_for,
-    warn_on_duration,
 )
 from distributed.utils_comm import gather_from_workers, pack_data, retry_operation
 from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
@@ -434,8 +432,6 @@ class Worker(BaseWorker, ServerNode):
     latency: float
     profile_cycle_interval: float
     workspace: WorkSpace
-    _workdir: WorkDir
-    local_directory: str
     _client: Client | None
     bandwidth_workers: defaultdict[str, tuple[float, int]]
     bandwidth_types: defaultdict[type, tuple[float, int]]
@@ -598,51 +594,9 @@ def __init__(
 
         self._setup_logging(logger)
 
-        if not local_directory:
-            local_directory = (
-                dask.config.get("temporary-directory") or tempfile.gettempdir()
-            )
-            local_directory = os.path.join(local_directory, "dask-worker-space")
-
-        with warn_on_duration(
-            "1s",
-            "Creating scratch directories is taking a surprisingly long time. ({duration:.2f}s) "
-            "This is often due to running workers on a network file system. "
-            "Consider specifying a local-directory to point workers to write "
-            "scratch data to a local disk.",
-        ):
-            self._workspace = WorkSpace(local_directory)
-            self._workdir = self._workspace.new_work_dir(prefix="worker-")
-            self.local_directory = self._workdir.dir_path
-
-        if not preload:
-            preload = dask.config.get("distributed.worker.preload")
-        if not preload_argv:
-            preload_argv = dask.config.get("distributed.worker.preload-argv")
-        assert preload is not None
-        assert preload_argv is not None
-        self.preloads = preloading.process_preloads(
-            self, preload, preload_argv, file_dir=self.local_directory
-        )
-
         self.death_timeout = parse_timedelta(death_timeout)
-        if scheduler_file:
-            cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
-            scheduler_addr = cfg["address"]
-        elif scheduler_ip is None and dask.config.get("scheduler-address", None):
-            scheduler_addr = dask.config.get("scheduler-address")
-        elif scheduler_port is None:
-            scheduler_addr = coerce_to_address(scheduler_ip)
-        else:
-            scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
         self.contact_address = contact_address
 
-        if protocol is None:
-            protocol_address = scheduler_addr.split("://")
-            if len(protocol_address) == 2:
-                protocol = protocol_address[0]
-            assert protocol
-
         self._start_port = port
         self._start_host = host
         if host:
@@ -654,7 +608,6 @@ def __init__(
                     f"got {host_address}"
                 )
         self._interface = interface
-        self._protocol = protocol
 
         nthreads = nthreads or CPU_COUNT
         if resources is None:
@@ -702,9 +655,6 @@ def __init__(
         self.scheduler_delay = 0
         self.stream_comms = {}
 
-        if self.local_directory not in sys.path:
-            sys.path.insert(0, self.local_directory)
-
         self.plugins = {}
         self._pending_plugins = plugins
@@ -769,8 +719,38 @@ def __init__(
             handlers=handlers,
             stream_handlers=stream_handlers,
             connection_args=self.connection_args,
+            local_directory=local_directory,
             **kwargs,
         )
+
+        if not preload:
+            preload = dask.config.get("distributed.worker.preload")
+        if not preload_argv:
+            preload_argv = dask.config.get("distributed.worker.preload-argv")
+        assert preload is not None
+        assert preload_argv is not None
+
+        self.preloads = preloading.process_preloads(
+            self, preload, preload_argv, file_dir=self.local_directory
+        )
+
+        if scheduler_file:
+            cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
+            scheduler_addr = cfg["address"]
+        elif scheduler_ip is None and dask.config.get("scheduler-address", None):
+            scheduler_addr = dask.config.get("scheduler-address")
+        elif scheduler_port is None:
+            scheduler_addr = coerce_to_address(scheduler_ip)
+        else:
+            scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
+
+        if protocol is None:
+            protocol_address = scheduler_addr.split("://")
+            if len(protocol_address) == 2:
+                protocol = protocol_address[0]
+            assert protocol
+        self._protocol = protocol
+
         self.memory_manager = WorkerMemoryManager(
             self,
             data=data,
@@ -1284,35 +1264,6 @@ async def handle_scheduler(self, comm: Comm) -> None:
         finally:
             await self.close(reason="worker-handle-scheduler-connection-broken")
 
-    async def upload_file(
-        self, filename: str, data: str | bytes, load: bool = True
-    ) -> dict[str, Any]:
-        out_filename = os.path.join(self.local_directory, filename)
-
-        def func(data):
-            if isinstance(data, str):
-                data = data.encode()
-            with open(out_filename, "wb") as f:
-                f.write(data)
-                f.flush()
-                os.fsync(f.fileno())
-            return data
-
-        if len(data) < 10000:
-            data = func(data)
-        else:
-            data = await offload(func, data)
-
-        if load:
-            try:
-                import_file(out_filename)
-                cache_loads.data.clear()
-            except Exception as e:
-                logger.exception(e)
-                raise e
-
-        return {"status": "OK", "nbytes": len(data)}
-
     def keys(self) -> list[str]:
         return list(self.data)
 
@@ -1598,7 +1549,6 @@ async def close(  # type: ignore
                 c.close()
 
         await self.scheduler.close_rpc()
-        self._workdir.release()
 
         self.stop_services()
 
@@ -2916,21 +2866,6 @@ async def get_data_from_worker(
 job_counter = [0]
 
 
-cache_loads = LRU(maxsize=100)
-
-
-def loads_function(bytes_object):
-    """Load a function from bytes, cache bytes"""
-    if len(bytes_object) < 100000:
-        try:
-            result = cache_loads[bytes_object]
-        except KeyError:
-            result = pickle.loads(bytes_object)
-            cache_loads[bytes_object] = result
-        return result
-    return pickle.loads(bytes_object)
-
-
 @context_meter.meter("deserialize")
 def _deserialize(function=None, args=None, kwargs=None, task=NO_VALUE):
     """Deserialize task inputs and regularize to func, args, kwargs"""
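
For context, a minimal usage sketch of the behaviour this patch introduces: `Client.upload_file` now registers both `SchedulerUploadFile` and `UploadFile`, so the uploaded module lands in the scratch directory of the scheduler as well as of every worker. This example is not part of the patch; the file name `mymodule.py` and its contents are hypothetical, and a default local cluster is assumed.

```python
import os

from distributed import Client

# Hypothetical helper module to ship to the cluster
with open("mymodule.py", "w") as f:
    f.write("def inc(x):\n    return x + 1\n")

client = Client()  # assumes a default LocalCluster
client.upload_file("mymodule.py")  # uploads to scheduler and workers


def use_it(x):
    from mymodule import inc  # importable inside tasks on the workers

    return inc(x)


print(client.submit(use_it, 41).result())  # 42


# The scheduler received a copy too, in its own dask-scratch-space directory
def check(dask_scheduler):
    return os.path.exists(os.path.join(dask_scheduler.local_directory, "mymodule.py"))


print(client.run_on_scheduler(check))  # True
```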