Annotation-less P2P shuffling #7801
@@ -0,0 +1,15 @@
from __future__ import annotations


class Reschedule(Exception):
    """Reschedule this task

    Raising this exception will stop the current execution of the task and ask
    the scheduler to reschedule this task, possibly on a different machine.

    This does not guarantee that the task will move onto a different machine.
    The scheduler will proceed through its normal heuristics to determine the
    optimal machine to accept this task. The machine will likely change if the
    load across the cluster has significantly changed since first scheduling
    the task.
    """
@@ -6,7 +6,9 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar
from functools import partial
from itertools import product
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Sequence

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.shuffle._rechunk import ChunkedAxes, NIndex

@@ -105,6 +107,7 @@ def __init__(self, scheduler: Scheduler):
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
        )
        self.heartbeats = defaultdict(lambda: defaultdict(dict))
@@ -122,14 +125,22 @@ async def barrier(self, id: ShuffleId, run_id: int) -> None:
            msg=msg, workers=list(shuffle.participating_workers)
        )

    def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict:
        shuffle = self.states[id]
        if shuffle.run_id != run_id:
            return {"status": "error", "message": "Stale shuffle"}
        ts = self.scheduler.tasks[key]
        self._set_restriction(ts, worker)
        return {"status": "OK"}

    def heartbeat(self, ws: WorkerState, data: dict) -> None:
        for shuffle_id, d in data.items():
            if shuffle_id in self.shuffle_ids():
                self.heartbeats[shuffle_id][ws.address].update(d)

    def get(self, id: ShuffleId, worker: str) -> dict[str, Any]:
        if exception := self.erred_shuffles.get(id):
            return {"status": "ERROR", "message": str(exception)}
            return {"status": "error", "message": str(exception)}
        state = self.states[id]
        state.participating_workers.add(worker)
        return state.to_msg()
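A hypothetical worker-side sketch (not part of this diff) of how the new `shuffle_restrict_task` handler could be invoked, assuming the standard distributed RPC pattern in which handler names registered on the scheduler become awaitable methods on the scheduler comm; `shuffle_id`, `run_id`, `key`, and `pin_output_to_me` are placeholder names:

```python
# Hypothetical sketch: a worker asking the scheduler to pin an output task to
# itself through the "shuffle_restrict_task" handler registered above.
from distributed import get_worker

async def pin_output_to_me(shuffle_id: str, run_id: int, key: str) -> None:
    worker = get_worker()
    result = await worker.scheduler.shuffle_restrict_task(
        id=shuffle_id, run_id=run_id, key=key, worker=worker.address
    )
    if result["status"] != "OK":
        # e.g. {"status": "error", "message": "Stale shuffle"}
        raise RuntimeError(result["message"])
```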
@@ -144,6 +155,11 @@ def get_or_create(
        try:
            return self.get(id, worker)
        except KeyError:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(id)

            state: ShuffleState
            if type == ShuffleType.DATAFRAME:
                state = self._create_dataframe_shuffle_state(id, spec)
@@ -155,33 +171,33 @@ def get_or_create(
            state.participating_workers.add(worker)
            return state.to_msg()

    def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
        key = barrier_key(id)
        try:
            self.scheduler.tasks[key]
        except KeyError:
            raise RuntimeError(
                f"Barrier task with key {key!r} does not exist. This may be caused by "
                "task fusion during graph generation. Please let us know that you ran "
                "into this by leaving a comment at distributed#7816."
            )

    def _create_dataframe_shuffle_state(
        self, id: ShuffleId, spec: dict[str, Any]
    ) -> DataFrameShuffleState:
        schema = spec["schema"]
        column = spec["column"]
        npartitions = spec["npartitions"]
        parts_out = spec["parts_out"]
        assert schema is not None
        assert column is not None
        assert npartitions is not None
        assert parts_out is not None

        workers = list(self.scheduler.workers)
        output_workers = set()

        name = barrier_key(id)
        mapping = {}
        pick_worker = partial(get_worker_for_range_sharding, npartitions)

        for ts in self.scheduler.tasks[name].dependents:
            part = get_partition_id(ts)
            if ts.worker_restrictions:
                output_worker = list(ts.worker_restrictions)[0]
            else:
                output_worker = get_worker_for_range_sharding(
                    part, workers, npartitions
                )
            mapping[part] = output_worker
            output_workers.add(output_worker)
            self.scheduler.set_restrictions({ts.key: {output_worker}})
        mapping = self._pin_output_workers(id, parts_out, pick_worker)
        output_workers = set(mapping.values())

        return DataFrameShuffleState(
            id=id,
@@ -193,6 +209,52 @@ def _create_dataframe_shuffle_state(
            participating_workers=output_workers.copy(),
        )

    def _pin_output_workers(
        self,
        id: ShuffleId,
        output_partitions: Iterable[Any],
        pick: Callable[[Any, Sequence[str]], str],
    ) -> dict[Any, str]:
        """Pin the outputs of a P2P shuffle to specific workers.

        Parameters
        ----------
        id: ID of the shuffle to pin
        output_partitions: Output partition IDs to pin
        pick: Function that picks a worker given a partition ID and a sequence of workers

        .. note:
            This function assumes that the barrier task and the output tasks share
            the same worker restrictions.
        """
        mapping = {}
        barrier = self.scheduler.tasks[barrier_key(id)]

        if barrier.worker_restrictions:
            workers = list(barrier.worker_restrictions)

Review comment: The implicit assumption here is that the barrier and the output tasks are guaranteed to have the same restrictions. I suggest documenting this, because it is a non-trivial conclusion and, depending on how future versions of fusion work, it may not even hold indefinitely.

Reply: I've added a docstring.
        else:
            workers = list(self.scheduler.workers)

        for partition in output_partitions:
            worker = pick(partition, workers)
            mapping[partition] = worker

        for dt in barrier.dependents:
            try:
                partition = dt.annotations["shuffle"]
            except KeyError:
                continue

            if dt.worker_restrictions:
                worker = pick(partition, list(dt.worker_restrictions))
                mapping[partition] = worker
            else:
                worker = mapping[partition]

            self._set_restriction(dt, worker)

        return mapping
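The `pick` parameter is just a `(partition, workers) -> worker` callable. An illustrative sketch (not part of the diff) of how the two call sites shape it, assuming the sharding helpers defined at the bottom of this module are in scope:

```python
# Illustrative only: how the two callers of _pin_output_workers build `pick`.
from functools import partial

npartitions = 8  # hypothetical shuffle width

# Dataframe path: bind npartitions up front, leaving (partition, workers) -> worker.
pick_dataframe = partial(get_worker_for_range_sharding, npartitions)

# Rechunk path: the hash helper already has that shape; partitions are NIndex tuples.
pick_rechunk = get_worker_for_hash_sharding
```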
    def _create_array_rechunk_state(
        self, id: ShuffleId, spec: dict[str, Any]
    ) -> ArrayRechunkState:

@@ -201,21 +263,9 @@ def _create_array_rechunk_state(
        assert old is not None
        assert new is not None

        workers = list(self.scheduler.workers)
        output_workers = set()

        name = barrier_key(id)
        mapping = {}

        for ts in self.scheduler.tasks[name].dependents:
            part = get_partition_id(ts)
            if ts.worker_restrictions:
                output_worker = list(ts.worker_restrictions)[0]
            else:
                output_worker = get_worker_for_hash_sharding(part, workers)
            mapping[part] = output_worker
            output_workers.add(output_worker)
            self.scheduler.set_restrictions({ts.key: {output_worker}})
        parts_out = product(*(range(len(c)) for c in new))
        mapping = self._pin_output_workers(id, parts_out, get_worker_for_hash_sharding)
        output_workers = set(mapping.values())

        return ArrayRechunkState(
            id=id,
@@ -227,6 +277,22 @@ def _create_array_rechunk_state(
            participating_workers=output_workers.copy(),
        )

    def _set_restriction(self, ts: TaskState, worker: str) -> None:
        if "shuffle_original_restrictions" in ts.annotations:
            # This may occur if multiple barriers share the same output task,
            # e.g. in a hash join.
            return
        ts.annotations["shuffle_original_restrictions"] = ts.worker_restrictions.copy()

Review comment: Dummy question, sorry. The problem that you're fixing is tasks occasionally losing annotations. AFAIK this is applicable to all annotations, not just […]

Reply: Based on what we've seen, the current assumption is that […]

        self.scheduler.set_restrictions({ts.key: {worker}})

    def _unset_restriction(self, ts: TaskState) -> None:
        # shuffle_original_restrictions is only set if the task was first scheduled
        # on the wrong worker
        if "shuffle_original_restrictions" not in ts.annotations:
            return
        original_restrictions = ts.annotations.pop("shuffle_original_restrictions")
        self.scheduler.set_restrictions({ts.key: original_restrictions})
    def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
        from time import time

@@ -247,7 +313,7 @@ def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
            for dt in barrier_task.dependents:
                if worker not in dt.worker_restrictions:
                    continue
                dt.worker_restrictions.clear()
                self._unset_restriction(dt)
                recs.update({dt.key: "waiting"})
            # TODO: Do we need to handle other states?
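A small illustrative sketch (not part of the diff) of the save/restore round trip that `_set_restriction` and `_unset_restriction` implement above, using plain Python objects instead of `TaskState` and hypothetical worker addresses:

```python
# Stand-ins for TaskState.annotations and TaskState.worker_restrictions.
annotations: dict = {}
worker_restrictions: set[str] = {"tcp://10.0.0.1:1234"}  # hypothetical original restriction

# _set_restriction: remember the original restriction, then pin to the shuffle's output worker.
annotations["shuffle_original_restrictions"] = worker_restrictions.copy()
worker_restrictions = {"tcp://10.0.0.7:4321"}  # pinned output worker

# _unset_restriction: drop the pin and restore whatever was there before.
worker_restrictions = annotations.pop("shuffle_original_restrictions")
assert worker_restrictions == {"tcp://10.0.0.1:1234"}
```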
@@ -293,34 +359,27 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None:
        with contextlib.suppress(KeyError):
            del self.heartbeats[id]

        barrier_task = self.scheduler.tasks[barrier_key(id)]
        for dt in barrier_task.dependents:
            self._unset_restriction(dt)

    def restart(self, scheduler: Scheduler) -> None:
        self.states.clear()
        self.heartbeats.clear()
        self.erred_shuffles.clear()


def get_partition_id(ts: TaskState) -> Any:
    """Get the output partition ID of this task state."""
    try:
        return ts.annotations["shuffle"]
    except KeyError:
        raise RuntimeError(
            f"{ts} has lost its ``shuffle`` annotation. This may be caused by "
            "unintended optimization during graph generation. "
            "Please report this problem on GitHub and link it to "
            "the tracking issue at https://github.com/dask/distributed/issues/7716."
        )


def get_worker_for_range_sharding(
    output_partition: int, workers: list[str], npartitions: int
    npartitions: int, output_partition: int, workers: Sequence[str]
) -> str:
    """Get address of target worker for this output partition using range sharding"""
    i = len(workers) * output_partition // npartitions
    return workers[i]


def get_worker_for_hash_sharding(output_partition: NIndex, workers: list[str]) -> str:
def get_worker_for_hash_sharding(
    output_partition: NIndex, workers: Sequence[str]
) -> str:
    """Get address of target worker for this output partition using hash sharding"""
    i = hash(output_partition) % len(workers)
    return workers[i]
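A worked example (illustrative, not from the diff) of how the two helpers with their new signatures spread output partitions over a toy cluster; the worker addresses are hypothetical:

```python
workers = ["tcp://a:1", "tcp://b:1", "tcp://c:1"]  # hypothetical worker addresses
npartitions = 6

# Range sharding: contiguous blocks of output partitions land on the same worker.
assert [get_worker_for_range_sharding(npartitions, p, workers) for p in range(6)] == [
    "tcp://a:1", "tcp://a:1", "tcp://b:1", "tcp://b:1", "tcp://c:1", "tcp://c:1"
]

# Hash sharding: rechunk output indices (tuples) are spread by their hash.
idx = (0, 1)
assert get_worker_for_hash_sharding(idx, workers) == workers[hash(idx) % len(workers)]
```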
Review comment: XREF: #7816