From b1710ac879d22ebbf74145b7b91641e5653b2095 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E6=B5=B7=E6=B3=89?=
Date: Fri, 23 May 2025 17:50:02 +0800
Subject: [PATCH 1/2] sglang_async_rollout: Implement sglang async rollout and
 multi-turn using AsyncServerBase

---
 verl/workers/fsdp_workers.py                       | 101 ++++++++++++------
 verl/workers/rollout/async_server.py               |   4 +-
 .../sglang_rollout/async_sglang_rollout.py         |  86 ++++++++++++++-
 .../sglang_rollout/async_sglang_server.py          |  71 ++++++++++++
 4 files changed, 227 insertions(+), 35 deletions(-)
 create mode 100644 verl/workers/rollout/sglang_rollout/async_sglang_server.py

diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 48114746b67..2478db109ac 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -407,41 +407,68 @@ def _build_rollout(self, trust_remote_code=False):
             log_gpu_memory_usage("After building sharding manager", logger=logger)
 
         elif rollout_name == "sglang":
-            from verl.workers.rollout.sglang_rollout import SGLangRollout
-
-            # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
-            # SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
-            # the main process of ray can not find any CUDA device, which would potentially lead to:
-            # "RuntimeError: No CUDA GPUs are available".
-            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
-            # we import it here use the abs path.
-            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
-            from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
-
-            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
-            local_path = copy_to_local(self.config.model.path)
-            rollout = SGLangRollout(
-                actor_module=local_path,
-                config=self.config.rollout,
-                tokenizer=self.tokenizer,
-                model_hf_config=self.actor_model_config,
-                trust_remote_code=trust_remote_code,
-            )
-            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
-
-            if torch.distributed.get_world_size() == 1:
-                self.config.rollout.load_format = "dummy_hf"
-            rollout_sharding_manager = FSDPSGLangShardingManager(
-                module=self.actor_module_fsdp,
-                inference_engine=rollout.inference_engine,
-                model_config=self.actor_model_config,
-                full_params="hf" in self.config.rollout.load_format,
-                device_mesh=rollout_device_mesh,
-                offload_param=self._is_offload_param,
-            )
-            log_gpu_memory_usage("After building sharding manager", logger=logger)
+            if self.config.rollout.mode == "sync":
+                from verl.workers.rollout.sglang_rollout import SGLangRollout
+
+                # NOTE(linjunrong): Due to recent fp8 support in SGLang, importing any symbol related to
+                # SGLang's model_runner triggers a CUDA device capability check. However, due to verl's setup,
+                # the main ray process cannot see any CUDA device, which can lead to:
+                # "RuntimeError: No CUDA GPUs are available".
+                # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager;
+                # we import it here using the absolute path.
+                # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
+                from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
+
+                log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
+                local_path = copy_to_local(self.config.model.path)
+                rollout = SGLangRollout(
+                    actor_module=local_path,
+                    config=self.config.rollout,
+                    tokenizer=self.tokenizer,
+                    model_hf_config=self.actor_model_config,
+                    trust_remote_code=trust_remote_code,
+                )
+                log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
+
+                if torch.distributed.get_world_size() == 1:
+                    self.config.rollout.load_format = "dummy_hf"
+                rollout_sharding_manager = FSDPSGLangShardingManager(
+                    module=self.actor_module_fsdp,
+                    inference_engine=rollout.inference_engine,
+                    model_config=self.actor_model_config,
+                    full_params="hf" in self.config.rollout.load_format,
+                    device_mesh=rollout_device_mesh,
+                    offload_param=self._is_offload_param,
+                )
+                log_gpu_memory_usage("After building sharding manager", logger=logger)
+            elif self.config.rollout.mode == "async":
+                from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
+                from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
+
+                local_path = copy_to_local(self.config.model.path)
+                log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
+                rollout = AsyncSGLangRollout(
+                    actor_module=local_path,
+                    config=self.config.rollout,
+                    tokenizer=self.tokenizer,
+                    model_hf_config=self.actor_model_config,
+                    trust_remote_code=trust_remote_code,
+                )
+                log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)
+
+                if torch.distributed.get_world_size() == 1:
+                    self.config.rollout.load_format = "dummy_hf"
+                rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
+                    module=self.actor_module_fsdp,
+                    inference_engine=rollout._engine,
+                    model_config=self.actor_model_config,
+                    full_params="hf" in self.config.rollout.load_format,
+                    device_mesh=rollout_device_mesh,
+                    offload_param=self._is_offload_param,
+                )
+                log_gpu_memory_usage("After building sharding manager", logger=None)
 
         elif rollout_name == "sglang_async":
+            # TODO: replace with rollout.mode == "async"
             from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
             from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
 
@@ -1369,3 +1396,11 @@ def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         if self.vllm_tp_rank == 0 and method != "execute_model":
             print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}")
         return self.rollout.execute_method(method, *args, **kwargs)
+
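+    # Entry points for AsyncSglangServer.wake_up()/sleep(): resume() re-enters
+    # the rollout's sharding manager to load weights for generation, and
+    # offload() exits it to release them after rollout finishes.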
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
+    def resume(self):
+        return self.rollout.resume()
+
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
+    def offload(self):
+        return self.rollout.offload()

diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py
index f3c5d5cc448..7055c7fe68e 100644
--- a/verl/workers/rollout/async_server.py
+++ b/verl/workers/rollout/async_server.py
@@ -348,6 +348,8 @@ def async_server_class(rollout_backend: str) -> Type[AsyncServerBase]:
         return AsyncvLLMServer
     elif rollout_backend == "sglang":
-        raise NotImplementedError
+        from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer
+
+        return AsyncSglangServer
     else:
         raise NotImplementedError

diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
index ada5eb3d2a5..6deeefd0a10 100644
--- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
@@ -17,10 +17,11 @@
 import asyncio
 import logging
 import os
+import time
 from contextlib import contextmanager
 from copy import deepcopy
 from json import JSONDecodeError
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 from uuid import uuid4
 
 import numpy as np
@@ -197,9 +198,11 @@ def __init__(
         else:
             self._engine = None
 
+        self.sharding_manager = None
         # offload
         if self._tp_rank == 0:
             self._engine.release_memory_occupation()
+        self.is_sleep = True
 
         kwargs = dict(
             n=1,
@@ -768,3 +771,84 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
             req_list.append(req)
 
         return req_list
+
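+    # chat_completion flow: TP rank 0 drives SGLang generation, the result is
+    # broadcast across the TP group (SGLang is SPMD, so every rank joins the
+    # collective), and only rank 0 returns a non-None OpenAI-style response.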
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
+        if method == "chat_completion":
+            json_request = args[0]
+
+            # naive prompt construction: flatten the chat messages into
+            # "role: content" lines (no chat template is applied here)
+            formatted_messages = []
+            for msg in json_request["messages"]:
+                role = msg.get("role", "user")
+                content = msg.get("content", "")
+                formatted_messages.append(f"{role}: {content}")
+            prompt_str = "\n".join(formatted_messages)
+
+            sampling_params_dict = {
+                "n": json_request.get("n", 1),
+                "max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length),
+                "temperature": json_request.get("temperature", 1.0),
+                "top_p": json_request.get("top_p", 1.0),
+            }
+            output = None
+            if self._tp_rank == 0:
+                loop = asyncio.get_event_loop()
+                output = loop.run_until_complete(
+                    self._engine.async_generate(
+                        prompt=prompt_str,
+                        sampling_params=sampling_params_dict,
+                        return_logprob=True,
+                    )
+                )
+            output = broadcast_pyobj(
+                data=[output],
+                rank=self._rank,
+                dist_group=self._device_mesh_cpu["tp"].get_group(),
+                src=self._device_mesh_cpu["tp"].mesh[0].item(),
+                force_cpu_device=False,
+            )
+
+            # only return a value from the master rank
+            if self._tp_rank != 0:
+                return None
+
+            # build an OpenAI chat completion response
+            choices = []
+            id = None
+            for i, content in enumerate(output):
+                choices.append(
+                    {
+                        "index": i,
+                        "message": {
+                            "role": "assistant",
+                            "content": content["text"],
+                        },
+                        "finish_reason": content["meta_info"]["finish_reason"]["type"],
+                    }
+                )
+                id = content["meta_info"]["id"]
+
+            return {
+                "id": "chatcmpl-" + id,
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": json_request.get("model", "sglang_model"),
+                "choices": choices,
+            }
+        else:
+            raise ValueError(f"unsupported method: {method}")
+
+    # this function is kept for uniform train-inference resharding
+    def resume(self):
+        if not self.is_sleep:
+            return
+        self.sharding_manager.__enter__()  # pylint: disable=C2801
+        self.is_sleep = False
+
+    # this function is kept for uniform train-inference resharding
+    def offload(self):
+        if self.is_sleep:
+            return
+        self.sharding_manager.__exit__(None, None, None)
+        self.is_sleep = True

diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
new file mode 100644
index 00000000000..c29494f794b
--- /dev/null
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -0,0 +1,71 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import logging
+
+import ray
+from omegaconf import DictConfig
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+
+from verl.workers.rollout.async_server import AsyncServerBase
+
+logger = logging.getLogger(__file__)
+
+
+@ray.remote(num_cpus=1)
+class AsyncSglangServer(AsyncServerBase):
+    def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str):
+        super().__init__()
+        self.config = config
+        rollout_config = config.get("rollout", {})
+        self._tp_size = rollout_config.get("tensor_model_parallel_size", 1)
+        self._dp_size = dp_size
+        self._dp_rank = dp_rank
+        self.wg_prefix = wg_prefix
+        self.workers = []
+
+    async def init_engine(self):
+        all_actors = ray.util.list_named_actors(all_namespaces=True)
+        matched_actors = [actor for actor in all_actors if actor.get("name", "").startswith(self.wg_prefix + "WorkerDict_")]
+
+        # TODO: support multi node
+        for matched_actor in matched_actors:
+            current_rank = int(matched_actor["name"].split(":")[-1])
+
+            # collect every worker in this DP rank's TP group, because sglang is SPMD
+            if self._dp_rank * self._tp_size <= current_rank < (self._dp_rank + 1) * self._tp_size:
+                self.workers.append(ray.get_actor(**matched_actor))
+
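+    # Each server instance fronts one DP replica: chat_completion fans the
+    # request out to every TP worker in its group and returns the single
+    # non-None payload produced by TP rank 0.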
+    async def chat_completion(self, raw_request: Request):
+        request = await raw_request.json()
+
+        output_dp_lst = []
+        for worker in self.workers:
+            output_future = worker.execute_method.remote("chat_completion", request)
+            output_dp_lst.append(output_future)
+        outputs = await asyncio.gather(*output_dp_lst)
+
+        for output in outputs:
+            if output is not None:
+                return JSONResponse(output)
+        raise RuntimeError(f"AsyncSglangServer got no output from workers: dp_rank={self._dp_rank}, tp_size={self._tp_size}, workers={self.workers}")
+
+    async def wake_up(self):
+        # await the ObjectRefs so the engines are resumed before generation restarts
+        await asyncio.gather(*[worker.resume.remote() for worker in self.workers])
+
+    async def sleep(self):
+        await asyncio.gather(*[worker.offload.remote() for worker in self.workers])

From 5adada1608756d294ab9abaff64f32a8f705e9b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E6=B5=B7=E6=B3=89?=
Date: Wed, 28 May 2025 10:29:27 +0800
Subject: [PATCH 2/2] sglang_async_rollout: add pytest for init engine

---
 .../rollout/test_async_sglang_server.py       | 62 +++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 tests/workers/rollout/test_async_sglang_server.py

diff --git a/tests/workers/rollout/test_async_sglang_server.py b/tests/workers/rollout/test_async_sglang_server.py
new file mode 100644
index 00000000000..a9cc7ade01c
--- /dev/null
+++ b/tests/workers/rollout/test_async_sglang_server.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from omegaconf import DictConfig
+
+
+@patch.dict(
+    "sys.modules",
+    {
+        "verl.workers.rollout.sglang_rollout.async_sglang_rollout": MagicMock(AsyncSGLangRollout=MagicMock()),
+        "verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()),
+    },
+)
+class TestAsyncSglangServer:
+    @pytest.fixture
+    def mock_ray_actor(self):
+        mock_actor = MagicMock()
+        mock_actor.execute_method.remote = AsyncMock(return_value={"content": "mocked response"})
+        mock_actor.resume.remote = AsyncMock()
+        mock_actor.offload.remote = AsyncMock()
+        return mock_actor
+
+    @pytest.fixture
+    def server_config(self):
+        return DictConfig({"rollout": {"tensor_model_parallel_size": 2}})
+
+    @pytest.mark.asyncio
+    @patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors")
+    @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock)
+    @pytest.mark.filterwarnings("ignore:Ray state API is no longer experimental:DeprecationWarning")
+    async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config, mock_ray_actor):
+        mock_list_actors.return_value = [
+            {"name": "test_prefixWorkerDict_0:0", "namespace": "test"},
+            {"name": "test_prefixWorkerDict_0:1", "namespace": "test"},
+            {"name": "test_prefixWorkerDict_1:2", "namespace": "test"},
+        ]
+        from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer
+
+        ActualClassToInstantiate = AsyncSglangServer
+        if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr(AsyncSglangServer.__ray_metadata__, "modified_class"):
+            ActualClassToInstantiate = AsyncSglangServer.__ray_metadata__.modified_class
+
+        with patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", return_value=mock_ray_actor):
+            instance = ActualClassToInstantiate(server_config, 2, 0, "test_prefix")
+
+            await instance.init_engine()
+
+            # Verify instance.workers is correctly populated
+            assert len(instance.workers) == 2
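+
+    # Minimal follow-up sketch: wake_up()/sleep() should fan resume()/offload()
+    # out to every worker in the TP group and await the resulting futures.
+    @pytest.mark.asyncio
+    @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock)
+    async def test_wake_up_and_sleep(self, mock_start_fastapi_server, server_config, mock_ray_actor):
+        from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer
+
+        ActualClassToInstantiate = AsyncSglangServer
+        if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr(AsyncSglangServer.__ray_metadata__, "modified_class"):
+            ActualClassToInstantiate = AsyncSglangServer.__ray_metadata__.modified_class
+
+        instance = ActualClassToInstantiate(server_config, 2, 0, "test_prefix")
+        instance.workers = [mock_ray_actor, mock_ray_actor]
+
+        await instance.wake_up()
+        await instance.sleep()
+
+        assert mock_ray_actor.resume.remote.await_count == 2
+        assert mock_ray_actor.offload.remote.await_count == 2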