20 changes: 20 additions & 0 deletions python/ray/_private/client_mode_hook.py
@@ -54,6 +54,26 @@ def client_mode_should_convert():
return client_mode_enabled and _client_hook_enabled


def client_mode_wrap(func):
"""Decorator to wrap a function call in a task.

    This is useful for functions whose goal is not to delegate module
    calls to their Ray Client equivalents, but to implement Ray Client
    features that can be executed as tasks on the server side.
"""
from ray.util.client import ray

@wraps(func)
def wrapper(*args, **kwargs):
if client_mode_should_convert():
f = ray.remote(func)
ref = f.remote(*args, **kwargs)
return ray.get(ref)
return func(*args, **kwargs)

return wrapper


def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
"""Runs a preregistered ray RemoteFunction through the ray client.

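As a usage sketch (not part of this diff), a driver-only helper can be wrapped with the new decorator so that, under Ray Client, its body executes as a task on the cluster rather than on the client; the helper name below is hypothetical:

import ray
from ray._private.client_mode_hook import client_mode_wrap


@client_mode_wrap
def _cluster_resources_snapshot():
    # Hypothetical helper: reads cluster state that should be gathered server-side.
    return ray.cluster_resources()

Called without Ray Client, this simply runs locally; under a client connection, client_mode_should_convert() is true, so the function is converted with ray.remote, executed remotely, and its result returned to the caller via ray.get.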
91 changes: 54 additions & 37 deletions python/ray/util/placement_group.py
@@ -5,25 +5,27 @@
import ray
from ray._raylet import PlacementGroupID, ObjectRef
from ray.utils import hex_to_binary
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import client_mode_wrap

bundle_reservation_check = None
_bundle_reservation_check = None


# We need this remote function for the ready() API, but ray.remote is only
# available at runtime. If we defined this function inside ready(), it would
# be exported every time ready() is called, which can hurt performance. See
# https://github.com/ray-project/ray/issues/6240.
def _export_bundle_reservation_check_method_if_needed():
global bundle_reservation_check
if bundle_reservation_check:
return
def _ensure_bundle_reservation_check():
global _bundle_reservation_check
if _bundle_reservation_check is None:

@ray.remote(num_cpus=0, max_calls=0)
def bundle_reservation_check_func(placement_group):
return placement_group
@ray.remote(num_cpus=0, max_calls=0)
def bundle_reservation_check_func(placement_group):
return placement_group

bundle_reservation_check = bundle_reservation_check_func
_bundle_reservation_check = bundle_reservation_check_func
return _bundle_reservation_check


class PlacementGroup:
@@ -53,7 +55,7 @@ def ready(self) -> ObjectRef:
"""
self._fill_bundle_cache_if_needed()

_export_bundle_reservation_check_method_if_needed()
reservation_check_fn = _ensure_bundle_reservation_check()

assert len(self.bundle_cache) != 0, (
"ready() cannot be called on placement group object with a "
@@ -77,7 +79,7 @@ def ready(self) -> ObjectRef:
else:
resources[resource_name] = value

return bundle_reservation_check.options(
return reservation_check_fn.options(
num_cpus=num_cpus,
num_gpus=num_gpus,
placement_group=self,
@@ -91,11 +93,7 @@ def wait(self, timeout_seconds: Union[float, int]) -> bool:
Return:
True if the placement group is created. False otherwise.
"""
worker = ray.worker.global_worker
worker.check_connected()

return worker.core_worker.wait_placement_group_ready(
self.id, timeout_seconds)
return _call_placement_group_ready(self.id, timeout_seconds)

@property
def bundle_specs(self) -> List[Dict]:
@@ -121,29 +119,42 @@ def _get_none_zero_resource(self, bundle: List[Dict]):

def _fill_bundle_cache_if_needed(self):
if not self.bundle_cache:
# Since creating placement group is async, it is
# possible table is not ready yet. To avoid the
# problem, we should keep trying with timeout.
TIMEOUT_SECOND = 30
WAIT_INTERVAL = 0.05
timeout_cnt = 0
worker = ray.worker.global_worker
worker.check_connected()

while timeout_cnt < int(TIMEOUT_SECOND / WAIT_INTERVAL):
pg_info = ray.state.state.placement_group_table(self.id)
if pg_info:
self.bundle_cache = list(pg_info["bundles"].values())
return
time.sleep(WAIT_INTERVAL)
timeout_cnt += 1

raise RuntimeError(
"Couldn't get the bundle information of placement group id "
f"{self.id} in {TIMEOUT_SECOND} seconds. It is likely "
"because GCS server is too busy.")
self.bundle_cache = _get_bundle_cache(self.id)


@client_mode_wrap
def _call_placement_group_ready(id, timeout_seconds):
worker = ray.worker.global_worker
worker.check_connected()

return worker.core_worker.wait_placement_group_ready(id, timeout_seconds)


@client_mode_wrap
def _get_bundle_cache(id):
    # Since placement group creation is asynchronous, the placement
    # group table may not be populated yet. To avoid this problem,
    # keep retrying until the timeout expires.
TIMEOUT_SECOND = 30
WAIT_INTERVAL = 0.05
timeout_cnt = 0
worker = ray.worker.global_worker
worker.check_connected()

while timeout_cnt < int(TIMEOUT_SECOND / WAIT_INTERVAL):
pg_info = ray.state.state.placement_group_table(id)
if pg_info:
return list(pg_info["bundles"].values())
time.sleep(WAIT_INTERVAL)
timeout_cnt += 1

raise RuntimeError(
"Couldn't get the bundle information of placement group id "
f"{id} in {TIMEOUT_SECOND} seconds. It is likely "
"because GCS server is too busy.")


@client_mode_wrap
def placement_group(bundles: List[Dict[str, float]],
strategy: str = "PACK",
name: str = "",
@@ -199,6 +210,7 @@ def placement_group(bundles: List[Dict[str, float]],
return PlacementGroup(placement_group_id)


@client_mode_wrap
def remove_placement_group(placement_group: PlacementGroup):
"""Asynchronously remove placement group.

@@ -212,6 +224,7 @@ def remove_placement_group(placement_group: PlacementGroup):
worker.core_worker.remove_placement_group(placement_group.id)


@client_mode_wrap
def get_placement_group(placement_group_name: str):
"""Get a placement group object with a global name.

@@ -235,6 +248,7 @@ def get_placement_group(placement_group_name: str):
hex_to_binary(placement_group_info["placement_group_id"])))


@client_mode_wrap
def placement_group_table(placement_group: PlacementGroup = None) -> list:
"""Get the state of the placement group from GCS.

@@ -277,6 +291,9 @@ def get_current_placement_group() -> Optional[PlacementGroup]:
None if the current task or actor wasn't
created with any placement group.
"""
if client_mode_should_convert():
        # Under client mode the caller is always a driver, which never
        # runs inside a placement group.
return None
worker = ray.worker.global_worker
worker.check_connected()
pg_id = worker.placement_group_id
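Taken together, these changes let the placement group APIs be driven end to end from a Ray Client connection. A rough sketch, assuming a Ray Client server from this era of Ray is listening on localhost:10001 (the address and bundle shape are placeholders):

import ray.util
from ray.util.placement_group import placement_group, remove_placement_group

ray.util.connect("localhost:10001")  # Ray Client connection; placeholder address

# placement_group() is routed through a server-side task by @client_mode_wrap.
pg = placement_group([{"CPU": 1}], strategy="PACK")

# wait() now goes through _call_placement_group_ready, which also runs server-side.
assert pg.wait(timeout_seconds=10)

remove_placement_group(pg)  # likewise executed as a server-side task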