diff --git a/distributed/client.py b/distributed/client.py
index 264fe232a93..8c897e965ad 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -1458,7 +1458,11 @@ def wait_for_workers(
             raise ValueError(
                 f"`n_workers` must be a positive integer. Instead got {n_workers}."
             )
-        return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
+
+        if self.cluster is None:
+            return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
+
+        return self.cluster.wait_for_workers(n_workers, timeout)
 
     def _heartbeat(self):
         # Don't send heartbeat if scheduler comm or cluster are already closed
diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py
index bbba149d6f0..c9170c1cfa4 100644
--- a/distributed/deploy/cluster.py
+++ b/distributed/deploy/cluster.py
@@ -19,6 +19,7 @@
 from distributed.compatibility import PeriodicCallback
 from distributed.core import Status
 from distributed.deploy.adaptive import Adaptive
+from distributed.metrics import time
 from distributed.objects import SchedulerInfo
 from distributed.utils import (
     Log,
@@ -33,6 +34,9 @@
 logger = logging.getLogger(__name__)
 
 
+no_default = "__no_default__"
+
+
 class Cluster(SyncMethodMixin):
     """Superclass for cluster objects
 
@@ -582,6 +586,57 @@ def __eq__(self, other):
     def __hash__(self):
         return id(self)
 
+    async def _wait_for_workers(self, n_workers=0, timeout=None):
+        self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())
+        if timeout:
+            deadline = time() + parse_timedelta(timeout)
+        else:
+            deadline = None
+
+        def running_workers(info):
+            return len(
+                [
+                    ws
+                    for ws in info["workers"].values()
+                    if ws["status"] == Status.running.name
+                ]
+            )
+
+        while n_workers and running_workers(self.scheduler_info) < n_workers:
+            if deadline and time() > deadline:
+                raise TimeoutError(
+                    "Only %d/%d workers arrived after %s"
+                    % (running_workers(self.scheduler_info), n_workers, timeout)
+                )
+            await asyncio.sleep(0.1)
+
+            self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())
+
+    def wait_for_workers(
+        self, n_workers: int | str = no_default, timeout: float | None = None
+    ) -> None:
+        """Blocking call to wait for n workers before continuing
+
+        Parameters
+        ----------
+        n_workers : int
+            The number of workers
+        timeout : number, optional
+            Time in seconds after which to raise a
+            ``dask.distributed.TimeoutError``
+        """
+        if n_workers is no_default:
+            warnings.warn(
+                "Please specify the `n_workers` argument when using `Client.wait_for_workers`. Not specifying `n_workers` will no longer be supported in future versions.",
+                FutureWarning,
+            )
+            n_workers = 0
+        elif not isinstance(n_workers, int) or n_workers < 1:
+            raise ValueError(
+                f"`n_workers` must be a positive integer. Instead got {n_workers}."
+            )
+        return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
+
 
 def _exponential_backoff(
     attempt: int, multiplier: float, exponential_base: float, max_interval: float
diff --git a/distributed/deploy/tests/test_cluster.py b/distributed/deploy/tests/test_cluster.py
index cb0b189ea2a..5533916072e 100644
--- a/distributed/deploy/tests/test_cluster.py
+++ b/distributed/deploy/tests/test_cluster.py
@@ -3,6 +3,7 @@
 import pytest
 from tornado.ioloop import IOLoop
 
+from distributed import LocalCluster, Status
 from distributed.deploy.cluster import Cluster, _exponential_backoff
 from distributed.utils_test import gen_test
 
@@ -38,6 +39,21 @@ async def test_logs_deprecated():
         cluster.logs()
 
 
+@gen_test()
+async def test_cluster_wait_for_worker():
+    async with LocalCluster(n_workers=2, asynchronous=True) as cluster:
+        assert len(cluster.scheduler.workers) == 2
+        cluster.scale(4)
+        await cluster.wait_for_workers(4)
+        assert all(
+            [
+                worker["status"] == Status.running.name
+                for _, worker in cluster.scheduler_info["workers"].items()
+            ]
+        )
+        assert len(cluster.scheduler.workers) == 4
+
+
 @gen_test()
 async def test_deprecated_loop_properties():
     class ExampleCluster(Cluster):
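For context, a minimal usage sketch of the API this diff adds. The worker counts and the 30-second timeout below are illustrative choices, not part of the change; the behavior shown (`Client.wait_for_workers` forwarding to `Cluster.wait_for_workers` when the client was built from a cluster object) is what the client.py hunk implements.

    from distributed import Client, LocalCluster

    # Start a cluster with two workers, then ask for two more.
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    cluster.scale(4)

    # New in this diff: block until 4 workers report Status.running,
    # raising dask.distributed.TimeoutError once the deadline passes.
    cluster.wait_for_workers(4, timeout=30)

    # Equivalent path through the client: because client.cluster is set,
    # this now forwards to cluster.wait_for_workers(n_workers, timeout)
    # instead of polling the scheduler via client._wait_for_workers.
    client.wait_for_workers(4, timeout=30)

    client.close()
    cluster.close()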