Annotation-less P2P shuffling #7801

Merged
28 commits merged on May 11, 2023

Changes from 24 commits

Commits
5b19534
WIP annotation-less shuffling
hendrikmakait Apr 25, 2023
edc8a8e
Remove comments
hendrikmakait Apr 25, 2023
0359078
XFAIL
hendrikmakait Apr 25, 2023
de75ab0
Rechunk
hendrikmakait Apr 25, 2023
7c603dd
Comment
hendrikmakait Apr 25, 2023
35186d3
Tests
hendrikmakait Apr 25, 2023
0937e19
Minor
hendrikmakait Apr 26, 2023
bf27d64
XFAIL
hendrikmakait Apr 26, 2023
0c2ead1
Merge branch 'main' into annotation-less-shuffling
hendrikmakait Apr 28, 2023
c60b60a
Unset restrictions
hendrikmakait May 2, 2023
46bfc71
Recover from lost annotations
hendrikmakait May 2, 2023
e69c985
Fix tests
hendrikmakait May 2, 2023
fb93688
P2P now works with fuse
hendrikmakait May 2, 2023
673b03c
Refactor
hendrikmakait May 2, 2023
6ca596a
Fix tests
hendrikmakait May 2, 2023
502993e
Keep annotations
hendrikmakait May 2, 2023
309d4d5
Test with lost annotations
hendrikmakait May 3, 2023
30f9b91
minor
hendrikmakait May 3, 2023
09d1436
Raise informative error if barrier task unknown
hendrikmakait May 3, 2023
fd119c1
Link with #7816
hendrikmakait May 3, 2023
930693c
Update distributed/shuffle/_scheduler_extension.py
hendrikmakait May 9, 2023
e5c4022
Document assumption
hendrikmakait May 9, 2023
7a1229d
Docs
hendrikmakait May 9, 2023
c57af86
Increase range
hendrikmakait May 9, 2023
dc8055d
Rename module
hendrikmakait May 11, 2023
d5aa8a0
Ensure failing worker worked on task
hendrikmakait May 11, 2023
642007a
Additional asserts
hendrikmakait May 11, 2023
b276e79
Merge branch 'main' into annotation-less-shuffling
hendrikmakait May 11, 2023
15 changes: 15 additions & 0 deletions distributed/reschedule.py
@@ -0,0 +1,15 @@
from __future__ import annotations


class Reschedule(Exception):
"""Reschedule this task

Raising this exception will stop the current execution of the task and ask
the scheduler to reschedule this task, possibly on a different machine.

This does not guarantee that the task will move onto a different machine.
The scheduler will proceed through its normal heuristics to determine the
optimal machine to accept this task. The machine will likely change if the
load across the cluster has significantly changed since first scheduling
the task.
"""
4 changes: 4 additions & 0 deletions distributed/shuffle/_merge.py
@@ -136,13 +136,15 @@ def merge_transfer(
id: ShuffleId,
input_partition: int,
npartitions: int,
parts_out: set[int],
):
return shuffle_transfer(
input=input,
id=id,
input_partition=input_partition,
npartitions=npartitions,
column=_HASH_COLUMN_NAME,
parts_out=parts_out,
)


@@ -340,6 +342,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
token_left,
i,
self.npartitions,
self.parts_out,
)
for i in range(self.n_partitions_right):
transfer_keys_right.append((name_right, i))
@@ -349,6 +352,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
token_right,
i,
self.npartitions,
self.parts_out,
)

_barrier_key_left = barrier_key(ShuffleId(token_left))
9 changes: 3 additions & 6 deletions distributed/shuffle/_rechunk.py
@@ -8,6 +8,7 @@
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

from distributed.reschedule import Reschedule
from distributed.shuffle._shuffle import (
ShuffleId,
ShuffleType,
@@ -57,6 +58,8 @@ def rechunk_unpack(
return _get_worker_extension().get_output_partition(
id, barrier_run_id, output_chunk
)
except Reschedule as e:
raise e
except Exception as e:
raise RuntimeError(f"rechunk_unpack failed during shuffle {id}") from e

@@ -70,12 +73,6 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
# Special case for empty array, as the algorithm below does not behave correctly
return da.empty(x.shape, chunks=chunks, dtype=x.dtype)

if dask.config.get("optimization.fuse.active") is not False:
raise RuntimeError(
"P2P rechunking requires the fuse optimization to be turned off. "
"Set the 'optimization.fuse.active' config to False to deactivate."
)

dsk: dict = {}
token = tokenize(x, chunks)
_barrier_key = barrier_key(ShuffleId(token))
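The hunk above re-raises Reschedule before the broad except Exception so that this control-flow exception is not swallowed and re-wrapped as a RuntimeError. A minimal sketch of the pattern in isolation; the wrapper name and message below are illustrative:

from distributed.reschedule import Reschedule


def call_with_error_context(fn, *args):
    """Illustrative wrapper: let Reschedule propagate, wrap everything else."""
    try:
        return fn(*args)
    except Reschedule:
        # Control-flow exception: the scheduler handles the rescheduling.
        raise
    except Exception as e:
        raise RuntimeError("operation failed during shuffle") from e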
157 changes: 108 additions & 49 deletions distributed/shuffle/_scheduler_extension.py
@@ -6,7 +6,9 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar
from functools import partial
from itertools import product
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Sequence

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.shuffle._rechunk import ChunkedAxes, NIndex
@@ -105,6 +107,7 @@ def __init__(self, scheduler: Scheduler):
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
@@ -122,14 +125,22 @@ async def barrier(self, id: ShuffleId, run_id: int) -> None:
msg=msg, workers=list(shuffle.participating_workers)
)

def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict:
shuffle = self.states[id]
if shuffle.run_id != run_id:
return {"status": "error", "message": "Stale shuffle"}
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}

def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)

def get(self, id: ShuffleId, worker: str) -> dict[str, Any]:
if exception := self.erred_shuffles.get(id):
return {"status": "ERROR", "message": str(exception)}
return {"status": "error", "message": str(exception)}
state = self.states[id]
state.participating_workers.add(worker)
return state.to_msg()
@@ -144,6 +155,11 @@ def get_or_create(
try:
return self.get(id, worker)
except KeyError:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(id)

state: ShuffleState
if type == ShuffleType.DATAFRAME:
state = self._create_dataframe_shuffle_state(id, spec)
@@ -155,33 +171,33 @@ def get_or_create(
state.participating_workers.add(worker)
return state.to_msg()

def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
key = barrier_key(id)
try:
self.scheduler.tasks[key]
except KeyError:
raise RuntimeError(
f"Barrier task with key {key!r} does not exist. This may be caused by "
"task fusion during graph generation. Please let us know that you ran "
"into this by leaving a comment at distributed#7816."
hendrikmakait (Member, Author) commented on this line: XREF: #7816
)
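If a user does hit this error because fusion mangled the barrier key, the config knob checked by the code removed from distributed/shuffle/_rechunk.py above can still be set explicitly as a workaround; a minimal sketch (the key name is taken from that removed check):

import dask

# Disable graph fusion so the barrier task keeps its well-known key.
# The key name comes from the check removed in distributed/shuffle/_rechunk.py.
dask.config.set({"optimization.fuse.active": False})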

def _create_dataframe_shuffle_state(
self, id: ShuffleId, spec: dict[str, Any]
) -> DataFrameShuffleState:
schema = spec["schema"]
column = spec["column"]
npartitions = spec["npartitions"]
parts_out = spec["parts_out"]
assert schema is not None
assert column is not None
assert npartitions is not None
assert parts_out is not None

workers = list(self.scheduler.workers)
output_workers = set()

name = barrier_key(id)
mapping = {}
pick_worker = partial(get_worker_for_range_sharding, npartitions)

for ts in self.scheduler.tasks[name].dependents:
part = get_partition_id(ts)
if ts.worker_restrictions:
output_worker = list(ts.worker_restrictions)[0]
else:
output_worker = get_worker_for_range_sharding(
part, workers, npartitions
)
mapping[part] = output_worker
output_workers.add(output_worker)
self.scheduler.set_restrictions({ts.key: {output_worker}})
mapping = self._pin_output_workers(id, parts_out, pick_worker)
output_workers = set(mapping.values())

return DataFrameShuffleState(
id=id,
Expand All @@ -193,6 +209,52 @@ def _create_dataframe_shuffle_state(
participating_workers=output_workers.copy(),
)

def _pin_output_workers(
self,
id: ShuffleId,
output_partitions: Iterable[Any],
pick: Callable[[Any, Sequence[str]], str],
) -> dict[Any, str]:
"""Pin the outputs of a P2P shuffle to specific workers.

Parameters
----------
id: ID of the shuffle to pin
output_partitions: Output partition IDs to pin
pick: Function that picks a worker given a partition ID and a sequence of workers

.. note:
This function assumes that the barrier task and the output tasks share
the same worker restrictions.
"""
mapping = {}
barrier = self.scheduler.tasks[barrier_key(id)]

if barrier.worker_restrictions:
workers = list(barrier.worker_restrictions)
Reviewer (Member) commented: The implicit assumption here is that the barrier and output tasks are guaranteed to have the same restrictions. I suggest documenting this because it is a non-trivial conclusion, and depending on how future versions of fusion work, it may not even remain true indefinitely.

hendrikmakait (Member, Author) replied: I've added a docstring.

else:
workers = list(self.scheduler.workers)

for partition in output_partitions:
worker = pick(partition, workers)
mapping[partition] = worker

for dt in barrier.dependents:
try:
partition = dt.annotations["shuffle"]
except KeyError:
continue

if dt.worker_restrictions:
worker = pick(partition, list(dt.worker_restrictions))
mapping[partition] = worker
else:
worker = mapping[partition]

self._set_restriction(dt, worker)

return mapping

def _create_array_rechunk_state(
self, id: ShuffleId, spec: dict[str, Any]
) -> ArrayRechunkState:
@@ -201,21 +263,9 @@ def _create_array_rechunk_state(
assert old is not None
assert new is not None

workers = list(self.scheduler.workers)
output_workers = set()

name = barrier_key(id)
mapping = {}

for ts in self.scheduler.tasks[name].dependents:
part = get_partition_id(ts)
if ts.worker_restrictions:
output_worker = list(ts.worker_restrictions)[0]
else:
output_worker = get_worker_for_hash_sharding(part, workers)
mapping[part] = output_worker
output_workers.add(output_worker)
self.scheduler.set_restrictions({ts.key: {output_worker}})
parts_out = product(*(range(len(c)) for c in new))
mapping = self._pin_output_workers(id, parts_out, get_worker_for_hash_sharding)
output_workers = set(mapping.values())

return ArrayRechunkState(
id=id,
Expand All @@ -227,6 +277,22 @@ def _create_array_rechunk_state(
participating_workers=output_workers.copy(),
)

def _set_restriction(self, ts: TaskState, worker: str) -> None:
if "shuffle_original_restrictions" in ts.annotations:
# This may occur if multiple barriers share the same output task,
# e.g. in a hash join.
return
ts.annotations["shuffle_original_restrictions"] = ts.worker_restrictions.copy()
Reviewer (Contributor) commented: Dummy question, sorry. The problem you're fixing is tasks occasionally losing annotations. AFAIK this applies to all annotations, not just shuffle annotations, and your ShuffleAnnotationChaosPlugin only kills the shuffle annotation. Is there a possibility that in a real-world scenario this new shuffle_original_restrictions annotation will be lost too?

hendrikmakait (Member, Author) replied: Based on what we've seen, the current assumption is that annotations get lost before the tasks make it to the scheduler.

self.scheduler.set_restrictions({ts.key: {worker}})

def _unset_restriction(self, ts: TaskState) -> None:
# shuffle_original_restrictions is only set if the task was first scheduled
# on the wrong worker
if "shuffle_original_restrictions" not in ts.annotations:
return
original_restrictions = ts.annotations.pop("shuffle_original_restrictions")
self.scheduler.set_restrictions({ts.key: original_restrictions})

def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
from time import time

@@ -247,7 +313,7 @@ def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
for dt in barrier_task.dependents:
if worker not in dt.worker_restrictions:
continue
dt.worker_restrictions.clear()
self._unset_restriction(dt)
recs.update({dt.key: "waiting"})
# TODO: Do we need to handle other states?

@@ -293,34 +359,27 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None:
with contextlib.suppress(KeyError):
del self.heartbeats[id]

barrier_task = self.scheduler.tasks[barrier_key(id)]
for dt in barrier_task.dependents:
self._unset_restriction(dt)

def restart(self, scheduler: Scheduler) -> None:
self.states.clear()
self.heartbeats.clear()
self.erred_shuffles.clear()


def get_partition_id(ts: TaskState) -> Any:
"""Get the output partition ID of this task state."""
try:
return ts.annotations["shuffle"]
except KeyError:
raise RuntimeError(
f"{ts} has lost its ``shuffle`` annotation. This may be caused by "
"unintended optimization during graph generation. "
"Please report this problem on GitHub and link it to "
"the tracking issue at https://github.com/dask/distributed/issues/7716."
)


def get_worker_for_range_sharding(
output_partition: int, workers: list[str], npartitions: int
npartitions: int, output_partition: int, workers: Sequence[str]
) -> str:
"""Get address of target worker for this output partition using range sharding"""
i = len(workers) * output_partition // npartitions
return workers[i]


def get_worker_for_hash_sharding(output_partition: NIndex, workers: list[str]) -> str:
def get_worker_for_hash_sharding(
output_partition: NIndex, workers: Sequence[str]
) -> str:
"""Get address of target worker for this output partition using hash sharding"""
i = hash(output_partition) % len(workers)
return workers[i]
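A small worked example of the two sharding helpers above; local copies of the functions are used so the snippet is self-contained, and the worker addresses are illustrative:

def get_worker_for_range_sharding(npartitions, output_partition, workers):
    # Same arithmetic as the helper above: contiguous partition ranges per worker.
    return workers[len(workers) * output_partition // npartitions]


def get_worker_for_hash_sharding(output_partition, workers):
    # Same arithmetic as the helper above: hash the chunk index onto the worker list.
    return workers[hash(output_partition) % len(workers)]


workers = ["tcp://a:1234", "tcp://b:1234", "tcp://c:1234"]

# Range sharding (dataframe shuffle): with npartitions=6, partitions 0-1 land on
# the first worker, 2-3 on the second, and 4-5 on the third.
assert [get_worker_for_range_sharding(6, p, workers) for p in range(6)] == [
    workers[0], workers[0], workers[1], workers[1], workers[2], workers[2],
]

# Hash sharding (array rechunk): the output chunk index is a tuple such as (0, 1).
assert get_worker_for_hash_sharding((0, 1), workers) in workers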