30 changes: 25 additions & 5 deletions doc/source/serve/advanced-guides/replica-ranks.md
@@ -6,7 +6,7 @@
This API is experimental and may change between Ray minor versions.
:::

Replica ranks provide a unique identifier for **each replica within a deployment**. Each replica receives a **rank (an integer from 0 to N-1)** and **a world size (the total number of replicas)**.
Replica ranks provide a unique identifier for **each replica within a deployment**. Each replica receives a **`ReplicaRank` object** containing rank information and **a world size (the total number of replicas)**. The rank object includes a global rank (an integer from 0 to N-1), a node rank, and a local rank on the node.

## Access replica ranks

@@ -28,9 +28,29 @@ The following example shows how to access replica rank information:

The [`ReplicaContext`](../api/doc/ray.serve.context.ReplicaContext.rst) provides two key fields:

- `rank`: An integer from 0 to N-1 representing this replica's unique identifier.
- `rank`: A [`ReplicaRank`](../api/doc/ray.serve.schema.ReplicaRank.rst) object containing rank information for this replica. Access the integer rank value with `.rank`.
- `world_size`: The target number of replicas for the deployment.

The `ReplicaRank` object contains three fields:
- `rank`: The global rank (an integer from 0 to N-1) representing this replica's unique identifier across all nodes.
- `node_rank`: The rank of the node this replica runs on (an integer from 0 to M-1 where M is the number of nodes).
- `local_rank`: The rank of this replica on its node (an integer from 0 to K-1 where K is the number of replicas on this node).

:::{note}
**Accessing rank values:**

To use the rank in your code, access the `.rank` attribute to get the integer value:

```python
context = serve.get_replica_context()
my_rank = context.rank.rank # Get the integer rank value
my_node_rank = context.rank.node_rank # Get the node rank
my_local_rank = context.rank.local_rank # Get the local rank on this node
```

Most use cases only need the global `rank` value. The `node_rank` and `local_rank` are useful for advanced scenarios such as coordinating replicas on the same node.
:::
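
As a rough illustration of why most use cases need only the global rank, the sketch below partitions a list of work items across replicas using `rank` and `world_size`. `ReplicaRankStandIn`, `shard_for_replica`, and `has_node_info` are hypothetical names introduced here so the snippet runs without a Ray cluster; in a deployment the values would come from `serve.get_replica_context()`. Treating negative `node_rank`/`local_rank` values as "not populated" is an assumption based on the `-1` values that `deployment_state.py` passes in this change, not documented behavior.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class ReplicaRankStandIn:
    """Hypothetical stand-in mirroring the three fields of ReplicaRank."""

    rank: int
    node_rank: int = -1   # assumption: -1 means "not populated"
    local_rank: int = -1  # assumption: -1 means "not populated"


def shard_for_replica(
    items: Sequence[str], rank: ReplicaRankStandIn, world_size: int
) -> List[str]:
    """Return the strided slice of items owned by the replica with this rank."""
    if not 0 <= rank.rank < world_size:
        raise ValueError(f"rank {rank.rank} is outside 0..{world_size - 1}")
    # Replica r takes items r, r + world_size, r + 2 * world_size, ...
    return list(items[rank.rank :: world_size])


def has_node_info(rank: ReplicaRankStandIn) -> bool:
    """Assumed check: node-level ranks are usable only when non-negative."""
    return rank.node_rank >= 0 and rank.local_rank >= 0
```

Because the global ranks are contiguous from 0 to N-1, the strided slices of all replicas together cover every item exactly once.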

## Handle rank changes with reconfigure

When a replica's rank changes (such as during downscaling), Ray Serve can automatically call the `reconfigure` method on your deployment class to notify it of the new rank. This allows you to update replica-specific state when ranks change.
@@ -54,15 +74,15 @@ The following example shows how to implement `reconfigure` to handle rank change
Ray Serve automatically calls your `reconfigure` method in the following situations:

1. **At replica startup:** When a replica starts, if your deployment has both a `reconfigure` method and a `user_config`, Ray Serve calls `reconfigure` after running `__init__`. This lets you initialize rank-aware state without duplicating code between `__init__` and `reconfigure`.
2. **When you update user_config:** When you redeploy with a new `user_config`, Ray Serve calls `reconfigure` on all running replicas. If your `reconfigure` method includes `rank` as a parameter, Ray Serve passes both the new `user_config` and the current rank.
3. **When a replica's rank changes:** During downscaling, ranks may be reassigned to maintain contiguity (0 to N-1). If your `reconfigure` method includes `rank` as a parameter and your deployment has a `user_config`, Ray Serve calls `reconfigure` with the existing `user_config` and the new rank.
2. **When you update user_config:** When you redeploy with a new `user_config`, Ray Serve calls `reconfigure` on all running replicas. If your `reconfigure` method includes `rank` as a parameter, Ray Serve passes both the new `user_config` and the current rank as a `ReplicaRank` object.
3. **When a replica's rank changes:** During downscaling, ranks may be reassigned to maintain contiguity (0 to N-1). If your `reconfigure` method includes `rank` as a parameter and your deployment has a `user_config`, Ray Serve calls `reconfigure` with the existing `user_config` and the new rank as a `ReplicaRank` object.
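
To make the third case concrete, the sketch below shows one way ranks could be made contiguous again after a replica is removed, while moving as few replicas as possible. It is an illustration only and is not Ray Serve's implementation; `reassign_contiguous` and `replicas_to_reconfigure` are hypothetical helpers, and the input ranks are assumed to be unique.

```python
from typing import Dict, Set


def reassign_contiguous(ranks: Dict[str, int]) -> Dict[str, int]:
    """Return ranks covering 0..N-1, keeping every rank that is already valid."""
    n = len(ranks)
    # Replicas whose rank already falls inside 0..N-1 keep it.
    new_ranks = {rid: r for rid, r in ranks.items() if 0 <= r < n}
    # Ranks inside 0..N-1 that no remaining replica holds.
    free = sorted(set(range(n)) - set(new_ranks.values()))
    # Replicas holding an out-of-range rank, lowest old rank first.
    displaced = sorted(
        (rid for rid in ranks if rid not in new_ranks), key=lambda rid: ranks[rid]
    )
    for rid, rank in zip(displaced, free):
        new_ranks[rid] = rank
    return new_ranks


def replicas_to_reconfigure(old: Dict[str, int], new: Dict[str, int]) -> Set[str]:
    """Replicas whose rank changed and would need a `reconfigure` call."""
    return {rid for rid in new if new[rid] != old[rid]}


# Four replicas held ranks 0..3; the replica with rank 1 was removed.
remaining = {"a": 0, "c": 2, "d": 3}
updated = reassign_contiguous(remaining)
changed = replicas_to_reconfigure(remaining, updated)
```

In this scenario only replica `"d"` moves (from rank 3 to rank 1), so it is the only one that would be notified of a new rank.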

:::{note}
**Requirements to receive rank updates:**

To get rank changes through `reconfigure`, your deployment needs:
- A class-based deployment (function deployments don't support `reconfigure`)
- A `reconfigure` method with `rank` as a parameter: `def reconfigure(self, user_config, rank: int)`
- A `reconfigure` method with `rank` as a parameter: `def reconfigure(self, user_config, rank: ReplicaRank)`
- A `user_config` in your deployment (even if it's just an empty dict: `user_config={}`)

Without a `user_config`, Ray Serve won't call `reconfigure` for rank changes.
12 changes: 7 additions & 5 deletions doc/source/serve/doc_code/replica_rank.py
@@ -5,9 +5,10 @@
 @serve.deployment(num_replicas=4)
 class ModelShard:
     def __call__(self):
+        context = serve.get_replica_context()
         return {
-            "rank": serve.get_replica_context().rank,
-            "world_size": serve.get_replica_context().world_size,
+            "rank": context.rank.rank,  # Access the integer rank value
+            "world_size": context.world_size,
         }


@@ -17,20 +18,21 @@ def __call__(self):
 # __reconfigure_rank_start__
 from typing import Any
 from ray import serve
+from ray.serve.schema import ReplicaRank


 @serve.deployment(num_replicas=4, user_config={"name": "model_v1"})
 class RankAwareModel:
     def __init__(self):
         context = serve.get_replica_context()
-        self.rank = context.rank
+        self.rank = context.rank.rank  # Extract integer rank value
         self.world_size = context.world_size
         self.model_name = None
         print(f"Replica rank: {self.rank}/{self.world_size}")

-    async def reconfigure(self, user_config: Any, rank: int):
+    async def reconfigure(self, user_config: Any, rank: ReplicaRank):
         """Called when user_config or rank changes."""
-        self.rank = rank
+        self.rank = rank.rank  # Extract integer rank value from ReplicaRank object
         self.world_size = serve.get_replica_context().world_size
         self.model_name = user_config.get("name")
         print(f"Reconfigured: rank={self.rank}, model={self.model_name}")
37 changes: 16 additions & 21 deletions python/ray/serve/_private/deployment_state.py
@@ -254,7 +254,7 @@ def __init__(
self._docs_path: Optional[str] = None
self._route_patterns: Optional[List[str]] = None
# Rank assigned to the replica.
self._rank: Optional[int] = None
self._rank: Optional[ReplicaRank] = None
# Populated in `on_scheduled` or `recover`.
self._actor_handle: ActorHandle = None
self._placement_group: PlacementGroup = None
@@ -290,7 +290,7 @@ def deployment_name(self) -> str:
return self._deployment_id.name

@property
def rank(self) -> Optional[int]:
def rank(self) -> Optional[ReplicaRank]:
return self._rank

@property
@@ -442,7 +442,7 @@ def initialization_latency_s(self) -> Optional[float]:
return self._initialization_latency_s

def start(
self, deployment_info: DeploymentInfo, rank: int
self, deployment_info: DeploymentInfo, rank: ReplicaRank
) -> ReplicaSchedulingRequest:
"""Start the current DeploymentReplica instance.

@@ -609,11 +609,7 @@ def _format_user_config(self, user_config: Any):
temp = msgpack_deserialize(temp)
return temp

def reconfigure(
self,
version: DeploymentVersion,
rank: int,
) -> bool:
def reconfigure(self, version: DeploymentVersion, rank: ReplicaRank) -> bool:
"""
Update replica version. Also, updates the deployment config on the actor
behind this DeploymentReplica instance if necessary.
@@ -1170,7 +1166,7 @@ def initialization_latency_s(self) -> Optional[float]:
return self._actor.initialization_latency_s

def start(
self, deployment_info: DeploymentInfo, rank: int
self, deployment_info: DeploymentInfo, rank: ReplicaRank
) -> ReplicaSchedulingRequest:
"""
Start a new actor for current DeploymentReplica instance.
@@ -1184,7 +1180,7 @@ def start(
def reconfigure(
self,
version: DeploymentVersion,
rank: int,
rank: ReplicaRank,
) -> bool:
"""
Update replica version. Also, updates the deployment config on the actor
@@ -1211,7 +1207,7 @@ def recover(self) -> bool:
return True

@property
def rank(self) -> Optional[int]:
def rank(self) -> Optional[ReplicaRank]:
"""Get the rank assigned to the replica."""
return self._actor.rank

@@ -1695,9 +1691,11 @@ def _assign_rank_impl():
# Assign global rank
rank = self._replica_rank_manager.assign_rank(replica_id)

return ReplicaRank(rank=rank)
return ReplicaRank(rank=rank, node_rank=-1, local_rank=-1)

return self._execute_with_error_handling(_assign_rank_impl, ReplicaRank(rank=0))
return self._execute_with_error_handling(
_assign_rank_impl, ReplicaRank(rank=0, node_rank=-1, local_rank=-1)
)

def release_rank(self, replica_id: str) -> None:
"""Release rank for a replica.
@@ -1776,10 +1774,10 @@ def _get_replica_rank_impl():
raise RuntimeError(f"Rank for {replica_id} not assigned")

global_rank = self._replica_rank_manager.get_rank(replica_id)
return ReplicaRank(rank=global_rank)
return ReplicaRank(rank=global_rank, node_rank=-1, local_rank=-1)

return self._execute_with_error_handling(
_get_replica_rank_impl, ReplicaRank(rank=0)
_get_replica_rank_impl, ReplicaRank(rank=0, node_rank=-1, local_rank=-1)
)

def check_rank_consistency_and_reassign_minimally(
@@ -2547,7 +2545,7 @@ def scale_deployment_replicas(
self._target_state.version,
)
scheduling_request = new_deployment_replica.start(
self._target_state.info, rank=assigned_rank.rank
self._target_state.info, rank=assigned_rank
)

upscale.append(scheduling_request)
@@ -2665,10 +2663,7 @@ def _check_startup_replicas(
# data structure with RUNNING state.
# Recover rank from the replica actor during controller restart
replica_id = replica.replica_id.unique_id
recovered_rank = replica.rank
self._rank_manager.recover_rank(
replica_id, ReplicaRank(rank=recovered_rank)
)
self._rank_manager.recover_rank(replica_id, replica.rank)
# This replica should now be added to the handle's replica
# set.
self._replicas.add(ReplicaState.RUNNING, replica)
@@ -2951,7 +2946,7 @@ def _reconfigure_replicas_with_new_ranks(
# World size is calculated automatically from deployment config
_ = replica.reconfigure(
self._target_state.version,
rank=new_rank.rank,
rank=new_rank,
)
updated_count += 1

16 changes: 8 additions & 8 deletions python/ray/serve/_private/replica.py
@@ -116,7 +116,7 @@
RayServeException,
)
from ray.serve.handle import DeploymentHandle
from ray.serve.schema import EncodingType, LoggingConfig
from ray.serve.schema import EncodingType, LoggingConfig, ReplicaRank

logger = logging.getLogger(SERVE_LOGGER_NAME)

@@ -129,7 +129,7 @@
Optional[str],
int,
int,
int, # rank
ReplicaRank, # rank
Optional[List[str]], # route_patterns
]

@@ -510,7 +510,7 @@ def __init__(
version: DeploymentVersion,
ingress: bool,
route_prefix: str,
rank: int,
rank: ReplicaRank,
):
self._version = version
self._replica_id = replica_id
@@ -610,7 +610,7 @@ def get_dynamically_created_handles(self) -> Set[DeploymentID]:
return self._dynamically_created_handles

def _set_internal_replica_context(
self, *, servable_object: Callable = None, rank: int = None
self, *, servable_object: Callable = None, rank: ReplicaRank = None
):
# Calculate world_size from deployment config instead of storing it
world_size = self._deployment_config.num_replicas
@@ -961,7 +961,7 @@ async def initialize(self, deployment_config: DeploymentConfig):
async def reconfigure(
self,
deployment_config: DeploymentConfig,
rank: int,
rank: ReplicaRank,
route_prefix: Optional[str] = None,
):
try:
@@ -1186,7 +1186,7 @@ async def __init__(
version: DeploymentVersion,
ingress: bool,
route_prefix: str,
rank: int,
rank: ReplicaRank,
):
deployment_config = DeploymentConfig.from_proto_bytes(
deployment_config_proto_bytes
@@ -1305,7 +1305,7 @@ async def record_routing_stats(self) -> Dict[str, Any]:
return await self._replica_impl.record_routing_stats()

async def reconfigure(
self, deployment_config, rank: int, route_prefix: Optional[str] = None
self, deployment_config, rank: ReplicaRank, route_prefix: Optional[str] = None
) -> ReplicaMetadata:
await self._replica_impl.reconfigure(deployment_config, rank, route_prefix)
return self._replica_impl.get_metadata()
@@ -1802,7 +1802,7 @@ async def _call_user_autoscaling_stats(self) -> Dict[str, Union[int, float]]:
return result

@_run_user_code
async def call_reconfigure(self, user_config: Optional[Any], rank: int):
async def call_reconfigure(self, user_config: Optional[Any], rank: ReplicaRank):
self._raise_if_not_initialized("call_reconfigure")

# NOTE(edoakes): there is the possibility of a race condition in user code if
5 changes: 3 additions & 2 deletions python/ray/serve/context.py
@@ -23,6 +23,7 @@
from ray.serve._private.replica_result import ReplicaResult
from ray.serve.exceptions import RayServeException
from ray.serve.grpc_util import RayServegRPCContext
from ray.serve.schema import ReplicaRank
from ray.util.annotations import DeveloperAPI

logger = logging.getLogger(SERVE_LOGGER_NAME)
@@ -48,7 +49,7 @@ class ReplicaContext:
replica_id: ReplicaID
servable_object: Callable
_deployment_config: DeploymentConfig
rank: int
rank: ReplicaRank
world_size: int
_handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None

@@ -113,7 +114,7 @@ def _set_internal_replica_context(
replica_id: ReplicaID,
servable_object: Callable,
_deployment_config: DeploymentConfig,
rank: int,
rank: ReplicaRank,
world_size: int,
handle_registration_callback: Optional[Callable[[str, str], None]] = None,
):
6 changes: 6 additions & 0 deletions python/ray/serve/schema.py
@@ -1600,3 +1600,9 @@ class ReplicaRank(BaseModel):
rank: int = Field(
description="Global rank of the replica across all nodes scoped to the deployment."
)

node_rank: int = Field(description="Rank of the node in the deployment.")

local_rank: int = Field(
description="Rank of the replica on the node scoped to the deployment."
)