-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[rollout] feat: Implement sglang async rollout and multi-turn using AsyncServerBase #1698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # Copyright 2025 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| import pytest | ||
| from omegaconf import DictConfig | ||
|
|
||
|
|
||
@patch.dict(
    "sys.modules",
    {
        "verl.workers.rollout.sglang_rollout.async_sglang_rollout": MagicMock(AsyncSGLangRollout=MagicMock()),
        "verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()),
    },
)
class TestAsyncSglangServer:
    """Unit tests for AsyncSglangServer with the heavy sglang modules mocked out."""

    @pytest.fixture
    def mock_ray_actor(self):
        # Stand-in for a remote WorkerDict actor; every remote call is awaitable.
        actor = MagicMock()
        actor.execute_method.remote = AsyncMock(return_value={"content": "mocked response"})
        actor.resume.remote = AsyncMock()
        actor.offload.remote = AsyncMock()
        return actor

    @pytest.fixture
    def server_config(self):
        # Minimal rollout config: one TP group of size 2.
        return DictConfig({"rollout": {"tensor_model_parallel_size": 2}})

    @pytest.mark.asyncio
    @patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors")
    @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock)
    @pytest.mark.filterwarnings("ignore:Ray state API is no longer experimental:DeprecationWarning")
    async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config, mock_ray_actor):
        # Three named actors exist; with dp_rank=0 and tp_size=2 only global
        # ranks 0 and 1 belong to this server's TP group.
        mock_list_actors.return_value = [
            {"name": "test_prefixWorkerDict_0:0", "namespace": "test"},
            {"name": "test_prefixWorkerDict_0:1", "namespace": "test"},
            {"name": "test_prefixWorkerDict_1:2", "namespace": "test"},
        ]
        from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer

        # If the class is wrapped by @ray.remote, unwrap it so we can
        # instantiate it locally without a Ray cluster.
        target_cls = AsyncSglangServer
        if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr(AsyncSglangServer.__ray_metadata__, "modified_class"):
            target_cls = AsyncSglangServer.__ray_metadata__.modified_class

        with patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", return_value=mock_ray_actor):
            instance = target_cls(server_config, 2, 0, "test_prefix")

            await instance.init_engine()

            # Verify instance.workers is correctly populated
            assert len(instance.workers) == 2
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,10 +17,11 @@ | |
| import asyncio | ||
| import logging | ||
| import os | ||
| import time | ||
| from contextlib import contextmanager | ||
| from copy import deepcopy | ||
| from json import JSONDecodeError | ||
| from typing import TYPE_CHECKING | ||
| from typing import TYPE_CHECKING, Union | ||
| from uuid import uuid4 | ||
|
|
||
| import numpy as np | ||
|
|
@@ -197,9 +198,11 @@ def __init__( | |
| else: | ||
| self._engine = None | ||
|
|
||
| self.sharding_manager = None | ||
| # offload | ||
| if self._tp_rank == 0: | ||
| self._engine.release_memory_occupation() | ||
| self.is_sleep = True | ||
|
|
||
| kwargs = dict( | ||
| n=1, | ||
|
|
@@ -768,3 +771,84 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in | |
| req_list.append(req) | ||
|
|
||
| return req_list | ||
|
|
||
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
    """Execute a named operation on this rollout worker.

    Only ``"chat_completion"`` is supported: the OpenAI-style chat messages are
    flattened into a single prompt string, generation runs on TP rank 0, the
    result is broadcast to the rest of the TP group, and an OpenAI
    ``chat.completion`` response dict is returned from TP rank 0 only.

    Args:
        method: Operation name; only ``"chat_completion"`` is supported.
        *args: For ``"chat_completion"``, ``args[0]`` is the JSON request dict.
        **kwargs: Unused.

    Returns:
        dict | None: OpenAI-format completion on TP rank 0, ``None`` elsewhere.

    Raises:
        ValueError: If ``method`` is not supported.
    """
    if method == "chat_completion":
        json_request = args[0]

        # Flatten the chat history into one "role: content" per line.
        formatted_messages = []
        for msg in json_request["messages"]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            formatted_messages.append(f"{role}: {content}")
        prompt_str = "\n".join(formatted_messages)

        sampling_params_dict = {
            "n": json_request.get("n", 1),
            "max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length),
            "temperature": json_request.get("temperature", 1.0),
            "top_p": json_request.get("top_p", 1.0),
        }
        output = None
        if self._tp_rank == 0:
            # NOTE(review): blocking on the async engine from a sync method;
            # planned to become a real async call (see #1721).
            loop = asyncio.get_event_loop()
            output = loop.run_until_complete(
                self._engine.async_generate(
                    prompt=prompt_str,
                    sampling_params=sampling_params_dict,
                    return_logprob=True,
                )
            )
        # sglang is SPMD: share TP rank 0's result with the whole TP group.
        output = broadcast_pyobj(
            data=[output],
            rank=self._rank,
            dist_group=self._device_mesh_cpu["tp"].get_group(),
            src=self._device_mesh_cpu["tp"].mesh[0].item(),
            force_cpu_device=False,
        )

        # only return value from master rank
        if self._tp_rank != 0:
            return None
        # Assemble the OpenAI chat.completion payload from the engine outputs.
        choices = []
        completion_id = None  # renamed from `id` to avoid shadowing the builtin
        for i, content in enumerate(output):
            choices.append(
                {
                    "index": i,
                    "message": {
                        "role": "assistant",
                        "content": content["text"],
                    },
                    "finish_reason": content["meta_info"]["finish_reason"]["type"],
                }
            )
            completion_id = content["meta_info"]["id"]

        return {
            "id": "chatcmpl-" + completion_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": json_request.get("model", "sglang_model"),
            "choices": choices,
        }
    else:
        raise ValueError(f"not supported method : {method}")
|
|
||
# this function is left for uniform train-inference resharding
def resume(self):
    """Re-enter the sharding manager to load weights back onto the device.

    No-op when the engine is already awake (``is_sleep`` is False).
    """
    if self.is_sleep:
        self.sharding_manager.__enter__()  # pylint: disable=C2801
        self.is_sleep = False
|
|
||
# this function is left for uniform train-inference resharding
def offload(self):
    """Exit the sharding manager to release device weights.

    No-op when the engine is already asleep (``is_sleep`` is True).
    """
    if not self.is_sleep:
        self.sharding_manager.__exit__(None, None, None)
        self.is_sleep = True
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # Copyright 2025 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import asyncio | ||
| import logging | ||
|
|
||
| import ray | ||
| from omegaconf import DictConfig | ||
| from starlette.requests import Request | ||
| from starlette.responses import JSONResponse | ||
|
|
||
| from verl.workers.rollout.async_server import AsyncServerBase | ||
|
|
||
| logger = logging.getLogger(__file__) | ||
|
|
||
|
|
||
@ray.remote(num_cpus=1)
class AsyncSglangServer(AsyncServerBase):
    """Ray actor exposing an OpenAI-style HTTP endpoint over one DP rank's
    sglang TP worker group.

    Args:
        config: Worker config; ``config.rollout.tensor_model_parallel_size``
            determines the TP group size (default 1).
        dp_size: Total data-parallel world size.
        dp_rank: Data-parallel rank served by this instance.
        wg_prefix: Name prefix of the worker-group actors to attach to.
    """

    def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str):
        super().__init__()
        self.config = config
        rollout_config = config.get("rollout", {})
        self._tp_size = rollout_config.get("tensor_model_parallel_size", 1)
        self._dp_size = dp_size
        self._dp_rank = dp_rank
        self.wg_prefix = wg_prefix
        self.workers = []

    async def init_engine(self):
        """Discover and attach the named worker actors in this server's TP group."""
        all_actors = ray.util.list_named_actors(all_namespaces=True)
        # Default the name to "" so an unnamed actor entry cannot raise
        # AttributeError on None.startswith(...).
        matched_actors = [
            actor for actor in all_actors if actor.get("name", "").startswith(self.wg_prefix + "WorkerDict_")
        ]

        # TODO support multi node
        for matched_actor in matched_actors:
            current_rank = int(matched_actor["name"].split(":")[-1])

            # send to all workers in this tp group, because sglang is SPMD
            if self._dp_rank * self._tp_size <= current_rank < (self._dp_rank + 1) * self._tp_size:
                self.workers.append(ray.get_actor(**matched_actor))

    async def chat_completion(self, raw_request: Request):
        """Fan the chat-completion request out to every TP worker and return the
        response produced by TP rank 0 (other ranks return None)."""
        request = await raw_request.json()

        output_dp_lst = []
        for worker in self.workers:
            output_future = worker.execute_method.remote("chat_completion", request)
            output_dp_lst.append(output_future)
        outputs = await asyncio.gather(*output_dp_lst)

        for output in outputs:
            if output is not None:
                return JSONResponse(output)
        # Bug fix: this message was a plain string, so the {...} placeholders
        # were never interpolated; it must be an f-string.
        raise RuntimeError(
            f"AsyncSglangServer No output from workers self._dp_rank: {self._dp_rank}, "
            f"self._tp_size: {self._tp_size}, self.workers: {self.workers}"
        )

    async def wake_up(self):
        """Resume (reload weights on) every worker, waiting for completion.

        Previously the remote futures were fired and forgotten, so the server
        could report awake before any worker had actually resumed.
        """
        await asyncio.gather(*[worker.resume.remote() for worker in self.workers])

    async def sleep(self):
        """Offload (release weights of) every worker, waiting for completion."""
        await asyncio.gather(*[worker.offload.remote() for worker in self.workers])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unavailable after PR 1717 is merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's merge this PR to avoid further conflict