diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 9460147e791a..7f9578668bb4 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -281,6 +281,9 @@ def __init__( self._last_record_routing_stats_time: float = 0.0 self._ingress: bool = False + # Outbound deployments polling state + self._outbound_deployments: Optional[List[DeploymentID]] = None + @property def replica_id(self) -> str: return self._replica_id @@ -775,6 +778,7 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: self._grpc_port, self._rank, self._route_patterns, + self._outbound_deployments, ) = ray.get(self._ready_obj_ref) except RayTaskError as e: logger.exception( @@ -1047,6 +1051,9 @@ def force_stop(self, log_shutdown_message: bool = False): except ValueError: pass + def get_outbound_deployments(self) -> Optional[List[DeploymentID]]: + return self._outbound_deployments + class DeploymentReplica: """Manages state transitions for deployment replicas. @@ -1327,6 +1334,9 @@ def resource_requirements(self) -> Tuple[str, str]: # https://github.com/ray-project/ray/issues/26210 for the issue. return json.dumps(required), json.dumps(available) + def get_outbound_deployments(self) -> Optional[List[DeploymentID]]: + return self._actor.get_outbound_deployments() + class ReplicaStateContainer: """Container for mapping ReplicaStates to lists of DeploymentReplicas.""" @@ -3093,6 +3103,27 @@ def _stop_one_running_replica_for_testing(self): def is_ingress(self) -> bool: return self._target_state.info.ingress + def get_outbound_deployments(self) -> Optional[List[DeploymentID]]: + """Get the outbound deployments. + + Returns: + Sorted list of deployment IDs that this deployment calls. None if + outbound deployments are not yet polled. 
+ """ + result: Set[DeploymentID] = set() + has_outbound_deployments = False + for replica in self._replicas.get([ReplicaState.RUNNING]): + if replica.version != self._target_state.version: + # Only consider replicas of the target version + continue + outbound_deployments = replica.get_outbound_deployments() + if outbound_deployments is not None: + result.update(outbound_deployments) + has_outbound_deployments = True + if not has_outbound_deployments: + return None + return sorted(result, key=lambda d: (d.name, d.app_name)) + class DeploymentStateManager: """Manages all state for deployments in the system. @@ -3701,3 +3732,21 @@ def _get_replica_ranks_mapping(self, deployment_id: DeploymentID) -> Dict[str, i return {} return deployment_state._get_replica_ranks_mapping() + + def get_deployment_outbound_deployments( + self, deployment_id: DeploymentID + ) -> Optional[List[DeploymentID]]: + """Get the cached outbound deployments for a specific deployment. + + Args: + deployment_id: The deployment ID to get outbound deployments for. + + Returns: + List of deployment IDs that this deployment calls, or None if + the deployment doesn't exist or hasn't been polled yet. 
+ """ + deployment_state = self._deployment_states.get(deployment_id) + if deployment_state is None: + return None + + return deployment_state.get_outbound_deployments() diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index c468fb55f1ff..3912043ac7ab 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -131,6 +131,7 @@ int, int, # rank Optional[List[str]], # route_patterns + Optional[List[DeploymentID]], # outbound_deployments ] @@ -604,11 +605,51 @@ def get_metadata(self) -> ReplicaMetadata: self._grpc_port, current_rank, route_patterns, + self.list_outbound_deployments(), ) def get_dynamically_created_handles(self) -> Set[DeploymentID]: return self._dynamically_created_handles + def list_outbound_deployments(self) -> List[DeploymentID]: + """List all outbound deployment IDs this replica calls into. + + This includes: + - Handles created via get_deployment_handle() + - Handles passed as init args/kwargs to the deployment constructor + + This is used to determine which deployments are reachable from this replica. + The list of DeploymentIDs can change over time as new handles can be created at runtime. + Also it's not guaranteed that the list of DeploymentIDs is identical across replicas + because it depends on user code. + + Returns: + A list of DeploymentIDs that this replica calls into. 
+ """ + seen_deployment_ids: Set[DeploymentID] = set() + + # First, collect dynamically created handles + for deployment_id in self.get_dynamically_created_handles(): + seen_deployment_ids.add(deployment_id) + + # Get the init args/kwargs + init_args = self._user_callable_wrapper._init_args + init_kwargs = self._user_callable_wrapper._init_kwargs + + # Use _PyObjScanner to find all DeploymentHandle objects in: + # The init_args and init_kwargs (handles might be passed as init args) + scanner = _PyObjScanner(source_type=DeploymentHandle) + try: + handles = scanner.find_nodes((init_args, init_kwargs)) + + for handle in handles: + deployment_id = handle.deployment_id + seen_deployment_ids.add(deployment_id) + finally: + scanner.clear() + + return list(seen_deployment_ids) + def _set_internal_replica_context( self, *, servable_object: Callable = None, rank: int = None ): @@ -1219,45 +1260,6 @@ def get_num_ongoing_requests(self) -> int: """ return self._replica_impl.get_num_ongoing_requests() - def list_outbound_deployments(self) -> List[DeploymentID]: - """List all outbound deployment IDs this replica calls into. - - This includes: - - Handles created via get_deployment_handle() - - Handles passed as init args/kwargs to the deployment constructor - - This is used to determine which deployments are reachable from this replica. - The list of DeploymentIDs can change over time as new handles can be created at runtime. - Also its not guaranteed that the list of DeploymentIDs are identical across replicas - because it depends on user code. - - Returns: - A list of DeploymentIDs that this replica calls into. 
- """ - seen_deployment_ids: Set[DeploymentID] = set() - - # First, collect dynamically created handles - for deployment_id in self._replica_impl.get_dynamically_created_handles(): - seen_deployment_ids.add(deployment_id) - - # Get the init args/kwargs - init_args = self._replica_impl._user_callable_wrapper._init_args - init_kwargs = self._replica_impl._user_callable_wrapper._init_kwargs - - # Use _PyObjScanner to find all DeploymentHandle objects in: - # The init_args and init_kwargs (handles might be passed as init args) - scanner = _PyObjScanner(source_type=DeploymentHandle) - try: - handles = scanner.find_nodes((init_args, init_kwargs)) - - for handle in handles: - deployment_id = handle.deployment_id - seen_deployment_ids.add(deployment_id) - finally: - scanner.clear() - - return list(seen_deployment_ids) - async def is_allocated(self) -> str: """poke the replica to check whether it's alive. @@ -1281,6 +1283,9 @@ async def is_allocated(self) -> str: get_component_logger_file_path(), ) + def list_outbound_deployments(self) -> List[DeploymentID]: + return self._replica_impl.list_outbound_deployments() + async def initialize_and_get_metadata( self, deployment_config: DeploymentConfig = None, _after: Optional[Any] = None ) -> ReplicaMetadata: diff --git a/python/ray/serve/tests/test_controller_recovery.py b/python/ray/serve/tests/test_controller_recovery.py index b1f056280e3a..d3c716f8e36e 100644 --- a/python/ray/serve/tests/test_controller_recovery.py +++ b/python/ray/serve/tests/test_controller_recovery.py @@ -65,7 +65,7 @@ def __call__(self, *args): replica_version_hash = None for replica in deployment_dict[id]: ref = replica.get_actor_handle().initialize_and_get_metadata.remote() - _, version, _, _, _, _, _, _, _ = ray.get(ref) + _, version, _, _, _, _, _, _, _, _ = ray.get(ref) if replica_version_hash is None: replica_version_hash = hash(version) assert replica_version_hash == hash(version), ( @@ -118,7 +118,7 @@ def __call__(self, *args): for 
replica_name in recovered_replica_names: actor_handle = ray.get_actor(replica_name, namespace=SERVE_NAMESPACE) ref = actor_handle.initialize_and_get_metadata.remote() - _, version, _, _, _, _, _, _, _ = ray.get(ref) + _, version, _, _, _, _, _, _, _, _ = ray.get(ref) assert replica_version_hash == hash( version ), "Replica version hash should be the same after recover from actor names" diff --git a/python/ray/serve/tests/unit/test_deployment_state.py b/python/ray/serve/tests/unit/test_deployment_state.py index df145bf0de31..b83c9e3b9eea 100644 --- a/python/ray/serve/tests/unit/test_deployment_state.py +++ b/python/ray/serve/tests/unit/test_deployment_state.py @@ -310,6 +310,9 @@ def check_health(self): def get_routing_stats(self) -> Dict[str, Any]: return {} + def get_outbound_deployments(self) -> Optional[List[DeploymentID]]: + return getattr(self, "_outbound_deployments", None) + @property def route_patterns(self) -> Optional[List[str]]: return None @@ -5600,5 +5603,143 @@ def test_rank_assignment_with_replica_failures(self, mock_deployment_state_manag }, f"Expected ranks [0, 1, 2], got {ranks_mapping.values()}" +class TestGetOutboundDeployments: + def test_basic_outbound_deployments(self, mock_deployment_state_manager): + """Test that outbound deployments are returned.""" + create_dsm, _, _, _ = mock_deployment_state_manager + dsm: DeploymentStateManager = create_dsm() + + deployment_id = DeploymentID(name="test_deployment", app_name="test_app") + b_info_1, _ = deployment_info(num_replicas=1) + dsm.deploy(deployment_id, b_info_1) + + # Create a RUNNING replica + ds = dsm._deployment_states[deployment_id] + dsm.update() # Transitions to STARTING + for replica in ds._replicas.get([ReplicaState.STARTING]): + replica._actor.set_ready() + dsm.update() # Transitions to RUNNING + + # Set outbound deployments on the mock replica + running_replicas = ds._replicas.get([ReplicaState.RUNNING]) + assert len(running_replicas) == 1 + d1 = DeploymentID(name="dep1", 
app_name="test_app") + d2 = DeploymentID(name="dep2", app_name="test_app") + running_replicas[0]._actor._outbound_deployments = [d1, d2] + + outbound_deployments = ds.get_outbound_deployments() + assert outbound_deployments == [d1, d2] + + # Verify it's accessible through DeploymentStateManager + assert dsm.get_deployment_outbound_deployments(deployment_id) == [ + d1, + d2, + ] + + def test_deployment_state_manager_returns_none_for_nonexistent_deployment( + self, mock_deployment_state_manager + ): + """Test that DeploymentStateManager returns None for nonexistent deployments.""" + ( + create_dsm, + timer, + cluster_node_info_cache, + autoscaling_state_manager, + ) = mock_deployment_state_manager + dsm = create_dsm() + + deployment_id = DeploymentID(name="nonexistent", app_name="test_app") + assert dsm.get_deployment_outbound_deployments(deployment_id) is None + + def test_returns_none_if_replicas_are_not_running( + self, mock_deployment_state_manager + ): + """Test that DeploymentStateManager returns None if replicas are not running.""" + create_dsm, _, _, _ = mock_deployment_state_manager + dsm: DeploymentStateManager = create_dsm() + + deployment_id = DeploymentID(name="test_deployment", app_name="test_app") + b_info_1, _ = deployment_info(num_replicas=2) + dsm.deploy(deployment_id, b_info_1) + ds = dsm._deployment_states[deployment_id] + dsm.update() + replicas = ds._replicas.get([ReplicaState.STARTING]) + assert len(replicas) == 2 + d1 = DeploymentID(name="dep1", app_name="test_app") + d2 = DeploymentID(name="dep2", app_name="test_app") + d3 = DeploymentID(name="dep3", app_name="test_app") + d4 = DeploymentID(name="dep4", app_name="test_app") + replicas[0]._actor._outbound_deployments = [d1, d2] + replicas[1]._actor._outbound_deployments = [d3, d4] + dsm.update() + + outbound_deployments = ds.get_outbound_deployments() + assert outbound_deployments is None + + # Set only the first replica ready + replicas[0]._actor.set_ready() + dsm.update() + outbound_deployments = 
ds.get_outbound_deployments() + assert outbound_deployments == [d1, d2] + + def test_only_considers_replicas_matching_target_version( + self, mock_deployment_state_manager + ): + """Test that only replicas with target version are considered. + + When a new version is deployed, old version replicas that are still + running should not be included in the outbound deployments result. + """ + create_dsm, _, _, _ = mock_deployment_state_manager + dsm: DeploymentStateManager = create_dsm() + + # Deploy version 1 + b_info_1, v1 = deployment_info(version="1") + dsm.deploy(TEST_DEPLOYMENT_ID, b_info_1) + ds = dsm._deployment_states[TEST_DEPLOYMENT_ID] + dsm.update() + + # Get v1 replica to RUNNING state + ds._replicas.get()[0]._actor.set_ready() + dsm.update() + + # Set outbound deployments for v1 replica + d1 = DeploymentID(name="dep1", app_name="test_app") + d2 = DeploymentID(name="dep2", app_name="test_app") + ds._replicas.get()[0]._actor._outbound_deployments = [d1, d2] + + # Verify v1 outbound deployments are returned + assert ds.get_outbound_deployments() == [d1, d2] + + # Deploy version 2 - this triggers rolling update + b_info_2, v2 = deployment_info(version="2") + dsm.deploy(TEST_DEPLOYMENT_ID, b_info_2) + dsm.update() + + # Now we have v1 stopping and v2 starting + check_counts( + ds, + total=2, + by_state=[(ReplicaState.STOPPING, 1, v1), (ReplicaState.STARTING, 1, v2)], + ) + + # Key test: Even though v1 replica exists (stopping), it should not be + # included because target version is v2. Since v2 is not RUNNING yet, + # should return None. + assert ds.get_outbound_deployments() is None + + # Set outbound deployments for v2 replica and mark it ready + d3 = DeploymentID(name="dep3", app_name="test_app") + ds._replicas.get(states=[ReplicaState.STARTING])[ + 0 + ]._actor._outbound_deployments = [d3] + ds._replicas.get(states=[ReplicaState.STARTING])[0]._actor.set_ready() + dsm.update() + + # Now v2 is running. 
Should only return v2's outbound deployments (d3), + # not v1's outbound deployments (d1, d2). + assert ds.get_outbound_deployments() == [d3] + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__]))