Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/workers/rollout/test_async_sglang_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from omegaconf import DictConfig


@patch.dict(
    "sys.modules",
    {
        # Stub out the sglang rollout modules so importing the server module
        # does not require a working sglang / CUDA environment.
        "verl.workers.rollout.sglang_rollout.async_sglang_rollout": MagicMock(AsyncSGLangRollout=MagicMock()),
        "verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()),
    },
)
class TestAsyncSglangServer:
    """Unit tests for AsyncSglangServer with Ray fully mocked out."""

    @pytest.fixture
    def mock_ray_actor(self):
        """Fake worker actor exposing the async `.remote` handles the server calls."""
        mock_actor = MagicMock()
        mock_actor.execute_method.remote = AsyncMock(return_value={"content": "mocked response"})
        mock_actor.resume.remote = AsyncMock()
        mock_actor.offload.remote = AsyncMock()
        return mock_actor

    @pytest.fixture
    def server_config(self):
        """Minimal rollout config: tensor-parallel size 2."""
        return DictConfig({"rollout": {"tensor_model_parallel_size": 2}})

    # NOTE: @patch decorators apply bottom-up, so the bottom-most patch
    # (_start_fastapi_server) maps to the first mock argument.
    @pytest.mark.asyncio
    @patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors")
    @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock)
    @pytest.mark.filterwarnings("ignore:Ray state API is no longer experimental:DeprecationWarning")
    async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config, mock_ray_actor):
        """init_engine should collect exactly the actors of this server's TP group.

        With dp_rank=0 and tp_size=2, global ranks 0 and 1 match; rank 2 does not.
        """
        mock_list_actors.return_value = [
            {"name": "test_prefixWorkerDict_0:0", "namespace": "test"},
            {"name": "test_prefixWorkerDict_0:1", "namespace": "test"},
            {"name": "test_prefixWorkerDict_1:2", "namespace": "test"},
        ]
        from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer

        # If AsyncSglangServer is a Ray actor class, unwrap the underlying
        # Python class so it can be instantiated without a Ray cluster.
        ActualClassToInstantiate = AsyncSglangServer
        if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr(AsyncSglangServer.__ray_metadata__, "modified_class"):
            ActualClassToInstantiate = AsyncSglangServer.__ray_metadata__.modified_class

        with patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", return_value=mock_ray_actor):
            instance = ActualClassToInstantiate(server_config, 2, 0, "test_prefix")

            await instance.init_engine()

            # Verify instance.workers is correctly populated
            assert len(instance.workers) == 2
101 changes: 68 additions & 33 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,41 +407,68 @@ def _build_rollout(self, trust_remote_code=False):
log_gpu_memory_usage("After building sharding manager", logger=logger)

elif rollout_name == "sglang":
from verl.workers.rollout.sglang_rollout import SGLangRollout

# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
# SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
# the main process of ray can not find any CUDA device, which would potentially lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
# we import it here use the abs path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager

log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
local_path = copy_to_local(self.config.model.path)
rollout = SGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)

if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = "dummy_hf"
rollout_sharding_manager = FSDPSGLangShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
if self.config.rollout.mode == "sync":
from verl.workers.rollout.sglang_rollout import SGLangRollout

# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
# SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
# the main process of ray can not find any CUDA device, which would potentially lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
# we import it here use the abs path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager

log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
local_path = copy_to_local(self.config.model.path)
rollout = SGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)

if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = "dummy_hf"
rollout_sharding_manager = FSDPSGLangShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
elif self.config.rollout.mode == "async":
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager

local_path = copy_to_local(self.config.model.path)
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
rollout = AsyncSGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)

if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = "dummy_hf"
rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout._engine,
model_config=self.actor_model_config,
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
)
log_gpu_memory_usage("After building sharding manager", logger=None)
elif rollout_name == "sglang_async":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unavailable after PR 1717 is merged.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unavailable after PR 1717 is merged.

let's merge this PR to avoid further conflict

# TODO replace by rollout.mode == "async"
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager

Expand Down Expand Up @@ -1369,3 +1396,11 @@ def execute_method(self, method: Union[str, bytes], *args, **kwargs):
if self.vllm_tp_rank == 0 and method != "execute_model":
print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}")
return self.rollout.execute_method(method, *args, **kwargs)

    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
    def resume(self):
        # Delegate to the rollout engine: reload weights onto GPU for async-rollout resharding.
        return self.rollout.resume()

    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
    def offload(self):
        # Delegate to the rollout engine: release GPU weight memory for async-rollout resharding.
        return self.rollout.offload()
4 changes: 3 additions & 1 deletion verl/workers/rollout/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def async_server_class(rollout_backend: str) -> Type[AsyncServerBase]:

return AsyncvLLMServer
elif rollout_backend == "sglang":
raise NotImplementedError
from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer

return AsyncSglangServer
else:
raise NotImplementedError
86 changes: 85 additions & 1 deletion verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import asyncio
import logging
import os
import time
from contextlib import contextmanager
from copy import deepcopy
from json import JSONDecodeError
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from uuid import uuid4

