diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index e6cf582e0df..ed7dc5800a1 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -66,6 +66,7 @@ jobs: - name: Download Model to Use run: | huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct + huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct huggingface-cli download 'Qwen/Qwen2-7B-Instruct' huggingface-cli download 'deepseek-ai/deepseek-llm-7b-chat' export HF_HUB_OFFLINE=1 @@ -94,4 +95,4 @@ jobs: - name: Running multi-turn rollout tests on 8 L20 GPUs run: | pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2 - python3 tests/workers/rollout/test_vllm_multi_turn.py + pytest -svvv tests/workers/rollout/test_vllm_chat_scheduler.py diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh index e355e2f9c7d..819fa9f0698 100644 --- a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh @@ -9,7 +9,6 @@ rollout_name="sglang" # sglang or vllm if [ "$rollout_mode" = "async" ]; then export VLLM_USE_V1=1 return_raw_chat="True" - chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler fi python3 -m verl.trainer.main_ppo \ @@ -38,7 +37,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=$rollout_name \ actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ diff --git a/examples/ppo_trainer/naive_chat_scheduler.py b/examples/ppo_trainer/naive_chat_scheduler.py deleted file mode 100644 index 87592c76672..00000000000 --- a/examples/ppo_trainer/naive_chat_scheduler.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2024 Bytedance 
Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -from typing import Any, Dict, List - -import torch -from openai.types.chat.chat_completion import ChatCompletion -from tensordict import TensorDict - -from verl.protocol import DataProto -from verl.workers.rollout.async_server import ChatCompletionScheduler - - -class NaiveChatCompletionScheduler(ChatCompletionScheduler): - """ - A very naive implementation of ChatCompletionScheduler for demo purpose, - only do single-turn chat completion. 
- """ - - async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: - kwargs = dict( - n=self.config.n, - max_completion_tokens=self.config.response_length, - temperature=self.config.temperature, - top_p=self.config.top_p, - ) - - do_sample = batch.meta_info.get("do_sample", True) - is_validate = batch.meta_info.get("validate", False) - if not do_sample or is_validate: - kwargs["n"] = 1 - kwargs["temperature"] = 0 - - kwargs.update(sampling_params) - print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}") - - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - assert exception is None, f"exception: {exception}" - conversation, batch_conversations, batch_index = ( - info["conversation"], - info["batch_conversations"], - info["batch_index"], - ) - - conversations = [] - for choice in completions.choices: - chat = conversation.copy() - chat.append({"role": choice.message.role, "content": choice.message.content}) - conversations.append(chat) - batch_conversations[batch_index] = conversations - - # NOTE: we can call tools and resubmit chat completions here. - # call_tools(completions, info) - # await self.submit_chat_completions(callback2, ...) - - # TODO: we may need to control max concurrent requests here, or it will harm prefix cache hit rate. - tasks, batch_conversations = [], [None] * len(batch) - for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]): - # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] 
- tasks.append( - asyncio.create_task( - self.submit_chat_completions( - callback=callback, - callback_additional_info={ - "batch_conversations": batch_conversations, - "batch_index": batch_index, - "conversation": list(conversation), - }, - model=self.model_name, - messages=conversation.tolist(), - **kwargs, - ) - ) - ) - await asyncio.gather(*tasks) - print("[NaiveChatCompletionScheduler] generate_sequences done") - - return self._postprocess(batch, batch_conversations, kwargs["n"]) - - def _postprocess(self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int) -> DataProto: - # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py - # prompts: left pad - # responses: right pad - # input_ids: prompt + response - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - - # prompts: [prompt] from input dataset - prompts = [self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False) for prompt in batch.non_tensor_batch["raw_prompt"]] - - # flatten batch_conversations if n > 1 - assert len(batch_conversations) == len(prompts) - batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations] - assert len(batch_conversations) == len(prompts) * n - - # sequences: [prompt + response] - sequences = [self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False) for conversation in batch_conversations] - - # responses: [response] - # TODO: mask out tools calling tokens? 
- responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] - - prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") - responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") - if n > 1: - prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) - prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) - - input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) - attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) - position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask - - batch = TensorDict( - { - "prompts": prompts["input_ids"], - "responses": responses["input_ids"], - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=len(input_ids), - ) - - return DataProto(batch=batch) diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh index abc490acdc7..72e5f9e847a 100644 --- a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh @@ -12,7 +12,6 @@ test_files="['$gsm8k_test_path', '$math_test_path']" rollout_mode="sync" if [ "$rollout_mode" = "async" ]; then return_raw_chat="True" - chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler fi python3 -m verl.trainer.main_ppo \ @@ -38,7 +37,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ critic.optim.lr=1e-5 \ diff --git 
a/tests/workers/rollout/async_rollout_utils.py b/tests/workers/rollout/async_rollout_utils.py index bc6186553ef..5328ec3c57b 100644 --- a/tests/workers/rollout/async_rollout_utils.py +++ b/tests/workers/rollout/async_rollout_utils.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Any, Dict import ray from omegaconf import DictConfig @@ -24,12 +22,7 @@ from verl.workers.rollout.async_server import AsyncLLMServerManager -def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, Any] = None) -> AsyncLLMServerManager: - # make openai client happy - os.environ["no_proxy"] = "" - os.environ["http_proxy"] = "" - os.environ["https_proxy"] = "" - +def init_async_rollout_manager(config: DictConfig) -> AsyncLLMServerManager: # =========================== 1. Create hybrid ActorRollout workers =========================== role_worker_mapping = { Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), @@ -61,9 +54,8 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A # =========================== 2. Create AsyncLLMServerManager =========================== async_rollout_manager = AsyncLLMServerManager( - config=config.actor_rollout_ref, + config=config, worker_group=actor_rollout_wg, - scheduler_kwargs=scheduler_kwargs, ) return async_rollout_manager diff --git a/tests/workers/rollout/test_custom_completion_callback.py b/tests/workers/rollout/test_custom_completion_callback.py new file mode 100644 index 00000000000..4e808e45f65 --- /dev/null +++ b/tests/workers/rollout/test_custom_completion_callback.py @@ -0,0 +1,280 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import concurrent.futures +import os +import re +import socket +import sys +import tempfile +from contextlib import asynccontextmanager +from typing import Any, Dict, List + +import fastapi +import numpy as np +import ray +import uvicorn +from datasets import load_dataset +from omegaconf import DictConfig, OmegaConf +from openai.types.chat.chat_completion import ChatCompletion +from starlette.requests import Request +from starlette.responses import JSONResponse + +from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager +from verl.protocol import DataProto +from verl.utils import hf_tokenizer +from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case +from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler, ToolCompletionCallback + + +def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +@ray.remote(num_cpus=1) +class Sandbox: + """Sandbox to execute python code. + + WARNING: This class is for testing purpose only, do not use it in production. + Please use a sandbox with strong isolation and security restrictions instead. 
+ """ + + def __init__(self): + self.address = ray._private.services.get_node_ip_address() + self.port = None + self.server_ready = asyncio.Event() + asyncio.create_task(self._start_fastapi_server()) + + async def code_execution(self, request: Request): + request_json = await request.json() + code = request_json["code"] + print(f"execute code:\n{code}") + + _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True) + with open(temp_file, "w") as f: + f.write(code) + + try: + process = await asyncio.create_subprocess_exec(sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + + stdout, stderr = await process.communicate() + + response = { + "status": "Success" if process.returncode == 0 else "Failed", + "run_result": { + "status": "Finished", + "stdout": stdout.decode(), + "stderr": stderr.decode(), + "return_code": process.returncode, + }, + } + return JSONResponse(content=response) + finally: + try: + os.unlink(temp_file) + except: # noqa: E722 + pass + + async def _start_fastapi_server(self): + @asynccontextmanager + async def lifespan(app: fastapi.FastAPI): + print("FastAPI startup") + self.server_ready.set() + yield + + print("FastAPI shutdown, maybe address already in use, exit process immediately.") + os._exit(-1) + + app = fastapi.FastAPI(lifespan=lifespan) + app.router.add_api_route("/run_code", self.code_execution, methods=["POST"]) + + self.port = _get_free_port() + config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") + server = uvicorn.Server(config) + await server.serve() + + async def get_server_address(self) -> str: + """Get FastAPI server address.""" + await self.server_ready.wait() + return f"{self.address}:{self.port}" + + +class CustomCompletionCallback(ToolCompletionCallback): + def __init__(self, config: DictConfig, scheduler: ChatCompletionScheduler): + super().__init__(config, scheduler) + + self.max_turns = 16 + self.answer_pattern = 
re.compile(r"(.*?)", re.DOTALL) + self.code_pattern = re.compile(r"\s*```python(.*?)```\s*", re.DOTALL) + + self.sandbox_fusion_url = config.reward_model.sandbox_fusion.url + self.default_timeout = 10 + # TODO: support asyncio executor + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5)) + + async def sandbox_code_execution(self, code: str) -> Dict[str, Any]: + loop = asyncio.get_running_loop() + result_status, metadata = await loop.run_in_executor( + self.executor, + _process_single_case, + 0, # case_index, + None, # stdin_data, + None, # expected_output, + self.sandbox_fusion_url, # sandbox_fusion_url + code, # generation + self.default_timeout, # timeout + "python", # language + ) + + return metadata + + @property + def extra_body(self): + extra = { + "include_stop_str_in_output": True, + "stop": ["", ""], + } + return extra + + async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]): + role, content, finish_reason = completions.choices[0].message.role, completions.choices[0].message.content, completions.choices[0].finish_reason + messages.append({"role": role, "content": content}) + turn = len(messages) + + # STEP 0: check if we reach max turns + if len(messages) >= self.max_turns: + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max turns, done!") + return + + # STEP 1: check if we reach max tokens + if finish_reason == "length": + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max tokens, done!") + return + + # STEP 2: check if we got answer + matches = self.answer_pattern.findall(content) + if matches: + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Got answer: {matches[0]}, done!") + return + + # STEP 3: check if we got code block + matches = self.code_pattern.findall(content) + if not matches: + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] No code 
block found, done!") + return + + # STEP 4: execute code block in sandbox + code = matches[0].strip() + metadata = await self.sandbox_code_execution(code) + if metadata["run_status"] != "Finished": + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block execution failed: {metadata}, done!") + return + + stdout, stderr = metadata["stdout"], metadata["stderr"] + messages.append({"role": "tool", "content": f"{stdout}{stderr}"}) + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block executed, continue...") + + # STEP 5: resubmit chat completions with code block output + self.scheduler.submit_chat_completions( + messages=messages, + request_id=completions.id, + info=info, + ) + + +user_prompt_template = """ +You are a helpful assistant. Let's solve math problem in following steps: +1. Write a python code first and return the code to user, the code must be in following format: + + +```python +import os + +print(...) +``` + + +The code must explictly print necessary output to stdout. Remember stop generation at immediately and return the code. +2. User will send the python code to a external sandbox to execute and get output from stdout. +3. User will send the output in format output to you, and you should use the output to answer the question. 
+The answer format must be: \\boxed{'The final answer goes here.'} + +*user question:* +{question} +""" + + +if __name__ == "__main__": + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # Load config + config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") + model_path = "Qwen/Qwen2.5-1.5B-Instruct" + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.multi_turn.format = "hermes" + config.actor_rollout_ref.rollout.multi_turn.completion_callback = "tests.workers.rollout.test_custom_completion_callback.CustomCompletionCallback" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 4 + + # Init sandbox and async rollout manager + sandbox = Sandbox.options(num_cpus=1).remote() + sandbox_address = ray.get(sandbox.get_server_address.remote()) + sandbox_fusion_url = f"http://{sandbox_address}/run_code" + config.reward_model.sandbox_fusion.url = sandbox_fusion_url + async_rollout_manager = init_async_rollout_manager(config) + + # Build dataset + dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train") + prompts = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([[{"role": "user", "content": user_prompt_template.replace("{question}", problem)}] for problem in dataset["Problem"]]), + }, + ) + + result = async_rollout_manager.generate_sequences(prompts=prompts) + assert len(result) == len(dataset) * config.actor_rollout_ref.rollout.n + + # Check max turns that sandbox is called + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + assert np.max(num_turns) > 2, f"max turns: {np.max(num_turns)}" + + # Check response_mask + tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path) + responses = 
result.batch["responses"] + response_mask = result.batch["response_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + + # Decode responses with response_mask + for i in range(len(responses)): + valid_tokens = responses[i][response_mask[i].bool()] + response_str = tokenizer.decode(valid_tokens) + assert "" not in response_str, f"found in response: {response_str}" + assert "" not in response_str, f"found in response: {response_str}" + print(f"response: {response_str}") + + print("Test passed!") diff --git a/tests/workers/rollout/test_vllm_chat_scheduler.py b/tests/workers/rollout/test_vllm_chat_scheduler.py new file mode 100644 index 00000000000..23550bff776 --- /dev/null +++ b/tests/workers/rollout/test_vllm_chat_scheduler.py @@ -0,0 +1,233 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +from typing import Any, Tuple + +import numpy as np +import pytest +import ray +from omegaconf import DictConfig, OmegaConf +from transformers.utils import get_json_schema + +from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager +from verl.protocol import DataProto +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema +from verl.utils import hf_tokenizer + + +@pytest.fixture +def init_config() -> DictConfig: + config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") + model_path = "Qwen/Qwen2.5-1.5B-Instruct" + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.multi_turn.format = "hermes" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + + return config + + +def test_vllm_async_rollout_without_tool_calls(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + async_rollout_manager = init_async_rollout_manager(init_config) + + # test sleep and wake_up + async_rollout_manager.sleep() + async_rollout_manager.wake_up() + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + { + "role": "user", + "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", + } + ], + [{"role": "user", "content": "Let's play a role playing game. 
Your name is Bob, your favorite color is red."}], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array(raw_prompts), + }, + ) + result = async_rollout_manager.generate_sequences(prompts=batch) + + # check result + seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) + assert len(result) == 2 + assert result.batch["input_ids"].size(1) == seq_len + assert result.batch["attention_mask"].size(1) == seq_len + assert result.batch["position_ids"].size(1) == seq_len + + # check turns + num_turns = result.non_tensor_batch["__num_turns__"] + assert np.all(num_turns == 2) + + print("Test passed!") + ray.shutdown() + + +class WeatherTool(BaseTool): + def get_current_temperature(self, location: str, unit: str = "celsius"): + """Get current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, and the unit in a dict + """ + return { + "temperature": 26.1, + "location": location, + "unit": unit, + } + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_current_temperature) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + try: + result = self.get_current_temperature(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +class WeatherToolWithData(BaseTool): + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_temperature_date) + return OpenAIFunctionToolSchema(**schema) + + def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): + """Get temperature at a location and date. 
+ + Args: + location: The location to get the temperature for, in the format "City, State, Country". + date: The date to get the temperature for, in the format "Year-Month-Day". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, the date and the unit in a dict + """ + return { + "temperature": 25.9, + "location": location, + "date": date, + "unit": unit, + } + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + try: + result = self.get_temperature_date(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +def test_vllm_async_rollout_with_tool_calls(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + tool_config = { + "tools": [ + { + "class_name": "tests.workers.rollout.test_vllm_chat_scheduler.WeatherTool", + "config": {}, + }, + { + "class_name": "tests.workers.rollout.test_vllm_chat_scheduler.WeatherToolWithData", + "config": {}, + }, + ] + } + tool_config_path = "/tmp/tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + + init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + async_rollout_manager = init_async_rollout_manager(init_config) + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + {"role": "user", "content": "What's the temperature in Los Angeles now?"}, + ], + [ + {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.\n\nCurrent Date: 2024-09-30"}, + {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + }, + ) + result = async_rollout_manager.generate_sequences(prompts=batch) + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + # [user, assistant] + assert num_turns[0] == 2 + # [user, assistant, tool, assistant] + assert num_turns[1] == 4 + # [system, user, assistant, tool, tool, assistant] + assert num_turns[2] == 6 + + # Check response_mask + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + + # Decode responses with response_mask + for i in range(len(responses)): + valid_tokens = responses[i][response_mask[i].bool()] + response_str = tokenizer.decode(valid_tokens) + assert "" not in response_str, f"found in response: {response_str}" + assert "" not in response_str, f"found in response: {response_str}" + print(f"response: {response_str}") + + print("Test passed!") + ray.shutdown() diff --git a/tests/workers/rollout/test_vllm_multi_turn.py b/tests/workers/rollout/test_vllm_multi_turn.py deleted file mode 100644 index b705d86a9ca..00000000000 --- a/tests/workers/rollout/test_vllm_multi_turn.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import json -from typing import Any, Dict - -import numpy as np -import ray -from omegaconf import DictConfig, OmegaConf -from openai.types.chat.chat_completion import ChatCompletion -from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ErrorResponse - -from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager -from verl.protocol import DataProto - - -def init_config() -> DictConfig: - config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") - model_path = "Qwen/Qwen2-7B-Instruct" - config.actor_rollout_ref.model.path = model_path - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.chat_scheduler = "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler" - config.actor_rollout_ref.rollout.prompt_length = 4096 - config.actor_rollout_ref.rollout.response_length = 4096 - - # test sleep/wake_up with fsdp offload - config.actor_rollout_ref.actor.fsdp_config.param_offload = True - config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True - - return config - - -def test_vllm_multi_turn(config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "WARN", - "VLLM_USE_V1": "1", - } - } - ) - - # =========================== 1. 
Init rollout manager =========================== - model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:]) - async_rollout_manager = init_async_rollout_manager(config) - - # test sleep and wake_up - async_rollout_manager.sleep() - async_rollout_manager.wake_up() - - async_chat_scheduler = async_rollout_manager.chat_scheduler - - # =========================== 2. Multi turn rollout =========================== - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - assert exception is None, f"exception: {exception}" - messages, round = info["messages"], info["round"] - message = completions.choices[0].message - messages.append({"role": message.role, "content": message.content}) - print(f"[round={round}] role: {message.role}, content: {message.content}") - - extra_headers = {"x-request-id": completions.id} - if round == 0: - messages.append({"role": "user", "content": "What is your name?"}) - await async_chat_scheduler.submit_chat_completions( - callback=callback, - callback_additional_info={"messages": messages, "round": 1}, - model=model_name, - messages=messages, - extra_headers=extra_headers, - ) - elif round == 1: - messages.append({"role": "user", "content": "What is your favorite color?"}) - await async_chat_scheduler.submit_chat_completions( - callback=callback, - callback_additional_info={"messages": messages, "round": 2}, - model=model_name, - messages=messages, - extra_headers=extra_headers, - ) - else: - print("Done!") - - messages = [{"role": "user", "content": "Let's play a role playing game. 
Your name is Bob, your favorite color is red."}] - async_rollout_manager.submit_chat_completions( - callback=callback, - callback_additional_info={"messages": messages, "round": 0}, - model=model_name, - messages=messages, - ) - assert len(messages) == 6 - for round, message in enumerate(messages): - if round % 2 == 0: - assert message["role"] == "user" - else: - assert message["role"] == "assistant" - - # =========================== 3. Generate sequences =========================== - raw_prompts = [ - [ - { - "role": "user", - "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", - } - ], - [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], - ] - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array(raw_prompts), - }, - ) - result = async_rollout_manager.generate_sequences(prompts=batch) - seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) - assert len(result) == 2 - assert result.batch["input_ids"].size(1) == seq_len - assert result.batch["attention_mask"].size(1) == seq_len - assert result.batch["position_ids"].size(1) == seq_len - - ray.shutdown() - - -async def test_vllm_streaming_response(config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "WARN", - "VLLM_USE_V1": "1", - } - } - ) - - model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:]) - async_rollout_manager = init_async_rollout_manager(config) - async_llm_server = async_rollout_manager.async_llm_servers[0] - - # non-streaming request - request = ChatCompletionRequest( - model=model_name, - messages=[{"role": "user", "content": "What is your name?"}], - stream=False, - ) - generator = async_llm_server.chat_completion_generator.remote(request) - async for ref in generator: - status_code, data = await ref - print(f">>>> status_code: {status_code}, {data}") - data = 
data[len("data: ") :].rstrip() - if status_code != 200: - response = ErrorResponse(**json.loads(data)) - else: - response = ChatCompletionResponse(**json.loads(data)) - assert response.choices[0].message.role == "assistant" - assert response.choices[0].message.content is not None - - # streaming request - request = ChatCompletionRequest( - model=model_name, - messages=[{"role": "user", "content": "How are you?"}], - stream=True, - ) - generator = async_llm_server.chat_completion_generator.remote(request) - async for ref in generator: - status_code, data = await ref - print(f">>>> status_code: {status_code}, {data}") - data = data[len("data: ") :].rstrip() - if status_code != 200: - response = ErrorResponse(**json.loads(data)) - elif data == "[DONE]": - break - else: - response = ChatCompletionStreamResponse(**json.loads(data)) - assert response.choices[0].delta.role is None or response.choices[0].delta.role == "assistant" - assert response.choices[0].delta.content is not None - - ray.shutdown() - - -if __name__ == "__main__": - config = init_config() - test_vllm_multi_turn(config) - asyncio.run(test_vllm_streaming_response(config)) diff --git a/tests/workers/rollout/test_vllm_tool_calling.py b/tests/workers/rollout/test_vllm_tool_calling.py deleted file mode 100644 index efc8fbf547b..00000000000 --- a/tests/workers/rollout/test_vllm_tool_calling.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import os -import re -import socket -import sys -import tempfile -from contextlib import asynccontextmanager -from typing import Any, Dict - -import aiohttp -import fastapi -import numpy as np -import ray -import uvicorn -from datasets import load_dataset -from omegaconf import OmegaConf -from openai.types.chat.chat_completion import ChatCompletion -from starlette.requests import Request -from starlette.responses import JSONResponse - -from examples.ppo_trainer.naive_chat_scheduler import NaiveChatCompletionScheduler -from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager -from verl.protocol import DataProto - - -def _get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - -@ray.remote(num_cpus=1) -class Sandbox: - """Sandbox to execute python code. - - WARNING: This class is for testing purpose only, do not use it in production. - Please use a sandbox with strong isolation and security restrictions instead. 
- """ - - def __init__(self): - self.address = ray._private.services.get_node_ip_address() - self.port = None - self.server_ready = asyncio.Event() - asyncio.create_task(self._start_fastapi_server()) - - async def code_execution(self, request: Request): - request_json = await request.json() - code = request_json["code"] - print(f"execute code:\n{code}") - - _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True) - with open(temp_file, "w") as f: - f.write(code) - - try: - process = await asyncio.create_subprocess_exec(sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) - - stdout, stderr = await process.communicate() - - return JSONResponse(content={"stdout": stdout.decode(), "stderr": stderr.decode(), "returncode": process.returncode}) - finally: - try: - os.unlink(temp_file) - except: # noqa: E722 - pass - - async def _start_fastapi_server(self): - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI): - print("FastAPI startup") - self.server_ready.set() - yield - - print("FastAPI shutdown, maybe address already in use, exit process immediately.") - os._exit(-1) - - app = fastapi.FastAPI(lifespan=lifespan) - app.router.add_api_route("/code/execution", self.code_execution, methods=["POST"]) - - self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") - server = uvicorn.Server(config) - await server.serve() - - async def get_server_address(self) -> str: - """Get FastAPI server address.""" - await self.server_ready.wait() - return f"{self.address}:{self.port}" - - -class ToolChatCompletionScheduler(NaiveChatCompletionScheduler): - """This is a demo chat completion scheduler that supports sandbox code execution - described in ReTool paper: https://arxiv.org/pdf/2504.11536 - """ - - def __init__(self, config, model_path, server_addresses, sandbox_address, system_prompt, **kwargs): - super().__init__(config, model_path, 
server_addresses, **kwargs) - self.sandbox_address = sandbox_address - self.system_prompt = system_prompt - - async def sandbox_code_execution(self, code: str) -> Dict[str, Any]: - """Execute python code in sandbox.""" - try: - session = aiohttp.ClientSession() - async with session.post( - url=f"http://{self.sandbox_address}/code/execution", - json={"code": code}, - ) as resp: - return await resp.json() - finally: - await session.close() - - async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: - kwargs = dict( - n=self.config.n, - max_completion_tokens=self.config.response_length, - temperature=self.config.temperature, - top_p=self.config.top_p, - extra_body={ - "include_stop_str_in_output": True, - "stop": ["", ""], - }, - ) - - do_sample = batch.meta_info.get("do_sample", True) - is_validate = batch.meta_info.get("validate", False) - if not do_sample or is_validate: - kwargs["n"] = 1 - kwargs["temperature"] = 0 - - kwargs.update(sampling_params) - print(f"[ToolChatCompletionScheduler] generate_sequences sampling params: {kwargs}") - - max_turns = 3 - - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - batch_conversations, batch_index, turn = ( - info["batch_conversations"], - info["batch_index"], - info["turn"], - ) - role, content = completions.choices[0].message.role, completions.choices[0].message.content - batch_conversations[batch_index].append({"role": role, "content": content}) - - # STEP 0: check if we reach max turns - if turn == max_turns: - print(f"[id={completions.id},turn={turn}] Reach max turns {max_turns}, done!") - return - - # STEP 1: check if we got answer - matches = re.findall(r"(.*?)", content, re.DOTALL) - if matches: - print(f"[id={completions.id},turn={turn}] Got answer: {matches[0]}, done!") - return - - # STEP 2: check if we got code block - matches = re.findall(r"\s*```python(.*?)```\s*", content, re.DOTALL) - if not matches: - 
print(f"[id={completions.id},turn={turn}] No code block found, done!") - return - - # STEP 3: execute code block in sandbox - code = matches[0].strip() - result = await self.sandbox_code_execution(code) - stdout, stderr = result["stdout"], result["stderr"] - batch_conversations[batch_index].append({"role": "tool", "content": f"{stdout}{stderr}"}) - print(f"[id={completions.id},turn={turn}] Code block executed, continue...") - - # STEP 4: resubmit chat completions with code block output - extra_headers = {"x-request-id": completions.id} - await self.submit_chat_completions( - callback=callback, - callback_additional_info={ - "batch_conversations": batch_conversations, - "batch_index": batch_index, - "turn": turn + 1, - }, - model=self.model_name, - messages=batch_conversations[batch_index], - extra_headers=extra_headers, - **kwargs, - ) - - tasks, batch_conversations = [], [None] * len(batch) - for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]): - # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] - batch_conversations[batch_index] = [{"role": "system", "content": self.system_prompt}] + list(conversation) - tasks.append( - asyncio.create_task( - self.submit_chat_completions( - callback=callback, - callback_additional_info={ - "batch_conversations": batch_conversations, - "batch_index": batch_index, - "turn": 1, - }, - model=self.model_name, - messages=batch_conversations[batch_index], - **kwargs, - ) - ) - ) - - await asyncio.gather(*tasks) - print("[NaiveChatCompletionScheduler] generate_sequences done") - - # _postprocess assumes n>=1 - batch_conversations = [[conversation] for conversation in batch_conversations] - return self._postprocess(batch, batch_conversations, kwargs["n"]) - - -system_prompt = """ -You are a helpful assistant. Let's solve math problem in following steps: -1. 
Write a python code first and return the code to user, the code must be in following format: - - -```python -import os - -print(...) -``` - - -The code must explictly print necessary output to stdout. Remember stop generation at immediately and return the code. -2. User will send the python code to a external sandbox to execute and get output from stdout. -3. User will send the output in format output to you, and you should use the output to answer the question. -The answer format must be: \\boxed{'The final answer goes here.'} -""" - - -def test_vllm_tool_calling(): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - # Load config - config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") - config.actor_rollout_ref.model.path = "Qwen/Qwen2-7B-Instruct" - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.chat_scheduler = "tests.workers.rollout.test_vllm_tool_calling.ToolChatCompletionScheduler" - config.actor_rollout_ref.rollout.prompt_length = 8192 - config.actor_rollout_ref.rollout.response_length = 8192 - - # Init sandbox and async rollout manager - sandbox = Sandbox.options(num_cpus=1).remote() - sandbox_address = ray.get(sandbox.get_server_address.remote()) - async_rollout_manager = init_async_rollout_manager(config, scheduler_kwargs={"sandbox_address": sandbox_address, "system_prompt": system_prompt}) - - # Build dataset - dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train") - prompts = DataProto(non_tensor_batch={"raw_prompt": np.array([[{"role": "user", "content": problem}] for problem in dataset["Problem"]])}) - - result = async_rollout_manager.generate_sequences(prompts=prompts) - assert len(result) == len(dataset) - - -if __name__ == "__main__": - test_vllm_tool_calling() diff --git a/verl/tools/base_tool.py b/verl/tools/base_tool.py index e627b99f2bc..6642a5329d4 100644 --- 
a/verl/tools/base_tool.py +++ b/verl/tools/base_tool.py @@ -12,9 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple +import importlib +import json +import sys +from typing import Any, List, Optional, Tuple from uuid import uuid4 +from omegaconf import OmegaConf + from .schemas import OpenAIFunctionToolSchema @@ -32,8 +37,10 @@ class BaseTool: def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): self.config = config - self.name = tool_schema.function.name - self.tool_schema = tool_schema + self.tool_schema = tool_schema or self.get_openai_tool_schema() + assert self.tool_schema is not None, "Tool schema is not set!" + self.name = self.tool_schema.function.name + print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2)) def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: return self.tool_schema @@ -84,3 +91,41 @@ async def release(self, instance_id: str, **kwargs) -> None: instance_id: The instance id of the tool. """ pass + + +def initialize_tools_from_config(tools_config_file) -> List[BaseTool]: + """Initialize tools from config file. + + Args: + tools_config_file: The config file of the tools. + + Returns: + A list of tools. 
+ """ + tools_config = OmegaConf.load(tools_config_file) + + tool_list = [] + for tool_config in tools_config.tools: + cls_name = tool_config.class_name + module_name, class_name = cls_name.rsplit(".", 1) + + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + tool_cls = getattr(module, class_name) + + if tool_config.get("tool_schema", None) is None: + tool_schema = None + else: + tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) + tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict) + + tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema) + tool_list.append(tool) + + return tool_list diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 929ffd43e67..e0814c35119 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -98,7 +98,6 @@ actor_rollout_ref: rollout: name: vllm mode: sync # sync: LLM, async: AsyncLLM - chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 @@ -144,6 +143,7 @@ actor_rollout_ref: max_turns: null # null for no limit (default max_length // 3) tool_config_path: null # null for no tool format: chatml # chatml, more formats will be supported in the future + completion_callback: null # null for default callback critic: rollout_n: ${actor_rollout_ref.rollout.n} diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py index 6ff42d7e02a..f246f47c354 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/verl/trainer/ppo/metric_utils.py @@ -50,12 +50,12 @@ def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: 
def _compute_response_info(batch: DataProto) -> Dict[str, Any]: """ Computes information about prompts and responses from a batch. - + This is an internal helper function that extracts masks and lengths for prompts and responses. - + Args: batch: A DataProto object containing batch data with responses and attention masks. - + Returns: A dictionary containing: - response_mask: Attention mask for the response tokens @@ -172,9 +172,9 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: """ Computes timing metrics for different processing stages in PPO training. - - This function calculates both raw timing metrics (in seconds) and per-token timing metrics - (in milliseconds) for various processing stages like generation, reference computation, + + This function calculates both raw timing metrics (in seconds) and per-token timing metrics + (in milliseconds) for various processing stages like generation, reference computation, value computation, advantage computation, and model updates. Args: @@ -211,23 +211,23 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]: """ Computes throughput metrics for PPO training. - + This function calculates performance metrics related to token processing speed, including the total number of tokens processed, time per step, and throughput (tokens per second per GPU). - + Args: batch: A DataProto object containing batch data with meta information about token counts. timing_raw: A dictionary mapping stage names to their execution times in seconds. Must contain a "step" key with the total step time. n_gpus: Number of GPUs used for training. 
- + Returns: A dictionary containing: - perf/total_num_tokens: Total number of tokens processed in the batch - perf/time_per_step: Time taken for the step in seconds - perf/throughput: Tokens processed per second per GPU - + Note: The throughput is calculated as total_tokens / (time * n_gpus) to normalize across different GPU counts. @@ -324,12 +324,12 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo def process_validation_metrics(data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42) -> dict[str, dict[str, dict[str, float]]]: """ Process validation metrics into a structured format with statistical analysis. - + This function organizes validation metrics by data source and prompt, then computes various statistical measures including means, standard deviations, best/worst values, and majority voting results. It also performs bootstrap sampling to estimate statistics for different sample sizes. - + Args: data_sources: List of data source identifiers for each sample. sample_inputs: List of input prompts corresponding to each sample. 
@@ -345,7 +345,7 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str] } } } - + Where metric_name includes: - "mean@N": Mean value across N samples - "std@N": Standard deviation across N samples @@ -355,7 +355,7 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str] - "worst@N/std": Standard deviation of the worst values in bootstrap samples - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists) - "maj@N/std": Standard deviation of majority voting results (if "pred" exists) - + Example: >>> data_sources = ["source1", "source1", "source2"] >>> sample_inputs = ["prompt1", "prompt1", "prompt2"] diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py index d693f39242b..441d59f0ad5 100644 --- a/verl/workers/rollout/async_server.py +++ b/verl/workers/rollout/async_server.py @@ -12,31 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import asyncio -import heapq -import importlib import logging import os import socket import threading from abc import ABC, abstractmethod from contextlib import asynccontextmanager -from typing import Any, Callable, Dict, List, Tuple, Type -from uuid import uuid4 +from typing import Any, Dict, List, Tuple, Type -import aiohttp import fastapi import ray import uvicorn -from cachetools import LRUCache from omegaconf import DictConfig -from openai import AsyncOpenAI -from openai.types.chat.chat_completion import ChatCompletion from starlette.requests import Request from verl.protocol import DataProto from verl.single_controller.ray.base import RayWorkerGroup -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_to_local +from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler logger = logging.getLogger(__file__) @@ -59,7 +51,7 @@ def __init__(self): async def _start_fastapi_server(self): @asynccontextmanager async def lifespan(app: fastapi.FastAPI): - print("FastAPI startup") + print(f"FastAPI listen on {self.address}:{self.port}") self.server_ready.set() yield @@ -105,130 +97,19 @@ async def sleep(self): raise NotImplementedError -class ChatCompletionScheduler: - def __init__( - self, - config: DictConfig, - model_path: str, - server_addresses: List[str], - max_cache_size: int = 10000, - ): - """ - Args: - config: DictConfig, rollout config. - model_path: str, model path. - server_addresses: List[str], server addresses. - max_cache_size: int, max cache size of request_id to address mapping. 
- """ - self.config = config - self.model_name = "/".join(model_path.split("/")[-2:]) - local_path = copy_to_local(model_path) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) - - # Least requests load balancing - self.weighted_addresses = [[0, address] for address in server_addresses] - heapq.heapify(self.weighted_addresses) - - # LRU cache to map request_id to address - self.request_id_to_address = LRUCache(maxsize=max_cache_size) - - async def submit_chat_completions( - self, - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], - callback_additional_info: Dict[str, Any], - **chat_complete_request, - ): - """ - Submit a chat completion request to the server with the least number of requests. - - Args: - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function - to handle the response. The callback function should have the following signature: - - ```python - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - ... - ``` - - completions: chat completion response from server. - - info: user provided `callback_additional_info`. - - exception: exception raise from OpenAI client if request failed, otherwise None. - - **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation, - please move to separate thread or process pool to avoid blocking the event loop. - - callback_additional_info: Dict[str, Any], additional info to pass to the callback function. - - **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create. 
- OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create - """ - if "extra_headers" not in chat_complete_request: - chat_complete_request["extra_headers"] = {} - - extra_headers = chat_complete_request["extra_headers"] - request_id = extra_headers.get("x-request-id", None) - if request_id: - if request_id.startswith("chatcmpl-"): - request_id = request_id[len("chatcmpl-") :] - extra_headers["x-request-id"] = request_id - - address = self.request_id_to_address.pop(request_id) - else: - address = self.weighted_addresses[0][1] - self.weighted_addresses[0][0] += 1 - heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) - - # use new request_id to avoid duplicate request_id problem - request_id = uuid4().hex - self.request_id_to_address[request_id] = address - chat_complete_request["extra_headers"]["x-request-id"] = request_id - - completions, exception = None, None - try: - # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. 
- completions = await self._chat_completions_aiohttp(address, **chat_complete_request) - except Exception as e: - # Let user handle the exception - exception = e - - await callback(completions, callback_additional_info, exception) - - async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: - client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0) - return await client.chat.completions.create(**chat_complete_request) - - async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: - try: - extra_headers = chat_complete_request.pop("extra_headers") - timeout = aiohttp.ClientTimeout(total=None) - session = aiohttp.ClientSession(timeout=timeout) - async with session.post( - url=f"http://{address}/v1/chat/completions", - headers={"Authorization": "Bearer token-abc123", **extra_headers}, - json=chat_complete_request, - ) as resp: - data = await resp.json() - return ChatCompletion(**data) - finally: - await session.close() - - async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: - raise NotImplementedError - - class AsyncLLMServerManager: """AsyncLLMServerManager manage a group of vllm instances, i.e AsyncvLLMServer.""" - def __init__(self, config: DictConfig, worker_group: RayWorkerGroup, *, scheduler_kwargs: Dict[str, Any] = None): + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): """Initialize AsyncLLMServerManager. Args: config: DictConfig, actor_rollout_ref config. worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. - scheduler_kwargs: Dict[str, Any], kwargs for chat scheduler. 
""" - self.config = config + self.full_config = config + self.config = config.actor_rollout_ref self.worker_group = worker_group - self.scheduler_kwargs = scheduler_kwargs if scheduler_kwargs else {} self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size @@ -284,14 +165,9 @@ def _init_chat_scheduler(self): self.chat_scheduler_loop = asyncio.new_event_loop() asyncio.set_event_loop(self.chat_scheduler_loop) - module_path, class_name = self.config.rollout.chat_scheduler.rsplit(".", 1) - module = importlib.import_module(module_path) - scheduler_cls = getattr(module, class_name) - self.chat_scheduler = scheduler_cls( - config=self.config.rollout, - model_path=self.config.model.path, + self.chat_scheduler = ChatCompletionScheduler( + config=self.full_config, server_addresses=self.server_addresses, - **self.scheduler_kwargs, ) self.chat_scheduler_ready.set() @@ -307,9 +183,8 @@ def sleep(self): def submit_chat_completions( self, - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], - callback_additional_info: Dict[str, Any], - **chat_complete_request, + messages: List[Dict[str, str]], + sampling_params: Dict[str, Any], ): """Submit a chat completion request to chat scheduler and wait until it is done. To submit multiple requests in parallel, please use `generate_sequences` instead. @@ -318,10 +193,10 @@ def submit_chat_completions( """ assert self.chat_scheduler is not None, "chat scheduler is not initialized." 
future = asyncio.run_coroutine_threadsafe( - self.chat_scheduler.submit_chat_completions( - callback=callback, - callback_additional_info=callback_additional_info, - **chat_complete_request, + self.chat_scheduler._submit_chat_completions_semaphore( + messages=messages, + request_id=None, + sampling_params=sampling_params, ), self.chat_scheduler_loop, ) diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py new file mode 100644 index 00000000000..1426f286fbe --- /dev/null +++ b/verl/workers/rollout/chat_scheduler.py @@ -0,0 +1,418 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import asyncio +import heapq +import importlib +import itertools +import json +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, List +from uuid import uuid4 + +import aiohttp +import numpy as np +import torch +from cachetools import LRUCache +from omegaconf import DictConfig +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion +from tensordict import TensorDict + +from verl.protocol import DataProto +from verl.tools.base_tool import initialize_tools_from_config +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local + +logger = logging.getLogger(__file__) + + +class CompletionCallback(ABC): + def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"): + self.config = config + self.scheduler = scheduler + + # Initialize tools from config file + self.max_turns = config.actor_rollout_ref.rollout.multi_turn.max_turns + tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path + tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] + self.tools = {tool.name: tool for tool in tool_list} + self._tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] + print(f"Initialized tools: {self.tools}", flush=True) + + local_path = copy_to_local(config.actor_rollout_ref.model.path) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) + + @property + def tool_schemas(self): + """OpenAI JSON tool schemas.""" + return self._tool_schemas + + @property + def extra_body(self) -> Dict[str, Any]: + """Extra body pass to OpenAI API.""" + return None + + @abstractmethod + async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]): + """Call back function to process completions. + + Args: + messages: List of messages including raw prompt and assistant, tool response generated so far. 
class ToolCompletionCallback(CompletionCallback):
    """Default multi-turn completion callback.

    On each finished turn it appends the assistant message, executes any tool
    calls the model requested, and resubmits the conversation — until the model
    stops calling tools, an error occurs, or the max turn count is reached.
    """

    def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"):
        super().__init__(config, scheduler)

        # TODO: add reward manager to calculate reward score once a sample finish

    async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]):
        """Handle one finished chat completion turn.

        Args:
            messages: Conversation so far; the new assistant message (and any
                tool responses) are appended in place.
            completions: Chat completion returned by the OpenAI-compatible server.
            info: Auxiliary bookkeeping passed across turns (depth/done/sampling).
        """
        message = completions.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
        # Ensure "content" is always present, even for pure tool-call messages.
        if "content" not in message:
            message["content"] = ""
        messages.append(message)
        finish_reason = completions.choices[0].finish_reason

        # STEP 0: check if we reach max turns
        if self.max_turns and len(messages) >= self.max_turns:
            print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Reach max turns, done!")
            return

        # STEP 1: check if the model called tools
        if finish_reason != "tool_calls":
            print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] No tool called, done!")
            return

        # STEP 2: call tools concurrently
        tool_calls = completions.choices[0].message.tool_calls
        print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Call {len(tool_calls)} tools")
        tasks = []
        for tool_call in tool_calls:
            tasks.append(self._call_tool(tool_call))
        tool_responses = await asyncio.gather(*tasks)
        # _call_tool returns the exception object instead of raising, so one
        # failed tool aborts this sample without cancelling sibling tool calls.
        if any(isinstance(item, Exception) for item in tool_responses):
            print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Error when calling tools, done!")
            return
        messages.extend(tool_responses)

        # STEP 3: resubmit completion request with tool responses
        self.scheduler.submit_chat_completions(messages=messages, request_id=completions.id, info=info)

    async def _call_tool(self, tool_call) -> Dict[str, str]:
        """Call tool and return tool response message (or the exception on failure)."""
        tool_name = tool_call.function.name
        tool_args = json.loads(tool_call.function.arguments)
        tool = self.tools[tool_name]

        instance_id = await tool.create()
        try:
            tool_response, tool_reward_score, tool_metrics = await tool.execute(instance_id, tool_args)
        except Exception as e:
            logger.exception(f"Error when executing tool: {e}")
            return e
        finally:
            # Always release the tool instance, on success or failure.
            await tool.release(instance_id)

        return {
            "role": "tool",
            "content": tool_response,
            "tool_call_id": tool_call.id,
        }

    def postprocess(self, batch: DataProto, batch_conversations: List[List[Dict[str, str]]], n: int) -> DataProto:
        """Tokenize finished conversations into the batch layout the trainer expects.

        Args:
            batch: Batch input messages from RLHFDataset.
            batch_conversations: Finished conversations; len == len(batch) * n.
            n: Number of chat completion choices generated per input message.

        Returns:
            DataProto with ["prompts", "responses", "response_mask", "input_ids",
            "attention_mask", "position_ids"] and "__num_turns__" per sample.
        """
        # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
        # prompts: left pad
        # responses: right pad
        # input_ids: prompt + response
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

        # prompts: [prompt] from input dataset
        prompts = [self.tokenizer.apply_chat_template(prompt, tools=self.tool_schemas, add_generation_prompt=True, tokenize=False) for prompt in batch.non_tensor_batch["raw_prompt"]]
        assert len(batch_conversations) == len(prompts) * n

        # sequences: [prompt + response]
        sequences = [self.tokenizer.apply_chat_template(conversation, tools=self.tool_schemas, add_generation_prompt=False, tokenize=False) for conversation in batch_conversations]

        # responses: [response] — strip the rendered prompt prefix from each sequence.
        responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)]

        prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left")
        responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right")
        if n > 1:
            # Repeat each prompt n times to align with the n responses per prompt.
            prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0)
            prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0)

        # response_mask: response mask with tools calling masked out
        response_mask = self._mask_out_tools_calling_tokens(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0), batch_conversations, responses["input_ids"], responses["attention_mask"])

        input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1)
        attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1)
        position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

        batch = TensorDict(
            {
                "prompts": prompts["input_ids"],  # [bsz, prompt_length]
                "responses": responses["input_ids"],  # [bsz, response_length]
                "response_mask": response_mask,  # [bsz, response_length]
                "input_ids": input_ids,  # [bsz, prompt_length + response_length]
                "attention_mask": attention_mask,  # [bsz, prompt_length + response_length]
                "position_ids": position_ids,  # [bsz, prompt_length + response_length]
            },
            batch_size=len(input_ids),
        )

        num_turns = np.array([len(conversation) for conversation in batch_conversations], dtype=np.int32)
        return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns})

    def _mask_out_tools_calling_tokens(
        self,
        raw_prompts: List[List[Dict[str, str]]],
        batch_conversations: List[List[Dict[str, str]]],
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Mask out tools calling tokens in the responses.

        Args:
            raw_prompts: [prompt] from input dataset
            batch_conversations: [prompt + response]
            input_ids: responses tokens
            attention_mask: responses attention mask

        Returns:
            mask: (batch_size, response_length)
        """
        batch_size = input_ids.size(0)
        assert len(raw_prompts) == batch_size, f"{len(raw_prompts)} != {batch_size}"
        assert len(batch_conversations) == batch_size, f"{len(batch_conversations)} != {batch_size}"

        # Deduplicate adjacent tool calls, since they're merged into one turn.
        # [user, assistant, tool, tool, assistant] -> [user, assistant, tool, assistant]
        # TODO: it's chat_template specific, find a more generic way to do this.
        def deduplicate_adjacent_tool_calls(roles):
            result = []
            for role, group in itertools.groupby(roles):
                if role == "tool":
                    result.append(role)
                else:
                    result.extend(group)
            return result

        loss_mask = attention_mask.clone()
        for i in range(batch_size):
            responses = batch_conversations[i][len(raw_prompts[i]) :]
            assert len(responses) > 0, f"responses is empty: {responses}"

            roles = deduplicate_adjacent_tool_calls([response["role"] for response in responses])
            # Each turn should be: [BOS]...[EOS]; take one EOS position per turn.
            eos_indices = input_ids[i].eq(self.tokenizer.eos_token_id).nonzero().squeeze(1)[: len(roles)]
            for j in range(len(roles)):
                if roles[j] == "tool":
                    # Zero out the whole tool-response turn, EOS inclusive.
                    bos = eos_indices[j - 1] + 1 if j > 0 else 0
                    eos = eos_indices[j]
                    loss_mask[i, bos : eos + 1] = 0

        return loss_mask


class ChatCompletionScheduler:
    """Async scheduler that fans chat completion requests out to a pool of
    OpenAI-compatible rollout servers with least-requests load balancing, and
    drives multi-turn tool-calling via a pluggable CompletionCallback."""

    def __init__(
        self,
        config: DictConfig,
        server_addresses: List[str],
        max_cache_size: int = 10000,
    ):
        """
        Args:
            config: DictConfig, full trainer config; reads `actor_rollout_ref`.
            server_addresses: List[str], OpenAI compatible server addresses.
            max_cache_size: int, max cache size of request_id to address mapping.
        """
        self.config = config.actor_rollout_ref.rollout
        model_path = config.actor_rollout_ref.model.path
        # Served model name is the last two path components, e.g. "Qwen/Qwen2-7B-Instruct".
        self.model_name = "/".join(model_path.split("/")[-2:])

        # Least requests load balancing: min-heap of [request_count, address].
        self.weighted_addresses = [[0, address] for address in server_addresses]
        heapq.heapify(self.weighted_addresses)

        # LRU cache to map request_id to address, so follow-up turns of the same
        # conversation go back to the same server.
        self.request_id_to_address = LRUCache(maxsize=max_cache_size)

        self.background_tasks = set()
        if self.config.multi_turn.completion_callback is None:
            self.completion_callback = ToolCompletionCallback(config, self)
            logger.warning("completion_callback is None, use ToolCompletionCallback")
        else:
            # Dynamically import the user-provided callback class by dotted path.
            module_path, class_name = self.config.multi_turn.completion_callback.rsplit(".", 1)
            module = importlib.import_module(module_path)
            self.completion_callback = getattr(module, class_name)(config, self)

    def submit_chat_completions(self, *, messages: List[Dict[str, str]], request_id: str, info: Dict[str, Any]):
        """Submit chat completion request without wait, completion_callback will be called when the request is done.

        Args:
            messages: List of messages.
            request_id: Request id; None for a fresh conversation.
            info: Auxiliary bookkeeping passed across turns (depth/done/sampling).
        """
        info["__depth__"] += 1
        task = asyncio.create_task(self._submit_chat_completions_and_callback(messages, request_id, info))

        # "fire-and-forget" background tasks: keep a strong reference so the
        # task isn't garbage-collected mid-flight, then drop it when done.
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

    async def _submit_chat_completions_and_callback(
        self,
        messages: List[Dict[str, str]],
        request_id: str,
        info: Dict[str, Any],
    ):
        """Submit chat completion request, wait request finish and do callback."""
        if request_id:
            # Follow-up turn: route to the server that handled the previous turn.
            request_id = request_id.removeprefix("chatcmpl-")
            assert request_id in self.request_id_to_address
            address = self.request_id_to_address.pop(request_id)
        else:
            # Fresh request: pick the least-loaded server and bump its count.
            address = self.weighted_addresses[0][1]
            self.weighted_addresses[0][0] += 1
            heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0])

        # use new request_id to avoid duplicate request_id problem
        request_id = uuid4().hex
        self.request_id_to_address[request_id] = address

        completions, exception = None, None
        try:
            # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests.
            completions = await self._chat_completions_aiohttp(
                address,
                messages=messages,
                tools=self.completion_callback.tool_schemas,
                extra_body=self.completion_callback.extra_body,
                extra_headers={"x-request-id": request_id},
                **info["__sampling_params__"],
            )
        except Exception as e:
            # Let user handle the exception
            exception = e

        info["__depth__"] -= 1

        if exception is not None:
            logger.exception(f"chat completion failed with exception: {exception}")
        else:
            try:
                await self.completion_callback(messages, completions, info)
            except Exception as e:
                logger.exception(f"completion callback failed with exception: {e}")

        # No more ongoing completion requests
        if info["__depth__"] == 0:
            info["__done__"].set()

    async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion:
        """Reference implementation using the official OpenAI client (slower under load)."""
        client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0)
        return await client.chat.completions.create(**chat_complete_request)

    async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion:
        """POST a chat completion request via aiohttp and parse the response."""
        # Merge extra_body into the top-level payload (OpenAI-client convention).
        extra_body = chat_complete_request.pop("extra_body", {})
        chat_complete_request.update(extra_body or {})
        extra_headers = chat_complete_request.pop("extra_headers")
        timeout = aiohttp.ClientTimeout(total=None)
        # FIX: use `async with` so the session is always closed; the previous
        # `try/finally: await session.close()` raised NameError (masking the
        # original error) if session construction itself failed.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                url=f"http://{address}/v1/chat/completions",
                headers={"Authorization": "Bearer token-abc123", **extra_headers},
                json=chat_complete_request,
            ) as resp:
                data = await resp.json()
                return ChatCompletion(**data)

    async def generate_sequences(self, batch: DataProto) -> DataProto:
        """Roll out all prompts in `batch` concurrently and return tokenized results."""
        kwargs = dict(
            model=self.model_name,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        # override sampling params for validation
        if batch.meta_info.get("validate", False):
            kwargs["top_p"] = self.config.val_kwargs.top_p
            kwargs["temperature"] = self.config.val_kwargs.temperature

        print(f"[ChatCompletionScheduler] generate_sequences sampling params: {kwargs}")

        # NOTE: For multi-turn rollout, repeat raw_prompt n times and process each prompt independently,
        # validation dataset has already been repeated in `PPOTrainer._validate`.
        n = 1 if batch.meta_info.get("validate", False) else self.config.n
        tasks, batch_conversations = [], [None] * len(batch) * n
        for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)):
            # raw_prompt: [{"role": "user", "content": ""}, {"role": "assistant", "content": ""}, ...]
            batch_conversations[batch_index] = conversation.tolist()

            tasks.append(
                asyncio.create_task(
                    self._submit_chat_completions_semaphore(
                        messages=batch_conversations[batch_index],
                        request_id=None,
                        sampling_params=kwargs,
                    )
                )
            )

        await asyncio.gather(*tasks)
        print("[ChatCompletionScheduler] generate_sequences done")

        return self.completion_callback.postprocess(batch, batch_conversations, n=n)

    async def _submit_chat_completions_semaphore(self, messages: List[Dict[str, str]], request_id: str, sampling_params: Dict[str, Any]):
        """Submit one conversation and block until all of its turns are finished."""
        done = asyncio.Event()

        info = {
            "__done__": done,
            "__depth__": 0,  # indicate how many ongoing completion requests
            "__sampling_params__": sampling_params,
        }

        self.submit_chat_completions(messages=messages, request_id=request_id, info=info)

        # Wait until all completion requests are done
        await done.wait()
import logging -from collections.abc import AsyncGenerator from typing import Any, Callable, Dict, List, Optional, Tuple, Union import cloudpickle @@ -122,14 +121,14 @@ class AsyncvLLMServer(AsyncServerBase): def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str): """ Args: - config: DictConfig, actor_rollout_ref config. + config: DictConfig. vllm_dp_size: int, vllm data parallel size. vllm_dp_rank: int, vllm data parallel rank. wg_prefix: str, worker group prefix, used to lookup actors. """ super().__init__() - self.config = config + self.config = config.actor_rollout_ref self.vllm_dp_size = vllm_dp_size self.vllm_dp_rank = vllm_dp_rank self.wg_prefix = wg_prefix @@ -154,7 +153,7 @@ async def init_engine(self): kwargs = dict( n=1, logprobs=0, - max_tokens=config.response_length, + max_new_tokens=config.response_length, ) for k in config.keys(): if hasattr(SamplingParams(), str(k)): @@ -201,6 +200,8 @@ async def init_engine(self): request_logger=RequestLogger(max_log_len=4096), chat_template=None, chat_template_content_format="auto", + enable_auto_tools=True, + tool_parser=config.multi_turn.format, # hermes, llama3_json, ... ) async def chat_completion(self, raw_request: Request): @@ -220,28 +221,6 @@ async def chat_completion(self, raw_request: Request): assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) - async def chat_completion_generator(self, request: ChatCompletionRequest) -> AsyncGenerator[Tuple[int, str]]: - """Direct chat completion without FastAPI. - - Args: - request: ChatCompletionRequest, request object. - - Returns: - AsyncGenerator[Tuple[int, str]]: async generator of (status_code, data) pairs. 
- """ - generator = await self.openai_serving_chat.create_chat_completion(request) - if isinstance(generator, ErrorResponse): - data = generator.model_dump_json(exclude_unset=True) - yield generator.code, f"data: {data}\n\n" - - if request.stream: - async for chunk in generator: - yield 200, chunk - else: - assert isinstance(generator, ChatCompletionResponse) - data = generator.model_dump_json(exclude_unset=True) - yield 200, f"data: {data}\n\n" - async def wake_up(self): await self.engine.wake_up()