Skip to content

Commit

Permalink
Delay awaiting async SchedulerPlugin.{add|remove}_worker hooks in order to immediately execute all sync ones (#7799)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Apr 27, 2023
1 parent 57ae3e7 commit 839f4a9
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 5 deletions.
28 changes: 26 additions & 2 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,36 @@ def transition(
"""

def add_worker(self, scheduler: Scheduler, worker: str) -> None | Awaitable[None]:
    """Run when a new worker enters the cluster

    If this method is synchronous, it is immediately and synchronously executed
    without ``Scheduler.add_worker`` ever yielding to the event loop.
    If it is asynchronous, it will be awaited after all synchronous
    ``SchedulerPlugin.add_worker`` hooks have executed.

    .. warning::

        There are no guarantees about the execution order between individual
        ``SchedulerPlugin.add_worker`` hooks and the ordering may be subject
        to change without deprecation cycle.
    """

def remove_worker(
    self, scheduler: Scheduler, worker: str
) -> None | Awaitable[None]:
    """Run when a worker leaves the cluster

    If this method is synchronous, it is immediately and synchronously executed
    without ``Scheduler.remove_worker`` ever yielding to the event loop.
    If it is asynchronous, it will be awaited after all synchronous
    ``SchedulerPlugin.remove_worker`` hooks have executed.

    .. warning::

        There are no guarantees about the execution order between individual
        ``SchedulerPlugin.remove_worker`` hooks and the ordering may be subject
        to change without deprecation cycle.
    """

def add_client(self, scheduler: Scheduler, client: str) -> None:
"""Run when a new client connects"""
Expand Down
158 changes: 158 additions & 0 deletions distributed/diagnostics/tests/test_scheduler_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import asyncio

import pytest

from distributed import Scheduler, SchedulerPlugin, Worker, get_worker
Expand Down Expand Up @@ -101,6 +103,7 @@ async def remove_worker(self, worker, scheduler):
async with Worker(s.address) as b:
pass

assert len(events) == 4
assert set(events) == {
("add_worker", a.address),
("add_worker", b.address),
Expand All @@ -115,6 +118,161 @@ async def remove_worker(self, worker, scheduler):
assert events == []


@gen_cluster(nthreads=[])
async def test_async_and_sync_add_remove_worker(s):
    """Sync ``add_worker``/``remove_worker`` hooks run immediately and
    synchronously; async hooks are awaited afterwards, once every sync hook
    has already executed.
    """
    events = []

    class MyAsyncPlugin(SchedulerPlugin):
        name: str

        def __init__(self, name: str) -> None:
            super().__init__()
            self.name = name
            self.in_remove_worker = asyncio.Event()
            self.block_remove_worker = asyncio.Event()

        async def add_worker(self, scheduler, worker):
            assert scheduler is s
            # Yield once so the event appears strictly after the sync hooks
            await asyncio.sleep(0)
            events.append((self.name, "add_worker", worker))

        async def remove_worker(self, scheduler, worker):
            assert scheduler is s
            self.in_remove_worker.set()
            await self.block_remove_worker.wait()
            events.append((self.name, "remove_worker", worker))

    class MySyncPlugin(SchedulerPlugin):
        name: str

        def __init__(self, name: str):
            self.name = name

        # NOTE(review): parameter order differs from the async plugin
        # (worker, scheduler); harmless because the scheduler invokes hooks
        # with keyword arguments only.
        def add_worker(self, worker, scheduler):
            assert scheduler is s
            events.append((self.name, "add_worker", worker))

        def remove_worker(self, worker, scheduler):
            assert scheduler is s
            events.append((self.name, "remove_worker", worker))

    sync_plugin_before = MySyncPlugin(name="before")
    s.add_plugin(sync_plugin_before)

    async_plugin = MyAsyncPlugin(name="async")
    s.add_plugin(async_plugin)

    sync_plugin_after = MySyncPlugin(name="after")
    s.add_plugin(sync_plugin_after)
    assert events == []

    async with Worker(s.address) as a:
        assert len(events) == 3
        # No ordering guarantees between these
        assert set(events[:2]) == {
            ("before", "add_worker", a.address),
            ("after", "add_worker", a.address),
        }
        # Async add_worker happens after sync add_worker,
        # but before code within the context manager is run
        assert events[2] == ("async", "add_worker", a.address)
        events.clear()

        async with Worker(s.address) as b:
            assert len(events) == 3
            # No ordering guarantees between these
            assert set(events[:2]) == {
                ("before", "add_worker", b.address),
                ("after", "add_worker", b.address),
            }
            # Async add_worker happens after sync add_worker,
            # but before code within the context manager is run
            assert events[2] == ("async", "add_worker", b.address)
            events.clear()

        # b has left; the async remove_worker hook is blocked, so only the
        # sync hooks have fired.
        await async_plugin.in_remove_worker.wait()
        assert len(events) == 2
        # No ordering guarantees between these
        assert set(events) == {
            ("before", "remove_worker", b.address),
            ("after", "remove_worker", b.address),
        }
        events.clear()

    # a has left as well; again only the sync hooks have fired so far.
    assert len(events) == 2
    # No ordering guarantees between these
    assert set(events) == {
        ("before", "remove_worker", a.address),
        ("after", "remove_worker", a.address),
    }
    events.clear()

    # Unblock the async hooks; both pending remove_worker awaitables complete.
    async_plugin.block_remove_worker.set()
    await asyncio.sleep(0)
    # No ordering guarantees between these
    assert len(events) == 2
    assert set(events) == {
        ("async", "remove_worker", a.address),
        ("async", "remove_worker", b.address),
    }
    events.clear()


@gen_cluster(nthreads=[])
async def test_failing_async_add_remove_worker(s):
    """Exceptions raised by *async* add/remove_worker hooks are logged by the
    scheduler instead of propagating to the caller.
    """

    class MyAsyncPlugin(SchedulerPlugin):
        name = "MyAsyncPlugin"

        async def add_worker(self, scheduler, worker):
            assert scheduler is s
            await asyncio.sleep(0)
            raise RuntimeError("Async add_worker failed")

        async def remove_worker(self, scheduler, worker):
            assert scheduler is s
            await asyncio.sleep(0)
            raise RuntimeError("Async remove_worker failed")

    plugin = MyAsyncPlugin()
    s.add_plugin(plugin)
    with captured_logger("distributed.scheduler") as logger:
        async with Worker(s.address):
            # The async hook is awaited in the background; poll until its
            # failure shows up in the scheduler log.
            while "add_worker failed" not in logger.getvalue():
                await asyncio.sleep(0)
        while "remove_worker failed" not in logger.getvalue():
            await asyncio.sleep(0)


@gen_cluster(nthreads=[])
async def test_failing_sync_add_remove_worker(s):
    """Exceptions raised by *sync* add/remove_worker hooks are logged by the
    scheduler instead of propagating, and are visible before the hook's
    caller yields to the event loop.
    """

    class MySyncPlugin(SchedulerPlugin):
        name = "MySyncPlugin"

        def add_worker(self, scheduler, worker):
            assert scheduler is s
            # Fixed copy-paste: this is the sync plugin, not the async one
            raise RuntimeError("Sync add_worker failed")

        def remove_worker(self, scheduler, worker):
            assert scheduler is s
            raise RuntimeError("Sync remove_worker failed")

    plugin = MySyncPlugin()
    s.add_plugin(plugin)
    with captured_logger("distributed.scheduler") as logger:
        # Sync hooks run (and fail) synchronously, so the log entries are
        # already present without polling.
        async with Worker(s.address):
            assert "add_worker failed" in logger.getvalue()
        assert "remove_worker failed" in logger.getvalue()


@gen_test()
async def test_lifecycle():
class LifeCycle(SchedulerPlugin):
Expand Down
16 changes: 14 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4207,14 +4207,20 @@ async def add_worker(

self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop)

awaitables = []
for plugin in list(self.plugins.values()):
try:
result = plugin.add_worker(scheduler=self, worker=address)
if result is not None and inspect.isawaitable(result):
await result
awaitables.append(result)
except Exception as e:
logger.exception(e)

plugin_msgs = await asyncio.gather(*awaitables, return_exceptions=True)
plugins_exceptions = [msg for msg in plugin_msgs if isinstance(msg, Exception)]
for exc in plugins_exceptions:
logger.exception(exc, exc_info=exc)

if ws.status == Status.running:
self.transitions(
self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id
Expand Down Expand Up @@ -4982,14 +4988,20 @@ async def remove_worker(

self.transitions(recommendations, stimulus_id=stimulus_id)

awaitables = []
for plugin in list(self.plugins.values()):
try:
result = plugin.remove_worker(scheduler=self, worker=address)
if inspect.isawaitable(result):
await result
awaitables.append(result)
except Exception as e:
logger.exception(e)

plugin_msgs = await asyncio.gather(*awaitables, return_exceptions=True)
plugins_exceptions = [msg for msg in plugin_msgs if isinstance(msg, Exception)]
for exc in plugins_exceptions:
logger.exception(exc, exc_info=exc)

if not self.workers:
logger.info("Lost all workers")

Expand Down
95 changes: 94 additions & 1 deletion distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import shutil
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from contextlib import AsyncExitStack
from itertools import count
from typing import Any, Mapping
from unittest import mock
Expand All @@ -22,11 +23,15 @@
from dask.utils import stringify

from distributed.client import Client
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.scheduler import Scheduler
from distributed.scheduler import TaskState as SchedulerTaskState
from distributed.shuffle._arrow import serialize_table
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._scheduler_extension import get_worker_for_range_sharding
from distributed.shuffle._scheduler_extension import (
ShuffleSchedulerExtension,
get_worker_for_range_sharding,
)
from distributed.shuffle._shuffle import P2PShuffleLayer, ShuffleId, barrier_key
from distributed.shuffle._worker_extension import (
DataFrameShuffleRun,
Expand Down Expand Up @@ -1733,3 +1738,91 @@ async def test_replace_stale_shuffle(c, s, a, b):
await clean_worker(a)
await clean_worker(b)
await clean_scheduler(s)


class BlockedRemoveWorkerSchedulerPlugin(SchedulerPlugin):
    """Scheduler plugin whose async ``remove_worker`` hook blocks until released.

    ``in_remove_worker`` is set as soon as the hook starts; the hook then
    waits for the test to set ``block_remove_worker``. The plugin registers
    itself on the given scheduler at construction time.
    """

    def __init__(self, scheduler: Scheduler, *args: Any, **kwargs: Any):
        self.scheduler = scheduler
        super().__init__(*args, **kwargs)
        self.in_remove_worker = asyncio.Event()
        self.block_remove_worker = asyncio.Event()
        self.scheduler.add_plugin(self)

    async def remove_worker(self, *args: Any, **kwargs: Any) -> None:
        self.in_remove_worker.set()
        await self.block_remove_worker.wait()


class BlockedBarrierSchedulerExtension(ShuffleSchedulerExtension):
    """Shuffle scheduler extension whose ``barrier`` blocks until released.

    ``in_barrier`` is set when ``barrier`` is entered; the real barrier only
    runs after the test sets ``block_barrier``.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.in_barrier = asyncio.Event()
        self.block_barrier = asyncio.Event()

    async def barrier(self, *args: Any, **kwargs: Any) -> None:
        self.in_barrier.set()
        await self.block_barrier.wait()
        await super().barrier(*args, **kwargs)


@gen_cluster(
    client=True,
    nthreads=[],
    scheduler_kwargs={
        "extensions": {
            "blocking": BlockedRemoveWorkerSchedulerPlugin,
            "shuffle": BlockedBarrierSchedulerExtension,
        }
    },
)
async def test_closed_worker_returns_before_barrier(c, s):
    """A worker that closes while the barrier is blocked (and while the
    scheduler's async ``remove_worker`` hook is still pending) causes the
    shuffle barrier to fail rather than hang.
    """
    async with AsyncExitStack() as stack:
        workers = [await stack.enter_async_context(Worker(s.address)) for _ in range(2)]

        df = dask.datasets.timeseries(
            start="2000-01-01",
            end="2000-01-10",
            dtypes={"x": float, "y": float},
            freq="10 s",
        )
        out = dd.shuffle.shuffle(df, "x", shuffle="p2p")
        out = out.persist()
        shuffle_id = await wait_until_new_shuffle_is_initialized(s)
        key = barrier_key(shuffle_id)
        await wait_for_state(key, "processing", s)
        scheduler_extension = s.extensions["shuffle"]
        await scheduler_extension.in_barrier.wait()

        # Make sure all transfer data has been handed off before closing
        flushes = [
            w.extensions["shuffle"].shuffles[shuffle_id]._flush_comm() for w in workers
        ]
        await asyncio.gather(*flushes)

        # Close the worker that is NOT running the barrier task
        ts = s.tasks[key]
        to_close = next(
            (w for w in workers if ts.processing_on.address != w.address), None
        )
        assert to_close
        closed_port = to_close.port
        await to_close.close()

        blocking_extension = s.extensions["blocking"]
        assert blocking_extension.in_remove_worker.is_set()

        # Restart a worker on the same port while remove_worker is still blocked
        workers.append(
            await stack.enter_async_context(Worker(s.address, port=closed_port))
        )

        scheduler_extension.block_barrier.set()

        with pytest.raises(
            RuntimeError, match=f"shuffle_barrier failed .* {shuffle_id}"
        ):
            await c.compute(out.x.size)

        blocking_extension.block_remove_worker.set()
        await c.close()
        await asyncio.gather(*[clean_worker(w) for w in workers])
        await clean_scheduler(s)

0 comments on commit 839f4a9

Please sign in to comment.