Merged
3 changes: 2 additions & 1 deletion examples/ppo_trainer/naive_chat_scheduler.py
@@ -46,6 +46,7 @@ async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto:
print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}")

async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
assert exception is None, f"exception: {exception}"
conversation, batch_conversations, batch_index = (
info["conversation"],
info["batch_conversations"],
@@ -77,7 +78,7 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
"conversation": list(conversation),
},
model=self.model_name,
messages=conversation,
messages=conversation.tolist(),
**kwargs,
)
)
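The `.tolist()` change matters because DataProto-style batches store non-tensor fields as numpy object arrays, so each `conversation` arrives as an `np.ndarray` of message dicts that is not JSON-serializable. A minimal sketch of the failure mode, with an illustrative message dict standing in for a real conversation:

```python
# Minimal sketch: numpy object arrays cannot be JSON-serialized directly,
# which is why the scheduler calls conversation.tolist() before sending
# messages to the chat completion endpoint. The message content is illustrative.
import json

import numpy as np

conversation = np.array(
    [{"role": "user", "content": "What is 2 + 2?"}], dtype=object
)

try:
    json.dumps(conversation)  # raises: ndarray is not JSON serializable
except TypeError as e:
    print(f"raw ndarray fails: {e}")

messages = conversation.tolist()  # plain list of dicts
print(json.dumps(messages))  # now serializes cleanly
```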
4 changes: 2 additions & 2 deletions tests/workers/rollout/test_vllm_tool_calling.py
@@ -32,7 +32,7 @@
from starlette.responses import JSONResponse

from examples.ppo_trainer.naive_chat_scheduler import NaiveChatCompletionScheduler
from tests.rollout.async_rollout_utils import init_async_rollout_manager
from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager
from verl.protocol import DataProto


@@ -257,7 +257,7 @@ def test_vllm_tool_calling():
config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
config.actor_rollout_ref.model.path = "Qwen/Qwen2-7B-Instruct"
config.actor_rollout_ref.rollout.mode = "async"
config.actor_rollout_ref.rollout.chat_scheduler = "tests.rollout.test_vllm_tool_calling.ToolChatCompletionScheduler"
config.actor_rollout_ref.rollout.chat_scheduler = "tests.workers.rollout.test_vllm_tool_calling.ToolChatCompletionScheduler"
config.actor_rollout_ref.rollout.prompt_length = 8192
config.actor_rollout_ref.rollout.response_length = 8192

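These path updates are load-bearing because `chat_scheduler` is a fully-qualified dotted path resolved at runtime, so a stale module path fails when the class is imported. A hedged sketch of the usual importlib pattern for resolving such a path; verl's actual loader may differ:

```python
# Hedged sketch: resolve a dotted class path such as
# "tests.workers.rollout.test_vllm_tool_calling.ToolChatCompletionScheduler".
# This is the common importlib idiom, not necessarily verl's exact loader.
import importlib


def resolve_class(dotted_path: str) -> type:
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)  # raises ModuleNotFoundError on a stale path
    return getattr(module, class_name)
```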
14 changes: 5 additions & 9 deletions verl/workers/rollout/async_server.py
@@ -184,29 +184,25 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):

completions, exception = None, None
try:
# TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests.
completions = await self._chat_completions_openai(address, **chat_complete_request)
# NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests.
completions = await self._chat_completions_aiohttp(address, **chat_complete_request)
except Exception as e:
# Let user handle the exception
exception = e

await callback(completions, callback_additional_info, exception)

async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion:
client = AsyncOpenAI(
base_url=f"http://{address}/v1",
api_key="token-abc123",
timeout=None,
max_retries=0
)
client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0)
@hongpeng-guo (Collaborator) commented on May 20, 2025:

It seems this line is the same as before, only the formatting changed. Just want to double-check that the current version is linted by the pre-commit hook :)

The PR author (Collaborator) replied:

Yes, it's auto-formatted by the pre-commit hook.

return await client.chat.completions.create(**chat_complete_request)

async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion:
try:
extra_headers = chat_complete_request.pop("extra_headers")
session = aiohttp.ClientSession()
async with session.post(
url=f"http://{address}/v1/chat/completions",
headers={"Authorization": "Bearer token-abc123"},
headers={"Authorization": "Bearer token-abc123", **extra_headers},
json=chat_complete_request,
) as resp:
data = await resp.json()
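The switch from `_chat_completions_openai` to `_chat_completions_aiohttp` works around the httpx concurrency issue noted in the comment. A self-contained sketch of the aiohttp request path, assuming an OpenAI-compatible endpoint; the `finally`-based session cleanup and returning the raw JSON dict are assumptions, since the diff hunk ends before that code:

```python
# Sketch of the aiohttp-based request path, assuming an OpenAI-compatible
# /v1/chat/completions endpoint. Session cleanup in finally is an assumption;
# the real code may also rebuild a ChatCompletion object from the JSON.
from typing import Any, Dict

import aiohttp


async def chat_completions_aiohttp(address: str, **chat_complete_request) -> Dict[str, Any]:
    # Pop caller-supplied headers so they travel as HTTP headers,
    # not inside the JSON request body.
    extra_headers = chat_complete_request.pop("extra_headers", {})
    session = aiohttp.ClientSession()
    try:
        async with session.post(
            url=f"http://{address}/v1/chat/completions",
            headers={"Authorization": "Bearer token-abc123", **extra_headers},
            json=chat_complete_request,
        ) as resp:
            return await resp.json()
    finally:
        await session.close()  # release the underlying connector
```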
1 change: 1 addition & 0 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -94,6 +94,7 @@ def collective_rpc(
sent_method = cloudpickle.dumps(method)
del method

# ~3ms overhead per schedule step due to SchedulerOutput/ModelRunnerOutput serialization/deserialization.
outputs = ray.get([worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers])
return outputs

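The added comment quantifies the cost of this fan-out: each scheduler step pays roughly 3 ms to serialize and deserialize `SchedulerOutput`/`ModelRunnerOutput` across the Ray boundary. A minimal sketch of the pattern, where the `Worker` actor and its `execute_method` are placeholders rather than verl's actual classes:

```python
# Minimal sketch of the collective RPC pattern: cloudpickle the callable
# once, fan it out to every Ray worker, and block on all results. The
# Worker actor here is a placeholder, not verl's actual worker class.
import cloudpickle
import ray


@ray.remote
class Worker:
    def execute_method(self, sent_method, *args, **kwargs):
        method = cloudpickle.loads(sent_method)  # deserialize per worker
        return method(self, *args, **kwargs)


def collective_rpc(workers, method, *args, **kwargs):
    sent_method = cloudpickle.dumps(method)  # serialize once on the driver
    # ray.get blocks until every worker finishes; per-step serialization of
    # inputs/outputs is where the ~3 ms overhead comes from.
    return ray.get(
        [w.execute_method.remote(sent_method, *args, **(kwargs or {})) for w in workers]
    )
```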