28 changes: 24 additions & 4 deletions distributed/core.py
@@ -7,13 +7,14 @@
import threading
import traceback
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Container
from contextlib import suppress
from enum import Enum
from functools import partial
from typing import ClassVar
from typing import Callable, ClassVar

import tblib
from tlz import merge
@@ -98,6 +99,22 @@ def _raise(*args, **kwargs):
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")


def _expects_comm(func: Callable) -> bool:
sig = inspect.signature(func)
params = list(sig.parameters)
if params and params[0] == "comm":
return True
if params and params[0] == "stream":
warnings.warn(
"Calling the first arugment of a RPC handler `stream` is "
"deprecated. Defining this argument is optional. Either remove the "
f"arugment or rename it to `comm` in {func}.",
FutureWarning,
)
return True
return False


class Server:
"""Dask Distributed Server

@@ -379,7 +396,7 @@ def port(self):
_, self._port = get_address_host_port(self.address)
return self._port

def identity(self, comm=None) -> dict[str, str]:
def identity(self) -> dict[str, str]:
return {"type": type(self).__name__, "id": self.id}

def _to_dict(
@@ -404,7 +421,7 @@ def _to_dict(
info = {k: v for k, v in info.items() if k not in exclude}
return recursive_to_dict(info, exclude=exclude)

def echo(self, comm=None, data=None):
def echo(self, data=None):
return data

async def listen(self, port_or_addr=None, allow_offload=True, **kwargs):
@@ -514,7 +531,10 @@ async def handle_comm(self, comm):

logger.debug("Calling into handler %s", handler.__name__)
try:
result = handler(comm, **msg)
if _expects_comm(handler):
result = handler(comm, **msg)
else:
result = handler(**msg)
if inspect.isawaitable(result):
result = asyncio.ensure_future(result)
self._ongoing_coroutines.add(result)
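For context: `handle_comm` now probes each registered handler with `_expects_comm` and passes the comm object only when the handler's first parameter is named `comm`. A minimal sketch of the resulting convention follows; the handler names and the standalone `Server` construction are illustrative, not part of this diff.

from distributed.core import Server

def ping(**kwargs):
    # New style: handlers no longer need to accept the comm.
    return "pong"

def peer(comm, **kwargs):
    # Handlers that want the comm simply name their first parameter `comm`.
    return str(comm.peer_address)

def legacy(stream, **kwargs):
    # Old style: still dispatched with the comm, but _expects_comm
    # now emits a FutureWarning asking for a rename to `comm`.
    return "pong"

server = Server(handlers={"ping": ping, "peer": peer, "legacy": legacy})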
8 changes: 4 additions & 4 deletions distributed/event.py
@@ -60,7 +60,7 @@ def __init__(self, scheduler):

self.scheduler.extensions["events"] = self

async def event_wait(self, comm=None, name=None, timeout=None):
async def event_wait(self, name=None, timeout=None):
"""Wait until the event is set to true.
Returns false, when this did not happen in the given time
and true otherwise.
@@ -89,7 +89,7 @@ async def event_wait(self, comm=None, name=None, timeout=None):

return True

def event_set(self, comm=None, name=None):
def event_set(self, name=None):
"""Set the event with the given name to true.

All waiters on this event will be notified.
@@ -100,7 +100,7 @@ def event_set(self, comm=None, name=None):
# we set the event to true
self._events[name].set()

def event_clear(self, comm=None, name=None):
def event_clear(self, name=None):
"""Set the event with the given name to false."""
with log_errors():
name = self._normalize_name(name)
@@ -121,7 +121,7 @@ def event_clear(self, comm=None, name=None):
event = self._events[name]
event.clear()

def event_is_set(self, comm=None, name=None):
def event_is_set(self, name=None):
with log_errors():
name = self._normalize_name(name)
# the default flag value is false
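These extension methods are reached only through the scheduler's RPC machinery, so dropping `comm=None` is invisible to callers; the client-side `Event` API is unchanged. A quick usage sketch, assuming a running cluster:

from distributed import Client, Event

client = Client()             # assumes a reachable scheduler
event = Event("can-proceed")
event.set()                   # scheduler runs event_set(name="can-proceed")
assert event.is_set()         # event_is_set(name="can-proceed")
event.clear()                 # event_clear(name="can-proceed")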
4 changes: 2 additions & 2 deletions distributed/lock.py
@@ -32,7 +32,7 @@ def __init__(self, scheduler):

self.scheduler.extensions["locks"] = self

async def acquire(self, comm=None, name=None, id=None, timeout=None):
async def acquire(self, name=None, id=None, timeout=None):
with log_errors():
if isinstance(name, list):
name = tuple(name)
@@ -60,7 +60,7 @@ async def acquire(self, comm=None, name=None, id=None, timeout=None):
self.ids[name] = id
return result

def release(self, comm=None, name=None, id=None):
def release(self, name=None, id=None):
with log_errors():
if isinstance(name, list):
name = tuple(name)
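The same holds for locks: `acquire` and `release` are invoked over RPC, and the client-side `Lock` wrapper keeps its behavior. A brief sketch, assuming a running cluster:

from distributed import Client, Lock

client = Client()
lock = Lock("resource-x")
if lock.acquire(timeout=5):    # LockExtension.acquire on the scheduler
    try:
        pass                   # critical section
    finally:
        lock.release()         # LockExtension.release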
6 changes: 2 additions & 4 deletions distributed/multi_lock.py
@@ -112,9 +112,7 @@ def _refain_locks(self, locks, id):
for waiter in waiters_ready:
self.scheduler.loop.add_callback(self.events[waiter].set)

async def acquire(
self, comm=None, locks=None, id=None, timeout=None, num_locks=None
):
async def acquire(self, locks=None, id=None, timeout=None, num_locks=None):
with log_errors():
if not self._request_locks(locks, id, num_locks):
assert id not in self.events
@@ -134,7 +132,7 @@ async def acquire(
assert self.requests_left[id] == 0
return True

def release(self, comm=None, id=None):
def release(self, id=None):
with log_errors():
self._refain_locks(self.requests[id], id)

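`MultiLock` follows suit; its client API is untouched. For example, assuming a running cluster:

from distributed import Client, MultiLock

client = Client()
with MultiLock(names=["a", "b"]):   # MultiLockExtension.acquire / release
    pass                            # both locks held here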
16 changes: 8 additions & 8 deletions distributed/nanny.py
@@ -342,7 +342,7 @@ async def start(self):

return self

async def kill(self, comm=None, timeout=2):
async def kill(self, timeout=2):
"""Kill the local worker process

Blocks until both the process is down and the scheduler is properly
@@ -355,7 +355,7 @@ async def kill(self, comm=None, timeout=2):
deadline = time() + timeout
await self.process.kill(timeout=0.8 * (deadline - time()))

async def instantiate(self, comm=None) -> Status:
async def instantiate(self) -> Status:
"""Start a local worker process

Blocks until the process is up and the scheduler is properly informed
@@ -420,7 +420,7 @@ async def instantiate(self, comm=None) -> Status:
raise
return result

async def plugin_add(self, comm=None, plugin=None, name=None):
async def plugin_add(self, plugin=None, name=None):
with log_errors(pdb=False):
if isinstance(plugin, bytes):
plugin = pickle.loads(plugin)
@@ -446,7 +446,7 @@ async def plugin_add(self, comm=None, plugin=None, name=None):

return {"status": "OK"}

async def plugin_remove(self, comm=None, name=None):
async def plugin_remove(self, name=None):
with log_errors(pdb=False):
logger.info(f"Removing Nanny plugin {name}")
try:
@@ -461,7 +461,7 @@ async def plugin_remove(self, comm=None, name=None):

return {"status": "OK"}

async def restart(self, comm=None, timeout=30, executor_wait=True):
async def restart(self, timeout=30, executor_wait=True):
async def _():
if self.process is not None:
await self.kill()
@@ -515,8 +515,8 @@ def memory_monitor(self):
def is_alive(self):
return self.process is not None and self.process.is_alive()

def run(self, *args, **kwargs):
return run(self, *args, **kwargs)
def run(self, comm, *args, **kwargs):
return run(self, comm, *args, **kwargs)

def _on_exit_sync(self, exitcode):
self.loop.add_callback(self._on_exit, exitcode)
@@ -560,7 +560,7 @@ def _close(self, *args, **kwargs):
warnings.warn("Worker._close has moved to Worker.close", stacklevel=2)
return self.close(*args, **kwargs)

def close_gracefully(self, comm=None):
def close_gracefully(self):
"""
A signal that we shouldn't try to restart workers if they go away

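Note that `Nanny.run` moves the other way: it now names `comm` explicitly, since the shared `run` utility it delegates to takes the comm as an argument, so `_expects_comm` must keep detecting it. Client code is unaffected; a small sketch, assuming a running cluster:

from distributed import Client

client = Client()
# Executes the function inside each nanny process; the comm is threaded
# through Nanny.run into the shared `run` utility.
results = client.run(lambda: "ok", nanny=True)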
7 changes: 3 additions & 4 deletions distributed/node.py
@@ -25,7 +25,7 @@ class ServerNode(Server):
# XXX avoid inheriting from Server? there is some large potential for confusion
# between base and derived attribute namespaces...

def versions(self, comm=None, packages=None):
def versions(self, packages=None):
return get_versions(packages=packages)

def start_services(self, default_listen_ip):
@@ -87,7 +87,7 @@ def _setup_logging(self, logger):
logger.addHandler(self._deque_handler)
weakref.finalize(self, logger.removeHandler, self._deque_handler)

def get_logs(self, comm=None, start=None, n=None, timestamps=False):
def get_logs(self, start=0, n=None, timestamps=False):
"""
Fetch log entries for this node

@@ -105,8 +105,7 @@ def get_logs(self, comm=None, start=None, n=None, timestamps=False):
List of tuples containing the log level, message, and (optional) timestamp for each filtered entry
"""
deque_handler = self._deque_handler
if start is None:
start = -1

L = []
for count, msg in enumerate(deque_handler.deque):
if n and count >= n or msg.created < start:
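With `start` defaulting to `0` rather than `None`, the old `start = -1` fallback is dead code: `msg.created` is a POSIX timestamp, so the default still includes every buffered record, while a real timestamp filters out older ones. For example, with `node` standing in for any `ServerNode`:

import time

everything = node.get_logs()                    # start=0 keeps all records
recent = node.get_logs(start=time.time() - 60)  # entries from the last minute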
8 changes: 3 additions & 5 deletions distributed/publish.py
@@ -28,17 +28,15 @@ def __init__(self, scheduler):
self.scheduler.handlers.update(handlers)
self.scheduler.extensions["publish"] = self

def put(
self, comm=None, keys=None, data=None, name=None, override=False, client=None
):
def put(self, keys=None, data=None, name=None, override=False, client=None):
with log_errors():
if not override and name in self.datasets:
raise KeyError("Dataset %s already exists" % name)
self.scheduler.client_desires_keys(keys, f"published-{stringify(name)}")
self.datasets[name] = {"data": data, "keys": keys}
return {"status": "OK", "name": name}

def delete(self, comm=None, name=None):
def delete(self, name=None):
with log_errors():
out = self.datasets.pop(name, {"keys": []})
self.scheduler.client_releases_keys(
@@ -49,7 +47,7 @@ def list(self, *args):
with log_errors():
return list(sorted(self.datasets.keys(), key=str))

def get(self, stream, name=None, client=None):
def get(self, name=None, client=None):
with log_errors():
return self.datasets.get(name, None)

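The publish handlers back the client's dataset methods, which are unchanged. A short sketch, assuming a running cluster and `my_data` standing in for any publishable object:

client.publish_dataset(my_data, name="my-data")  # PublishExtension.put
same = client.get_dataset("my-data")             # PublishExtension.get
client.unpublish_dataset("my-data")              # PublishExtension.delete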
8 changes: 4 additions & 4 deletions distributed/pubsub.py
@@ -36,7 +36,7 @@ def __init__(self, scheduler):

self.scheduler.extensions["pubsub"] = self

def add_publisher(self, comm=None, name=None, worker=None):
def add_publisher(self, name=None, worker=None):
logger.debug("Add publisher: %s %s", name, worker)
self.publishers[name].add(worker)
return {
@@ -45,7 +45,7 @@ def add_subscriber(self, comm=None, name=None, worker=None, client=None):
and len(self.client_subscribers[name]) > 0,
}

def add_subscriber(self, comm=None, name=None, worker=None, client=None):
def add_subscriber(self, name=None, worker=None, client=None):
if worker:
logger.debug("Add worker subscriber: %s %s", name, worker)
self.subscribers[name].add(worker)
@@ -63,7 +63,7 @@ def add_subscriber(self, comm=None, name=None, worker=None, client=None):
)
self.client_subscribers[name].add(client)

def remove_publisher(self, comm=None, name=None, worker=None):
def remove_publisher(self, name=None, worker=None):
if worker in self.publishers[name]:
logger.debug("Remove publisher: %s %s", name, worker)
self.publishers[name].remove(worker)
@@ -72,7 +72,7 @@ def remove_subscriber(self, comm=None, name=None, worker=None, client=None):
del self.subscribers[name]
del self.publishers[name]

def remove_subscriber(self, comm=None, name=None, worker=None, client=None):
def remove_subscriber(self, name=None, worker=None, client=None):
if worker:
logger.debug("Remove worker subscriber: %s %s", name, worker)
self.subscribers[name].remove(worker)
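As with the other extensions, the client-facing `Pub`/`Sub` API is untouched. For example, assuming a running cluster:

from distributed import Client, Pub, Sub

client = Client()
pub = Pub("topic")            # registers via add_publisher
sub = Sub("topic")            # registers via add_subscriber
pub.put("hello")
message = sub.get(timeout=5)  # receives the published message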
12 changes: 5 additions & 7 deletions distributed/queues.py
@@ -44,15 +44,15 @@ def __init__(self, scheduler):

self.scheduler.extensions["queues"] = self

def create(self, comm=None, name=None, client=None, maxsize=0):
def create(self, name=None, client=None, maxsize=0):
logger.debug(f"Queue name: {name}")
if name not in self.queues:
self.queues[name] = asyncio.Queue(maxsize=maxsize)
self.client_refcount[name] = 1
else:
self.client_refcount[name] += 1

def release(self, comm=None, name=None, client=None):
def release(self, name=None, client=None):
if name not in self.queues:
return

@@ -65,9 +65,7 @@ def release(self, comm=None, name=None, client=None):
if keys:
self.scheduler.client_releases_keys(keys=keys, client="queue-%s" % name)

async def put(
self, comm=None, name=None, key=None, data=None, client=None, timeout=None
):
async def put(self, name=None, key=None, data=None, client=None, timeout=None):
if key is not None:
record = {"type": "Future", "value": key}
self.future_refcount[name, key] += 1
@@ -82,7 +80,7 @@ def future_release(self, name=None, key=None, client=None):
self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
del self.future_refcount[name, key]

async def get(self, comm=None, name=None, client=None, timeout=None, batch=False):
async def get(self, name=None, client=None, timeout=None, batch=False):
def process(record):
"""Add task status if known"""
if record["type"] == "Future":
@@ -122,7 +120,7 @@ def process(record):
record = process(record)
return record

def qsize(self, comm=None, name=None, client=None):
def qsize(self, name=None, client=None):
return self.queues[name].qsize()


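Finally, the distributed `Queue` wrapper keeps its public behavior. A last sketch, assuming a running cluster:

from distributed import Client, Queue

client = Client()
q = Queue("jobs")       # QueueExtension.create
q.put(123)              # QueueExtension.put
assert q.qsize() == 1   # QueueExtension.qsize
value = q.get()         # QueueExtension.get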