diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py
index 74682f1cfa9d..7d34d96d96a3 100644
--- a/python/ray/_private/client_mode_hook.py
+++ b/python/ray/_private/client_mode_hook.py
@@ -54,6 +54,26 @@ def client_mode_should_convert():
     return client_mode_enabled and _client_hook_enabled
 
 
+def client_mode_wrap(func):
+    """Decorator that runs the wrapped function as a task in client mode.
+
+    This is useful for functions where the goal isn't to delegate
+    module calls to the Ray client equivalent, but instead to implement
+    Ray client features that can be executed by tasks on the server side.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        from ray.util.client import ray
+        if client_mode_should_convert():
+            f = ray.remote(func)
+            ref = f.remote(*args, **kwargs)
+            return ray.get(ref)
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
     """Runs a preregistered ray RemoteFunction through the ray client.
 
diff --git a/python/ray/util/placement_group.py b/python/ray/util/placement_group.py
index c723f77d3ecc..ffd6aae6ba28 100644
--- a/python/ray/util/placement_group.py
+++ b/python/ray/util/placement_group.py
@@ -5,8 +5,10 @@
 import ray
 from ray._raylet import PlacementGroupID, ObjectRef
 from ray.utils import hex_to_binary
+from ray._private.client_mode_hook import client_mode_should_convert
+from ray._private.client_mode_hook import client_mode_wrap
 
-bundle_reservation_check = None
+_bundle_reservation_check = None
 
 
 # We need to import this method to use for ready API.
@@ -14,16 +16,16 @@
 # if we define this method inside ready method, this function is
 # exported whenever ready is called, which can impact performance,
 # https://github.com/ray-project/ray/issues/6240.
-def _export_bundle_reservation_check_method_if_needed():
-    global bundle_reservation_check
-    if bundle_reservation_check:
-        return
+def _ensure_bundle_reservation_check():
+    global _bundle_reservation_check
+    if _bundle_reservation_check is None:
 
-    @ray.remote(num_cpus=0, max_calls=0)
-    def bundle_reservation_check_func(placement_group):
-        return placement_group
+        @ray.remote(num_cpus=0, max_calls=0)
+        def bundle_reservation_check_func(placement_group):
+            return placement_group
 
-    bundle_reservation_check = bundle_reservation_check_func
+        _bundle_reservation_check = bundle_reservation_check_func
+    return _bundle_reservation_check
 
 
 class PlacementGroup:
@@ -53,7 +55,7 @@ def ready(self) -> ObjectRef:
         """
         self._fill_bundle_cache_if_needed()
 
-        _export_bundle_reservation_check_method_if_needed()
+        reservation_check_fn = _ensure_bundle_reservation_check()
 
         assert len(self.bundle_cache) != 0, (
             "ready() cannot be called on placement group object with a "
@@ -77,7 +79,7 @@ def ready(self) -> ObjectRef:
             else:
                 resources[resource_name] = value
 
-        return bundle_reservation_check.options(
+        return reservation_check_fn.options(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            placement_group=self,
@@ -91,11 +93,7 @@ def wait(self, timeout_seconds: Union[float, int]) -> bool:
         Return:
             True if the placement group is created. False otherwise.
         """
-        worker = ray.worker.global_worker
-        worker.check_connected()
-
-        return worker.core_worker.wait_placement_group_ready(
-            self.id, timeout_seconds)
+        return _call_placement_group_ready(self.id, timeout_seconds)
 
     @property
     def bundle_specs(self) -> List[Dict]:
@@ -121,29 +119,42 @@ def _get_none_zero_resource(self, bundle: List[Dict]):
 
     def _fill_bundle_cache_if_needed(self):
         if not self.bundle_cache:
-            # Since creating placement group is async, it is
-            # possible table is not ready yet. To avoid the
-            # problem, we should keep trying with timeout.
-            TIMEOUT_SECOND = 30
-            WAIT_INTERVAL = 0.05
-            timeout_cnt = 0
-            worker = ray.worker.global_worker
-            worker.check_connected()
-
-            while timeout_cnt < int(TIMEOUT_SECOND / WAIT_INTERVAL):
-                pg_info = ray.state.state.placement_group_table(self.id)
-                if pg_info:
-                    self.bundle_cache = list(pg_info["bundles"].values())
-                    return
-                time.sleep(WAIT_INTERVAL)
-                timeout_cnt += 1
-
-            raise RuntimeError(
-                "Couldn't get the bundle information of placement group id "
-                f"{self.id} in {TIMEOUT_SECOND} seconds. It is likely "
-                "because GCS server is too busy.")
+            self.bundle_cache = _get_bundle_cache(self.id)
+
+
+@client_mode_wrap
+def _call_placement_group_ready(pg_id, timeout_seconds):
+    worker = ray.worker.global_worker
+    worker.check_connected()
+    return worker.core_worker.wait_placement_group_ready(
+        pg_id, timeout_seconds)
+
+
+@client_mode_wrap
+def _get_bundle_cache(pg_id):
+    # Since placement group creation is asynchronous, the table may not
+    # be ready yet. To avoid the problem, keep retrying until the
+    # timeout is reached.
+    TIMEOUT_SECOND = 30
+    WAIT_INTERVAL = 0.05
+    timeout_cnt = 0
+    worker = ray.worker.global_worker
+    worker.check_connected()
+
+    while timeout_cnt < int(TIMEOUT_SECOND / WAIT_INTERVAL):
+        pg_info = ray.state.state.placement_group_table(pg_id)
+        if pg_info:
+            return list(pg_info["bundles"].values())
+        time.sleep(WAIT_INTERVAL)
+        timeout_cnt += 1
+
+    raise RuntimeError(
+        "Couldn't get the bundle information of placement group id "
+        f"{pg_id} in {TIMEOUT_SECOND} seconds. It is likely "
+        "because GCS server is too busy.")
 
 
+@client_mode_wrap
 def placement_group(bundles: List[Dict[str, float]],
                     strategy: str = "PACK",
                     name: str = "",
@@ -199,6 +210,7 @@ def placement_group(bundles: List[Dict[str, float]],
     return PlacementGroup(placement_group_id)
 
 
+@client_mode_wrap
 def remove_placement_group(placement_group: PlacementGroup):
     """Asynchronously remove placement group.
 
@@ -212,6 +224,7 @@ def remove_placement_group(placement_group: PlacementGroup):
     worker.core_worker.remove_placement_group(placement_group.id)
 
 
+@client_mode_wrap
 def get_placement_group(placement_group_name: str):
     """Get a placement group object with a global name.
 
@@ -235,6 +248,7 @@ def get_placement_group(placement_group_name: str):
             hex_to_binary(placement_group_info["placement_group_id"])))
 
 
+@client_mode_wrap
 def placement_group_table(placement_group: PlacementGroup = None) -> list:
     """Get the state of the placement group from GCS.
 
@@ -277,6 +291,9 @@ def get_current_placement_group() -> Optional[PlacementGroup]:
         None if the current task or actor wasn't created with any
         placement group.
     """
+    if client_mode_should_convert():
+        # Client mode is only a driver, so no placement group is set.
+        return None
     worker = ray.worker.global_worker
     worker.check_connected()
     pg_id = worker.placement_group_id