import numpy as np
Expand Down Expand Up @@ -197,9 +198,11 @@ def __init__(
else:
self._engine = None

self.sharding_manager = None
# offload
if self._tp_rank == 0:
self._engine.release_memory_occupation()
self.is_sleep = True

kwargs = dict(
n=1,
Expand Down Expand Up @@ -768,3 +771,84 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
req_list.append(req)

return req_list

def execute_method(self, method: Union[str, bytes], *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to implement this as an async function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to implement this as an async function?

sure. It will be done in #1721
It have to be sync method at this moment.

if method == "chat_completion":
json_request = args[0]

formatted_messages = []
for msg in json_request["messages"]:
role = msg.get("role", "user")
content = msg.get("content", "")
formatted_messages.append(f"{role}: {content}")
prompt_str = "\n".join(formatted_messages)

sampling_params_dict = {
"n": json_request.get("n", 1),
"max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length),
"temperature": json_request.get("temperature", 1.0),
"top_p": json_request.get("top_p", 1.0),
}
output = None
if self._tp_rank == 0:
loop = asyncio.get_event_loop()
output = loop.run_until_complete(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to use _async_rollout_a_request?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to use _async_rollout_a_request?

It will be replaced soon in #1721

self._engine.async_generate(
prompt=prompt_str,
sampling_params=sampling_params_dict,
return_logprob=True,
)
)
output = broadcast_pyobj(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be removed if we only need TP0 to return all results from the current TP group and assemble them in the chat scheduler.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be removed if we only need TP0 to return all results from the current TP group and assemble them in the chat scheduler.

yes

data=[output],
rank=self._rank,
dist_group=self._device_mesh_cpu["tp"].get_group(),
src=self._device_mesh_cpu["tp"].mesh[0].item(),
force_cpu_device=False,
)

# only return value from master rank
if self._tp_rank != 0:
return None
# build openai chat completion format
choices = []
id = None
for i, content in enumerate(output):
choices.append(
{
"index": i,
"message": {
"role": "assistant",
"content": content["text"],
},
"finish_reason": content["meta_info"]["finish_reason"]["type"],
}
)
id = content["meta_info"]["id"]

return {
"id": "chatcmpl-" + id,
"object": "chat.completion",
"created": int(time.time()),
"model": json_request.get("model", "sglang_model"),
"choices": choices,
}
else:
raise ValueError(f"not supported method : {method}")

# this function is left for uniform train-inference resharding

def resume(self):
if not self.is_sleep:
return
self.sharding_manager.__enter__() # pylint: disable=C2801
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any better implementation way ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any better implementation way ?

vllm part also has code like this. I will create a new issue for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any better implementation way ?

#1734


self.is_sleep = False

# this function is left for uniform train-inference resharding
def offload(self):
if self.is_sleep:
return

self.sharding_manager.__exit__(None, None, None)
self.is_sleep = True
71 changes: 71 additions & 0 deletions verl/workers/rollout/sglang_rollout/async_sglang_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging

import ray
from omegaconf import DictConfig
from starlette.requests import Request
from starlette.responses import JSONResponse

from verl.workers.rollout.async_server import AsyncServerBase

logger = logging.getLogger(__file__)


@ray.remote(num_cpus=1)
class AsyncSglangServer(AsyncServerBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test for AsyncSglangServer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test:
tests/workers/rollout/test_async_sglang_server.py

def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str):
super().__init__()
self.config = config
rollout_config = config.get("rollout", {})
self._tp_size = rollout_config.get("tensor_model_parallel_size", 1)
self._dp_size = dp_size
self._dp_rank = dp_rank
self.wg_prefix = wg_prefix
self.workers = []

async def init_engine(self):
all_actors = ray.util.list_named_actors(all_namespaces=True)
matched_actors = [actor for actor in all_actors if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_")]

# TODO support multi node
for matched_actor in matched_actors:
current_rank = int(matched_actor["name"].split(":")[-1])

# send to all works in this tp group, because sglang is SPMD
if current_rank >= self._dp_rank * self._tp_size and current_rank < (self._dp_rank + 1) * self._tp_size:
self.workers.append(ray.get_actor(**matched_actor))

async def chat_completion(self, raw_request: Request):
request = await raw_request.json()

output_dp_lst = []
for worker in self.workers:
output_future = worker.execute_method.remote("chat_completion", request)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can refer to the implementation in the following link: https://github.com/SwordFaith/verl/blob/refactor/merge_sgl_rollouts_and_bump_to_0.4.6.post4/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L213. We could only retrieve results from TP0, meaning non-TP0 workers would return a special response and be skipped by default.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can refer to the implementation in the following link: https://github.com/SwordFaith/verl/blob/refactor/merge_sgl_rollouts_and_bump_to_0.4.6.post4/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L213. We could only retrieve results from TP0, meaning non-TP0 workers would return a special response and be skipped by default.

yes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing broadcast_pyobj, is it just to call TP0 worker.execute_method?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing broadcast_pyobj, is it just to call TP0 worker.execute_method?

If we have a proper registration mechanism, it would be possible.

output_dp_lst.append(output_future)
outputs = await asyncio.gather(*output_dp_lst)

for output in outputs:
if output is not None:
return JSONResponse(output)
raise RuntimeError("AsyncSglangServer No output from workers self._dp_rank: {self._dp_rank}, self._tp_size: {self._tp_size}, self.workers: {self.workers}")

async def wake_up(self):
for worker in self.workers:
worker.resume.remote()

async def sleep(self):
for worker in self.workers:
worker.offload.remote()
Loading