Commit

Client.upload_file send to both Workers and Scheduler and rename scratch directory (#7802)
milesgranger authored May 9, 2023
1 parent 581e91d commit 09ea0f5
Showing 8 changed files with 189 additions and 144 deletions.
17 changes: 13 additions & 4 deletions distributed/client.py
@@ -74,6 +74,7 @@
from distributed.diagnostics.plugin import (
ForwardLoggingPlugin,
NannyPlugin,
SchedulerUploadFile,
UploadFile,
WorkerPlugin,
_get_plugin_name,
@@ -3734,10 +3735,18 @@ def upload_file(self, filename, **kwargs):
>>> from mylibrary import myfunc # doctest: +SKIP
>>> L = client.map(myfunc, seq) # doctest: +SKIP
"""
return self.register_worker_plugin(
UploadFile(filename),
name=filename + str(uuid.uuid4()),
)
name = filename + str(uuid.uuid4())

async def _():
results = await asyncio.gather(
self.register_scheduler_plugin(
SchedulerUploadFile(filename), name=name
),
self.register_worker_plugin(UploadFile(filename), name=name),
)
return results[1] # Results from workers upload

return self.sync(_)

async def _rebalance(self, futures=None, workers=None):
if futures is not None:
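Taken together, upload_file now fans the file out to the scheduler and the workers concurrently and returns the workers' result. A minimal sketch of the user-facing behavior (assuming a local cluster; mymodule.py is a hypothetical local file defining VALUE):

from distributed import Client

client = Client()  # spins up a local cluster, for illustration
client.upload_file("mymodule.py")  # now reaches the scheduler as well as every worker

def uses_upload():
    import mymodule  # importable because the scratch directory is on sys.path
    return mymodule.VALUE

print(client.submit(uses_upload).result())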
90 changes: 90 additions & 0 deletions distributed/core.py
@@ -3,7 +3,9 @@
import asyncio
import inspect
import logging
import os
import sys
import tempfile
import threading
import traceback
import types
@@ -24,6 +26,7 @@
from dask.utils import parse_timedelta

from distributed import profile, protocol
from distributed.collections import LRU
from distributed.comm import (
Comm,
CommClosedError,
@@ -35,16 +38,21 @@
)
from distributed.compatibility import PeriodicCallback
from distributed.counter import Counter
from distributed.diskutils import WorkDir, WorkSpace
from distributed.metrics import context_meter, time
from distributed.protocol import pickle
from distributed.system_monitor import SystemMonitor
from distributed.utils import (
NoOpAwaitable,
get_traceback,
has_keyword,
import_file,
iscoroutinefunction,
offload,
recursive_to_dict,
truncate_exception,
wait_for,
warn_on_duration,
)

if TYPE_CHECKING:
@@ -56,6 +64,21 @@
Coro = Coroutine[Any, Any, T]


cache_loads = LRU(maxsize=100)


def loads_function(bytes_object):
"""Load a function from bytes, cache bytes"""
if len(bytes_object) < 100000:
try:
result = cache_loads[bytes_object]
except KeyError:
result = pickle.loads(bytes_object)
cache_loads[bytes_object] = result
return result
return pickle.loads(bytes_object)


class Status(Enum):
"""
This Enum contains the various states a cluster, worker, scheduler and nanny can be
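Stepping back to loads_function above: payloads under 100 kB are deserialized once and memoized in the cache_loads LRU, so repeated messages carrying the same pickled callable skip redundant unpickling. A hedged illustration (loads_function lives in distributed.core as of this commit; for a function pickled by reference the round-trip returns the same object anyway, so this only traces the control flow):

import pickle

from distributed.core import loads_function

def add(x, y):
    return x + y

blob = pickle.dumps(add)   # small payload, well under the 100 kB threshold
f1 = loads_function(blob)  # deserialized once, then stored in cache_loads
f2 = loads_function(blob)  # identical bytes are served from the cache
assert f1(2, 3) == 5 and f1 is f2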
@@ -303,6 +326,9 @@ class Server:

default_ip = ""
default_port = 0
local_directory: str
_workspace: WorkSpace
_workdir: None | WorkDir

def __init__(
self,
@@ -316,7 +342,39 @@ def __init__(
connection_args=None,
timeout=None,
io_loop=None,
local_directory=None,
needs_workdir=True,
):
if local_directory is None:
local_directory = (
dask.config.get("temporary-directory") or tempfile.gettempdir()
)

if "dask-scratch-space" not in str(local_directory):
local_directory = os.path.join(local_directory, "dask-scratch-space")

self._original_local_dir = local_directory

with warn_on_duration(
"1s",
"Creating scratch directories is taking a surprisingly long time. ({duration:.2f}s) "
"This is often due to running workers on a network file system. "
"Consider specifying a local-directory to point workers to write "
"scratch data to a local disk.",
):
self._workspace = WorkSpace(local_directory)

if not needs_workdir: # eg. Nanny will not need a WorkDir
self._workdir = None
self.local_directory = self._workspace.base_dir
else:
name = type(self).__name__.lower()
self._workdir = self._workspace.new_work_dir(prefix=f"{name}-")
self.local_directory = self._workdir.dir_path

if self.local_directory not in sys.path:
sys.path.insert(0, self.local_directory)

if io_loop is not None:
warnings.warn(
"The io_loop kwarg to Server is ignored and will be deprecated",
@@ -437,6 +495,35 @@ def set_thread_ident():

self.__stopped = False

async def upload_file(
self, filename: str, data: str | bytes, load: bool = True
) -> dict[str, Any]:
out_filename = os.path.join(self.local_directory, filename)

def func(data):
if isinstance(data, str):
data = data.encode()
with open(out_filename, "wb") as f:
f.write(data)
f.flush()
os.fsync(f.fileno())
return data

if len(data) < 10000:
data = func(data)
else:
data = await offload(func, data)

if load:
try:
import_file(out_filename)
cache_loads.data.clear()
except Exception as e:
logger.exception(e)
raise e

return {"status": "OK", "nbytes": len(data)}

def _shift_counters(self):
for counter in self.counters.values():
counter.shift()
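Because upload_file is now defined on Server, every subclass (Scheduler, Worker, Nanny) exposes the same handler: payloads under 10 kB are written inline, larger ones are offloaded to a thread so the event loop is not blocked, and load=True imports the file and clears the deserialization cache. A hedged sketch, assuming a throwaway scheduler and worker can bind locally:

import asyncio

from distributed import Scheduler, Worker

async def main():
    async with Scheduler(dashboard_address=":0") as s:
        async with Worker(s.address) as w:
            # Worker inherits upload_file from Server now
            out = await w.upload_file("helper.py", b"VALUE = 42\n")
            print(out)  # {'status': 'OK', 'nbytes': 11}

asyncio.run(main())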
@@ -573,6 +660,9 @@ def stop(self):
if self.__stopped:
return

if self._workdir is not None:
self._workdir.release()

self.monitor.close()

self.__stopped = True
15 changes: 15 additions & 0 deletions distributed/diagnostics/plugin.py
@@ -311,6 +311,21 @@ def _get_plugin_name(plugin: SchedulerPlugin | WorkerPlugin | NannyPlugin) -> st
return funcname(type(plugin)) + "-" + str(uuid.uuid4())


class SchedulerUploadFile(SchedulerPlugin):
name = "upload_file"

def __init__(self, filepath):
"""
Initialize the plugin by reading in the data from the given file.
"""
self.filename = os.path.basename(filepath)
with open(filepath, "rb") as f:
self.data = f.read()

async def start(self, scheduler: Scheduler) -> None:
await scheduler.upload_file(self.filename, self.data)


class PackageInstall(WorkerPlugin, abc.ABC):
"""Abstract parent class for a worker plugin to install a set of packages
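Client.upload_file wires this plugin in automatically; registering it by hand looks roughly like this (a sketch assuming a running cluster and a local helper.py):

from distributed import Client
from distributed.diagnostics.plugin import SchedulerUploadFile

client = Client()  # assumes a running cluster, for illustration
client.register_scheduler_plugin(
    SchedulerUploadFile("helper.py"), name="upload-helper"
)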
52 changes: 20 additions & 32 deletions distributed/nanny.py
@@ -8,7 +8,6 @@
import multiprocessing
import os
import shutil
import tempfile
import threading
import uuid
import warnings
@@ -39,7 +38,6 @@
error_message,
)
from distributed.diagnostics.plugin import _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.metrics import time
from distributed.node import ServerNode
from distributed.process import AsyncProcess
@@ -175,19 +173,6 @@ def __init__( # type: ignore[no-untyped-def]
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")

if local_directory is None:
local_directory = (
dask.config.get("temporary-directory") or tempfile.gettempdir()
)
self._original_local_dir = local_directory
local_directory = os.path.join(local_directory, "dask-worker-space")
else:
self._original_local_dir = local_directory

# Create directory if it doesn't exist and test for write access.
# In case of PermissionError, change the name.
self.local_directory = WorkSpace(local_directory).base_dir

self.preload = preload
if self.preload is None:
self.preload = dask.config.get("distributed.worker.preload")
@@ -200,6 +185,25 @@
if preload_nanny_argv is None:
preload_nanny_argv = dask.config.get("distributed.nanny.preload-argv")

handlers = {
"instantiate": self.instantiate,
"kill": self.kill,
"restart": self.restart,
"get_logs": self.get_logs,
# cannot call it 'close' on the rpc side for naming conflict
"terminate": self.close,
"close_gracefully": self.close_gracefully,
"run": self.run,
"plugin_add": self.plugin_add,
"plugin_remove": self.plugin_remove,
}
super().__init__(
handlers=handlers,
connection_args=self.connection_args,
local_directory=local_directory,
needs_workdir=False,
)

self.preloads = preloading.process_preloads(
self, preload_nanny, preload_nanny_argv, file_dir=self.local_directory
)
@@ -221,7 +225,7 @@ def __init__( # type: ignore[no-untyped-def]
protocol = protocol_address[0]

self._given_worker_port = worker_port
self.nthreads = nthreads or CPU_COUNT
self.nthreads: int = nthreads or CPU_COUNT
self.reconnect = reconnect
self.validate = validate
self.resources = resources
@@ -256,23 +260,7 @@ def __init__( # type: ignore[no-untyped-def]
stack.enter_context(silence_logging_cmgr(level=silence_logs))
self.silence_logs = silence_logs

handlers = {
"instantiate": self.instantiate,
"kill": self.kill,
"restart": self.restart,
"get_logs": self.get_logs,
# cannot call it 'close' on the rpc side for naming conflict
"terminate": self.close,
"close_gracefully": self.close_gracefully,
"run": self.run,
"plugin_add": self.plugin_add,
"plugin_remove": self.plugin_remove,
}

self.plugins: dict[str, NannyPlugin] = {}

super().__init__(handlers=handlers, connection_args=self.connection_args)

self.scheduler = self.rpc(self.scheduler_addr)
self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit)

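The net effect of the Nanny changes: directory handling moves into the Server base class, super().__init__ is called early with needs_workdir=False, and the nanny shares the renamed dask-scratch-space root instead of creating a WorkDir of its own. A hedged sketch, assuming a local scheduler and nanny can start:

import asyncio
import tempfile

import dask
from distributed import Nanny, Scheduler

async def main():
    with dask.config.set(temporary_directory=tempfile.gettempdir()):
        async with Scheduler(dashboard_address=":0") as s:
            async with Nanny(s.address) as n:
                # local_directory now comes from Server, with the new name
                assert "dask-scratch-space" in n.local_directory

asyncio.run(main())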
9 changes: 8 additions & 1 deletion distributed/tests/test_client.py
@@ -1486,9 +1486,16 @@ def g():

with save_sys_modules():
for value in [123, 456]:
with tmp_text("myfile.py", f"def f():\n return {value}") as fn:
code = f"def f():\n return {value}"
with tmp_text("myfile.py", code) as fn:
await c.upload_file(fn)

# Confirm workers _and_ scheduler got the file
for server in [s, a, b]:
file = pathlib.Path(server.local_directory).joinpath("myfile.py")
assert file.is_file()
assert file.read_text() == code

x = c.submit(g, pure=False)
result = await x
assert result == value
12 changes: 6 additions & 6 deletions distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
@@ -363,23 +363,23 @@ async def test_local_directory(s):
with dask.config.set(temporary_directory=fn):
async with Nanny(s.address) as n:
assert n.local_directory.startswith(fn)
assert "dask-worker-space" in n.local_directory
assert n.process.worker_dir.count("dask-worker-space") == 1
assert "dask-scratch-space" in n.local_directory
assert n.process.worker_dir.count("dask-scratch-space") == 1


@pytest.mark.skipif(WINDOWS, reason="Need POSIX filesystem permissions and UIDs")
@gen_cluster(nthreads=[])
async def test_unwriteable_dask_worker_space(s, tmp_path):
os.mkdir(f"{tmp_path}/dask-worker-space", mode=0o500)
os.mkdir(f"{tmp_path}/dask-scratch-space", mode=0o500)
with pytest.raises(PermissionError):
open(f"{tmp_path}/dask-worker-space/tryme", "w")
open(f"{tmp_path}/dask-scratch-space/tryme", "w")

with dask.config.set(temporary_directory=tmp_path):
async with Nanny(s.address) as n:
assert n.local_directory == os.path.join(
tmp_path, f"dask-worker-space-{os.getuid()}"
tmp_path, f"dask-scratch-space-{os.getuid()}"
)
assert n.process.worker_dir.count(f"dask-worker-space-{os.getuid()}") == 1
assert n.process.worker_dir.count(f"dask-scratch-space-{os.getuid()}") == 1


def _noop(x):
9 changes: 5 additions & 4 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
@@ -217,8 +217,9 @@ def g():
result = await future
assert result == 123

await c.close()
await s.close()
await c.close()

assert not os.path.exists(os.path.join(a.local_directory, "foobar.py"))


@@ -978,7 +979,7 @@ async def test_worker_dir(c, s, a, b):

@gen_cluster(client=True, nthreads=[], config={"temporary-directory": None})
async def test_default_worker_dir(c, s):
expect = os.path.join(tempfile.gettempdir(), "dask-worker-space")
expect = os.path.join(tempfile.gettempdir(), "dask-scratch-space")

async with Worker(s.address) as w:
assert os.path.dirname(w.local_directory) == expect
@@ -1385,15 +1386,15 @@ async def test_local_directory(s, tmp_path):
with dask.config.set(temporary_directory=str(tmp_path)):
async with Worker(s.address) as w:
assert w.local_directory.startswith(str(tmp_path))
assert "dask-worker-space" in w.local_directory
assert "dask-scratch-space" in w.local_directory


@gen_cluster(nthreads=[])
async def test_local_directory_make_new_directory(s, tmp_path):
async with Worker(s.address, local_directory=str(tmp_path / "foo" / "bar")) as w:
assert w.local_directory.startswith(str(tmp_path))
assert "foo" in w.local_directory
assert "dask-worker-space" in w.local_directory
assert "dask-scratch-space" in w.local_directory


@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")