diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml
index e6cf582e0df..ed7dc5800a1 100644
--- a/.github/workflows/vllm.yml
+++ b/.github/workflows/vllm.yml
@@ -66,6 +66,7 @@ jobs:
- name: Download Model to Use
run: |
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct
+ huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct
huggingface-cli download 'Qwen/Qwen2-7B-Instruct'
huggingface-cli download 'deepseek-ai/deepseek-llm-7b-chat'
export HF_HUB_OFFLINE=1
@@ -94,4 +95,4 @@ jobs:
- name: Running multi-turn rollout tests on 8 L20 GPUs
run: |
pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2
- python3 tests/workers/rollout/test_vllm_multi_turn.py
+ pytest -svvv tests/workers/rollout/test_vllm_chat_scheduler.py
diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
index e355e2f9c7d..819fa9f0698 100644
--- a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
+++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
@@ -9,7 +9,6 @@ rollout_name="sglang" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
- chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
fi
python3 -m verl.trainer.main_ppo \
@@ -38,7 +37,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=$rollout_name \
actor_rollout_ref.rollout.mode=$rollout_mode \
- actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
+ actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
diff --git a/examples/ppo_trainer/naive_chat_scheduler.py b/examples/ppo_trainer/naive_chat_scheduler.py
deleted file mode 100644
index 87592c76672..00000000000
--- a/examples/ppo_trainer/naive_chat_scheduler.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import asyncio
-from typing import Any, Dict, List
-
-import torch
-from openai.types.chat.chat_completion import ChatCompletion
-from tensordict import TensorDict
-
-from verl.protocol import DataProto
-from verl.workers.rollout.async_server import ChatCompletionScheduler
-
-
-class NaiveChatCompletionScheduler(ChatCompletionScheduler):
- """
- A very naive implementation of ChatCompletionScheduler for demo purpose,
- only do single-turn chat completion.
- """
-
- async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto:
- kwargs = dict(
- n=self.config.n,
- max_completion_tokens=self.config.response_length,
- temperature=self.config.temperature,
- top_p=self.config.top_p,
- )
-
- do_sample = batch.meta_info.get("do_sample", True)
- is_validate = batch.meta_info.get("validate", False)
- if not do_sample or is_validate:
- kwargs["n"] = 1
- kwargs["temperature"] = 0
-
- kwargs.update(sampling_params)
- print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}")
-
- async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
- assert exception is None, f"exception: {exception}"
- conversation, batch_conversations, batch_index = (
- info["conversation"],
- info["batch_conversations"],
- info["batch_index"],
- )
-
- conversations = []
- for choice in completions.choices:
- chat = conversation.copy()
- chat.append({"role": choice.message.role, "content": choice.message.content})
- conversations.append(chat)
- batch_conversations[batch_index] = conversations
-
- # NOTE: we can call tools and resubmit chat completions here.
- # call_tools(completions, info)
- # await self.submit_chat_completions(callback2, ...)
-
- # TODO: we may need to control max concurrent requests here, or it will harm prefix cache hit rate.
- tasks, batch_conversations = [], [None] * len(batch)
- for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]):
- # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...]
- tasks.append(
- asyncio.create_task(
- self.submit_chat_completions(
- callback=callback,
- callback_additional_info={
- "batch_conversations": batch_conversations,
- "batch_index": batch_index,
- "conversation": list(conversation),
- },
- model=self.model_name,
- messages=conversation.tolist(),
- **kwargs,
- )
- )
- )
- await asyncio.gather(*tasks)
- print("[NaiveChatCompletionScheduler] generate_sequences done")
-
- return self._postprocess(batch, batch_conversations, kwargs["n"])
-
- def _postprocess(self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int) -> DataProto:
- # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
- # prompts: left pad
- # responses: right pad
- # input_ids: prompt + response
- # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
- # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
-
- # prompts: [prompt] from input dataset
- prompts = [self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False) for prompt in batch.non_tensor_batch["raw_prompt"]]
-
- # flatten batch_conversations if n > 1
- assert len(batch_conversations) == len(prompts)
- batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations]
- assert len(batch_conversations) == len(prompts) * n
-
- # sequences: [prompt + response]
- sequences = [self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False) for conversation in batch_conversations]
-
- # responses: [response]
- # TODO: mask out tools calling tokens?
- responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)]
-
- prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left")
- responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right")
- if n > 1:
- prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0)
- prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0)
-
- input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1)
- attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1)
- position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask
-
- batch = TensorDict(
- {
- "prompts": prompts["input_ids"],
- "responses": responses["input_ids"],
- "input_ids": input_ids,
- "attention_mask": attention_mask,
- "position_ids": position_ids,
- },
- batch_size=len(input_ids),
- )
-
- return DataProto(batch=batch)
diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh
index abc490acdc7..72e5f9e847a 100644
--- a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh
+++ b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh
@@ -12,7 +12,6 @@ test_files="['$gsm8k_test_path', '$math_test_path']"
rollout_mode="sync"
if [ "$rollout_mode" = "async" ]; then
return_raw_chat="True"
- chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
fi
python3 -m verl.trainer.main_ppo \
@@ -38,7 +37,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=$rollout_mode \
- actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
+ actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
critic.optim.lr=1e-5 \
diff --git a/tests/workers/rollout/async_rollout_utils.py b/tests/workers/rollout/async_rollout_utils.py
index bc6186553ef..5328ec3c57b 100644
--- a/tests/workers/rollout/async_rollout_utils.py
+++ b/tests/workers/rollout/async_rollout_utils.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-from typing import Any, Dict
import ray
from omegaconf import DictConfig
@@ -24,12 +22,7 @@
from verl.workers.rollout.async_server import AsyncLLMServerManager
-def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, Any] = None) -> AsyncLLMServerManager:
- # make openai client happy
- os.environ["no_proxy"] = ""
- os.environ["http_proxy"] = ""
- os.environ["https_proxy"] = ""
-
+def init_async_rollout_manager(config: DictConfig) -> AsyncLLMServerManager:
# =========================== 1. Create hybrid ActorRollout workers ===========================
role_worker_mapping = {
Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),
@@ -61,9 +54,8 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A
# =========================== 2. Create AsyncLLMServerManager ===========================
async_rollout_manager = AsyncLLMServerManager(
- config=config.actor_rollout_ref,
+ config=config,
worker_group=actor_rollout_wg,
- scheduler_kwargs=scheduler_kwargs,
)
return async_rollout_manager
diff --git a/tests/workers/rollout/test_custom_completion_callback.py b/tests/workers/rollout/test_custom_completion_callback.py
new file mode 100644
index 00000000000..4e808e45f65
--- /dev/null
+++ b/tests/workers/rollout/test_custom_completion_callback.py
@@ -0,0 +1,280 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import concurrent.futures
+import os
+import re
+import socket
+import sys
+import tempfile
+from contextlib import asynccontextmanager
+from typing import Any, Dict, List
+
+import fastapi
+import numpy as np
+import ray
+import uvicorn
+from datasets import load_dataset
+from omegaconf import DictConfig, OmegaConf
+from openai.types.chat.chat_completion import ChatCompletion
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+
+from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager
+from verl.protocol import DataProto
+from verl.utils import hf_tokenizer
+from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case
+from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler, ToolCompletionCallback
+
+
+def _get_free_port():
+ with socket.socket() as sock:
+ sock.bind(("", 0))
+ return sock.getsockname()[1]
+
+
+@ray.remote(num_cpus=1)
+class Sandbox:
+ """Sandbox to execute python code.
+
+ WARNING: This class is for testing purpose only, do not use it in production.
+ Please use a sandbox with strong isolation and security restrictions instead.
+ """
+
+ def __init__(self):
+ self.address = ray._private.services.get_node_ip_address()
+ self.port = None
+ self.server_ready = asyncio.Event()
+ asyncio.create_task(self._start_fastapi_server())
+
+ async def code_execution(self, request: Request):
+ request_json = await request.json()
+ code = request_json["code"]
+ print(f"execute code:\n{code}")
+
+ _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True)
+ with open(temp_file, "w") as f:
+ f.write(code)
+
+ try:
+ process = await asyncio.create_subprocess_exec(sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
+
+ stdout, stderr = await process.communicate()
+
+ response = {
+ "status": "Success" if process.returncode == 0 else "Failed",
+ "run_result": {
+ "status": "Finished",
+ "stdout": stdout.decode(),
+ "stderr": stderr.decode(),
+ "return_code": process.returncode,
+ },
+ }
+ return JSONResponse(content=response)
+ finally:
+ try:
+ os.unlink(temp_file)
+ except: # noqa: E722
+ pass
+
+ async def _start_fastapi_server(self):
+ @asynccontextmanager
+ async def lifespan(app: fastapi.FastAPI):
+ print("FastAPI startup")
+ self.server_ready.set()
+ yield
+
+ print("FastAPI shutdown, maybe address already in use, exit process immediately.")
+ os._exit(-1)
+
+ app = fastapi.FastAPI(lifespan=lifespan)
+ app.router.add_api_route("/run_code", self.code_execution, methods=["POST"])
+
+ self.port = _get_free_port()
+ config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning")
+ server = uvicorn.Server(config)
+ await server.serve()
+
+ async def get_server_address(self) -> str:
+ """Get FastAPI server address."""
+ await self.server_ready.wait()
+ return f"{self.address}:{self.port}"
+
+
+class CustomCompletionCallback(ToolCompletionCallback):
+ def __init__(self, config: DictConfig, scheduler: ChatCompletionScheduler):
+ super().__init__(config, scheduler)
+
+ self.max_turns = 16
+ self.answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
+ self.code_pattern = re.compile(r"<code>\s*```python(.*?)```\s*</code>", re.DOTALL)
+
+ self.sandbox_fusion_url = config.reward_model.sandbox_fusion.url
+ self.default_timeout = 10
+ # TODO: support asyncio executor
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5))
+
+ async def sandbox_code_execution(self, code: str) -> Dict[str, Any]:
+ loop = asyncio.get_running_loop()
+ result_status, metadata = await loop.run_in_executor(
+ self.executor,
+ _process_single_case,
+ 0, # case_index,
+ None, # stdin_data,
+ None, # expected_output,
+ self.sandbox_fusion_url, # sandbox_fusion_url
+ code, # generation
+ self.default_timeout, # timeout
+ "python", # language
+ )
+
+ return metadata
+
+ @property
+ def extra_body(self):
+ extra = {
+ "include_stop_str_in_output": True,
+ "stop": ["</answer>", "</code>"],
+ }
+ return extra
+
+ async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]):
+ role, content, finish_reason = completions.choices[0].message.role, completions.choices[0].message.content, completions.choices[0].finish_reason
+ messages.append({"role": role, "content": content})
+ turn = len(messages)
+
+ # STEP 0: check if we reach max turns
+ if len(messages) >= self.max_turns:
+ print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max turns, done!")
+ return
+
+ # STEP 1: check if we reach max tokens
+ if finish_reason == "length":
+ print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max tokens, done!")
+ return
+
+ # STEP 2: check if we got answer
+ matches = self.answer_pattern.findall(content)
+ if matches:
+ print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Got answer: {matches[0]}, done!")
+ return
+
+ # STEP 3: check if we got code block
+ matches = self.code_pattern.findall(content)
+ if not matches:
+ print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] No code block found, done!")
+ return
+
+ # STEP 4: execute code block in sandbox
+ code = matches[0].strip()
+ metadata = await self.sandbox_code_execution(code)
+ if metadata["run_status"] != "Finished":
+ print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block execution failed: {metadata}, done!")
+ return
+
+ stdout, stderr = metadata["stdout"], metadata["stderr"]
+ messages.append({"role": "tool", "content": f"<output>{stdout}{stderr}</output>"})
+ print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block executed, continue...")
+
+ # STEP 5: resubmit chat completions with code block output
+ self.scheduler.submit_chat_completions(
+ messages=messages,
+ request_id=completions.id,
+ info=info,
+ )
+
+
+user_prompt_template = """
+You are a helpful assistant. Let's solve math problem in following steps:
+1. Write a python code first and return the code to user, the code must be in following format:
+
+<code>
+```python
+import os
+
+print(...)
+```
+</code>
+
+The code must explictly print necessary output to stdout. Remember stop generation at </code> immediately and return the code.
+2. User will send the python code to a external sandbox to execute and get output from stdout.
+3. User will send the output in format <output>output</output> to you, and you should use the output to answer the question.
+The answer format must be: \\boxed{'The final answer goes here.'}
+
+*user question:*
+{question}
+"""
+
+
+if __name__ == "__main__":
+ ray.init(
+ runtime_env={
+ "env_vars": {
+ "TOKENIZERS_PARALLELISM": "true",
+ "NCCL_DEBUG": "WARN",
+ "VLLM_LOGGING_LEVEL": "INFO",
+ "VLLM_USE_V1": "1",
+ }
+ }
+ )
+
+ # Load config
+ config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
+ model_path = "Qwen/Qwen2.5-1.5B-Instruct"
+ config.actor_rollout_ref.model.path = model_path
+ config.actor_rollout_ref.rollout.mode = "async"
+ config.actor_rollout_ref.rollout.multi_turn.format = "hermes"
+ config.actor_rollout_ref.rollout.multi_turn.completion_callback = "tests.workers.rollout.test_custom_completion_callback.CustomCompletionCallback"
+ config.actor_rollout_ref.rollout.prompt_length = 4096
+ config.actor_rollout_ref.rollout.response_length = 4096
+ config.actor_rollout_ref.rollout.n = 4
+
+ # Init sandbox and async rollout manager
+ sandbox = Sandbox.options(num_cpus=1).remote()
+ sandbox_address = ray.get(sandbox.get_server_address.remote())
+ sandbox_fusion_url = f"http://{sandbox_address}/run_code"
+ config.reward_model.sandbox_fusion.url = sandbox_fusion_url
+ async_rollout_manager = init_async_rollout_manager(config)
+
+ # Build dataset
+ dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train")
+ prompts = DataProto(
+ non_tensor_batch={
+ "raw_prompt": np.array([[{"role": "user", "content": user_prompt_template.replace("{question}", problem)}] for problem in dataset["Problem"]]),
+ },
+ )
+
+ result = async_rollout_manager.generate_sequences(prompts=prompts)
+ assert len(result) == len(dataset) * config.actor_rollout_ref.rollout.n
+
+ # Check max turns that sandbox is called
+ num_turns = result.non_tensor_batch["__num_turns__"]
+ print(f"num_turns: {num_turns}")
+ assert np.max(num_turns) > 2, f"max turns: {np.max(num_turns)}"
+
+ # Check response_mask
+ tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path)
+ responses = result.batch["responses"]
+ response_mask = result.batch["response_mask"]
+ assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
+
+ # Decode responses with response_mask
+ for i in range(len(responses)):
+ valid_tokens = responses[i][response_mask[i].bool()]
+ response_str = tokenizer.decode(valid_tokens)
+ assert "<tool_response>" not in response_str, f"found <tool_response> in response: {response_str}"
+ assert "</tool_response>" not in response_str, f"found </tool_response> in response: {response_str}"
+ print(f"response: {response_str}")
+
+ print("Test passed!")
diff --git a/tests/workers/rollout/test_vllm_chat_scheduler.py b/tests/workers/rollout/test_vllm_chat_scheduler.py
new file mode 100644
index 00000000000..23550bff776
--- /dev/null
+++ b/tests/workers/rollout/test_vllm_chat_scheduler.py
@@ -0,0 +1,233 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from typing import Any, Tuple
+
+import numpy as np
+import pytest
+import ray
+from omegaconf import DictConfig, OmegaConf
+from transformers.utils import get_json_schema
+
+from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager
+from verl.protocol import DataProto
+from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema
+from verl.utils import hf_tokenizer
+
+
+@pytest.fixture
+def init_config() -> DictConfig:
+ config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
+ model_path = "Qwen/Qwen2.5-1.5B-Instruct"
+ config.actor_rollout_ref.model.path = model_path
+ config.actor_rollout_ref.rollout.mode = "async"
+ config.actor_rollout_ref.rollout.multi_turn.format = "hermes"
+ config.actor_rollout_ref.rollout.prompt_length = 4096
+ config.actor_rollout_ref.rollout.response_length = 4096
+
+ # test sleep/wake_up with fsdp offload
+ config.actor_rollout_ref.actor.fsdp_config.param_offload = True
+ config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
+
+ return config
+
+
+def test_vllm_async_rollout_without_tool_calls(init_config):
+ ray.init(
+ runtime_env={
+ "env_vars": {
+ "TOKENIZERS_PARALLELISM": "true",
+ "NCCL_DEBUG": "WARN",
+ "VLLM_LOGGING_LEVEL": "INFO",
+ "VLLM_USE_V1": "1",
+ }
+ }
+ )
+
+ # =========================== 1. Init rollout manager ===========================
+ async_rollout_manager = init_async_rollout_manager(init_config)
+
+ # test sleep and wake_up
+ async_rollout_manager.sleep()
+ async_rollout_manager.wake_up()
+
+ # =========================== 2. Generate sequences ===========================
+ raw_prompts = [
+ [
+ {
+ "role": "user",
+ "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.",
+ }
+ ],
+ [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}],
+ ]
+ batch = DataProto(
+ non_tensor_batch={
+ "raw_prompt": np.array(raw_prompts),
+ },
+ )
+ result = async_rollout_manager.generate_sequences(prompts=batch)
+
+ # check result
+ seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1)
+ assert len(result) == 2
+ assert result.batch["input_ids"].size(1) == seq_len
+ assert result.batch["attention_mask"].size(1) == seq_len
+ assert result.batch["position_ids"].size(1) == seq_len
+
+ # check turns
+ num_turns = result.non_tensor_batch["__num_turns__"]
+ assert np.all(num_turns == 2)
+
+ print("Test passed!")
+ ray.shutdown()
+
+
+class WeatherTool(BaseTool):
+ def get_current_temperature(self, location: str, unit: str = "celsius"):
+ """Get current temperature at a location.
+
+ Args:
+ location: The location to get the temperature for, in the format "City, State, Country".
+ unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
+
+ Returns:
+ the temperature, the location, and the unit in a dict
+ """
+ return {
+ "temperature": 26.1,
+ "location": location,
+ "unit": unit,
+ }
+
+ def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
+ schema = get_json_schema(self.get_current_temperature)
+ return OpenAIFunctionToolSchema(**schema)
+
+ async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
+ try:
+ result = self.get_current_temperature(**parameters)
+ return json.dumps(result), 0, {}
+ except Exception as e:
+ return str(e), 0, {}
+
+
+class WeatherToolWithData(BaseTool):
+ def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
+ schema = get_json_schema(self.get_temperature_date)
+ return OpenAIFunctionToolSchema(**schema)
+
+ def get_temperature_date(self, location: str, date: str, unit: str = "celsius"):
+ """Get temperature at a location and date.
+
+ Args:
+ location: The location to get the temperature for, in the format "City, State, Country".
+ date: The date to get the temperature for, in the format "Year-Month-Day".
+ unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
+
+ Returns:
+ the temperature, the location, the date and the unit in a dict
+ """
+ return {
+ "temperature": 25.9,
+ "location": location,
+ "date": date,
+ "unit": unit,
+ }
+
+ async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
+ try:
+ result = self.get_temperature_date(**parameters)
+ return json.dumps(result), 0, {}
+ except Exception as e:
+ return str(e), 0, {}
+
+
+def test_vllm_async_rollout_with_tool_calls(init_config):
+ ray.init(
+ runtime_env={
+ "env_vars": {
+ "TOKENIZERS_PARALLELISM": "true",
+ "NCCL_DEBUG": "WARN",
+ "VLLM_LOGGING_LEVEL": "INFO",
+ "VLLM_USE_V1": "1",
+ }
+ }
+ )
+
+ # =========================== 1. Init rollout manager ===========================
+ tool_config = {
+ "tools": [
+ {
+ "class_name": "tests.workers.rollout.test_vllm_chat_scheduler.WeatherTool",
+ "config": {},
+ },
+ {
+ "class_name": "tests.workers.rollout.test_vllm_chat_scheduler.WeatherToolWithData",
+ "config": {},
+ },
+ ]
+ }
+ tool_config_path = "/tmp/tool_config.json"
+ with open(tool_config_path, "w") as f:
+ json.dump(tool_config, f)
+
+ init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
+ async_rollout_manager = init_async_rollout_manager(init_config)
+
+ # =========================== 2. Generate sequences ===========================
+ raw_prompts = [
+ [
+ {"role": "user", "content": "How are you?"},
+ ],
+ [
+ {"role": "user", "content": "What's the temperature in Los Angeles now?"},
+ ],
+ [
+ {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\nCurrent Date: 2024-09-30"},
+ {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
+ ],
+ ]
+ batch = DataProto(
+ non_tensor_batch={
+ "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
+ },
+ )
+ result = async_rollout_manager.generate_sequences(prompts=batch)
+
+ # Check turns
+ num_turns = result.non_tensor_batch["__num_turns__"]
+ # [user, assistant]
+ assert num_turns[0] == 2
+ # [user, assistant, tool, assistant]
+ assert num_turns[1] == 4
+ # [system, user, assistant, tool, tool, assistant]
+ assert num_turns[2] == 6
+
+ # Check response_mask
+ tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
+ responses = result.batch["responses"]
+ response_mask = result.batch["response_mask"]
+ assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
+
+ # Decode responses with response_mask
+ for i in range(len(responses)):
+ valid_tokens = responses[i][response_mask[i].bool()]
+ response_str = tokenizer.decode(valid_tokens)
+ assert "<tool_response>" not in response_str, f"found <tool_response> in response: {response_str}"
+ assert "</tool_response>" not in response_str, f"found </tool_response> in response: {response_str}"
+ print(f"response: {response_str}")
+
+ print("Test passed!")
+ ray.shutdown()
diff --git a/tests/workers/rollout/test_vllm_multi_turn.py b/tests/workers/rollout/test_vllm_multi_turn.py
deleted file mode 100644
index b705d86a9ca..00000000000
--- a/tests/workers/rollout/test_vllm_multi_turn.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import asyncio
-import json
-from typing import Any, Dict
-
-import numpy as np
-import ray
-from omegaconf import DictConfig, OmegaConf
-from openai.types.chat.chat_completion import ChatCompletion
-from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ErrorResponse
-
-from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager
-from verl.protocol import DataProto
-
-
-def init_config() -> DictConfig:
- config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
- model_path = "Qwen/Qwen2-7B-Instruct"
- config.actor_rollout_ref.model.path = model_path
- config.actor_rollout_ref.rollout.mode = "async"
- config.actor_rollout_ref.rollout.chat_scheduler = "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler"
- config.actor_rollout_ref.rollout.prompt_length = 4096
- config.actor_rollout_ref.rollout.response_length = 4096
-
- # test sleep/wake_up with fsdp offload
- config.actor_rollout_ref.actor.fsdp_config.param_offload = True
- config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
-
- return config
-
-
-def test_vllm_multi_turn(config):
- ray.init(
- runtime_env={
- "env_vars": {
- "TOKENIZERS_PARALLELISM": "true",
- "NCCL_DEBUG": "WARN",
- "VLLM_LOGGING_LEVEL": "WARN",
- "VLLM_USE_V1": "1",
- }
- }
- )
-
- # =========================== 1. Init rollout manager ===========================
- model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:])
- async_rollout_manager = init_async_rollout_manager(config)
-
- # test sleep and wake_up
- async_rollout_manager.sleep()
- async_rollout_manager.wake_up()
-
- async_chat_scheduler = async_rollout_manager.chat_scheduler
-
- # =========================== 2. Multi turn rollout ===========================
- async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
- assert exception is None, f"exception: {exception}"
- messages, round = info["messages"], info["round"]
- message = completions.choices[0].message
- messages.append({"role": message.role, "content": message.content})
- print(f"[round={round}] role: {message.role}, content: {message.content}")
-
- extra_headers = {"x-request-id": completions.id}
- if round == 0:
- messages.append({"role": "user", "content": "What is your name?"})
- await async_chat_scheduler.submit_chat_completions(
- callback=callback,
- callback_additional_info={"messages": messages, "round": 1},
- model=model_name,
- messages=messages,
- extra_headers=extra_headers,
- )
- elif round == 1:
- messages.append({"role": "user", "content": "What is your favorite color?"})
- await async_chat_scheduler.submit_chat_completions(
- callback=callback,
- callback_additional_info={"messages": messages, "round": 2},
- model=model_name,
- messages=messages,
- extra_headers=extra_headers,
- )
- else:
- print("Done!")
-
- messages = [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}]
- async_rollout_manager.submit_chat_completions(
- callback=callback,
- callback_additional_info={"messages": messages, "round": 0},
- model=model_name,
- messages=messages,
- )
- assert len(messages) == 6
- for round, message in enumerate(messages):
- if round % 2 == 0:
- assert message["role"] == "user"
- else:
- assert message["role"] == "assistant"
-
- # =========================== 3. Generate sequences ===========================
- raw_prompts = [
- [
- {
- "role": "user",
- "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.",
- }
- ],
- [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}],
- ]
- batch = DataProto(
- non_tensor_batch={
- "raw_prompt": np.array(raw_prompts),
- },
- )
- result = async_rollout_manager.generate_sequences(prompts=batch)
- seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1)
- assert len(result) == 2
- assert result.batch["input_ids"].size(1) == seq_len
- assert result.batch["attention_mask"].size(1) == seq_len
- assert result.batch["position_ids"].size(1) == seq_len
-
- ray.shutdown()
-
-
-async def test_vllm_streaming_response(config):
- ray.init(
- runtime_env={
- "env_vars": {
- "TOKENIZERS_PARALLELISM": "true",
- "NCCL_DEBUG": "WARN",
- "VLLM_LOGGING_LEVEL": "WARN",
- "VLLM_USE_V1": "1",
- }
- }
- )
-
- model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:])
- async_rollout_manager = init_async_rollout_manager(config)
- async_llm_server = async_rollout_manager.async_llm_servers[0]
-
- # non-streaming request
- request = ChatCompletionRequest(
- model=model_name,
- messages=[{"role": "user", "content": "What is your name?"}],
- stream=False,
- )
- generator = async_llm_server.chat_completion_generator.remote(request)
- async for ref in generator:
- status_code, data = await ref
- print(f">>>> status_code: {status_code}, {data}")
- data = data[len("data: ") :].rstrip()
- if status_code != 200:
- response = ErrorResponse(**json.loads(data))
- else:
- response = ChatCompletionResponse(**json.loads(data))
- assert response.choices[0].message.role == "assistant"
- assert response.choices[0].message.content is not None
-
- # streaming request
- request = ChatCompletionRequest(
- model=model_name,
- messages=[{"role": "user", "content": "How are you?"}],
- stream=True,
- )
- generator = async_llm_server.chat_completion_generator.remote(request)
- async for ref in generator:
- status_code, data = await ref
- print(f">>>> status_code: {status_code}, {data}")
- data = data[len("data: ") :].rstrip()
- if status_code != 200:
- response = ErrorResponse(**json.loads(data))
- elif data == "[DONE]":
- break
- else:
- response = ChatCompletionStreamResponse(**json.loads(data))
- assert response.choices[0].delta.role is None or response.choices[0].delta.role == "assistant"
- assert response.choices[0].delta.content is not None
-
- ray.shutdown()
-
-
-if __name__ == "__main__":
- config = init_config()
- test_vllm_multi_turn(config)
- asyncio.run(test_vllm_streaming_response(config))
diff --git a/tests/workers/rollout/test_vllm_tool_calling.py b/tests/workers/rollout/test_vllm_tool_calling.py
deleted file mode 100644
index efc8fbf547b..00000000000
--- a/tests/workers/rollout/test_vllm_tool_calling.py
+++ /dev/null
@@ -1,278 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import asyncio
-import os
-import re
-import socket
-import sys
-import tempfile
-from contextlib import asynccontextmanager
-from typing import Any, Dict
-
-import aiohttp
-import fastapi
-import numpy as np
-import ray
-import uvicorn
-from datasets import load_dataset
-from omegaconf import OmegaConf
-from openai.types.chat.chat_completion import ChatCompletion
-from starlette.requests import Request
-from starlette.responses import JSONResponse
-
-from examples.ppo_trainer.naive_chat_scheduler import NaiveChatCompletionScheduler
-from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager
-from verl.protocol import DataProto
-
-
-def _get_free_port():
- with socket.socket() as sock:
- sock.bind(("", 0))
- return sock.getsockname()[1]
-
-
-@ray.remote(num_cpus=1)
-class Sandbox:
- """Sandbox to execute python code.
-
- WARNING: This class is for testing purpose only, do not use it in production.
- Please use a sandbox with strong isolation and security restrictions instead.
- """
-
- def __init__(self):
- self.address = ray._private.services.get_node_ip_address()
- self.port = None
- self.server_ready = asyncio.Event()
- asyncio.create_task(self._start_fastapi_server())
-
- async def code_execution(self, request: Request):
- request_json = await request.json()
- code = request_json["code"]
- print(f"execute code:\n{code}")
-
- _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True)
- with open(temp_file, "w") as f:
- f.write(code)
-
- try:
- process = await asyncio.create_subprocess_exec(sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
-
- stdout, stderr = await process.communicate()
-
- return JSONResponse(content={"stdout": stdout.decode(), "stderr": stderr.decode(), "returncode": process.returncode})
- finally:
- try:
- os.unlink(temp_file)
- except: # noqa: E722
- pass
-
- async def _start_fastapi_server(self):
- @asynccontextmanager
- async def lifespan(app: fastapi.FastAPI):
- print("FastAPI startup")
- self.server_ready.set()
- yield
-
- print("FastAPI shutdown, maybe address already in use, exit process immediately.")
- os._exit(-1)
-
- app = fastapi.FastAPI(lifespan=lifespan)
- app.router.add_api_route("/code/execution", self.code_execution, methods=["POST"])
-
- self.port = _get_free_port()
- config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning")
- server = uvicorn.Server(config)
- await server.serve()
-
- async def get_server_address(self) -> str:
- """Get FastAPI server address."""
- await self.server_ready.wait()
- return f"{self.address}:{self.port}"
-
-
-class ToolChatCompletionScheduler(NaiveChatCompletionScheduler):
- """This is a demo chat completion scheduler that supports sandbox code execution
- described in ReTool paper: https://arxiv.org/pdf/2504.11536
- """
-
- def __init__(self, config, model_path, server_addresses, sandbox_address, system_prompt, **kwargs):
- super().__init__(config, model_path, server_addresses, **kwargs)
- self.sandbox_address = sandbox_address
- self.system_prompt = system_prompt
-
- async def sandbox_code_execution(self, code: str) -> Dict[str, Any]:
- """Execute python code in sandbox."""
- try:
- session = aiohttp.ClientSession()
- async with session.post(
- url=f"http://{self.sandbox_address}/code/execution",
- json={"code": code},
- ) as resp:
- return await resp.json()
- finally:
- await session.close()
-
- async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto:
- kwargs = dict(
- n=self.config.n,
- max_completion_tokens=self.config.response_length,
- temperature=self.config.temperature,
- top_p=self.config.top_p,
- extra_body={
- "include_stop_str_in_output": True,
- "stop": ["", ""],
- },
- )
-
- do_sample = batch.meta_info.get("do_sample", True)
- is_validate = batch.meta_info.get("validate", False)
- if not do_sample or is_validate:
- kwargs["n"] = 1
- kwargs["temperature"] = 0
-
- kwargs.update(sampling_params)
- print(f"[ToolChatCompletionScheduler] generate_sequences sampling params: {kwargs}")
-
- max_turns = 3
-
- async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
- batch_conversations, batch_index, turn = (
- info["batch_conversations"],
- info["batch_index"],
- info["turn"],
- )
- role, content = completions.choices[0].message.role, completions.choices[0].message.content
- batch_conversations[batch_index].append({"role": role, "content": content})
-
- # STEP 0: check if we reach max turns
- if turn == max_turns:
- print(f"[id={completions.id},turn={turn}] Reach max turns {max_turns}, done!")
- return
-
- # STEP 1: check if we got answer
- matches = re.findall(r"(.*?)", content, re.DOTALL)
- if matches:
- print(f"[id={completions.id},turn={turn}] Got answer: {matches[0]}, done!")
- return
-
- # STEP 2: check if we got code block
- matches = re.findall(r"\s*```python(.*?)```\s*", content, re.DOTALL)
- if not matches:
- print(f"[id={completions.id},turn={turn}] No code block found, done!")
- return
-
- # STEP 3: execute code block in sandbox
- code = matches[0].strip()
- result = await self.sandbox_code_execution(code)
- stdout, stderr = result["stdout"], result["stderr"]
- batch_conversations[batch_index].append({"role": "tool", "content": f"{stdout}{stderr}"})
- print(f"[id={completions.id},turn={turn}] Code block executed, continue...")
-
- # STEP 4: resubmit chat completions with code block output
- extra_headers = {"x-request-id": completions.id}
- await self.submit_chat_completions(
- callback=callback,
- callback_additional_info={
- "batch_conversations": batch_conversations,
- "batch_index": batch_index,
- "turn": turn + 1,
- },
- model=self.model_name,
- messages=batch_conversations[batch_index],
- extra_headers=extra_headers,
- **kwargs,
- )
-
- tasks, batch_conversations = [], [None] * len(batch)
- for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]):
- # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...]
- batch_conversations[batch_index] = [{"role": "system", "content": self.system_prompt}] + list(conversation)
- tasks.append(
- asyncio.create_task(
- self.submit_chat_completions(
- callback=callback,
- callback_additional_info={
- "batch_conversations": batch_conversations,
- "batch_index": batch_index,
- "turn": 1,
- },
- model=self.model_name,
- messages=batch_conversations[batch_index],
- **kwargs,
- )
- )
- )
-
- await asyncio.gather(*tasks)
- print("[NaiveChatCompletionScheduler] generate_sequences done")
-
- # _postprocess assumes n>=1
- batch_conversations = [[conversation] for conversation in batch_conversations]
- return self._postprocess(batch, batch_conversations, kwargs["n"])
-
-
-system_prompt = """
-You are a helpful assistant. Let's solve math problem in following steps:
-1. Write a python code first and return the code to user, the code must be in following format:
-
-
-```python
-import os
-
-print(...)
-```
-
-
-The code must explictly print necessary output to stdout. Remember stop generation at immediately and return the code.
-2. User will send the python code to a external sandbox to execute and get output from stdout.
-3. User will send the output in format output to you, and you should use the output to answer the question.
-The answer format must be: \\boxed{'The final answer goes here.'}
-"""
-
-
-def test_vllm_tool_calling():
- ray.init(
- runtime_env={
- "env_vars": {
- "TOKENIZERS_PARALLELISM": "true",
- "NCCL_DEBUG": "WARN",
- "VLLM_LOGGING_LEVEL": "INFO",
- "VLLM_USE_V1": "1",
- }
- }
- )
-
- # Load config
- config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
- config.actor_rollout_ref.model.path = "Qwen/Qwen2-7B-Instruct"
- config.actor_rollout_ref.rollout.mode = "async"
- config.actor_rollout_ref.rollout.chat_scheduler = "tests.workers.rollout.test_vllm_tool_calling.ToolChatCompletionScheduler"
- config.actor_rollout_ref.rollout.prompt_length = 8192
- config.actor_rollout_ref.rollout.response_length = 8192
-
- # Init sandbox and async rollout manager
- sandbox = Sandbox.options(num_cpus=1).remote()
- sandbox_address = ray.get(sandbox.get_server_address.remote())
- async_rollout_manager = init_async_rollout_manager(config, scheduler_kwargs={"sandbox_address": sandbox_address, "system_prompt": system_prompt})
-
- # Build dataset
- dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train")
- prompts = DataProto(non_tensor_batch={"raw_prompt": np.array([[{"role": "user", "content": problem}] for problem in dataset["Problem"]])})
-
- result = async_rollout_manager.generate_sequences(prompts=prompts)
- assert len(result) == len(dataset)
-
-
-if __name__ == "__main__":
- test_vllm_tool_calling()
diff --git a/verl/tools/base_tool.py b/verl/tools/base_tool.py
index e627b99f2bc..6642a5329d4 100644
--- a/verl/tools/base_tool.py
+++ b/verl/tools/base_tool.py
@@ -12,9 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Optional, Tuple
+import importlib
+import json
+import sys
+from typing import Any, List, Optional, Tuple
from uuid import uuid4
+from omegaconf import OmegaConf
+
from .schemas import OpenAIFunctionToolSchema
@@ -32,8 +37,10 @@ class BaseTool:
def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
self.config = config
- self.name = tool_schema.function.name
- self.tool_schema = tool_schema
+ self.tool_schema = tool_schema or self.get_openai_tool_schema()
+ assert self.tool_schema is not None, "Tool schema is not set!"
+ self.name = self.tool_schema.function.name
+ print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2))
def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
return self.tool_schema
@@ -84,3 +91,41 @@ async def release(self, instance_id: str, **kwargs) -> None:
instance_id: The instance id of the tool.
"""
pass
+
+
+def initialize_tools_from_config(tools_config_file) -> List[BaseTool]:
+ """Initialize tools from config file.
+
+ Args:
+ tools_config_file: The config file of the tools.
+
+ Returns:
+ A list of tools.
+ """
+ tools_config = OmegaConf.load(tools_config_file)
+
+ tool_list = []
+ for tool_config in tools_config.tools:
+ cls_name = tool_config.class_name
+ module_name, class_name = cls_name.rsplit(".", 1)
+
+ if module_name not in sys.modules:
+ spec = importlib.util.find_spec(module_name)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module)
+ else:
+ module = sys.modules[module_name]
+
+ tool_cls = getattr(module, class_name)
+
+ if tool_config.get("tool_schema", None) is None:
+ tool_schema = None
+ else:
+ tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)
+ tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict)
+
+ tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema)
+ tool_list.append(tool)
+
+ return tool_list
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 929ffd43e67..e0814c35119 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -98,7 +98,6 @@ actor_rollout_ref:
rollout:
name: vllm
mode: sync # sync: LLM, async: AsyncLLM
- chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
@@ -144,6 +143,7 @@ actor_rollout_ref:
max_turns: null # null for no limit (default max_length // 3)
tool_config_path: null # null for no tool
format: chatml # chatml, more formats will be supported in the future
+ completion_callback: null # null for default callback
critic:
rollout_n: ${actor_rollout_ref.rollout.n}
diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py
index 6ff42d7e02a..f246f47c354 100644
--- a/verl/trainer/ppo/metric_utils.py
+++ b/verl/trainer/ppo/metric_utils.py
@@ -50,12 +50,12 @@ def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
"""
Computes information about prompts and responses from a batch.
-
+
This is an internal helper function that extracts masks and lengths for prompts and responses.
-
+
Args:
batch: A DataProto object containing batch data with responses and attention masks.
-
+
Returns:
A dictionary containing:
- response_mask: Attention mask for the response tokens
@@ -172,9 +172,9 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str,
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
"""
Computes timing metrics for different processing stages in PPO training.
-
- This function calculates both raw timing metrics (in seconds) and per-token timing metrics
- (in milliseconds) for various processing stages like generation, reference computation,
+
+ This function calculates both raw timing metrics (in seconds) and per-token timing metrics
+ (in milliseconds) for various processing stages like generation, reference computation,
value computation, advantage computation, and model updates.
Args:
@@ -211,23 +211,23 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
"""
Computes throughput metrics for PPO training.
-
+
This function calculates performance metrics related to token processing speed,
including the total number of tokens processed, time per step, and throughput
(tokens per second per GPU).
-
+
Args:
batch: A DataProto object containing batch data with meta information about token counts.
timing_raw: A dictionary mapping stage names to their execution times in seconds.
Must contain a "step" key with the total step time.
n_gpus: Number of GPUs used for training.
-
+
Returns:
A dictionary containing:
- perf/total_num_tokens: Total number of tokens processed in the batch
- perf/time_per_step: Time taken for the step in seconds
- perf/throughput: Tokens processed per second per GPU
-
+
Note:
The throughput is calculated as total_tokens / (time * n_gpus) to normalize
across different GPU counts.
@@ -324,12 +324,12 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo
def process_validation_metrics(data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42) -> dict[str, dict[str, dict[str, float]]]:
"""
Process validation metrics into a structured format with statistical analysis.
-
+
This function organizes validation metrics by data source and prompt, then computes
various statistical measures including means, standard deviations, best/worst values,
and majority voting results. It also performs bootstrap sampling to estimate statistics
for different sample sizes.
-
+
Args:
data_sources: List of data source identifiers for each sample.
sample_inputs: List of input prompts corresponding to each sample.
@@ -345,7 +345,7 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str]
}
}
}
-
+
Where metric_name includes:
- "mean@N": Mean value across N samples
- "std@N": Standard deviation across N samples
@@ -355,7 +355,7 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str]
- "worst@N/std": Standard deviation of the worst values in bootstrap samples
- "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
- "maj@N/std": Standard deviation of majority voting results (if "pred" exists)
-
+
Example:
>>> data_sources = ["source1", "source1", "source2"]
>>> sample_inputs = ["prompt1", "prompt1", "prompt2"]
diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py
index d693f39242b..441d59f0ad5 100644
--- a/verl/workers/rollout/async_server.py
+++ b/verl/workers/rollout/async_server.py
@@ -12,31 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
-import heapq
-import importlib
import logging
import os
import socket
import threading
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
-from typing import Any, Callable, Dict, List, Tuple, Type
-from uuid import uuid4
+from typing import Any, Dict, List, Tuple, Type
-import aiohttp
import fastapi
import ray
import uvicorn
-from cachetools import LRUCache
from omegaconf import DictConfig
-from openai import AsyncOpenAI
-from openai.types.chat.chat_completion import ChatCompletion
from starlette.requests import Request
from verl.protocol import DataProto
from verl.single_controller.ray.base import RayWorkerGroup
-from verl.utils import hf_tokenizer
-from verl.utils.fs import copy_to_local
+from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler
logger = logging.getLogger(__file__)
@@ -59,7 +51,7 @@ def __init__(self):
async def _start_fastapi_server(self):
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
- print("FastAPI startup")
+ print(f"FastAPI listen on {self.address}:{self.port}")
self.server_ready.set()
yield
@@ -105,130 +97,19 @@ async def sleep(self):
raise NotImplementedError
-class ChatCompletionScheduler:
- def __init__(
- self,
- config: DictConfig,
- model_path: str,
- server_addresses: List[str],
- max_cache_size: int = 10000,
- ):
- """
- Args:
- config: DictConfig, rollout config.
- model_path: str, model path.
- server_addresses: List[str], server addresses.
- max_cache_size: int, max cache size of request_id to address mapping.
- """
- self.config = config
- self.model_name = "/".join(model_path.split("/")[-2:])
- local_path = copy_to_local(model_path)
- self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
-
- # Least requests load balancing
- self.weighted_addresses = [[0, address] for address in server_addresses]
- heapq.heapify(self.weighted_addresses)
-
- # LRU cache to map request_id to address
- self.request_id_to_address = LRUCache(maxsize=max_cache_size)
-
- async def submit_chat_completions(
- self,
- callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None],
- callback_additional_info: Dict[str, Any],
- **chat_complete_request,
- ):
- """
- Submit a chat completion request to the server with the least number of requests.
-
- Args:
- callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function
- to handle the response. The callback function should have the following signature:
-
- ```python
- async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
- ...
- ```
- - completions: chat completion response from server.
- - info: user provided `callback_additional_info`.
- - exception: exception raise from OpenAI client if request failed, otherwise None.
-
- **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation,
- please move to separate thread or process pool to avoid blocking the event loop.
-
- callback_additional_info: Dict[str, Any], additional info to pass to the callback function.
-
- **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create.
- OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create
- """
- if "extra_headers" not in chat_complete_request:
- chat_complete_request["extra_headers"] = {}
-
- extra_headers = chat_complete_request["extra_headers"]
- request_id = extra_headers.get("x-request-id", None)
- if request_id:
- if request_id.startswith("chatcmpl-"):
- request_id = request_id[len("chatcmpl-") :]
- extra_headers["x-request-id"] = request_id
-
- address = self.request_id_to_address.pop(request_id)
- else:
- address = self.weighted_addresses[0][1]
- self.weighted_addresses[0][0] += 1
- heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0])
-
- # use new request_id to avoid duplicate request_id problem
- request_id = uuid4().hex
- self.request_id_to_address[request_id] = address
- chat_complete_request["extra_headers"]["x-request-id"] = request_id
-
- completions, exception = None, None
- try:
- # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests.
- completions = await self._chat_completions_aiohttp(address, **chat_complete_request)
- except Exception as e:
- # Let user handle the exception
- exception = e
-
- await callback(completions, callback_additional_info, exception)
-
- async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion:
- client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0)
- return await client.chat.completions.create(**chat_complete_request)
-
- async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion:
- try:
- extra_headers = chat_complete_request.pop("extra_headers")
- timeout = aiohttp.ClientTimeout(total=None)
- session = aiohttp.ClientSession(timeout=timeout)
- async with session.post(
- url=f"http://{address}/v1/chat/completions",
- headers={"Authorization": "Bearer token-abc123", **extra_headers},
- json=chat_complete_request,
- ) as resp:
- data = await resp.json()
- return ChatCompletion(**data)
- finally:
- await session.close()
-
- async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto:
- raise NotImplementedError
-
-
class AsyncLLMServerManager:
"""AsyncLLMServerManager manage a group of vllm instances, i.e AsyncvLLMServer."""
- def __init__(self, config: DictConfig, worker_group: RayWorkerGroup, *, scheduler_kwargs: Dict[str, Any] = None):
+ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup):
"""Initialize AsyncLLMServerManager.
Args:
config: DictConfig, actor_rollout_ref config.
worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker.
- scheduler_kwargs: Dict[str, Any], kwargs for chat scheduler.
"""
- self.config = config
+ self.full_config = config
+ self.config = config.actor_rollout_ref
self.worker_group = worker_group
- self.scheduler_kwargs = scheduler_kwargs if scheduler_kwargs else {}
self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size
self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size
@@ -284,14 +165,9 @@ def _init_chat_scheduler(self):
self.chat_scheduler_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.chat_scheduler_loop)
- module_path, class_name = self.config.rollout.chat_scheduler.rsplit(".", 1)
- module = importlib.import_module(module_path)
- scheduler_cls = getattr(module, class_name)
- self.chat_scheduler = scheduler_cls(
- config=self.config.rollout,
- model_path=self.config.model.path,
+ self.chat_scheduler = ChatCompletionScheduler(
+ config=self.full_config,
server_addresses=self.server_addresses,
- **self.scheduler_kwargs,
)
self.chat_scheduler_ready.set()
@@ -307,9 +183,8 @@ def sleep(self):
def submit_chat_completions(
self,
- callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None],
- callback_additional_info: Dict[str, Any],
- **chat_complete_request,
+ messages: List[Dict[str, str]],
+ sampling_params: Dict[str, Any],
):
"""Submit a chat completion request to chat scheduler and wait until it is done.
To submit multiple requests in parallel, please use `generate_sequences` instead.
@@ -318,10 +193,10 @@ def submit_chat_completions(
"""
assert self.chat_scheduler is not None, "chat scheduler is not initialized."
future = asyncio.run_coroutine_threadsafe(
- self.chat_scheduler.submit_chat_completions(
- callback=callback,
- callback_additional_info=callback_additional_info,
- **chat_complete_request,
+ self.chat_scheduler._submit_chat_completions_semaphore(
+ messages=messages,
+ request_id=None,
+ sampling_params=sampling_params,
),
self.chat_scheduler_loop,
)
diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py
new file mode 100644
index 00000000000..1426f286fbe
--- /dev/null
+++ b/verl/workers/rollout/chat_scheduler.py
@@ -0,0 +1,418 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import heapq
+import importlib
+import itertools
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+from uuid import uuid4
+
+import aiohttp
+import numpy as np
+import torch
+from cachetools import LRUCache
+from omegaconf import DictConfig
+from openai import AsyncOpenAI
+from openai.types.chat.chat_completion import ChatCompletion
+from tensordict import TensorDict
+
+from verl.protocol import DataProto
+from verl.tools.base_tool import initialize_tools_from_config
+from verl.utils import hf_tokenizer
+from verl.utils.fs import copy_to_local
+
+logger = logging.getLogger(__file__)
+
+
class CompletionCallback(ABC):
    """Hook invoked by ChatCompletionScheduler after each chat completion returns.

    Subclasses drive the multi-turn loop (`__call__`) and convert the finished
    conversations into a training batch (`postprocess`).
    """

    def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"):
        """
        Args:
            config: Full trainer config; reads `actor_rollout_ref.*` sub-keys.
            scheduler: Owning scheduler, used by subclasses to resubmit follow-up requests.
        """
        self.config = config
        self.scheduler = scheduler

        # Initialize tools from config file; no tool config path means an empty tool set.
        self.max_turns = config.actor_rollout_ref.rollout.multi_turn.max_turns
        tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
        tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
        self.tools = {tool.name: tool for tool in tool_list}
        # OpenAI-style JSON schemas advertised to the server on every request.
        self._tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
        print(f"Initialized tools: {self.tools}", flush=True)

        # Tokenizer is needed by `postprocess` implementations to render chat templates.
        local_path = copy_to_local(config.actor_rollout_ref.model.path)
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)

    @property
    def tool_schemas(self):
        """OpenAI JSON tool schemas."""
        return self._tool_schemas

    @property
    def extra_body(self) -> "Dict[str, Any] | None":
        """Extra body passed to the OpenAI API; None means no extra body.

        The transport layer tolerates None (`extra_body or {}`), so the
        annotation is optional rather than a plain Dict.
        """
        return None

    @abstractmethod
    async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]):
        """Call back function to process completions.

        Args:
            messages: List of messages including raw prompt and assistant, tool response generated so far.
            completions: Chat completions from OpenAI compatible server.
            info: Any other auxiliary information pass across multi-turn.
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess(self, batch: DataProto, batch_conversations: List[List[Dict[str, str]]], n: int) -> DataProto:
        """Post process batch data.

        Args:
            batch: Batch input messages from RLHFDataset.
            batch_conversations: List of messages including raw prompt, assistant response, tool response.
                Note that `len(batch_conversations) == len(batch) * n`, e.g. n=2,
                batch_conversations=[messages_0_0, messages_0_1, messages_1_0, messages_1_1, ...]
            n: How many chat completion choices to generate for each input message.

        Returns:
            Batch data, should include ["prompts", "responses", "response_mask", "input_ids", "attention_mask", "position_ids"].
        """
        raise NotImplementedError
+
+
class ToolCompletionCallback(CompletionCallback):
    """Default callback: executes tool calls requested by the model and
    resubmits the conversation until the model stops calling tools or the
    configured max turn count is reached."""

    def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"):
        super().__init__(config, scheduler)

        # TODO: add reward manager to calculate reward score once a sample finish

    async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]):
        """Append the assistant turn, run requested tools, and resubmit.

        Mutates `messages` in place; only the first choice is consumed.
        """
        message = completions.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
        # Normalize: "content" may be absent for pure tool-call turns.
        if "content" not in message:
            message["content"] = ""
        messages.append(message)
        finish_reason = completions.choices[0].finish_reason

        # STEP 0: check if we reach max turns
        if self.max_turns and len(messages) >= self.max_turns:
            print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Reach max turns, done!")
            return

        # STEP 1: check if the model called tools
        if finish_reason != "tool_calls":
            print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] No tool called, done!")
            return

        # STEP 2: call tools concurrently. `_call_tool` returns (not raises)
        # the Exception on failure, so gather() itself never throws here.
        tool_calls = completions.choices[0].message.tool_calls
        print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Call {len(tool_calls)} tools")
        tasks = []
        for tool_call in tool_calls:
            tasks.append(self._call_tool(tool_call))
        tool_responses = await asyncio.gather(*tasks)
        # Any tool failure ends this sample's rollout; partial tool output is
        # never fed back to the model.
        if any(isinstance(item, Exception) for item in tool_responses):
            print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Error when calling tools, done!")
            return
        messages.extend(tool_responses)

        # STEP 3: resubmit completion request with tool responses
        self.scheduler.submit_chat_completions(messages=messages, request_id=completions.id, info=info)

    async def _call_tool(self, tool_call) -> Dict[str, str]:
        """Call tool and return tool response message, or the Exception on failure.

        The tool instance is always released, even when execution raises.
        """
        tool_name = tool_call.function.name
        tool_args = json.loads(tool_call.function.arguments)
        tool = self.tools[tool_name]

        instance_id = await tool.create()
        try:
            tool_response, tool_reward_score, tool_metrics = await tool.execute(instance_id, tool_args)
        except Exception as e:
            logger.exception(f"Error when executing tool: {e}")
            return e
        finally:
            await tool.release(instance_id)

        return {
            "role": "tool",
            "content": tool_response,
            "tool_call_id": tool_call.id,
        }

    def postprocess(self, batch: DataProto, batch_conversations: List[List[Dict[str, str]]], n: int) -> DataProto:
        """Tokenize finished conversations into a padded training batch.

        Returns a DataProto whose tensor batch holds ["prompts", "responses",
        "response_mask", "input_ids", "attention_mask", "position_ids"], plus
        per-sample "__num_turns__" in non_tensor_batch.
        """
        # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
        # prompts: left pad
        # responses: right pad
        # input_ids: prompt + response
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

        # prompts: [prompt] from input dataset
        prompts = [self.tokenizer.apply_chat_template(prompt, tools=self.tool_schemas, add_generation_prompt=True, tokenize=False) for prompt in batch.non_tensor_batch["raw_prompt"]]
        assert len(batch_conversations) == len(prompts) * n

        # sequences: [prompt + response]
        sequences = [self.tokenizer.apply_chat_template(conversation, tools=self.tool_schemas, add_generation_prompt=False, tokenize=False) for conversation in batch_conversations]

        # responses: [response] — string-prefix slicing; relies on the rendered
        # sequence starting with the rendered prompt verbatim.
        responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)]

        prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left")
        responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right")
        if n > 1:
            # Repeat each prompt row n times so rows align with batch_conversations.
            prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0)
            prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0)

        # response_mask: response mask with tools calling masked out
        response_mask = self._mask_out_tools_calling_tokens(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0), batch_conversations, responses["input_ids"], responses["attention_mask"])

        input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1)
        attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1)
        # Positions count only attended tokens; pad positions collapse to 0.
        position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

        batch = TensorDict(
            {
                "prompts": prompts["input_ids"],  # [bsz, prompt_length]
                "responses": responses["input_ids"],  # [bsz, response_length]
                "response_mask": response_mask,  # [bsz, response_length]
                "input_ids": input_ids,  # [bsz, prompt_length + response_length]
                "attention_mask": attention_mask,  # [bsz, prompt_length + response_length]
                "position_ids": position_ids,  # [bsz, prompt_length + response_length]
            },
            batch_size=len(input_ids),
        )

        num_turns = np.array([len(conversation) for conversation in batch_conversations], dtype=np.int32)
        return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns})

    def _mask_out_tools_calling_tokens(
        self,
        raw_prompts: List[List[Dict[str, str]]],
        batch_conversations: List[List[Dict[str, str]]],
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Mask out tools calling tokens in the responses.

        Args:
            raw_prompts: [prompt] from input dataset
            batch_conversations: [prompt + response]
            input_ids: responses tokens
            attention_mask: responses attention mask

        Returns:
            mask: (batch_size, response_length)
        """
        batch_size = input_ids.size(0)
        assert len(raw_prompts) == batch_size, f"{len(raw_prompts)} != {batch_size}"
        assert len(batch_conversations) == batch_size, f"{len(batch_conversations)} != {batch_size}"

        # Deduplicate adjacent tool calls, since they're merged into one turn.
        # [user, assistant, tool, tool, assistant] -> [user, assistant, tool, assistant]
        # TODO: it's chat_template specific, find a more generic way to do this.
        def deduplicate_adjacent_tool_calls(roles):
            result = []
            for role, group in itertools.groupby(roles):
                if role == "tool":
                    result.append(role)
                else:
                    result.extend(group)
            return result

        loss_mask = attention_mask.clone()
        for i in range(batch_size):
            # Response turns are everything after the raw prompt.
            responses = batch_conversations[i][len(raw_prompts[i]) :]
            assert len(responses) > 0, f"responses is empty: {responses}"

            roles = deduplicate_adjacent_tool_calls([response["role"] for response in responses])
            # Each turn should be: [BOS]...[EOS]
            # NOTE(review): assumes every rendered turn ends with exactly one EOS
            # token — chat-template specific (see TODO above); confirm per model.
            eos_indices = input_ids[i].eq(self.tokenizer.eos_token_id).nonzero().squeeze(1)[: len(roles)]
            for j in range(len(roles)):
                if roles[j] == "tool":
                    # Zero out the whole tool turn, inclusive of its EOS.
                    bos = eos_indices[j - 1] + 1 if j > 0 else 0
                    eos = eos_indices[j]
                    loss_mask[i, bos : eos + 1] = 0

        return loss_mask
+
+
class ChatCompletionScheduler:
    """Dispatches chat completion requests across multiple OpenAI-compatible
    servers and drives multi-turn rollouts through a CompletionCallback."""

    def __init__(
        self,
        config: DictConfig,
        server_addresses: List[str],
        max_cache_size: int = 10000,
    ):
        """
        Args:
            config: DictConfig, full trainer config (reads `actor_rollout_ref.*`).
            server_addresses: List[str], OpenAI compatible server addresses.
            max_cache_size: int, max cache size of request_id to address mapping.
        """
        self.config = config.actor_rollout_ref.rollout
        # Served model name is the last two path components, e.g. "org/model".
        model_path = config.actor_rollout_ref.model.path
        self.model_name = "/".join(model_path.split("/")[-2:])

        # Least requests load balancing: min-heap of [request_count, address].
        # NOTE(review): counts are only ever incremented in this class, so this
        # balances total requests issued, not currently in-flight requests.
        self.weighted_addresses = [[0, address] for address in server_addresses]
        heapq.heapify(self.weighted_addresses)

        # LRU cache to map request_id to address, so a follow-up turn is routed
        # back to the same server that handled the previous turn.
        self.request_id_to_address = LRUCache(maxsize=max_cache_size)

        self.background_tasks = set()
        # Resolve the callback: default ToolCompletionCallback, or a
        # user-supplied "module.ClassName" dotted path from the config.
        if self.config.multi_turn.completion_callback is None:
            self.completion_callback = ToolCompletionCallback(config, self)
            logger.warning("completion_callback is None, use ToolCompletionCallback")
        else:
            module_path, class_name = self.config.multi_turn.completion_callback.rsplit(".", 1)
            module = importlib.import_module(module_path)
            self.completion_callback = getattr(module, class_name)(config, self)

    def submit_chat_completions(self, *, messages: List[Dict[str, str]], request_id: "str | None", info: Dict[str, Any]):
        """Submit chat completion request without wait, completion_callback will be called when the request is done.

        Args:
            messages: List of messages.
            request_id: Request id of the previous turn, or None for a fresh conversation.
            info: Any other auxiliary information pass across multi-turn.
        """
        # Depth tracks ongoing requests for this conversation, including turns
        # resubmitted from the callback; it reaches 0 only when all are done.
        info["__depth__"] += 1
        task = asyncio.create_task(self._submit_chat_completions_and_callback(messages, request_id, info))

        # "fire-and-forget" background tasks: keep a strong reference so the
        # task is not garbage-collected before it completes.
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

    async def _submit_chat_completions_and_callback(
        self,
        messages: List[Dict[str, str]],
        request_id: "str | None",
        info: Dict[str, Any],
    ):
        """Submit chat completion request, wait request finish and do callback."""
        if request_id:
            # Follow-up turn: reuse the server that handled the previous turn.
            request_id = request_id.removeprefix("chatcmpl-")
            assert request_id in self.request_id_to_address
            address = self.request_id_to_address.pop(request_id)
        else:
            # New conversation: take the least-loaded server, bump its count,
            # then heapreplace the mutated root to restore the heap invariant.
            address = self.weighted_addresses[0][1]
            self.weighted_addresses[0][0] += 1
            heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0])

        # use new request_id to avoid duplicate request_id problem
        request_id = uuid4().hex
        self.request_id_to_address[request_id] = address

        completions, exception = None, None
        try:
            # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests.
            completions = await self._chat_completions_aiohttp(
                address,
                messages=messages,
                tools=self.completion_callback.tool_schemas,
                extra_body=self.completion_callback.extra_body,
                extra_headers={"x-request-id": request_id},
                **info["__sampling_params__"],
            )
        except Exception as e:
            # Let user handle the exception
            exception = e

        info["__depth__"] -= 1

        if exception is not None:
            logger.exception(f"chat completion failed with exception: {exception}")
        else:
            try:
                await self.completion_callback(messages, completions, info)
            except Exception as e:
                logger.exception(f"completion callback failed with exception: {e}")

        # No more ongoing completion requests: wake the waiting submitter.
        if info["__depth__"] == 0:
            info["__done__"].set()

    async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion:
        """Alternative transport via the official OpenAI client; not called
        elsewhere in this module (see NOTE about httpx performance above)."""
        client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0)
        return await client.chat.completions.create(**chat_complete_request)

    async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion:
        """POST a chat completion over aiohttp and parse the JSON response.

        `extra_body` entries are flattened into the top-level request payload.
        """
        try:
            extra_body = chat_complete_request.pop("extra_body", {})
            chat_complete_request.update(extra_body or {})
            extra_headers = chat_complete_request.pop("extra_headers")
            timeout = aiohttp.ClientTimeout(total=None)
            session = aiohttp.ClientSession(timeout=timeout)
            # NOTE(review): if any line above the session assignment raised,
            # `session` would be unbound in `finally` (NameError); consider
            # creating the session first or using `async with`.
            async with session.post(
                url=f"http://{address}/v1/chat/completions",
                headers={"Authorization": "Bearer token-abc123", **extra_headers},
                json=chat_complete_request,
            ) as resp:
                data = await resp.json()
                return ChatCompletion(**data)
        finally:
            await session.close()

    async def generate_sequences(self, batch: DataProto) -> DataProto:
        """Roll out every prompt in `batch` (n samples each in training mode)
        and return the post-processed batch from the completion callback."""
        kwargs = dict(
            model=self.model_name,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        # override sampling params for validation
        if batch.meta_info.get("validate", False):
            kwargs["top_p"] = self.config.val_kwargs.top_p
            kwargs["temperature"] = self.config.val_kwargs.temperature

        print(f"[ChatCompletionScheduler] generate_sequences sampling params: {kwargs}")

        # NOTE: For multi-turn rollout, repeat raw_prompt n times and process each prompt independently,
        # validation dataset has already been repeated in `PPOTrainer._validate`.
        n = 1 if batch.meta_info.get("validate", False) else self.config.n
        tasks, batch_conversations = [], [None] * len(batch) * n
        for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)):
            # raw_prompt: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
            batch_conversations[batch_index] = conversation.tolist()

            tasks.append(
                asyncio.create_task(
                    self._submit_chat_completions_semaphore(
                        messages=batch_conversations[batch_index],
                        request_id=None,
                        sampling_params=kwargs,
                    )
                )
            )

        await asyncio.gather(*tasks)
        print("[ChatCompletionScheduler] generate_sequences done")

        return self.completion_callback.postprocess(batch, batch_conversations, n=n)

    async def _submit_chat_completions_semaphore(self, messages: List[Dict[str, str]], request_id: "str | None", sampling_params: Dict[str, Any]):
        """Submit one conversation and wait until its entire multi-turn
        exchange (including callback-resubmitted turns) has finished."""
        done = asyncio.Event()

        info = {
            "__done__": done,
            "__depth__": 0,  # indicate how many ongoing completion requests
            "__sampling_params__": sampling_params,
        }

        self.submit_chat_completions(messages=messages, request_id=request_id, info=info)

        # Wait until all completion requests are done
        await done.wait()
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index 607c9769540..16e57a89e61 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from collections.abc import AsyncGenerator
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import cloudpickle
@@ -122,14 +121,14 @@ class AsyncvLLMServer(AsyncServerBase):
def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str):
"""
Args:
- config: DictConfig, actor_rollout_ref config.
+ config: DictConfig.
vllm_dp_size: int, vllm data parallel size.
vllm_dp_rank: int, vllm data parallel rank.
wg_prefix: str, worker group prefix, used to lookup actors.
"""
super().__init__()
- self.config = config
+ self.config = config.actor_rollout_ref
self.vllm_dp_size = vllm_dp_size
self.vllm_dp_rank = vllm_dp_rank
self.wg_prefix = wg_prefix
@@ -154,7 +153,7 @@ async def init_engine(self):
kwargs = dict(
n=1,
logprobs=0,
- max_tokens=config.response_length,
+ max_new_tokens=config.response_length,
)
for k in config.keys():
if hasattr(SamplingParams(), str(k)):
@@ -201,6 +200,8 @@ async def init_engine(self):
request_logger=RequestLogger(max_log_len=4096),
chat_template=None,
chat_template_content_format="auto",
+ enable_auto_tools=True,
+ tool_parser=config.multi_turn.format, # hermes, llama3_json, ...
)
async def chat_completion(self, raw_request: Request):
@@ -220,28 +221,6 @@ async def chat_completion(self, raw_request: Request):
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
- async def chat_completion_generator(self, request: ChatCompletionRequest) -> AsyncGenerator[Tuple[int, str]]:
- """Direct chat completion without FastAPI.
-
- Args:
- request: ChatCompletionRequest, request object.
-
- Returns:
- AsyncGenerator[Tuple[int, str]]: async generator of (status_code, data) pairs.
- """
- generator = await self.openai_serving_chat.create_chat_completion(request)
- if isinstance(generator, ErrorResponse):
- data = generator.model_dump_json(exclude_unset=True)
- yield generator.code, f"data: {data}\n\n"
-
- if request.stream:
- async for chunk in generator:
- yield 200, chunk
- else:
- assert isinstance(generator, ChatCompletionResponse)
- data = generator.model_dump_json(exclude_unset=True)
- yield 200, f"data: {data}\n\n"
-
async def wake_up(self):
await self.engine.wake_up()