diff --git a/distributed/core.py b/distributed/core.py
index c75e1451935..fdad4cf1439 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -7,13 +7,14 @@
 import threading
 import traceback
 import uuid
+import warnings
 import weakref
 from collections import defaultdict
 from collections.abc import Container
 from contextlib import suppress
 from enum import Enum
 from functools import partial
-from typing import ClassVar
+from typing import Callable, ClassVar
 
 import tblib
 from tlz import merge
@@ -98,6 +99,22 @@ def _raise(*args, **kwargs):
 LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
 
 
+def _expects_comm(func: Callable) -> bool:
+    sig = inspect.signature(func)
+    params = list(sig.parameters)
+    if params and params[0] == "comm":
+        return True
+    if params and params[0] == "stream":
+        warnings.warn(
+            "Calling the first argument of an RPC handler `stream` is "
+            "deprecated. Defining this argument is optional. Either remove the "
+            f"argument or rename it to `comm` in {func}.",
+            FutureWarning,
+        )
+        return True
+    return False
+
+
 class Server:
     """Dask Distributed Server
 
@@ -379,7 +396,7 @@ def port(self):
             _, self._port = get_address_host_port(self.address)
         return self._port
 
-    def identity(self, comm=None) -> dict[str, str]:
+    def identity(self) -> dict[str, str]:
         return {"type": type(self).__name__, "id": self.id}
 
     def _to_dict(
@@ -404,7 +421,7 @@ def _to_dict(
         info = {k: v for k, v in info.items() if k not in exclude}
         return recursive_to_dict(info, exclude=exclude)
 
-    def echo(self, comm=None, data=None):
+    def echo(self, data=None):
         return data
 
     async def listen(self, port_or_addr=None, allow_offload=True, **kwargs):
@@ -514,7 +531,10 @@ async def handle_comm(self, comm):
                 logger.debug("Calling into handler %s", handler.__name__)
                 try:
-                    result = handler(comm, **msg)
+                    if _expects_comm(handler):
+                        result = handler(comm, **msg)
+                    else:
+                        result = handler(**msg)
                     if inspect.isawaitable(result):
                         result = asyncio.ensure_future(result)
                         self._ongoing_coroutines.add(result)
diff --git a/distributed/event.py b/distributed/event.py
index 882281692c4..2f9a7787af7 100644
--- a/distributed/event.py
+++ b/distributed/event.py
@@ -60,7 +60,7 @@ def __init__(self, scheduler):
 
         self.scheduler.extensions["events"] = self
 
-    async def event_wait(self, comm=None, name=None, timeout=None):
+    async def event_wait(self, name=None, timeout=None):
         """Wait until the event is set to true.
 
         Returns false, when this did not happen in the given time
         and true otherwise.
@@ -89,7 +89,7 @@ async def event_wait(self, comm=None, name=None, timeout=None):
 
         return True
 
-    def event_set(self, comm=None, name=None):
+    def event_set(self, name=None):
         """Set the event with the given name to true.
 
         All waiters on this event will be notified.
@@ -100,7 +100,7 @@ def event_set(self, comm=None, name=None):
             # we set the event to true
             self._events[name].set()
 
-    def event_clear(self, comm=None, name=None):
+    def event_clear(self, name=None):
         """Set the event with the given name to false."""
         with log_errors():
             name = self._normalize_name(name)
@@ -121,7 +121,7 @@ def event_clear(self, comm=None, name=None):
             event = self._events[name]
             event.clear()
 
-    def event_is_set(self, comm=None, name=None):
+    def event_is_set(self, name=None):
        with log_errors():
            name = self._normalize_name(name)
            # the default flag value is false
diff --git a/distributed/lock.py b/distributed/lock.py
index d8b50ac8d03..a4ed4926325 100644
--- a/distributed/lock.py
+++ b/distributed/lock.py
@@ -32,7 +32,7 @@ def __init__(self, scheduler):
 
         self.scheduler.extensions["locks"] = self
 
-    async def acquire(self, comm=None, name=None, id=None, timeout=None):
+    async def acquire(self, name=None, id=None, timeout=None):
         with log_errors():
             if isinstance(name, list):
                 name = tuple(name)
@@ -60,7 +60,7 @@ async def acquire(self, comm=None, name=None, id=None, timeout=None):
             self.ids[name] = id
             return result
 
-    def release(self, comm=None, name=None, id=None):
+    def release(self, name=None, id=None):
         with log_errors():
             if isinstance(name, list):
                 name = tuple(name)
diff --git a/distributed/multi_lock.py b/distributed/multi_lock.py
index 5d26653cd03..7141a300252 100644
--- a/distributed/multi_lock.py
+++ b/distributed/multi_lock.py
@@ -112,9 +112,7 @@ def _refain_locks(self, locks, id):
         for waiter in waiters_ready:
             self.scheduler.loop.add_callback(self.events[waiter].set)
 
-    async def acquire(
-        self, comm=None, locks=None, id=None, timeout=None, num_locks=None
-    ):
+    async def acquire(self, locks=None, id=None, timeout=None, num_locks=None):
         with log_errors():
             if not self._request_locks(locks, id, num_locks):
                 assert id not in self.events
@@ -134,7 +132,7 @@ async def acquire(
             assert self.requests_left[id] == 0
             return True
 
-    def release(self, comm=None, id=None):
+    def release(self, id=None):
         with log_errors():
             self._refain_locks(self.requests[id], id)
 
diff --git a/distributed/nanny.py b/distributed/nanny.py
index 314583212a3..55c7838d8f3 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -342,7 +342,7 @@ async def start(self):
 
         return self
 
-    async def kill(self, comm=None, timeout=2):
+    async def kill(self, timeout=2):
         """Kill the local worker process
 
         Blocks until both the process is down and the scheduler is properly
@@ -355,7 +355,7 @@ async def kill(self, comm=None, timeout=2):
 
         deadline = time() + timeout
         await self.process.kill(timeout=0.8 * (deadline - time()))
 
-    async def instantiate(self, comm=None) -> Status:
+    async def instantiate(self) -> Status:
         """Start a local worker process
 
         Blocks until the process is up and the scheduler is properly informed
@@ -420,7 +420,7 @@ async def instantiate(self, comm=None) -> Status:
             raise
         return result
 
-    async def plugin_add(self, comm=None, plugin=None, name=None):
+    async def plugin_add(self, plugin=None, name=None):
         with log_errors(pdb=False):
             if isinstance(plugin, bytes):
                 plugin = pickle.loads(plugin)
@@ -446,7 +446,7 @@ async def plugin_add(self, comm=None, plugin=None, name=None):
 
         return {"status": "OK"}
 
-    async def plugin_remove(self, comm=None, name=None):
+    async def plugin_remove(self, name=None):
         with log_errors(pdb=False):
             logger.info(f"Removing Nanny plugin {name}")
             try:
@@ -461,7 +461,7 @@ async def plugin_remove(self, comm=None, name=None):
 
         return {"status": "OK"}
 
-    async def restart(self, comm=None, timeout=30, executor_wait=True):
+    async def restart(self, timeout=30, executor_wait=True):
         async def _():
             if self.process is not None:
                 await self.kill()
@@ -515,8 +515,8 @@ def memory_monitor(self):
     def is_alive(self):
         return self.process is not None and self.process.is_alive()
 
-    def run(self, *args, **kwargs):
-        return run(self, *args, **kwargs)
+    def run(self, comm, *args, **kwargs):
+        return run(self, comm, *args, **kwargs)
 
     def _on_exit_sync(self, exitcode):
         self.loop.add_callback(self._on_exit, exitcode)
@@ -560,7 +560,7 @@ def _close(self, *args, **kwargs):
         warnings.warn("Worker._close has moved to Worker.close", stacklevel=2)
         return self.close(*args, **kwargs)
 
-    def close_gracefully(self, comm=None):
+    def close_gracefully(self):
         """
         A signal that we shouldn't try to restart workers if they go away
 
diff --git a/distributed/node.py b/distributed/node.py
index 8e59bafb22c..b7b5b02031c 100644
--- a/distributed/node.py
+++ b/distributed/node.py
@@ -25,7 +25,7 @@ class ServerNode(Server):
     # XXX avoid inheriting from Server? there is some large potential for confusion
     # between base and derived attribute namespaces...
 
-    def versions(self, comm=None, packages=None):
+    def versions(self, packages=None):
         return get_versions(packages=packages)
 
     def start_services(self, default_listen_ip):
@@ -87,7 +87,7 @@ def _setup_logging(self, logger):
         logger.addHandler(self._deque_handler)
         weakref.finalize(self, logger.removeHandler, self._deque_handler)
 
-    def get_logs(self, comm=None, start=None, n=None, timestamps=False):
+    def get_logs(self, start=0, n=None, timestamps=False):
         """
         Fetch log entries for this node
 
@@ -105,8 +105,7 @@ def get_logs(self, comm=None, start=None, n=None, timestamps=False):
             List of tuples containing the log level, message, and
             (optional) timestamp for each filtered entry
         """
         deque_handler = self._deque_handler
-        if start is None:
-            start = -1
+
         L = []
         for count, msg in enumerate(deque_handler.deque):
             if n and count >= n or msg.created < start:
diff --git a/distributed/publish.py b/distributed/publish.py
index 066a073f598..0310c826a1d 100644
--- a/distributed/publish.py
+++ b/distributed/publish.py
@@ -28,9 +28,7 @@ def __init__(self, scheduler):
         self.scheduler.handlers.update(handlers)
         self.scheduler.extensions["publish"] = self
 
-    def put(
-        self, comm=None, keys=None, data=None, name=None, override=False, client=None
-    ):
+    def put(self, keys=None, data=None, name=None, override=False, client=None):
         with log_errors():
             if not override and name in self.datasets:
                 raise KeyError("Dataset %s already exists" % name)
@@ -38,7 +36,7 @@ def put(
             self.datasets[name] = {"data": data, "keys": keys}
             return {"status": "OK", "name": name}
 
-    def delete(self, comm=None, name=None):
+    def delete(self, name=None):
         with log_errors():
             out = self.datasets.pop(name, {"keys": []})
             self.scheduler.client_releases_keys(
@@ -49,7 +47,7 @@ def delete(self, comm=None, name=None):
         with log_errors():
             return list(sorted(self.datasets.keys(), key=str))
 
-    def get(self, stream, name=None, client=None):
+    def get(self, name=None, client=None):
         with log_errors():
             return self.datasets.get(name, None)
 
diff --git a/distributed/pubsub.py b/distributed/pubsub.py
index 20822e145ca..6d5ab7924ed 100644
--- a/distributed/pubsub.py
+++ b/distributed/pubsub.py
@@ -36,7 +36,7 @@ def __init__(self, scheduler):
 
         self.scheduler.extensions["pubsub"] = self
 
-    def add_publisher(self, comm=None, name=None, worker=None):
+    def add_publisher(self, name=None, worker=None):
         logger.debug("Add publisher: %s %s", name, worker)
         self.publishers[name].add(worker)
         return {
@@ -45,7 +45,7 @@ def add_publisher(self, comm=None, name=None, worker=None):
             and len(self.client_subscribers[name]) > 0,
         }
 
-    def add_subscriber(self, comm=None, name=None, worker=None, client=None):
+    def add_subscriber(self, name=None, worker=None, client=None):
         if worker:
             logger.debug("Add worker subscriber: %s %s", name, worker)
             self.subscribers[name].add(worker)
@@ -63,7 +63,7 @@ def add_subscriber(self, comm=None, name=None, worker=None, client=None):
             )
             self.client_subscribers[name].add(client)
 
-    def remove_publisher(self, comm=None, name=None, worker=None):
+    def remove_publisher(self, name=None, worker=None):
         if worker in self.publishers[name]:
             logger.debug("Remove publisher: %s %s", name, worker)
             self.publishers[name].remove(worker)
@@ -72,7 +72,7 @@ def remove_publisher(self, comm=None, name=None, worker=None):
             del self.subscribers[name]
             del self.publishers[name]
 
-    def remove_subscriber(self, comm=None, name=None, worker=None, client=None):
+    def remove_subscriber(self, name=None, worker=None, client=None):
         if worker:
             logger.debug("Remove worker subscriber: %s %s", name, worker)
             self.subscribers[name].remove(worker)
diff --git a/distributed/queues.py b/distributed/queues.py
index 052d9791e72..b7eb37cd93e 100644
--- a/distributed/queues.py
+++ b/distributed/queues.py
@@ -44,7 +44,7 @@ def __init__(self, scheduler):
 
         self.scheduler.extensions["queues"] = self
 
-    def create(self, comm=None, name=None, client=None, maxsize=0):
+    def create(self, name=None, client=None, maxsize=0):
         logger.debug(f"Queue name: {name}")
         if name not in self.queues:
             self.queues[name] = asyncio.Queue(maxsize=maxsize)
@@ -52,7 +52,7 @@ def create(self, comm=None, name=None, client=None, maxsize=0):
         else:
             self.client_refcount[name] += 1
 
-    def release(self, comm=None, name=None, client=None):
+    def release(self, name=None, client=None):
         if name not in self.queues:
             return
 
@@ -65,9 +65,7 @@ def release(self, comm=None, name=None, client=None):
         if keys:
             self.scheduler.client_releases_keys(keys=keys, client="queue-%s" % name)
 
-    async def put(
-        self, comm=None, name=None, key=None, data=None, client=None, timeout=None
-    ):
+    async def put(self, name=None, key=None, data=None, client=None, timeout=None):
         if key is not None:
             record = {"type": "Future", "value": key}
             self.future_refcount[name, key] += 1
@@ -82,7 +80,7 @@ def future_release(self, name=None, key=None, client=None):
             self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
         del self.future_refcount[name, key]
 
-    async def get(self, comm=None, name=None, client=None, timeout=None, batch=False):
+    async def get(self, name=None, client=None, timeout=None, batch=False):
         def process(record):
             """Add task status if known"""
             if record["type"] == "Future":
@@ -122,7 +120,7 @@ def process(record):
             record = process(record)
         return record
 
-    def qsize(self, comm=None, name=None, client=None):
+    def qsize(self, name=None, client=None):
         return self.queues[name].qsize()
 
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 59c1d4a41cb..4bd84851f7c 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -4037,7 +4037,7 @@ def _repr_html_(self):
             tasks=parent._tasks,
         )
 
-    def identity(self, comm=None):
+    def identity(self):
         """Basic information about ourselves and our cluster"""
         parent: SchedulerState = cast(SchedulerState, self)
         d = {
@@ -4173,7 +4173,7 @@ def del_scheduler_file():
 
         setproctitle(f"dask-scheduler [{self.address}]")
         return self
 
-    async def close(self, comm=None, fast=False, close_workers=False):
+    async def close(self, fast=False, close_workers=False):
         """Send cleanup signal to all coroutines then wait until finished
 
         See Also
@@ -4243,7 +4243,7 @@ async def close(self, comm=None, fast=False, close_workers=False):
         setproctitle("dask-scheduler [closed]")
         disable_gc_diagnosis()
 
-    async def close_worker(self, comm=None, worker=None, safe=None):
+    async def close_worker(self, worker: str, safe: bool = False):
         """Remove a worker from the cluster
 
         This both removes the worker from our local state and also sends a
@@ -4975,7 +4975,7 @@ def stimulus_task_erred(
             **kwargs,
         )
 
-    def stimulus_retry(self, comm=None, keys=None, client=None):
+    def stimulus_retry(self, keys, client=None):
         parent: SchedulerState = cast(SchedulerState, self)
         logger.info("Client %s requests to retry %d keys", client, len(keys))
         if client:
@@ -5005,7 +5005,7 @@ def stimulus_retry(self, comm=None, keys=None, client=None):
 
         return tuple(seen)
 
-    async def remove_worker(self, comm=None, address=None, safe=False, close=True):
+    async def remove_worker(self, address, safe=False, close=True):
         """
         Remove worker from cluster
 
@@ -5555,7 +5555,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
         else:
             self.transitions({key: "forgotten"})
 
-    def release_worker_data(self, comm=None, key=None, worker=None):
+    def release_worker_data(self, key, worker):
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState = parent._workers_dv.get(worker)
         ts: TaskState = parent._tasks.get(key)
@@ -5771,9 +5771,7 @@ def remove_plugin(
             f"Could not find plugin {name!r} among the current scheduler plugins"
         )
 
-    async def register_scheduler_plugin(
-        self, comm=None, plugin=None, name=None, idempotent=None
-    ):
+    async def register_scheduler_plugin(self, plugin, name=None, idempotent=None):
         """Register a plugin on the scheduler."""
         if not dask.config.get("distributed.scheduler.pickle"):
             raise ValueError(
@@ -5910,7 +5908,7 @@ async def scatter(
         )
         return keys
 
-    async def gather(self, comm=None, keys=None, serializers=None):
+    async def gather(self, keys, serializers=None):
         """Collect data from workers to the scheduler"""
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState
@@ -7027,7 +7025,7 @@ async def _track_retire_worker(
         logger.info("Retired worker %s", ws._address)
         return ws._address, ws.identity()
 
-    def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None):
+    def add_keys(self, worker=None, keys=(), stimulus_id=None):
         """
         Learn that a worker has certain keys
 
@@ -7063,12 +7061,10 @@ def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None):
 
     def update_data(
         self,
-        comm=None,
         *,
         who_has: dict,
         nbytes: dict,
         client=None,
-        serializers=None,
    ):
        """
        Learn that new data has entered the network from an external source
 
@@ -7176,7 +7172,7 @@ def subscribe_worker_status(self, comm=None):
             del v["last_seen"]
         return ident
 
-    def get_processing(self, comm=None, workers=None):
+    def get_processing(self, workers=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState
         ts: TaskState
@@ -7191,7 +7187,7 @@ def get_processing(self, comm=None, workers=None):
             for w, ws in parent._workers_dv.items()
         }
 
-    def get_who_has(self, comm=None, keys=None):
+    def get_who_has(self, keys=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState
         ts: TaskState
@@ -7208,7 +7204,7 @@ def get_who_has(self, comm=None, keys=None):
             for key, ts in parent._tasks.items()
         }
 
-    def get_has_what(self, comm=None, workers=None):
+    def get_has_what(self, workers=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState
         ts: TaskState
@@ -7226,7 +7222,7 @@ def get_has_what(self, comm=None, workers=None):
             for w, ws in parent._workers_dv.items()
         }
 
-    def get_ncores(self, comm=None, workers=None):
+    def get_ncores(self, workers=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState
         if workers is not None:
@@ -7239,7 +7235,7 @@ def get_ncores(self, comm=None, workers=None):
         else:
             return {w: ws._nthreads for w, ws in parent._workers_dv.items()}
 
-    def get_ncores_running(self, comm=None, workers=None):
+    def get_ncores_running(self, workers=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ncores = self.get_ncores(workers=workers)
         return {
@@ -7248,7 +7244,7 @@ def get_ncores_running(self, comm=None, workers=None):
             if parent._workers_dv[w].status == Status.running
         }
 
-    async def get_call_stack(self, comm=None, keys=None):
+    async def get_call_stack(self, keys=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ts: TaskState
         dts: TaskState
@@ -7279,7 +7275,7 @@ async def get_call_stack(self, comm=None, keys=None):
         response = {w: r for w, r in zip(workers, results) if r}
         return response
 
-    def get_nbytes(self, comm=None, keys=None, summary=True):
+    def get_nbytes(self, keys=None, summary=True):
         parent: SchedulerState = cast(SchedulerState, self)
         ts: TaskState
         with log_errors():
@@ -7298,7 +7294,7 @@ def get_nbytes(self, comm=None, keys=None, summary=True):
 
         return result
 
-    def run_function(self, stream, function, args=(), kwargs={}, wait=True):
+    def run_function(self, comm, function, args=(), kwargs=None, wait=True):
         """Run a function within this process
 
         See Also
@@ -7313,11 +7309,11 @@ def run_function(self, stream, function, args=(), kwargs={}, wait=True):
                 "deserializing arbitrary bytestrings using pickle via the "
                 "'distributed.scheduler.pickle' configuration setting."
             )
-
+        kwargs = kwargs or {}
         self.log_event("all", {"action": "run-function", "function": function})
-        return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait)
+        return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)
 
-    def set_metadata(self, comm=None, keys=None, value=None):
+    def set_metadata(self, keys=None, value=None):
         parent: SchedulerState = cast(SchedulerState, self)
         metadata = parent._task_metadata
         for key in keys[:-1]:
@@ -7326,7 +7322,7 @@ def set_metadata(self, comm=None, keys=None, value=None):
             metadata = metadata[key]
         metadata[keys[-1]] = value
 
-    def get_metadata(self, comm=None, keys=None, default=no_default):
+    def get_metadata(self, keys, default=no_default):
         parent: SchedulerState = cast(SchedulerState, self)
         metadata = parent._task_metadata
         for key in keys[:-1]:
@@ -7339,7 +7335,7 @@ def get_metadata(self, comm=None, keys=None, default=no_default):
         else:
             raise
 
-    def set_restrictions(self, comm=None, worker=None):
+    def set_restrictions(self, worker: "dict[str, Collection[str] | str]"):
         ts: TaskState
         for key, restrictions in worker.items():
             ts = self.tasks[key]
@@ -7347,7 +7343,7 @@ def set_restrictions(self, comm=None, worker=None):
                 restrictions = {restrictions}
             ts._worker_restrictions = set(restrictions)
 
-    def get_task_prefix_states(self, comm=None):
+    def get_task_prefix_states(self):
         with log_errors():
             state = {}
 
@@ -7367,14 +7363,14 @@ def get_task_prefix_states(self, comm=None):
 
         return state
 
-    def get_task_status(self, comm=None, keys=None):
+    def get_task_status(self, keys=None):
         parent: SchedulerState = cast(SchedulerState, self)
         return {
             key: (parent._tasks[key].state if key in parent._tasks else None)
             for key in keys
         }
 
-    def get_task_stream(self, comm=None, start=None, stop=None, count=None):
+    def get_task_stream(self, start=None, stop=None, count=None):
         from distributed.diagnostics.task_stream import TaskStreamPlugin
 
         if TaskStreamPlugin.name not in self.plugins:
@@ -7384,11 +7380,11 @@ def get_task_stream(self, comm=None, start=None, stop=None, count=None):
 
         return plugin.collect(start=start, stop=stop, count=count)
 
-    def start_task_metadata(self, comm=None, name=None):
+    def start_task_metadata(self, name=None):
         plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name)
         self.add_plugin(plugin)
 
-    def stop_task_metadata(self, comm=None, name=None):
+    def stop_task_metadata(self, name=None):
         plugins = [
             p
             for p in list(self.plugins.values())
@@ -7517,7 +7513,7 @@ def reschedule(self, key=None, worker=None):
     # Utility functions #
     #####################
 
-    def add_resources(self, comm=None, worker=None, resources=None):
+    def add_resources(self, worker: str, resources=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState = parent._workers_dv[worker]
         if resources:
@@ -7582,7 +7578,7 @@ def workers_list(self, workers):
         out.update({ww for ww in parent._workers if w in ww})  # TODO: quadratic
         return list(out)
 
-    def start_ipython(self, comm=None):
+    def start_ipython(self):
         """Start an IPython kernel
 
         Returns Jupyter connection info dictionary.
@@ -7683,7 +7679,7 @@ async def get_profile_metadata(
         return {"counts": counts, "keys": keys}
 
     async def performance_report(
-        self, comm=None, start=None, last_count=None, code="", mode=None
+        self, start: float, last_count: int, code="", mode=None
     ):
         parent: SchedulerState = cast(SchedulerState, self)
         stop = time()
@@ -7709,9 +7705,9 @@ def profile_to_figure(state):
         # Task stream
         task_stream = self.get_task_stream(start=start)
         total_tasks = len(task_stream)
-        timespent = defaultdict(int)
+        timespent: "defaultdict[str, float]" = defaultdict(float)
         for d in task_stream:
-            for x in d.get("startstops", []):
+            for x in d["startstops"]:
                 timespent[x["action"]] += x["stop"] - x["start"]
         tasks_timings = ""
         for k in sorted(timespent.keys()):
@@ -7850,7 +7846,7 @@ def profile_to_figure(state):
 
         return data
 
-    async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
+    async def get_worker_logs(self, n=None, workers=None, nanny=False):
         results = await self.broadcast(
             msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny
         )
@@ -7885,7 +7881,7 @@ def subscribe_topic(self, topic, client):
     def unsubscribe_topic(self, topic, client):
         self.event_subscriber[topic].discard(client)
 
-    def get_events(self, comm=None, topic=None):
+    def get_events(self, topic=None):
         if topic is not None:
             return tuple(self.events[topic])
         else:
@@ -7990,7 +7986,7 @@ def check_idle(self):
             )
             self.loop.add_callback(self.close)
 
-    def adaptive_target(self, comm=None, target_duration=None):
+    def adaptive_target(self, target_duration=None):
         """Desired number of workers based on the current workload
 
         This looks at the current running tasks and memory use, and returns a
diff --git a/distributed/semaphore.py b/distributed/semaphore.py
index e3245848248..f0d255000ce 100644
--- a/distributed/semaphore.py
+++ b/distributed/semaphore.py
@@ -92,11 +92,10 @@ def __init__(self, scheduler):
             dask.config.get("distributed.scheduler.locks.lease-timeout"), default="s"
         )
 
-    async def get_value(self, comm=None, name=None):
+    async def get_value(self, name=None):
         return len(self.leases[name])
 
-    # `comm` here is required by the handler interface
-    def create(self, comm=None, name=None, max_leases=None):
+    def create(self, name=None, max_leases=None):
         # We use `self.max_leases` as the point of truth to find out if a semaphore with a specific
         # `name` has been created.
         if name not in self.max_leases:
@@ -109,7 +109,7 @@ def create(self, comm=None, name=None, max_leases=None):
                 % (max_leases, self.max_leases[name])
             )
 
-    def refresh_leases(self, comm=None, name=None, lease_ids=None):
+    def refresh_leases(self, name=None, lease_ids=None):
         with log_errors():
             now = time()
             logger.debug(
@@ -145,7 +145,7 @@ def _semaphore_exists(self, name):
             return False
         return True
 
-    async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):
+    async def acquire(self, name=None, timeout=None, lease_id=None):
         with log_errors():
             if not self._semaphore_exists(name):
                 raise RuntimeError(f"Semaphore `{name}` not known or already closed.")
@@ -195,7 +195,7 @@ async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):
 
         return result
 
-    def release(self, comm=None, name=None, lease_id=None):
+    def release(self, name=None, lease_id=None):
         with log_errors():
             if not self._semaphore_exists(name):
                 logger.warning(
@@ -242,7 +242,7 @@ def _check_lease_timeout(self):
                 )
                 self._release_value(name=name, lease_id=_id)
 
-    def close(self, comm=None, name=None):
+    def close(self, name=None):
         """Hard close the semaphore without warning clients which still hold a lease."""
         with log_errors():
             if not self._semaphore_exists(name):
diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py
index 5567b9bd688..2cae8e4ea39 100644
--- a/distributed/tests/test_core.py
+++ b/distributed/tests/test_core.py
@@ -13,6 +13,7 @@
     ConnectionPool,
     Server,
     Status,
+    _expects_comm,
     clean_exception,
     coerce_to_address,
     connect,
@@ -835,7 +836,7 @@ async def test_deserialize_error():
     comm = await connect(server.address, deserialize=False)
 
     with pytest.raises(Exception) as info:
-        await send_recv(comm, op="throws")
+        await send_recv(comm, op="throws", x="foo")
 
     assert type(info.value) == Exception
     for c in str(info.value):
@@ -994,3 +995,58 @@ async def long_handler(comm, delay=10):
         await asyncio.wait_for(fut, 0.5)
     await comm.close()
     await server.close()
+
+
+def test_expects_comm():
+    class A:
+        def empty(self):
+            ...
+
+        def one_arg(self, arg):
+            ...
+
+        def comm_arg(self, comm):
+            ...
+
+        def stream_arg(self, stream):
+            ...
+
+        def two_arg(self, arg, other):
+            ...
+
+        def comm_arg_other(self, comm, other):
+            ...
+
+        def stream_arg_other(self, stream, other):
+            ...
+
+        def arg_kwarg(self, arg, other=None):
+            ...
+
+        def comm_posarg_only(self, comm, /, other):
+            ...
+
+        def comm_not_leading_position(self, other, comm):
+            ...
+
+        def stream_not_leading_position(self, other, stream):
+            ...
+
+    expected_warning = "first argument of an RPC handler `stream` is deprecated"
+
+    instance = A()
+
+    assert not _expects_comm(instance.empty)
+    assert not _expects_comm(instance.one_arg)
+    assert _expects_comm(instance.comm_arg)
+    with pytest.warns(FutureWarning, match=expected_warning):
+        assert _expects_comm(instance.stream_arg)
+    assert not _expects_comm(instance.two_arg)
+    assert _expects_comm(instance.comm_arg_other)
+    with pytest.warns(FutureWarning, match=expected_warning):
+        assert _expects_comm(instance.stream_arg_other)
+    assert not _expects_comm(instance.arg_kwarg)
+    assert _expects_comm(instance.comm_posarg_only)
+    assert not _expects_comm(instance.comm_not_leading_position)
+
+    assert not _expects_comm(instance.stream_not_leading_position)
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 0f735dece79..24b7f790c68 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -99,7 +99,7 @@ async def test_str(s, a, b):
 @gen_cluster(nthreads=[])
 async def test_identity(s):
     async with Worker(s.address) as w:
-        ident = w.identity(None)
+        ident = w.identity()
         assert "Worker" in ident["type"]
         assert ident["scheduler"] == s.address
         assert isinstance(ident["nthreads"], int)
diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py
index 1fdc0e5af07..4acea7e8b10 100644
--- a/distributed/utils_comm.py
+++ b/distributed/utils_comm.py
@@ -118,7 +118,7 @@ def __repr__(self):
 _round_robin_counter = [0]
 
 
-async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None):
+async def scatter_to_workers(nthreads, data, rpc=rpc, report=True):
     """Scatter data directly to workers
 
     This distributes data in a round-robin fashion to a set of workers based on
@@ -145,7 +145,8 @@ async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=N
     out = await All(
         [
             rpcs[address].update_data(
-                data=v, report=report, serializers=serializers
+                data=v,
+                report=report,
             )
             for address, v in d.items()
         ]
diff --git a/distributed/variable.py b/distributed/variable.py
index 13a5bb5e8a3..5206dc6254f 100644
--- a/distributed/variable.py
+++ b/distributed/variable.py
@@ -42,7 +42,7 @@ def __init__(self, scheduler):
 
         self.scheduler.extensions["variables"] = self
 
-    async def set(self, comm=None, name=None, key=None, data=None, client=None):
+    async def set(self, name=None, key=None, data=None, client=None):
         if key is not None:
             record = {"type": "Future", "value": key}
             self.scheduler.client_desires_keys(keys=[key], client="variable-%s" % name)
@@ -74,7 +74,7 @@ async def future_release(self, name=None, key=None, token=None, client=None):
             async with self.waiting_conditions[name]:
                 self.waiting_conditions[name].notify_all()
 
-    async def get(self, comm=None, name=None, client=None, timeout=None):
+    async def get(self, name=None, client=None, timeout=None):
         start = time()
         while name not in self.variables:
             if timeout is not None:
@@ -107,7 +107,7 @@ async def _():  # Python 3.6 is odd and requires special help here
                 self.waiting[key, name].add(token)
         return record
 
-    async def delete(self, comm=None, name=None, client=None):
+    async def delete(self, name=None, client=None):
         with log_errors():
             try:
                 old = self.variables[name]
diff --git a/distributed/worker.py b/distributed/worker.py
index 6ed8cb5c5dd..5ed20c226cd 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1208,7 +1208,7 @@ async def get_startup_information(self):
 
         return result
 
-    def identity(self, comm=None):
+    def identity(self):
         return {
             "type": type(self).__name__,
             "id": self.id,
@@ -1452,10 +1452,10 @@ def func(data):
 
         return {"status": "OK", "nbytes": len(data)}
 
-    def keys(self, comm=None):
+    def keys(self):
         return list(self.data)
 
-    async def gather(self, comm=None, who_has=None):
+    async def gather(self, who_has: dict[str, list[str]]):
         who_has = {
             k: [coerce_to_address(addr) for addr in v]
             for k, v in who_has.items()
@@ -1476,7 +1476,7 @@ async def gather(self, comm=None, who_has=None):
         else:
             return {"status": "OK"}
 
-    def get_monitor_info(self, comm=None, recent=False, start=0):
+    def get_monitor_info(self, recent=False, start=0):
         result = dict(
             range_query=(
                 self.monitor.recent()
@@ -1751,7 +1751,7 @@ async def close_gracefully(self, restart=None):
         )
         await self.close(safe=True, nanny=not restart)
 
-    async def terminate(self, comm=None, report=True, **kwargs):
+    async def terminate(self, report: bool = True, **kwargs):
         await self.close(report=report, **kwargs)
         return "OK"
 
@@ -1877,11 +1877,14 @@ async def get_data(
     ###################
 
     def update_data(
-        self, comm=None, data=None, report=True, serializers=None, stimulus_id=None
+        self,
+        data: dict[str, object],
+        report: bool = True,
+        stimulus_id: str | None = None,
     ):
         if stimulus_id is None:
             stimulus_id = f"update-data-{time()}"
-        recommendations = {}
+        recommendations: dict[TaskState, tuple] = {}
         scheduler_messages = []
         for key, value in data.items():
             try:
@@ -1910,7 +1913,7 @@ def update_data(
             self.batched_stream.send(msg)
         return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}
 
-    def handle_free_keys(self, comm=None, keys=None, stimulus_id=None):
+    def handle_free_keys(self, keys=None, stimulus_id=None):
         """
         Handler to be called by the scheduler.
 
@@ -2010,7 +2013,6 @@ def handle_cancel_compute(self, key, reason):
 
     def handle_acquire_replicas(
         self,
-        comm=None,
         *,
         keys: Collection[str],
         who_has: dict[str, Collection[str]],
@@ -3361,7 +3363,7 @@ def run(self, comm, function, args=(), wait=True, kwargs=None):
     def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True):
         return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)
 
-    async def plugin_add(self, comm=None, plugin=None, name=None, catch_errors=True):
+    async def plugin_add(self, plugin=None, name=None, catch_errors=True):
         with log_errors(pdb=False):
             if isinstance(plugin, bytes):
                 plugin = pickle.loads(plugin)
@@ -3372,7 +3374,7 @@ async def plugin_add(self, comm=None, plugin=None, name=None, catch_errors=True)
             assert name
 
             if name in self.plugins:
-                await self.plugin_remove(comm=comm, name=name)
+                await self.plugin_remove(name=name)
 
             self.plugins[name] = plugin
 
@@ -3390,7 +3392,7 @@ async def plugin_add(self, comm=None, plugin=None, name=None, catch_errors=True)
 
             return {"status": "OK"}
 
-    async def plugin_remove(self, comm=None, name=None):
+    async def plugin_remove(self, name=None):
         with log_errors(pdb=False):
             logger.info(f"Removing Worker plugin {name}")
             try:
@@ -3407,7 +3409,6 @@ async def plugin_remove(self, comm=None, name=None):
 
     async def actor_execute(
         self,
-        comm=None,
         actor=None,
         function=None,
         args=(),
@@ -3441,7 +3442,7 @@ async def actor_execute(
         except Exception as ex:
             return {"status": "error", "exception": to_serialize(ex)}
 
-    def actor_attribute(self, comm=None, actor=None, attribute=None):
+    def actor_attribute(self, actor=None, attribute=None):
         try:
             value = getattr(self.actors[actor], attribute)
             return {"status": "OK", "result": to_serialize(value)}
@@ -3832,9 +3833,7 @@ def trigger_profile(self):
         if self.digests is not None:
             self.digests["profile-duration"].add(stop - start)
 
-    async def get_profile(
-        self, comm=None, start=None, stop=None, key=None, server=False
-    ):
+    async def get_profile(self, start=None, stop=None, key=None, server=False):
         now = time() + self.scheduler_delay
         if server:
             history = self.io_loop.profile
@@ -3875,7 +3874,7 @@ async def get_profile(
 
         return prof
 
-    async def get_profile_metadata(self, comm=None, start=0, stop=None):
+    async def get_profile_metadata(self, start=0, stop=None):
         add_recent = stop is None
         now = time() + self.scheduler_delay
         stop = stop or now
@@ -3897,7 +3896,7 @@ async def get_profile_metadata(self, comm=None, start=0, stop=None):
         )
         return result
 
-    def get_call_stack(self, comm=None, keys=None):
+    def get_call_stack(self, keys=None):
         with self.active_threads_lock:
             frames = sys._current_frames()
         active_threads = self.active_threads.copy()
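With this patch applied, `handle_comm` injects the communication object only when a handler's signature asks for it by naming its first parameter `comm`. Below is a minimal sketch of the resulting convention, not part of the patch itself: the `ping`/`who` handler names are made up for illustration, and it assumes a fresh asyncio event loop and that port 0 resolves to an OS-assigned free port, as it does for workers.

import asyncio

from distributed.core import Server, rpc


async def main():
    # First parameter is not `comm`, so handle_comm calls this as handler(**msg)
    def ping(delay=0):
        return "pong"

    # First parameter is literally named `comm`, so the Comm object is passed in
    def who(comm):
        return comm.peer_address

    server = Server({"ping": ping, "who": who})
    await server.listen(0)  # listen on an arbitrary free port
    try:
        async with rpc(server.address) as r:
            assert await r.ping() == "pong"
            print(await r.who())  # address of the connecting client
    finally:
        await server.close()


asyncio.run(main())

A handler whose first parameter is named `stream` still receives the comm object, but now emits the FutureWarning added in `_expects_comm` above.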