From 0247390494fd5919014dcdf65514c1d978db71a4 Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Thu, 8 May 2025 21:32:18 +0800
Subject: [PATCH] [ray] fix: make spawn worker group hold strong references to actors

---
 tests/rollout/async_rollout_utils.py    |  8 +++-----
 tests/rollout/test_vllm_multi_turn.py   |  4 ++--
 tests/rollout/test_vllm_tool_calling.py |  2 +-
 verl/single_controller/ray/base.py      | 15 +++++++++++----
 verl/trainer/ppo/ray_trainer.py         |  3 ---
 5 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/tests/rollout/async_rollout_utils.py b/tests/rollout/async_rollout_utils.py
index e42d5580ac5..bc6186553ef 100644
--- a/tests/rollout/async_rollout_utils.py
+++ b/tests/rollout/async_rollout_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Dict, Tuple
+from typing import Any, Dict
 
 import ray
 from omegaconf import DictConfig
@@ -24,7 +24,7 @@
 from verl.workers.rollout.async_server import AsyncLLMServerManager
 
 
-def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, Any] = None) -> Tuple[Dict[str, RayWorkerGroup], AsyncLLMServerManager]:
+def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, Any] = None) -> AsyncLLMServerManager:
     # make openai client happy
     os.environ["no_proxy"] = ""
     os.environ["http_proxy"] = ""
@@ -51,13 +51,11 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A
     resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
 
     all_wg = {}
-    wg_dicts = []
     for resource_pool, class_dict in resource_pool_to_cls.items():
         worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
         wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
         spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
         all_wg.update(spawn_wg)
-        wg_dicts.append(wg_dict)
 
     actor_rollout_wg = all_wg["actor_rollout"]
     actor_rollout_wg.init_model()
@@ -68,4 +66,4 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A
         scheduler_kwargs=scheduler_kwargs,
     )
 
-    return all_wg, async_rollout_manager
+    return async_rollout_manager
diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py
index ea683b83024..27077224a13 100644
--- a/tests/rollout/test_vllm_multi_turn.py
+++ b/tests/rollout/test_vllm_multi_turn.py
@@ -55,7 +55,7 @@ def test_vllm_multi_turn(config):
 
     # =========================== 1. Init rollout manager ===========================
     model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:])
-    worker_groups, async_rollout_manager = init_async_rollout_manager(config)
+    async_rollout_manager = init_async_rollout_manager(config)
 
     # test sleep and wake_up
     async_rollout_manager.sleep()
@@ -145,7 +145,7 @@ async def test_vllm_streaming_response(config):
     )
 
     model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:])
-    worker_groups, async_rollout_manager = init_async_rollout_manager(config)
+    async_rollout_manager = init_async_rollout_manager(config)
     async_llm_server = async_rollout_manager.async_llm_servers[0]
 
     # non-streaming request
diff --git a/tests/rollout/test_vllm_tool_calling.py b/tests/rollout/test_vllm_tool_calling.py
index 6d007184961..9003a6f8d75 100644
--- a/tests/rollout/test_vllm_tool_calling.py
+++ b/tests/rollout/test_vllm_tool_calling.py
@@ -264,7 +264,7 @@ def test_vllm_tool_calling():
     # Init sandbox and async rollout manager
    sandbox = Sandbox.options(num_cpus=1).remote()
     sandbox_address = ray.get(sandbox.get_server_address.remote())
-    worker_groups, async_rollout_manager = init_async_rollout_manager(config, scheduler_kwargs={"sandbox_address": sandbox_address, "system_prompt": system_prompt})
+    async_rollout_manager = init_async_rollout_manager(config, scheduler_kwargs={"sandbox_address": sandbox_address, "system_prompt": system_prompt})
 
     # Build dataset
     dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train")
diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py
index 72be01588eb..6970e5fb809 100644
--- a/verl/single_controller/ray/base.py
+++ b/verl/single_controller/ray/base.py
@@ -189,6 +189,7 @@ def __init__(
         name_prefix: str = None,
         detached=False,
         worker_names=None,
+        worker_handles: List[ray.actor.ActorHandle] = None,
         ray_wait_register_center_timeout: int = 300,
         **kwargs,
     ) -> None:
@@ -206,7 +207,7 @@ def __init__(
         self._worker_names = worker_names
 
         if self._is_init_with_detached_workers:
-            self._init_with_detached_workers(worker_names=worker_names)
+            self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles)
         else:
             self._init_with_resource_pool(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached)
 
@@ -220,8 +221,12 @@ def _is_worker_alive(self, worker: ray.actor.ActorHandle):
         worker_state_dict = get_actor(worker._actor_id.hex())
         return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
 
-    def _init_with_detached_workers(self, worker_names):
-        workers = [ray.get_actor(name=name) for name in worker_names]
+    def _init_with_detached_workers(self, worker_names, worker_handles):
+        # ray.get_actor holds only a weak reference to the actor, which can cause actors to be garbage
+        # collected unexpectedly if we only hold the spawned RayWorkerGroup. By passing the actor handles
+        # explicitly, the spawned RayWorkerGroup holds strong references to these actors.
+        # https://github.com/ray-project/ray/pull/45699
+        workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names]
         self._workers = workers
         self._world_size = len(worker_names)
 
@@ -319,9 +324,10 @@ def from_detached(
         cls,
         name_prefix,
         worker_names=None,
+        worker_handles=None,
         ray_cls_with_init=None,
     ):
-        worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names)
+        worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names, worker_handles=worker_handles)
         return worker_group
 
     def spawn(self, prefix_set):
@@ -349,6 +355,7 @@ def _rebind_actor_methods(worker_group, actor_name):
         new_worker_group = self.from_detached(
             name_prefix=self.name_prefix,
             worker_names=self._worker_names,
+            worker_handles=self._workers,
             ray_cls_with_init=self.ray_cls_with_init,
         )
 
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index d31524444f1..ddadafb6453 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -708,7 +708,6 @@ def init_workers(self):
         # Instead, directly pass different resource pool to different worker groups.
         # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
         all_wg = {}
-        self.wg_dicts = []
         wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup
         if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
             wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
@@ -718,8 +717,6 @@ def init_workers(self):
             wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs)
             spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
             all_wg.update(spawn_wg)
-            # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
-            self.wg_dicts.append(wg_dict)
 
         if self.use_critic:
             self.critic_wg = all_wg["critic"]
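
Reviewer note (not part of the patch): the standalone sketch below illustrates the Ray behavior this change works around. It is toy code, not verl code; EchoWorker, the simplified WorkerGroup, and the actor names are hypothetical stand-ins for the real RayWorkerGroup. The underlying behavior is the one the patch comment cites: handles returned by ray.get_actor are weak references for reference-counting purposes (https://github.com/ray-project/ray/pull/45699), so a group rebuilt only via ray.get_actor does not keep its actors alive, while forwarding the original ActorHandles does.

import ray

ray.init()


@ray.remote
class EchoWorker:
    def ping(self) -> str:
        return "pong"


class WorkerGroup:
    def __init__(self, worker_names, worker_handles=None):
        self._worker_names = list(worker_names)
        # Prefer explicit handles: they are strong references, so the actors
        # stay alive as long as some group still holds the handle list.
        # The ray.get_actor fallback yields only weak references.
        self._workers = worker_handles if worker_handles else [ray.get_actor(name=n) for n in worker_names]

    def spawn(self):
        # Forward the strong handles so the child group keeps the actors
        # alive even after the parent group is garbage collected.
        return WorkerGroup(self._worker_names, worker_handles=self._workers)


names = ["worker_0", "worker_1"]
handles = [EchoWorker.options(name=n).remote() for n in names]
parent = WorkerGroup(names, worker_handles=handles)
child = parent.spawn()
del parent, handles  # actors survive: child still holds strong references
print(ray.get([w.ping.remote() for w in child._workers]))  # ['pong', 'pong']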