Commit

Cluster wait (#6700)
* Moving wait_for_worker logic to cluster, and having client call that if it can

* Adding test for cluster.wait_for_workers

* use try/except to catch the case where cluster is None or wait_for_workers is not implemented

* linting

* This test has been removed on the main branch, but for some reason git merge didn't remove it from my branch

* Cluster has to use the scheduler_info attribute instead of scheduler.identity

* lint

* reverting

* need to use cluster.scale when client.wait_for_workers is called while using a cluster

* need to use scheduler_info. Also, using cluster.scale to emulate the behaviour of client.wait_for_workers

* using scheduler_info and don't need to call scale anymore

* lint

* adding gen_test decorator

* Don't think we need to scale at start of wait_for_workers

* self.scheduler_info does not update worker status from init to running, so we need to request the status again

* Use Status

* Scale was fixing the nworkers test because it forced the worker status to update. Now that the worker status is checked, we don't need this (and shouldn't really have included it anyway)

* Refactoring

* Fixing type information

* Experimenting with creating new comm

* Create separate comm in _start and use that to update scheduler_info

* Close new comm

* initialise scheduler_info_comm

* Don't allow n_workers to be zero for cluster wait_for_workers

* Adding return type

* Change try-catch to be an explicit if-else

* Check explicitly for cluster being None, as I think it's clearer

* linting

* use scheduler_comm instead of opening new comm

* remove update_scheduler_info method

* pre-commit changes

* Reduce number of workers to see if it fixes GitHub tests

* Changing test to make it work in Python 3.8
idorrington92 authored Apr 20, 2023
1 parent 1baa5ff commit 76bbfaf
Showing 3 changed files with 76 additions and 1 deletion.
6 changes: 5 additions & 1 deletion distributed/client.py
@@ -1458,7 +1458,11 @@ def wait_for_workers(
             raise ValueError(
                 f"`n_workers` must be a positive integer. Instead got {n_workers}."
             )
-        return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
+
+        if self.cluster is None:
+            return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
+
+        return self.cluster.wait_for_workers(n_workers, timeout)
 
     def _heartbeat(self):
         # Don't send heartbeat if scheduler comm or cluster are already closed
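For context, a minimal sketch (not part of the diff; the cluster size and worker counts are illustrative) of the behaviour this change enables when a Client is backed by a cluster object:

    # Illustrative sketch only -- values are examples, not taken from this commit.
    from distributed import Client, LocalCluster

    cluster = LocalCluster(n_workers=1)
    client = Client(cluster)        # client.cluster is set, so the new branch applies

    cluster.scale(3)                # ask the cluster manager for more workers
    client.wait_for_workers(3)      # now delegates to cluster.wait_for_workers(3, None)

    client.close()
    cluster.close()
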
55 changes: 55 additions & 0 deletions distributed/deploy/cluster.py
@@ -19,6 +19,7 @@
 from distributed.compatibility import PeriodicCallback
 from distributed.core import Status
 from distributed.deploy.adaptive import Adaptive
+from distributed.metrics import time
 from distributed.objects import SchedulerInfo
 from distributed.utils import (
     Log,
@@ -33,6 +34,9 @@
 logger = logging.getLogger(__name__)
 
 
+no_default = "__no_default__"
+
+
 class Cluster(SyncMethodMixin):
     """Superclass for cluster objects
@@ -582,6 +586,57 @@ def __eq__(self, other):
     def __hash__(self):
         return id(self)
 
+    async def _wait_for_workers(self, n_workers=0, timeout=None):
+        self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())
+        if timeout:
+            deadline = time() + parse_timedelta(timeout)
+        else:
+            deadline = None
+
+        def running_workers(info):
+            return len(
+                [
+                    ws
+                    for ws in info["workers"].values()
+                    if ws["status"] == Status.running.name
+                ]
+            )
+
+        while n_workers and running_workers(self.scheduler_info) < n_workers:
+            if deadline and time() > deadline:
+                raise TimeoutError(
+                    "Only %d/%d workers arrived after %s"
+                    % (running_workers(self.scheduler_info), n_workers, timeout)
+                )
+            await asyncio.sleep(0.1)
+
+            self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())
+
+    def wait_for_workers(
+        self, n_workers: int | str = no_default, timeout: float | None = None
+    ) -> None:
+        """Blocking call to wait for n workers before continuing
+        Parameters
+        ----------
+        n_workers : int
+            The number of workers
+        timeout : number, optional
+            Time in seconds after which to raise a
+            ``dask.distributed.TimeoutError``
+        """
+        if n_workers is no_default:
+            warnings.warn(
+                "Please specify the `n_workers` argument when using `Client.wait_for_workers`. Not specifying `n_workers` will no longer be supported in future versions.",
+                FutureWarning,
+            )
+            n_workers = 0
+        elif not isinstance(n_workers, int) or n_workers < 1:
+            raise ValueError(
+                f"`n_workers` must be a positive integer. Instead got {n_workers}."
+            )
+        return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
+
+
 def _exponential_backoff(
     attempt: int, multiplier: float, exponential_base: float, max_interval: float
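For reference, a hedged sketch (not from the diff; the timeout value is an example) of calling the new Cluster.wait_for_workers directly. In synchronous mode self.sync blocks, while with asynchronous=True it returns an awaitable, which the test below relies on:

    # Illustrative sketch only.
    from distributed import LocalCluster

    cluster = LocalCluster(n_workers=0)
    cluster.scale(2)
    # Blocks until 2 workers report Status.running, or raises TimeoutError
    # after roughly 30 seconds (example timeout).
    cluster.wait_for_workers(2, timeout=30)
    cluster.close()
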
16 changes: 16 additions & 0 deletions distributed/deploy/tests/test_cluster.py
@@ -3,6 +3,7 @@
 import pytest
 from tornado.ioloop import IOLoop
 
+from distributed import LocalCluster, Status
 from distributed.deploy.cluster import Cluster, _exponential_backoff
 from distributed.utils_test import gen_test
 
@@ -38,6 +39,21 @@ async def test_logs_deprecated():
             cluster.logs()
 
 
+@gen_test()
+async def test_cluster_wait_for_worker():
+    async with LocalCluster(n_workers=2, asynchronous=True) as cluster:
+        assert len(cluster.scheduler.workers) == 2
+        cluster.scale(4)
+        await cluster.wait_for_workers(4)
+        assert all(
+            [
+                worker["status"] == Status.running.name
+                for _, worker in cluster.scheduler_info["workers"].items()
+            ]
+        )
+        assert len(cluster.scheduler.workers) == 4
+
+
 @gen_test()
 async def test_deprecated_loop_properties():
     class ExampleCluster(Cluster):
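To run just the new test locally, an invocation along these lines should work (the exact flags are a matter of taste):

    python -m pytest distributed/deploy/tests/test_cluster.py::test_cluster_wait_for_worker -v
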
