Annotation-less P2P shuffling (#7801)
Co-authored-by: Florian Jetter <[email protected]>
hendrikmakait and fjetter authored May 11, 2023
1 parent b68d71d commit 21b70be
Showing 12 changed files with 328 additions and 238 deletions.
15 changes: 15 additions & 0 deletions distributed/exceptions.py
@@ -0,0 +1,15 @@
from __future__ import annotations


class Reschedule(Exception):
"""Reschedule this task
Raising this exception will stop the current execution of the task and ask
the scheduler to reschedule this task, possibly on a different machine.
This does not guarantee that the task will move onto a different machine.
The scheduler will proceed through its normal heuristics to determine the
optimal machine to accept this task. The machine will likely change if the
load across the cluster has significantly changed since first scheduling
the task.
"""
4 changes: 4 additions & 0 deletions distributed/shuffle/_merge.py
@@ -136,13 +136,15 @@ def merge_transfer(
id: ShuffleId,
input_partition: int,
npartitions: int,
parts_out: set[int],
):
return shuffle_transfer(
input=input,
id=id,
input_partition=input_partition,
npartitions=npartitions,
column=_HASH_COLUMN_NAME,
parts_out=parts_out,
)


@@ -340,6 +342,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
token_left,
i,
self.npartitions,
self.parts_out,
)
for i in range(self.n_partitions_right):
transfer_keys_right.append((name_right, i))
@@ -349,6 +352,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
token_right,
i,
self.npartitions,
self.parts_out,
)

_barrier_key_left = barrier_key(ShuffleId(token_left))
9 changes: 3 additions & 6 deletions distributed/shuffle/_rechunk.py
@@ -8,6 +8,7 @@
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

from distributed.exceptions import Reschedule
from distributed.shuffle._shuffle import (
ShuffleId,
ShuffleType,
@@ -57,6 +58,8 @@ def rechunk_unpack(
return _get_worker_extension().get_output_partition(
id, barrier_run_id, output_chunk
)
except Reschedule as e:
raise e
except Exception as e:
raise RuntimeError(f"rechunk_unpack failed during shuffle {id}") from e

@@ -70,12 +73,6 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
# Special case for empty array, as the algorithm below does not behave correctly
return da.empty(x.shape, chunks=chunks, dtype=x.dtype)

if dask.config.get("optimization.fuse.active") is not False:
raise RuntimeError(
"P2P rechunking requires the fuse optimization to be turned off. "
"Set the 'optimization.fuse.active' config to False to deactivate."
)

dsk: dict = {}
token = tokenize(x, chunks)
_barrier_key = barrier_key(ShuffleId(token))
157 changes: 108 additions & 49 deletions distributed/shuffle/_scheduler_extension.py
@@ -6,7 +6,9 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar
from functools import partial
from itertools import product
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Sequence

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.shuffle._rechunk import ChunkedAxes, NIndex
@@ -105,6 +107,7 @@ def __init__(self, scheduler: Scheduler):
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
@@ -122,14 +125,22 @@ async def barrier(self, id: ShuffleId, run_id: int) -> None:
msg=msg, workers=list(shuffle.participating_workers)
)

def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict:
shuffle = self.states[id]
if shuffle.run_id != run_id:
return {"status": "error", "message": "Stale shuffle"}
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}

def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)

def get(self, id: ShuffleId, worker: str) -> dict[str, Any]:
if exception := self.erred_shuffles.get(id):
return {"status": "ERROR", "message": str(exception)}
return {"status": "error", "message": str(exception)}
state = self.states[id]
state.participating_workers.add(worker)
return state.to_msg()
@@ -144,6 +155,11 @@ def get_or_create(
try:
return self.get(id, worker)
except KeyError:
# FIXME: The current implementation relies on the barrier task being
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(id)

state: ShuffleState
if type == ShuffleType.DATAFRAME:
state = self._create_dataframe_shuffle_state(id, spec)
@@ -155,33 +171,33 @@ def get_or_create(
state.participating_workers.add(worker)
return state.to_msg()

def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
key = barrier_key(id)
try:
self.scheduler.tasks[key]
except KeyError:
raise RuntimeError(
f"Barrier task with key {key!r} does not exist. This may be caused by "
"task fusion during graph generation. Please let us know that you ran "
"into this by leaving a comment at distributed#7816."
)

def _create_dataframe_shuffle_state(
self, id: ShuffleId, spec: dict[str, Any]
) -> DataFrameShuffleState:
schema = spec["schema"]
column = spec["column"]
npartitions = spec["npartitions"]
parts_out = spec["parts_out"]
assert schema is not None
assert column is not None
assert npartitions is not None
assert parts_out is not None

workers = list(self.scheduler.workers)
output_workers = set()

name = barrier_key(id)
mapping = {}
pick_worker = partial(get_worker_for_range_sharding, npartitions)

for ts in self.scheduler.tasks[name].dependents:
part = get_partition_id(ts)
if ts.worker_restrictions:
output_worker = list(ts.worker_restrictions)[0]
else:
output_worker = get_worker_for_range_sharding(
part, workers, npartitions
)
mapping[part] = output_worker
output_workers.add(output_worker)
self.scheduler.set_restrictions({ts.key: {output_worker}})
mapping = self._pin_output_workers(id, parts_out, pick_worker)
output_workers = set(mapping.values())

return DataFrameShuffleState(
id=id,
@@ -193,6 +209,52 @@ def _create_dataframe_shuffle_state(
participating_workers=output_workers.copy(),
)

def _pin_output_workers(
self,
id: ShuffleId,
output_partitions: Iterable[Any],
pick: Callable[[Any, Sequence[str]], str],
) -> dict[Any, str]:
"""Pin the outputs of a P2P shuffle to specific workers.
Parameters
----------
id: ID of the shuffle to pin
output_partitions: Output partition IDs to pin
pick: Function that picks a worker given a partition ID and sequence of worker
.. note:
This function assumes that the barrier task and the output tasks share
the same worker restrictions.
"""
mapping = {}
barrier = self.scheduler.tasks[barrier_key(id)]

if barrier.worker_restrictions:
workers = list(barrier.worker_restrictions)
else:
workers = list(self.scheduler.workers)

for partition in output_partitions:
worker = pick(partition, workers)
mapping[partition] = worker

for dt in barrier.dependents:
try:
partition = dt.annotations["shuffle"]
except KeyError:
continue

if dt.worker_restrictions:
worker = pick(partition, list(dt.worker_restrictions))
mapping[partition] = worker
else:
worker = mapping[partition]

self._set_restriction(dt, worker)

return mapping
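
For orientation, a hedged sketch of how the pick callable handed to _pin_output_workers is built and what the resulting mapping looks like (worker addresses and partition counts are made up; the import path follows the file in this diff):

from functools import partial

from distributed.shuffle._scheduler_extension import get_worker_for_range_sharding

npartitions = 4
workers = ["tcp://10.0.0.1:8786", "tcp://10.0.0.2:8786"]

# partial() adapts the three-argument helper to the (partition, workers)
# signature that _pin_output_workers expects from `pick`.
pick = partial(get_worker_for_range_sharding, npartitions)
mapping = {part: pick(part, workers) for part in range(npartitions)}
# {0: 'tcp://10.0.0.1:8786', 1: 'tcp://10.0.0.1:8786',
#  2: 'tcp://10.0.0.2:8786', 3: 'tcp://10.0.0.2:8786'}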

def _create_array_rechunk_state(
self, id: ShuffleId, spec: dict[str, Any]
) -> ArrayRechunkState:
@@ -201,21 +263,9 @@ def _create_array_rechunk_state(
assert old is not None
assert new is not None

workers = list(self.scheduler.workers)
output_workers = set()

name = barrier_key(id)
mapping = {}

for ts in self.scheduler.tasks[name].dependents:
part = get_partition_id(ts)
if ts.worker_restrictions:
output_worker = list(ts.worker_restrictions)[0]
else:
output_worker = get_worker_for_hash_sharding(part, workers)
mapping[part] = output_worker
output_workers.add(output_worker)
self.scheduler.set_restrictions({ts.key: {output_worker}})
parts_out = product(*(range(len(c)) for c in new))
mapping = self._pin_output_workers(id, parts_out, get_worker_for_hash_sharding)
output_workers = set(mapping.values())

return ArrayRechunkState(
id=id,
@@ -227,6 +277,22 @@ def _create_array_rechunk_state(
participating_workers=output_workers.copy(),
)

def _set_restriction(self, ts: TaskState, worker: str) -> None:
if "shuffle_original_restrictions" in ts.annotations:
# This may occur if multiple barriers share the same output task,
# e.g. in a hash join.
return
ts.annotations["shuffle_original_restrictions"] = ts.worker_restrictions.copy()
self.scheduler.set_restrictions({ts.key: {worker}})

def _unset_restriction(self, ts: TaskState) -> None:
# shuffle_original_restrictions is only set if the task was first scheduled
# on the wrong worker
if "shuffle_original_restrictions" not in ts.annotations:
return
original_restrictions = ts.annotations.pop("shuffle_original_restrictions")
self.scheduler.set_restrictions({ts.key: original_restrictions})

def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
from time import time

@@ -247,7 +313,7 @@ def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
for dt in barrier_task.dependents:
if worker not in dt.worker_restrictions:
continue
dt.worker_restrictions.clear()
self._unset_restriction(dt)
recs.update({dt.key: "waiting"})
# TODO: Do we need to handle other states?

@@ -293,34 +359,27 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None:
with contextlib.suppress(KeyError):
del self.heartbeats[id]

barrier_task = self.scheduler.tasks[barrier_key(id)]
for dt in barrier_task.dependents:
self._unset_restriction(dt)

def restart(self, scheduler: Scheduler) -> None:
self.states.clear()
self.heartbeats.clear()
self.erred_shuffles.clear()


def get_partition_id(ts: TaskState) -> Any:
"""Get the output partition ID of this task state."""
try:
return ts.annotations["shuffle"]
except KeyError:
raise RuntimeError(
f"{ts} has lost its ``shuffle`` annotation. This may be caused by "
"unintended optimization during graph generation. "
"Please report this problem on GitHub and link it to "
"the tracking issue at https://github.com/dask/distributed/issues/7716."
)


def get_worker_for_range_sharding(
output_partition: int, workers: list[str], npartitions: int
npartitions: int, output_partition: int, workers: Sequence[str]
) -> str:
"""Get address of target worker for this output partition using range sharding"""
i = len(workers) * output_partition // npartitions
return workers[i]


def get_worker_for_hash_sharding(output_partition: NIndex, workers: list[str]) -> str:
def get_worker_for_hash_sharding(
output_partition: NIndex, workers: Sequence[str]
) -> str:
"""Get address of target worker for this output partition using hash sharding"""
i = hash(output_partition) % len(workers)
return workers[i]
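
A hedged, worked example of the two sharding helpers above (worker addresses and counts are made up; the import path follows the file in this diff): range sharding maps contiguous partition ranges to workers, while hash sharding spreads N-dimensional chunk indices via Python's hash().

from distributed.shuffle._scheduler_extension import (
    get_worker_for_hash_sharding,
    get_worker_for_range_sharding,
)

workers = ["tcp://10.0.0.1:8786", "tcp://10.0.0.2:8786", "tcp://10.0.0.3:8786"]

# Range sharding: 6 output partitions over 3 workers -> contiguous blocks 0-1, 2-3, 4-5.
assert [get_worker_for_range_sharding(6, part, workers) for part in range(6)] == [
    workers[0], workers[0], workers[1], workers[1], workers[2], workers[2],
]

# Hash sharding: an N-dimensional rechunk output chunk index picks its worker by hash().
target = get_worker_for_hash_sharding((0, 1), workers)
assert target in workers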