diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 4bd84851f7c..8b6af14a8d0 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -2848,7 +2848,7 @@ def transition_processing_memory(
                 {
                     "op": "cancel-compute",
                     "key": key,
-                    "reason": "Finished on different worker",
+                    "stimulus_id": f"processing-memory-{time()}",
                 }
             ]
 
@@ -7629,12 +7629,10 @@ async def get_profile(
 
     async def get_profile_metadata(
         self,
-        comm=None,
-        workers=None,
-        merge_workers=True,
-        start=None,
-        stop=None,
-        profile_cycle_interval=None,
+        workers: "Iterable[str] | None" = None,
+        start: float = 0,
+        stop: "float | None" = None,
+        profile_cycle_interval: "str | float | None" = None,
     ):
         parent: SchedulerState = cast(SchedulerState, self)
         dt = profile_cycle_interval or dask.config.get(
@@ -7652,16 +7650,19 @@ async def get_profile_metadata(
         )
         results = [r for r in results if not isinstance(r, Exception)]
 
-        counts = [v["counts"] for v in results]
-        counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt)
-        counts = [(time, sum(pluck(1, group))) for time, group in counts]
-
-        keys = set()
-        for v in results:
-            for t, d in v["keys"]:
-                for k in d:
-                    keys.add(k)
-        keys = {k: [] for k in keys}
+        counts = [
+            (time, sum(pluck(1, group)))
+            for time, group in itertools.groupby(
+                merge_sorted(
+                    *(v["counts"] for v in results),
+                ),
+                lambda t: t[0] // dt * dt,
+            )
+        ]
+
+        keys: dict[str, list[list]] = {
+            k: [] for v in results for t, d in v["keys"] for k in d
+        }
 
         groups1 = [v["keys"] for v in results]
         groups2 = list(merge_sorted(*groups1, key=first))
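A minimal sketch of the bucketing idiom used by `get_profile_metadata` above, with invented sample series: `merge_sorted` interleaves the per-worker `(timestamp, count)` pairs, `itertools.groupby` buckets them by flooring each timestamp to a multiple of `dt`, and `pluck(1, ...)` feeds the per-bucket sum.

import itertools

from tlz import merge_sorted, pluck

dt = 10.0  # bucket width, i.e. the profile cycle interval in seconds
worker_a = [(100.0, 1), (103.0, 2), (112.0, 5)]  # (timestamp, count) per worker
worker_b = [(101.0, 3), (115.0, 4)]

counts = [
    (t, sum(pluck(1, group)))
    for t, group in itertools.groupby(
        merge_sorted(worker_a, worker_b), lambda t: t[0] // dt * dt
    )
]
assert counts == [(100.0, 6), (110.0, 9)]  # two dt-wide buckets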
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 4612d9179d6..5137d56fcc0 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -859,10 +859,10 @@ async def test_missing_data_heals(c, s, a, b):
     # Secretly delete y's key
     if y.key in a.data:
         del a.data[y.key]
-        a.release_key(y.key)
+        a.release_key(y.key, stimulus_id="test")
     if y.key in b.data:
         del b.data[y.key]
-        b.release_key(y.key)
+        b.release_key(y.key, stimulus_id="test")
     await asyncio.sleep(0)
 
     w = c.submit(add, y, z)
@@ -884,7 +884,7 @@ async def test_gather_robust_to_missing_data(c, s, a, b):
         if f.key in w.data:
             del w.data[f.key]
             await asyncio.sleep(0)
-            w.release_key(f.key)
+            w.release_key(f.key, stimulus_id="test")
 
     xx, yy, zz = await c.gather([x, y, z])
     assert (xx, yy, zz) == (1, 2, 3)
@@ -907,7 +907,7 @@ async def test_gather_robust_to_nested_missing_data(c, s, a, b):
             if datum.key in worker.data:
                 del worker.data[datum.key]
                 await asyncio.sleep(0)
-                worker.release_key(datum.key)
+                worker.release_key(datum.key, stimulus_id="test")
 
     result = await c.gather([z])
diff --git a/distributed/worker.py b/distributed/worker.py
index 443b0364bca..eb0d07c6505 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -12,7 +12,7 @@
 import threading
 import warnings
 import weakref
-from collections import defaultdict, deque, namedtuple
+from collections import defaultdict, deque
 from collections.abc import (
     Callable,
     Collection,
@@ -27,13 +27,7 @@
 from datetime import timedelta
 from inspect import isawaitable
 from pickle import PicklingError
-from typing import TYPE_CHECKING, Any, ClassVar, Literal
-
-if TYPE_CHECKING:
-    from .diagnostics.plugin import WorkerPlugin
-    from .actor import Actor
-    from .client import Client
-    from .nanny import Nanny
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast
 
 from tlz import first, keymap, merge, pluck  # noqa: F401
 from tornado.ioloop import IOLoop, PeriodicCallback
@@ -101,6 +95,19 @@
 from .utils_perf import ThrottledGC, disable_gc_diagnosis, enable_gc_diagnosis
 from .versions import get_versions
 
+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
+    from .actor import Actor
+    from .client import Client
+    from .diagnostics.plugin import WorkerPlugin
+    from .nanny import Nanny
+
+    # {TaskState -> finish: str | (finish: str, *args)}
+    Recs: TypeAlias = "dict[TaskState, str | tuple]"
+    Smsgs: TypeAlias = "list[dict[str, Any]]"
+
+
 logger = logging.getLogger(__name__)
 
 LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
@@ -129,7 +136,19 @@
     dask.config.get("distributed.scheduler.default-data-size")
 )
 
-SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"])
+
+class SerializedTask(NamedTuple):
+    function: Callable
+    args: tuple
+    kwargs: dict[str, Any]
+    task: object  # distributed.scheduler.TaskState.run_spec
+
+
+class StartStop(TypedDict, total=False):
+    action: str
+    start: float
+    stop: float
+    source: str  # optional
 
 
 class InvalidTransition(Exception):
@@ -194,24 +213,24 @@ class TaskState:
     """
 
     key: str
-    run_spec: object
+    run_spec: SerializedTask | None
     dependencies: set[TaskState]
    dependents: set[TaskState]
     duration: float | None
     priority: tuple[int, ...] | None
     state: str
     who_has: set[str]
-    coming_from: str
+    coming_from: str | None
     waiting_for_data: set[TaskState]
     waiters: set[TaskState]
-    resource_restrictions: dict
+    resource_restrictions: dict[str, float]
     exception: Exception | None
     exception_text: str | None
     traceback: object | None
     traceback_text: str | None
     type: type | None
     suspicious_count: int
-    startstops: list[dict]
+    startstops: list[StartStop]
     start_time: float | None
     stop_time: float | None
     metadata: dict
@@ -221,7 +240,7 @@ class TaskState:
     _previous: str | None
     _next: str | None
 
-    def __init__(self, key, run_spec=None):
+    def __init__(self, key: str, run_spec: SerializedTask | None = None):
         assert key is not None
         self.key = key
         self.run_spec = run_spec
@@ -251,7 +270,7 @@ def __init__(self, key, run_spec=None):
         self._previous = None
         self._next = None
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"<TaskState {self.key!r} {self.state}>"
 
     def get_nbytes(self) -> int:
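A side note on what the `NamedTuple`/`TypedDict` rewrite above buys over `collections.namedtuple`: the runtime objects stay a plain tuple and a plain dict, but every field now has a type that mypy can check, and `total=False` marks all `StartStop` keys optional. A standalone sketch (Python 3.9+; the field values are made up):

from typing import Any, Callable, NamedTuple, TypedDict


class SerializedTask(NamedTuple):
    function: Callable
    args: tuple
    kwargs: dict[str, Any]
    task: object


class StartStop(TypedDict, total=False):
    action: str
    start: float
    stop: float
    source: str


st = SerializedTask(function=print, args=("x",), kwargs={}, task=None)
assert isinstance(st, tuple) and st.function is print  # still a tuple at runtime

ss: StartStop = {"action": "compute", "start": 1.0, "stop": 2.0}  # "source" omitted
assert ss["stop"] - ss["start"] == 1.0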
@@ -543,13 +562,13 @@ class Worker(ServerNode):
     profile_recent: dict[str, Any]
     profile_history: deque[tuple[float, dict[str, Any]]]
     generation: int
-    ready: list[str]
+    ready: list[tuple[tuple[int, ...], str]]  # heapq [(priority, key), ...]
     constrained: deque[str]
     _executing: set[TaskState]
     _in_flight_tasks: set[TaskState]
     executed_count: int
-    long_running: set[TaskState]
-    log: deque[tuple]
+    long_running: set[str]
+    log: deque[tuple]  # [(..., stimulus_id: str | None, timestamp: float), ...]
     incoming_transfer_log: deque[dict[str, Any]]
     outgoing_transfer_log: deque[dict[str, Any]]
     target_message_size: int
@@ -1032,7 +1051,8 @@ def __init__(
         )
         self.periodic_callbacks["keep-alive"] = pc
 
-        pc = PeriodicCallback(self.find_missing, 1000)
+        # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
+        pc = PeriodicCallback(self.find_missing, 1000)  # type: ignore
         self.periodic_callbacks["find-missing"] = pc
 
         self._address = contact_address
@@ -1044,7 +1064,8 @@ def __init__(
         if self.memory_limit:
             assert self.memory_monitor_interval is not None
             pc = PeriodicCallback(
-                self.memory_monitor, self.memory_monitor_interval * 1000
+                self.memory_monitor,  # type: ignore
+                self.memory_monitor_interval * 1000,
             )
             self.periodic_callbacks["memory"] = pc
 
@@ -1452,10 +1473,10 @@ def func(data):
 
         return {"status": "OK", "nbytes": len(data)}
 
-    def keys(self):
+    def keys(self) -> list[str]:
         return list(self.data)
 
-    async def gather(self, who_has: dict[str, list[str]]):
+    async def gather(self, who_has: dict[str, list[str]]) -> dict[str, Any]:
         who_has = {
             k: [coerce_to_address(addr) for addr in v]
             for k, v in who_has.items()
@@ -1476,7 +1497,9 @@ async def gather(self, who_has: dict[str, list[str]]):
         else:
             return {"status": "OK"}
 
-    def get_monitor_info(self, recent=False, start=0):
+    def get_monitor_info(
+        self, recent: bool = False, start: float = 0
+    ) -> dict[str, Any]:
         result = dict(
             range_query=(
                 self.monitor.recent()
@@ -1751,7 +1774,7 @@ async def close_gracefully(self, restart=None):
         )
         await self.close(safe=True, nanny=not restart)
 
-    async def terminate(self, report: bool = True, **kwargs):
+    async def terminate(self, report: bool = True, **kwargs) -> str:
         await self.close(report=report, **kwargs)
         return "OK"
 
@@ -1881,10 +1904,10 @@ def update_data(
         data: dict[str, object],
         report: bool = True,
         stimulus_id: str = None,
-    ):
+    ) -> dict[str, Any]:
         if stimulus_id is None:
             stimulus_id = f"update-data-{time()}"
-        recommendations: dict[TaskState, tuple] = {}
+        recommendations: Recs = {}
         scheduler_messages = []
         for key, value in data.items():
             try:
@@ -1913,7 +1936,7 @@ def update_data(
             self.batched_stream.send(msg)
         return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}
 
-    def handle_free_keys(self, keys=None, stimulus_id=None):
+    def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None:
         """
         Handler to be called by the scheduler.
 
@@ -1925,7 +1948,7 @@ def handle_free_keys(self, keys=None, stimulus_id=None):
         upstream dependency.
         """
         self.log.append(("free-keys", keys, stimulus_id, time()))
-        recommendations = {}
+        recommendations: Recs = {}
         for key in keys:
             ts = self.tasks.get(key)
             if ts:
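The handlers above all follow the same convention: each externally triggered event carries a `stimulus_id`, which lands both in the structured `log` and in the `transitions()` call it causes, so a task's history can later be traced back to its trigger. A toy model of that convention — not the real Worker API:

from collections import deque
from time import time

log: deque[tuple] = deque(maxlen=100_000)

def handle_free_keys(keys: list[str], stimulus_id: str) -> None:
    log.append(("free-keys", keys, stimulus_id, time()))
    # ...build recommendations, then pass the same stimulus_id to transitions()

handle_free_keys(["x", "y"], stimulus_id=f"free-keys-{time()}")
event, keys, stimulus_id, when = log[-1]
assert event == "free-keys" and stimulus_id.startswith("free-keys-")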
@@ -1933,7 +1956,7 @@ def handle_free_keys(self, keys=None, stimulus_id=None):
 
         self.transitions(recommendations, stimulus_id=stimulus_id)
 
-    def handle_remove_replicas(self, keys, stimulus_id):
+    def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:
         """Stream handler notifying the worker that it might be holding unreferenced,
         superfluous data.
@@ -1953,7 +1976,7 @@ def handle_remove_replicas(self, keys, stimulus_id):
 
         For stronger guarantees, see handler free_keys
         """
         self.log.append(("remove-replicas", keys, stimulus_id, time()))
-        recommendations = {}
+        recommendations: Recs = {}
 
         rejected = []
         for key in keys:
@@ -1978,7 +2001,7 @@ def handle_remove_replicas(self, keys, stimulus_id):
 
         return "OK"
 
-    async def set_resources(self, **resources):
+    async def set_resources(self, **resources) -> None:
         for r, quantity in resources.items():
             if r in self.total_resources:
                 self.available_resources[r] += quantity - self.total_resources[r]
@@ -1996,7 +2019,7 @@ async def set_resources(self, **resources):
     # Task Management #
     ###################
 
-    def handle_cancel_compute(self, key, reason):
+    def handle_cancel_compute(self, key: str, stimulus_id: str) -> None:
         """
         Cancel a task on a best effort basis. This is only possible while a
         task is in state `waiting` or `ready`.
@@ -2004,12 +2027,12 @@ def handle_cancel_compute(self, key, reason):
         """
         ts = self.tasks.get(key)
         if ts and ts.state in READY | {"waiting"}:
-            self.log.append((key, "cancel-compute", reason, time()))
+            self.log.append((key, "cancel-compute", stimulus_id, time()))
             # All possible dependents of TS should not be in state Processing on
             # scheduler side and therefore should not be assigned to a worker,
             # yet.
             assert not ts.dependents
-            self.transition(ts, "released", stimulus_id=reason)
+            self.transition(ts, "released", stimulus_id=stimulus_id)
 
     def handle_acquire_replicas(
         self,
@@ -2017,8 +2040,8 @@ def handle_acquire_replicas(
         keys: Collection[str],
         who_has: dict[str, Collection[str]],
         stimulus_id: str,
-    ):
-        recommendations = {}
+    ) -> None:
+        recommendations: Recs = {}
         for key in keys:
             ts = self.ensure_task_exists(
                 key=key,
@@ -2060,13 +2083,13 @@ def handle_compute_task(
         function=None,
         args=None,
         kwargs=None,
-        task=no_value,
+        task=no_value,  # distributed.scheduler.TaskState.run_spec
         nbytes: dict[str, int] | None = None,
-        resource_restrictions=None,
+        resource_restrictions: dict[str, float] | None = None,
         actor: bool = False,
-        annotations=None,
+        annotations: dict | None = None,
         stimulus_id: str,
-    ):
+    ) -> None:
         self.log.append((key, "compute-task", stimulus_id, time()))
         try:
             ts = self.tasks[key]
@@ -2096,8 +2119,8 @@ def handle_compute_task(
             ts.resource_restrictions = resource_restrictions
             ts.annotations = annotations
 
-            recommendations = {}
-            scheduler_msgs = []
+            recommendations: Recs = {}
+            scheduler_msgs: Smsgs = []
             for dependency in who_has:
                 dep_ts = self.ensure_task_exists(
                     key=dependency,
@@ -2136,7 +2159,9 @@ def handle_compute_task(
             for key, value in nbytes.items():
                 self.tasks[key].nbytes = value
 
-    def transition_missing_fetch(self, ts, *, stimulus_id):
+    def transition_missing_fetch(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert ts.state == "missing"
             assert ts.priority is not None
@@ -2147,22 +2172,26 @@ def transition_missing_fetch(self, ts, *, stimulus_id):
             self.data_needed.push(ts)
         return {}, []
 
-    def transition_missing_released(self, ts, *, stimulus_id):
+    def transition_missing_released(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         self._missing_dep_flight.discard(ts)
-        recommendations, smsgs = self.transition_generic_released(
-            ts, stimulus_id=stimulus_id
-        )
+        recs, smsgs = self.transition_generic_released(ts, stimulus_id=stimulus_id)
         assert ts.key in self.tasks
-        return recommendations, smsgs
+        return recs, smsgs
 
-    def transition_flight_missing(self, ts, *, stimulus_id):
+    def transition_flight_missing(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         assert ts.done
         ts.state = "missing"
         self._missing_dep_flight.add(ts)
         ts.done = False
         return {}, []
 
-    def transition_released_fetch(self, ts, *, stimulus_id):
+    def transition_released_fetch(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert ts.state == "released"
             assert ts.priority is not None
@@ -2173,9 +2202,11 @@ def transition_released_fetch(self, ts, *, stimulus_id):
             self.data_needed.push(ts)
         return {}, []
 
-    def transition_generic_released(self, ts, *, stimulus_id):
-        self.release_key(ts.key, reason=stimulus_id)
-        recs = {}
+    def transition_generic_released(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
+        self.release_key(ts.key, stimulus_id=stimulus_id)
+        recs: Recs = {}
         for dependency in ts.dependencies:
             if (
                 not dependency.waiters
@@ -2188,12 +2219,14 @@ def transition_generic_released(self, ts, *, stimulus_id):
 
         return recs, []
 
-    def transition_released_waiting(self, ts, *, stimulus_id):
+    def transition_released_waiting(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert ts.state == "released"
             assert all(d.key in self.tasks for d in ts.dependencies)
 
-        recommendations = {}
+        recommendations: Recs = {}
         ts.waiting_for_data.clear()
         for dep_ts in ts.dependencies:
             if not dep_ts.state == "memory":
@@ -2212,7 +2245,9 @@ def transition_released_waiting(self, ts, *, stimulus_id):
         ts.state = "waiting"
         return recommendations, []
 
-    def transition_fetch_flight(self, ts, worker, *, stimulus_id):
+    def transition_fetch_flight(
+        self, ts: TaskState, worker, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert ts.state == "fetch"
             assert ts.who_has
@@ -2223,12 +2258,16 @@ def transition_fetch_flight(self, ts, worker, *, stimulus_id):
         self._in_flight_tasks.add(ts)
         return {}, []
 
-    def transition_memory_released(self, ts, *, stimulus_id):
+    def transition_memory_released(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         recs, smsgs = self.transition_generic_released(ts, stimulus_id=stimulus_id)
         smsgs.append({"op": "release-worker-data", "key": ts.key})
         return recs, smsgs
 
-    def transition_waiting_constrained(self, ts, *, stimulus_id):
+    def transition_waiting_constrained(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert ts.state == "waiting"
             assert not ts.waiting_for_data
@@ -2242,41 +2281,55 @@ def transition_waiting_constrained(self, ts, *, stimulus_id):
         self.constrained.append(ts.key)
         return {}, []
 
-    def transition_long_running_rescheduled(self, ts, *, stimulus_id):
-        recs = {ts: "released"}
+    def transition_long_running_rescheduled(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
+        recs: Recs = {ts: "released"}
         smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}]
         return recs, smsgs
 
-    def transition_executing_rescheduled(self, ts, *, stimulus_id):
+    def transition_executing_rescheduled(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         for resource, quantity in ts.resource_restrictions.items():
             self.available_resources[resource] += quantity
         self._executing.discard(ts)
 
-        recs = {ts: "released"}
-        smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}]
+        recs: Recs = {ts: "released"}
+        smsgs: Smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}]
         return recs, smsgs
 
-    def transition_waiting_ready(self, ts, *, stimulus_id):
+    def transition_waiting_ready(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert ts.state == "waiting"
             assert ts.key not in self.ready
             assert not ts.waiting_for_data
-            assert ts.priority is not None
             for dep in ts.dependencies:
                 assert dep.key in self.data or dep.key in self.actors
                 assert dep.state == "memory"
 
         ts.state = "ready"
+        assert ts.priority is not None
         heapq.heappush(self.ready, (ts.priority, ts.key))
         return {}, []
 
     def transition_cancelled_error(
-        self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id
-    ):
-        recs, msgs = {}, []
+        self,
+        ts: TaskState,
+        exception,
+        traceback,
+        exception_text,
+        traceback_text,
+        *,
+        stimulus_id: str,
+    ) -> tuple[Recs, Smsgs]:
+        recs: Recs = {}
+        smsgs: Smsgs = []
         if ts._previous == "executing":
-            recs, msgs = self.transition_executing_error(
+            recs, smsgs = self.transition_executing_error(
                 ts,
                 exception,
                 traceback,
@@ -2285,7 +2338,7 @@ def transition_cancelled_error(
                 stimulus_id=stimulus_id,
             )
         elif ts._previous == "flight":
-            recs, msgs = self.transition_flight_error(
+            recs, smsgs = self.transition_flight_error(
                 ts,
                 exception,
                 traceback,
@@ -2295,11 +2348,18 @@ def transition_cancelled_error(
             )
         if ts._next:
             recs[ts] = ts._next
-        return recs, msgs
+        return recs, smsgs
 
     def transition_generic_error(
-        self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id
-    ):
+        self,
+        ts: TaskState,
+        exception,
+        traceback,
+        exception_text,
+        traceback_text,
+        *,
+        stimulus_id: str,
+    ) -> tuple[Recs, Smsgs]:
         ts.exception = exception
         ts.traceback = traceback
         ts.exception_text = exception_text
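All of the `transition_*` methods above now share one contract: mutate the `TaskState`, then return a `(recommendations, scheduler_messages)` pair — the `Recs` and `Smsgs` aliases — for the caller to act on, rather than acting directly. A stripped-down sketch of that shape, with an invented `Task` class and states:

class Task:
    def __init__(self, key: str) -> None:
        self.key = key
        self.state = "waiting"

def transition_waiting_released(ts: Task) -> tuple[dict, list]:
    ts.state = "released"
    recs = {ts: "forgotten"}  # follow-up transitions for the caller to run
    smsgs = [{"op": "release", "key": ts.key}]  # messages bound for the scheduler
    return recs, smsgs

t = Task("x")
recs, smsgs = transition_waiting_released(t)
assert t.state == "released" and recs[t] == "forgotten"
assert smsgs == [{"op": "release", "key": "x"}]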
@@ -2322,8 +2382,15 @@ def transition_generic_error(
         return {}, [smsg]
 
     def transition_executing_error(
-        self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id
-    ):
+        self,
+        ts: TaskState,
+        exception,
+        traceback,
+        exception_text,
+        traceback_text,
+        *,
+        stimulus_id: str,
+    ) -> tuple[Recs, Smsgs]:
         for resource, quantity in ts.resource_restrictions.items():
             self.available_resources[resource] += quantity
         self._executing.discard(ts)
@@ -2336,7 +2403,9 @@ def transition_executing_error(
             stimulus_id=stimulus_id,
         )
 
-    def _transition_from_resumed(self, ts, finish, *, stimulus_id):
+    def _transition_from_resumed(
+        self, ts: TaskState, finish: str, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         """`resumed` is an intermediate degenerate state which splits further up
         into two states depending on what the last signal / next state is
         intended to be. There are only two viable choices depending on whether
         the task is required to be fetched from another worker `resumed(fetch)`
         or the task shall be computed on this worker `resumed(waiting)`.
@@ -2356,7 +2425,8 @@ def _transition_from_resumed(self, ts, finish, *, stimulus_id):
 
         See also `transition_resumed_waiting`
         """
-        recs, smsgs = {}, []
+        recs: Recs = {}
+        smsgs: Smsgs = []
         if ts.done:
             next_state = ts._next
             # if the next state is already intended to be waiting or if the
@@ -2365,30 +2435,37 @@ def _transition_from_resumed(self, ts, finish, *, stimulus_id):
                 recs, smsgs = self.transition_generic_released(
                     ts, stimulus_id=stimulus_id
                 )
+                assert next_state
                 recs[ts] = next_state
             else:
                 ts._next = finish
         return recs, smsgs
 
-    def transition_resumed_fetch(self, ts, *, stimulus_id):
+    def transition_resumed_fetch(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         """
         See Worker._transition_from_resumed
         """
         return self._transition_from_resumed(ts, "fetch", stimulus_id=stimulus_id)
 
-    def transition_resumed_missing(self, ts, *, stimulus_id):
+    def transition_resumed_missing(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         """
         See Worker._transition_from_resumed
         """
         return self._transition_from_resumed(ts, "missing", stimulus_id=stimulus_id)
 
-    def transition_resumed_waiting(self, ts, *, stimulus_id):
+    def transition_resumed_waiting(self, ts: TaskState, *, stimulus_id: str):
         """
         See Worker._transition_from_resumed
         """
         return self._transition_from_resumed(ts, "waiting", stimulus_id=stimulus_id)
 
-    def transition_cancelled_fetch(self, ts, *, stimulus_id):
+    def transition_cancelled_fetch(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if ts.done:
             return {ts: "released"}, []
         elif ts._previous == "flight":
@@ -2398,12 +2475,16 @@ def transition_cancelled_fetch(self, ts, *, stimulus_id):
             assert ts._previous == "executing"
             return {ts: ("resumed", "fetch")}, []
 
-    def transition_cancelled_resumed(self, ts, next, *, stimulus_id):
+    def transition_cancelled_resumed(
+        self, ts: TaskState, next: str, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         ts._next = next
         ts.state = "resumed"
         return {}, []
 
-    def transition_cancelled_waiting(self, ts, *, stimulus_id):
+    def transition_cancelled_waiting(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if ts.done:
             return {ts: "released"}, []
         elif ts._previous == "executing":
@@ -2413,30 +2494,35 @@ def transition_cancelled_waiting(self, ts, *, stimulus_id):
             assert ts._previous == "flight"
             return {ts: ("resumed", "waiting")}, []
 
-    def transition_cancelled_forgotten(self, ts, *, stimulus_id):
+    def transition_cancelled_forgotten(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         ts._next = "forgotten"
         if not ts.done:
             return {}, []
         return {ts: "released"}, []
 
-    def transition_cancelled_released(self, ts, *, stimulus_id):
+    def transition_cancelled_released(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if not ts.done:
             ts._next = "released"
             return {}, []
         next_state = ts._next
+        assert next_state
         self._executing.discard(ts)
         self._in_flight_tasks.discard(ts)
 
         for resource, quantity in ts.resource_restrictions.items():
             self.available_resources[resource] += quantity
 
-        recommendations, smsgs = self.transition_generic_released(
-            ts, stimulus_id=stimulus_id
-        )
+        recs, smsgs = self.transition_generic_released(ts, stimulus_id=stimulus_id)
         if next_state != "released":
-            recommendations[ts] = next_state
-        return recommendations, smsgs
+            recs[ts] = next_state
+        return recs, smsgs
 
-    def transition_executing_released(self, ts, *, stimulus_id):
+    def transition_executing_released(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         ts._previous = ts.state
         ts._next = "released"
         # See https://github.com/dask/distributed/pull/5046#discussion_r685093940
@@ -2444,11 +2530,15 @@ def transition_executing_released(self, ts, *, stimulus_id):
         ts.done = False
         return {}, []
 
-    def transition_long_running_memory(self, ts, value=no_value, *, stimulus_id):
+    def transition_long_running_memory(
+        self, ts: TaskState, value=no_value, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         self.executed_count += 1
         return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
 
-    def transition_generic_memory(self, ts, value=no_value, *, stimulus_id):
+    def transition_generic_memory(
+        self, ts: TaskState, value=no_value, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if value is no_value and ts.key not in self.data:
             raise RuntimeError(
                 f"Tried to transition task {ts} to `memory` without data available"
@@ -2471,7 +2561,9 @@ def transition_generic_memory(self, ts, value=no_value, *, stimulus_id):
         smsgs = [self._get_task_finished_msg(ts)]
         return recs, smsgs
 
-    def transition_executing_memory(self, ts, value=no_value, *, stimulus_id):
+    def transition_executing_memory(
+        self, ts: TaskState, value=no_value, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert ts.state == "executing" or ts.key in self.long_running
             assert not ts.waiting_for_data
@@ -2481,7 +2573,9 @@ def transition_executing_memory(self, ts, value=no_value, *, stimulus_id):
         self.executed_count += 1
         return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
 
-    def transition_constrained_executing(self, ts, *, stimulus_id):
+    def transition_constrained_executing(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert not ts.waiting_for_data
             assert ts.key not in self.data
@@ -2497,7 +2591,9 @@ def transition_constrained_executing(self, ts, *, stimulus_id):
         self.loop.add_callback(self.execute, ts.key, stimulus_id=stimulus_id)
         return {}, []
 
-    def transition_ready_executing(self, ts, *, stimulus_id):
+    def transition_ready_executing(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if self.validate:
             assert not ts.waiting_for_data
             assert ts.key not in self.data
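The `cancelled`/`resumed` transitions above hinge on two bookkeeping fields: `ts._previous` remembers which in-flight state (`executing` or `flight`) was interrupted, and `ts._next` records where the task should go once the coroutine finally finishes (`ts.done`). Roughly — a simplification, not the actual code paths:

class Task:
    def __init__(self) -> None:
        self.state = "executing"
        self.done = False      # set once the executing/flight coroutine ends
        self._previous = None  # interrupted state: "executing" or "flight"
        self._next = None      # state to resume into once done

def cancel(ts: Task) -> None:
    ts._previous, ts.state = ts.state, "cancelled"

def request_fetch(ts: Task) -> None:
    # cancelled -> resumed(fetch): discard the local computation, fetch instead
    if ts.state == "cancelled" and ts._previous == "executing":
        ts.state, ts._next = "resumed", "fetch"

t = Task()
cancel(t)
request_fetch(t)
assert (t.state, t._previous, t._next) == ("resumed", "executing", "fetch")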
@@ -2513,28 +2609,37 @@ def transition_ready_executing(self, ts, *, stimulus_id):
         self.loop.add_callback(self.execute, ts.key, stimulus_id=stimulus_id)
         return {}, []
 
-    def transition_flight_fetch(self, ts, *, stimulus_id):
+    def transition_flight_fetch(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         # If this transition is called after the flight coroutine has finished,
         # we can reset the task and transition to fetch again. If it is not yet
         # finished, this should be a no-op
-        if ts.done:
-            recommendations = {}
-            ts.state = "fetch"
-            ts.coming_from = None
-            ts.done = False
-            if not ts.who_has:
-                recommendations[ts] = "missing"
-            else:
-                self.data_needed.push(ts)
-                for w in ts.who_has:
-                    self.pending_data_per_worker[w].push(ts)
-            return recommendations, []
-        else:
+        if not ts.done:
             return {}, []
 
+        recommendations: Recs = {}
+        ts.state = "fetch"
+        ts.coming_from = None
+        ts.done = False
+        if not ts.who_has:
+            recommendations[ts] = "missing"
+        else:
+            self.data_needed.push(ts)
+            for w in ts.who_has:
+                self.pending_data_per_worker[w].push(ts)
+        return recommendations, []
+
     def transition_flight_error(
-        self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id
-    ):
+        self,
+        ts: TaskState,
+        exception,
+        traceback,
+        exception_text,
+        traceback_text,
+        *,
+        stimulus_id: str,
+    ) -> tuple[Recs, Smsgs]:
         self._in_flight_tasks.discard(ts)
         ts.coming_from = None
         return self.transition_generic_error(
@@ -2546,7 +2651,9 @@ def transition_flight_error(
             stimulus_id=stimulus_id,
         )
 
-    def transition_flight_released(self, ts, *, stimulus_id):
+    def transition_flight_released(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         if ts.done:
             # FIXME: Is this even possible? Would an assert instead be more
             # sensible?
@@ -2558,10 +2665,15 @@ def transition_flight_released(self, ts, *, stimulus_id):
             ts.state = "cancelled"
             return {}, []
 
-    def transition_cancelled_memory(self, ts, value, *, stimulus_id):
+    def transition_cancelled_memory(
+        self, ts: TaskState, value, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
+        assert ts._next
         return {ts: ts._next}, []
 
-    def transition_executing_long_running(self, ts, compute_duration, *, stimulus_id):
+    def transition_executing_long_running(
+        self, ts: TaskState, compute_duration, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         ts.state = "long-running"
         self._executing.discard(ts)
         self.long_running.add(ts.key)
@@ -2576,48 +2688,50 @@ def transition_executing_long_running(self, ts, compute_duration, *, stimulus_id
         self.io_loop.add_callback(self.ensure_computing)
         return {}, smsgs
 
-    def transition_released_memory(self, ts, value, *, stimulus_id):
-        recommendations = {}
+    def transition_released_memory(
+        self, ts: TaskState, value, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
+        recs: Recs = {}
         try:
-            recommendations = self._put_key_in_memory(
-                ts, value, stimulus_id=stimulus_id
-            )
+            recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
         except Exception as e:
             msg = error_message(e)
-            recommendations[ts] = (
+            recs[ts] = (
                 "error",
                 msg["exception"],
                 msg["traceback"],
                 msg["exception_text"],
                 msg["traceback_text"],
             )
-            return recommendations, []
+            return recs, []
         smsgs = [{"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}]
-        return recommendations, smsgs
+        return recs, smsgs
 
-    def transition_flight_memory(self, ts, value, *, stimulus_id):
+    def transition_flight_memory(
+        self, ts: TaskState, value, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
         self._in_flight_tasks.discard(ts)
         ts.coming_from = None
-        recommendations = {}
+        recs: Recs = {}
         try:
-            recommendations = self._put_key_in_memory(
-                ts, value, stimulus_id=stimulus_id
-            )
+            recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
         except Exception as e:
             msg = error_message(e)
-            recommendations[ts] = (
+            recs[ts] = (
                 "error",
                 msg["exception"],
                 msg["traceback"],
                 msg["exception_text"],
                 msg["traceback_text"],
            )
-            return recommendations, []
+            return recs, []
         smsgs = [{"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}]
-        return recommendations, smsgs
+        return recs, smsgs
 
-    def transition_released_forgotten(self, ts, *, stimulus_id):
-        recommendations = {}
+    def transition_released_forgotten(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> tuple[Recs, Smsgs]:
+        recommendations: Recs = {}
         # Dependents _should_ be released by the scheduler before this
         if self.validate:
             assert not any(d.state != "forgotten" for d in ts.dependents)
@@ -2631,17 +2745,19 @@ def transition_released_forgotten(self, ts, *, stimulus_id):
         self.tasks.pop(ts.key, None)
         return recommendations, []
 
-    def _transition(self, ts, finish, *args, stimulus_id, **kwargs):
+    def _transition(
+        self, ts: TaskState, finish: str | tuple, *args, stimulus_id: str, **kwargs
+    ) -> tuple[Recs, Smsgs]:
         if isinstance(finish, tuple):
             # the concatenated transition path might need to access the tuple
             assert not args
-            finish, *args = finish
+            finish, *args = finish  # type: ignore
 
         if ts is None or ts.state == finish:
             return {}, []
 
         start = ts.state
-        func = self._transitions_table.get((start, finish))
+        func = self._transitions_table.get((start, cast(str, finish)))
 
         if func is not None:
             self._transition_counter += 1
@@ -2653,6 +2769,8 @@ def _transition(self, ts, finish, *args, stimulus_id, **kwargs):
             try:
                 recs, smsgs = self._transition(ts, "released", stimulus_id=stimulus_id)
                 v = recs.get(ts, (finish, *args))
+                v_state: str
+                v_args: list | tuple
                 if isinstance(v, tuple):
                     v_state, *v_args = v
                 else:
@@ -2690,7 +2808,9 @@ def _transition(self, ts, finish, *args, stimulus_id, **kwargs):
         )
         return recs, smsgs
 
-    def transition(self, ts, finish: str, *, stimulus_id, **kwargs):
+    def transition(
+        self, ts: TaskState, finish: str, *, stimulus_id: str, **kwargs
+    ) -> None:
         """Transition a key from its current state to the finish state
 
         Examples
@@ -2711,7 +2831,7 @@ def transition(self, ts, finish: str, *, stimulus_id, **kwargs):
             self.batched_stream.send(msg)
         self.transitions(recs, stimulus_id=stimulus_id)
 
-    def transitions(self, recommendations: dict, *, stimulus_id):
+    def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:
         """Process transitions until none are left
 
         This includes feedback from previous transitions and continues until we
@@ -2743,7 +2863,9 @@ def transitions(self, recommendations: dict, *, stimulus_id):
                 len(smsgs),
             )
 
-    def maybe_transition_long_running(self, ts, *, stimulus_id, compute_duration=None):
+    def maybe_transition_long_running(
+        self, ts: TaskState, *, stimulus_id: str, compute_duration=None
+    ):
         if ts.state == "executing":
             self.transition(
                 ts,
@@ -2753,7 +2875,7 @@ def maybe_transition_long_running(self, ts, *, stimulus_id, compute_duration=Non
             )
             assert ts.state == "long-running"
 
-    def stateof(self, key):
+    def stateof(self, key: str) -> dict[str, Any]:
         ts = self.tasks[key]
         return {
             "executing": ts.state == "executing",
@@ -2762,8 +2884,8 @@ def stateof(self, key):
             "data": key in self.data,
         }
 
-    def story(self, *keys):
-        keys = [key.key if isinstance(key, TaskState) else key for key in keys]
+    def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
+        keys = [e.key if isinstance(e, TaskState) else e for e in keys_or_tasks]
         return [
             msg
             for msg in self.log
@@ -2776,7 +2898,7 @@ def story(self, *keys):
             )
         ]
 
-    def ensure_communicating(self):
+    def ensure_communicating(self) -> None:
         stimulus_id = f"ensure-communicating-{time()}"
 
         skipped_worker_in_flight = []
@@ -2818,7 +2940,9 @@ def ensure_communicating(self):
             self.comm_nbytes += total_nbytes
             self.in_flight_workers[worker] = to_gather
-            recommendations = {self.tasks[d]: ("flight", worker) for d in to_gather}
+            recommendations: Recs = {
+                self.tasks[d]: ("flight", worker) for d in to_gather
+            }
             self.transitions(recommendations, stimulus_id=stimulus_id)
 
             self.loop.add_callback(
@@ -2832,7 +2956,7 @@ def ensure_communicating(self):
         for el in skipped_worker_in_flight:
             self.data_needed.push(el)
 
-    def _get_task_finished_msg(self, ts):
+    def _get_task_finished_msg(self, ts: TaskState) -> dict[str, Any]:
         if ts.key not in self.data and ts.key not in self.actors:
             raise RuntimeError(f"Task {ts} not ready")
         typ = ts.type
@@ -2864,7 +2988,7 @@ def _get_task_finished_msg(self, ts):
             d["startstops"] = ts.startstops
         return d
 
-    def _put_key_in_memory(self, ts, value, *, stimulus_id):
+    def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
         """
         Put a key into memory and set data related task state attributes.
         On success, generate recommendations for dependents.
@@ -2889,7 +3013,7 @@ def _put_key_in_memory(self, ts, value, *, stimulus_id):
             ts.state = "memory"
             return {}
 
-        recommendations = {}
+        recommendations: Recs = {}
         if ts.key in self.actors:
             self.actors[ts.key] = value
         else:
@@ -3047,8 +3171,8 @@ async def gather_dep(
         to_gather: Iterable[str],
         total_nbytes: int,
         *,
-        stimulus_id,
-    ):
+        stimulus_id: str,
+    ) -> None:
         """Gather dependencies for a task from a worker who has them
 
         Parameters
@@ -3065,7 +3189,7 @@ async def gather_dep(
         if self.status not in Status.ANY_RUNNING:  # type: ignore
             return
 
-        recommendations: dict[TaskState, str | tuple] = {}
+        recommendations: Recs = {}
         with log_errors():
             response = {}
             to_gather_keys: set[str] = set()
@@ -3182,7 +3306,7 @@ async def gather_dep(
 
             self.ensure_communicating()
 
-    async def find_missing(self):
+    async def find_missing(self) -> None:
         with log_errors():
             if not self._missing_dep_flight:
                 return
@@ -3198,7 +3322,7 @@ async def find_missing(self):
             )
             who_has = {k: v for k, v in who_has.items() if v}
             self.update_who_has(who_has)
-            recommendations = {}
+            recommendations: Recs = {}
             for ts in self._missing_dep_flight:
                 if ts.who_has:
                     recommendations[ts] = "fetch"
@@ -3247,7 +3371,7 @@ def update_who_has(self, who_has: dict[str, Collection[str]]) -> None:
                 pdb.set_trace()
             raise
 
-    def handle_steal_request(self, key, stimulus_id):
+    def handle_steal_request(self, key: str, stimulus_id: str) -> None:
         # There may be a race condition between stealing and releasing a task.
         # In this case the self.tasks is already cleared. The `None` will be
         # registered as `already-computing` on the other end
@@ -3263,6 +3387,7 @@ def handle_steal_request(self, key, stimulus_id):
             self.batched_stream.send(response)
 
             if state in READY | {"waiting"}:
+                assert ts
                 # If task is marked as "constrained" we haven't yet assigned it an
                 # `available_resources` to run on, that happens in
                 # `transition_constrained_executing`
@@ -3288,8 +3413,9 @@ def handle_steal_request(self, key, stimulus_id):
         self,
         key: str,
         cause: TaskState | None = None,
-        reason: str | None = None,
         report: bool = True,
+        *,
+        stimulus_id: str,
     ) -> None:
         try:
             if self.validate:
@@ -3300,12 +3426,15 @@ def release_key(
             ts.state = "released"
 
             logger.debug(
-                "Release key %s", {"key": key, "cause": cause, "reason": reason}
+                "Release key %s",
+                {"key": key, "cause": cause, "stimulus_id": stimulus_id},
             )
             if cause:
-                self.log.append((key, "release-key", {"cause": cause}, reason, time()))
+                self.log.append(
+                    (key, "release-key", {"cause": cause}, stimulus_id, time())
+                )
             else:
-                self.log.append((key, "release-key", reason, time()))
+                self.log.append((key, "release-key", stimulus_id, time()))
             if key in self.data:
                 try:
                     del self.data[key]
@@ -3340,7 +3469,7 @@ def release_key(
                 self._in_flight_tasks.discard(ts)
 
             self._notify_plugins(
-                "release_key", key, state_before, cause, reason, report
+                "release_key", key, state_before, cause, stimulus_id, report
             )
         except CommClosedError:
             # Batched stream send might raise if it was already closed
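`release_key` above replaces the optional `reason=None` with a mandatory, keyword-only `stimulus_id` — the bare `*` in the signature is what forces every caller (including the tests updated earlier) to spell it out. The pattern in isolation:

def release_key(key: str, cause=None, report: bool = True, *, stimulus_id: str) -> None:
    # everything after the bare `*` must be passed by keyword
    ...

release_key("x", stimulus_id="test")    # OK
# release_key("x", None, True, "test")  # TypeError: stimulus_id is keyword-only
# release_key("x")                      # TypeError: stimulus_id is required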
+ plugin = cast("WorkerPlugin", pickle.loads(plugin)) if name is None: name = _get_plugin_name(plugin) @@ -3392,7 +3528,7 @@ async def plugin_add(self, plugin=None, name=None, catch_errors=True): return {"status": "OK"} - async def plugin_remove(self, name=None): + async def plugin_remove(self, name: str) -> dict[str, Any]: with log_errors(pdb=False): logger.info(f"Removing Worker plugin {name}") try: @@ -3413,7 +3549,7 @@ async def actor_execute( function=None, args=(), kwargs: dict | None = None, - ): + ) -> dict[str, Any]: kwargs = kwargs or {} separate_thread = kwargs.pop("separate_thread", True) key = actor @@ -3442,7 +3578,7 @@ async def actor_execute( except Exception as ex: return {"status": "error", "exception": to_serialize(ex)} - def actor_attribute(self, actor=None, attribute=None): + def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]: try: value = getattr(self.actors[actor], attribute) return {"status": "OK", "result": to_serialize(value)} @@ -3459,9 +3595,11 @@ def meets_resource_constraints(self, key: str) -> bool: return True - async def _maybe_deserialize_task(self, ts, *, stimulus_id): - if not isinstance(ts.run_spec, SerializedTask): - return ts.run_spec + async def _maybe_deserialize_task( + self, ts: TaskState, *, stimulus_id: str + ) -> tuple[Callable, tuple, dict[str, Any]] | None: + if ts.run_spec is None: + return None try: start = time() # Offload deserializing large tasks @@ -3489,7 +3627,7 @@ async def _maybe_deserialize_task(self, ts, *, stimulus_id): ) raise - def ensure_computing(self): + def ensure_computing(self) -> None: if self.status in (Status.paused, Status.closing_gracefully): return try: @@ -3526,7 +3664,7 @@ def ensure_computing(self): pdb.set_trace() raise - async def execute(self, key, *, stimulus_id): + async def execute(self, key: str, *, stimulus_id: str) -> None: if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: return if key not in self.tasks: @@ -3549,7 +3687,7 @@ async def execute(self, key, *, stimulus_id): assert ts.state == "executing" assert ts.run_spec is not None - function, args, kwargs = await self._maybe_deserialize_task( + function, args, kwargs = await self._maybe_deserialize_task( # type: ignore ts, stimulus_id=stimulus_id ) @@ -3600,9 +3738,9 @@ async def execute(self, key, *, stimulus_id): self.active_keys.discard(ts.key) key = ts.key - # key *must* be still in tasks. Releasing it direclty is forbidden + # key *must* be still in tasks. 
Releasing it directly is forbidden # without going through cancelled - ts = self.tasks.get(key) + ts = self.tasks.get(key) # type: ignore assert ts, self.story(key) ts.done = True result["key"] = ts.key @@ -3611,7 +3749,7 @@ async def execute(self, key, *, stimulus_id): {"action": "compute", "start": result["start"], "stop": result["stop"]} ) self.threads[ts.key] = result["thread"] - recommendations = {} + recommendations: Recs = {} if result["op"] == "task-finished": ts.nbytes = result["nbytes"] ts.type = result["type"] @@ -3665,7 +3803,9 @@ async def execute(self, key, *, stimulus_id): self.ensure_computing() self.ensure_communicating() - def _prepare_args_for_execution(self, ts, args, kwargs): + def _prepare_args_for_execution( + self, ts: TaskState, args: tuple, kwargs: dict[str, Any] + ) -> tuple[tuple, dict[str, Any]]: start = time() data = {} for dep in ts.dependencies: @@ -3689,7 +3829,7 @@ def _prepare_args_for_execution(self, ts, args, kwargs): # Administrative # ################## - async def memory_monitor(self): + async def memory_monitor(self) -> None: """Track this process's memory usage and act accordingly If we rise above 70% memory use, start dumping data to disk. @@ -3699,6 +3839,7 @@ async def memory_monitor(self): if self._memory_monitoring: return self._memory_monitoring = True + assert self.memory_limit total = 0 memory = self.monitor.get_process_memory() @@ -3738,6 +3879,10 @@ def check_pause(memory): check_pause(memory) # Dump data to disk if above 70% if self.memory_spill_fraction and frac > self.memory_spill_fraction: + from .spill import SpillBuffer + + assert isinstance(self.data, SpillBuffer) + logger.debug( "Worker is at %.0f%% memory usage. Start spilling data to disk.", frac * 100, @@ -3791,9 +3936,8 @@ def check_pause(memory): ) self._memory_monitoring = False - return total - def cycle_profile(self): + def cycle_profile(self) -> None: now = time() + self.scheduler_delay prof, self.profile_recent = self.profile_recent, profile.create() self.profile_history.append((now, prof)) @@ -3801,7 +3945,7 @@ def cycle_profile(self): self.profile_keys_history.append((now, dict(self.profile_keys))) self.profile_keys.clear() - def trigger_profile(self): + def trigger_profile(self) -> None: """ Get a frame from all actively computing threads @@ -3834,7 +3978,13 @@ def trigger_profile(self): if self.digests is not None: self.digests["profile-duration"].add(stop - start) - async def get_profile(self, start=None, stop=None, key=None, server=False): + async def get_profile( + self, + start=None, + stop=None, + key=None, + server: bool = False, + ): now = time() + self.scheduler_delay if server: history = self.io_loop.profile @@ -3875,11 +4025,12 @@ async def get_profile(self, start=None, stop=None, key=None, server=False): return prof - async def get_profile_metadata(self, start=0, stop=None): + async def get_profile_metadata( + self, start: float = 0, stop: float | None = None + ) -> dict[str, Any]: add_recent = stop is None now = time() + self.scheduler_delay stop = stop or now - start = start or 0 result = { "counts": [ (t, d["count"]) for t, d in self.profile_history if start < t < stop @@ -3897,16 +4048,14 @@ async def get_profile_metadata(self, start=0, stop=None): ) return result - def get_call_stack(self, keys=None): + def get_call_stack(self, keys: Collection[str] | None = None) -> dict[str, Any]: with self.active_threads_lock: - frames = sys._current_frames() - active_threads = self.active_threads.copy() - frames = {k: frames[ident] for ident, k in 
active_threads.items()} + sys_frames = sys._current_frames() + frames = {key: sys_frames[tid] for tid, key in self.active_threads.items()} if keys is not None: - frames = {k: frame for k, frame in frames.items() if k in keys} + frames = {key: frames[key] for key in keys if key in frames} - result = {k: profile.call_stack(frame) for k, frame in frames.items()} - return result + return {key: profile.call_stack(frame) for key, frame in frames.items()} def _notify_plugins(self, method_name, *args, **kwargs): for name, plugin in self.plugins.items():
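For reference, a self-contained approximation of what the rewritten `get_call_stack` does — snapshot `sys._current_frames()` once, map thread ids to task keys, optionally filter by key, and render each frame — using the stdlib `traceback` in place of `distributed.profile.call_stack`:

import sys
import threading
import traceback

active_threads = {threading.get_ident(): "task-x"}  # thread id -> task key

sys_frames = sys._current_frames()
frames = {key: sys_frames[tid] for tid, key in active_threads.items()}

keys = ["task-x", "task-y"]  # hypothetical filter; unknown keys are dropped
frames = {key: frames[key] for key in keys if key in frames}

call_stacks = {key: traceback.format_stack(frame) for key, frame in frames.items()}
assert list(call_stacks) == ["task-x"] and call_stacks["task-x"]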