Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tests/rollout/async_rollout_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Tuple
from typing import Any, Dict

import ray
from omegaconf import DictConfig
Expand All @@ -24,7 +24,7 @@
from verl.workers.rollout.async_server import AsyncLLMServerManager


def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, Any] = None) -> Tuple[Dict[str, RayWorkerGroup], AsyncLLMServerManager]:
def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, Any] = None) -> AsyncLLMServerManager:
# make openai client happy
os.environ["no_proxy"] = ""
os.environ["http_proxy"] = ""
Expand All @@ -51,13 +51,11 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A
resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls

all_wg = {}
wg_dicts = []
for resource_pool, class_dict in resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
wg_dicts.append(wg_dict)
actor_rollout_wg = all_wg["actor_rollout"]
actor_rollout_wg.init_model()

Expand All @@ -68,4 +66,4 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A
scheduler_kwargs=scheduler_kwargs,
)

return all_wg, async_rollout_manager
return async_rollout_manager
4 changes: 2 additions & 2 deletions tests/rollout/test_vllm_multi_turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_vllm_multi_turn(config):

# =========================== 1. Init rollout manager ===========================
model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:])
worker_groups, async_rollout_manager = init_async_rollout_manager(config)
async_rollout_manager = init_async_rollout_manager(config)

# test sleep and wake_up
async_rollout_manager.sleep()
Expand Down Expand Up @@ -145,7 +145,7 @@ async def test_vllm_streaming_response(config):
)

model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:])
worker_groups, async_rollout_manager = init_async_rollout_manager(config)
async_rollout_manager = init_async_rollout_manager(config)
async_llm_server = async_rollout_manager.async_llm_servers[0]

# non-streaming request
Expand Down
2 changes: 1 addition & 1 deletion tests/rollout/test_vllm_tool_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def test_vllm_tool_calling():
# Init sandbox and async rollout manager
sandbox = Sandbox.options(num_cpus=1).remote()
sandbox_address = ray.get(sandbox.get_server_address.remote())
worker_groups, async_rollout_manager = init_async_rollout_manager(config, scheduler_kwargs={"sandbox_address": sandbox_address, "system_prompt": system_prompt})
async_rollout_manager = init_async_rollout_manager(config, scheduler_kwargs={"sandbox_address": sandbox_address, "system_prompt": system_prompt})

# Build dataset
dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train")
Expand Down
15 changes: 11 additions & 4 deletions verl/single_controller/ray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __init__(
name_prefix: str = None,
detached=False,
worker_names=None,
worker_handles: List[ray.actor.ActorHandle] = None,
ray_wait_register_center_timeout: int = 300,
**kwargs,
) -> None:
Expand All @@ -206,7 +207,7 @@ def __init__(
self._worker_names = worker_names

if self._is_init_with_detached_workers:
self._init_with_detached_workers(worker_names=worker_names)
self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles)
else:
self._init_with_resource_pool(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached)

Expand All @@ -220,8 +221,12 @@ def _is_worker_alive(self, worker: ray.actor.ActorHandle):
worker_state_dict = get_actor(worker._actor_id.hex())
return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False

def _init_with_detached_workers(self, worker_names):
workers = [ray.get_actor(name=name) for name in worker_names]
def _init_with_detached_workers(self, worker_names, worker_handles):
# ray.get_actor only holds a weak reference to the actor, so the actors may be garbage-collected unexpectedly
# if we only keep the spawned RayWorkerGroup. Passing the actor handles explicitly gives the spawned
# RayWorkerGroup strong references to these actors.
# https://github.com/ray-project/ray/pull/45699
workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names]
self._workers = workers
self._world_size = len(worker_names)

Expand Down Expand Up @@ -319,9 +324,10 @@ def from_detached(
cls,
name_prefix,
worker_names=None,
worker_handles=None,
ray_cls_with_init=None,
):
worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names)
worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names, worker_handles=worker_handles)
return worker_group

def spawn(self, prefix_set):
Expand Down Expand Up @@ -349,6 +355,7 @@ def _rebind_actor_methods(worker_group, actor_name):
new_worker_group = self.from_detached(
name_prefix=self.name_prefix,
worker_names=self._worker_names,
worker_handles=self._workers,
ray_cls_with_init=self.ray_cls_with_init,
)

Expand Down
3 changes: 0 additions & 3 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,6 @@ def init_workers(self):
# Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
self.wg_dicts = []
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
Expand All @@ -718,8 +717,6 @@ def init_workers(self):
wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
# keep a reference to the WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
self.wg_dicts.append(wg_dict)

if self.use_critic:
self.critic_wg = all_wg["critic"]
Expand Down