diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 817ebcf92f8..b848d3e52c7 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -252,3 +252,209 @@ def tearDownClass(cls): child.send_signal(signal.SIGTERM) cls.server_process.terminate() cls.server_process.wait() + + +@pytest.mark.slow +@require_torch_multi_gpu +class TestVLLMClientServerAsync(unittest.TestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setUpClass(cls): + # We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1" + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "1" # Restrict to GPU 1 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve-async", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) + + #Initialize the client + cls.client = VLLMClient(connection_timeout=240) + cls.client.init_communicator() + + def test_generate(self): + prompt = "Hello, AI! Tell me a joke." + response = self.client.session.post( + url="http://localhost:8000/v1/completions", + json={ + "model": self.model_id, + "prompt": prompt, + "max_tokens": 50 + } + ) + response.raise_for_status() + response_json = response.json() + + # Check basic response structure + self.assertIn("choices", response_json) + self.assertGreater(len(response_json["choices"]), 0) + + # Check that we got a non-empty text response + first_choice = response_json["choices"][0] + self.assertIn("text", first_choice) + self.assertGreater(len(first_choice["text"]), 0) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda") + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + parent = psutil.Process(cls.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.send_signal(signal.SIGTERM) + cls.server_process.terminate() + cls.server_process.wait() + + +@pytest.mark.slow +@require_3_gpus +class TestVLLMClientAsyncServerTP(unittest.TestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setUpClass(cls): + # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2" + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "1,2" # Restrict to GPU 1 and 2 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve-async", "--model", cls.model_id, "--tensor_parallel_size", "2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240) + cls.client.init_communicator() + + def test_generate(self): + prompt = "Hello, AI! Tell me a joke." + response = self.client.session.post( + url="http://localhost:8000/v1/completions", + json={ + "model": self.model_id, + "prompt": prompt, + "max_tokens": 50 + } + ) + response.raise_for_status() + response_json = response.json() + + # Check basic response structure + self.assertIn("choices", response_json) + self.assertGreater(len(response_json["choices"]), 0) + + # Check that we got a non-empty text response + first_choice = response_json["choices"][0] + self.assertIn("text", first_choice) + self.assertGreater(len(first_choice["text"]), 0) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda") + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + parent = psutil.Process(cls.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.send_signal(signal.SIGTERM) + cls.server_process.terminate() + cls.server_process.wait() + + +@pytest.mark.slow +@require_3_gpus +class TestVLLMClientAsyncServerDP(unittest.TestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setUpClass(cls): + # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2" + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "1,2" # Restrict to GPU 1 and 2 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve-async", "--model", cls.model_id, "--data_parallel_size", "2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240) + + def test_generate(self): + prompt = "Hello, AI! Tell me a joke." + response = self.client.session.post( + url="http://localhost:8000/v1/completions", + json={ + "model": self.model_id, + "prompt": prompt, + "max_tokens": 50 + } + ) + response.raise_for_status() + response_json = response.json() + + # Check basic response structure + self.assertIn("choices", response_json) + self.assertGreater(len(response_json["choices"]), 0) + + # Check that we got a non-empty text response + first_choice = response_json["choices"][0] + self.assertIn("text", first_choice) + self.assertGreater(len(first_choice["text"]), 0) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda") + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + parent = psutil.Process(cls.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.send_signal(signal.SIGTERM) + cls.server_process.terminate() + cls.server_process.wait() diff --git a/trl/cli.py b/trl/cli.py index 7d85d570e68..3efc5337a9d 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -29,6 +29,8 @@ from .scripts.utils import TrlParser from .scripts.vllm_serve import main as vllm_serve_main from .scripts.vllm_serve import make_parser as make_vllm_serve_parser +from .scripts.vllm_serve_async import main as vllm_serve_async_main +from .scripts.vllm_serve_async import make_parser as make_vllm_serve_async_parser def main(): @@ -45,6 +47,7 @@ def main(): make_kto_parser(subparsers) make_sft_parser(subparsers) make_vllm_serve_parser(subparsers) + make_vllm_serve_async_parser(subparsers) # Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser. # Duplicates may occur if the same argument is provided in both the config file and CLI. @@ -138,7 +141,13 @@ def main(): ) vllm_serve_main(script_args) + + elif args.command == "vllm-serve-async": + # Here we defer to vllm's argument parser, so that we don't have to reimplement all of its logic + sys.argv = ["trl/scripts/vllm_serve_async.py"] + launch_args + vllm_serve_async_main() if __name__ == "__main__": main() + \ No newline at end of file diff --git a/trl/import_utils.py b/trl/import_utils.py index dd594e36ef6..7bd150f6421 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -35,6 +35,7 @@ _requests_available = _is_package_available("requests") _unsloth_available = _is_package_available("unsloth") _uvicorn_available = _is_package_available("uvicorn") +_uvloop_available = _is_package_available("uvloop") _vllm_available = _is_package_available("vllm") _vllm_ascend_available = _is_package_available("vllm_ascend") _joblib_available = _is_package_available("joblib") @@ -80,6 +81,10 @@ def is_uvicorn_available() -> bool: return _uvicorn_available +def is_uvloop_available() -> bool: + return _uvloop_available + + def is_vllm_available() -> bool: return _vllm_available diff --git a/trl/scripts/vllm_serve_async.py b/trl/scripts/vllm_serve_async.py new file mode 100644 index 00000000000..22dd0234e05 --- /dev/null +++ b/trl/scripts/vllm_serve_async.py @@ -0,0 +1,332 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +TRL Script for Running a vLLM OpenAI-Compatible API Server with Weight Synchronization. + +This script provides a RESTful API server that is compatible with the OpenAI API +standard, built upon the vLLM library. It largely mirrors the functionality of +vLLM's own `vllm.entrypoints.openai.api_server` but extends it with crucial +endpoints specifically designed for TRL's online training or reinforcement +learning workflows. + +Core Functionality Leveraged from vLLM OpenAI Server: +- Uses `vllm.entrypoints.openai.api_server.build_app` to create the base FastAPI app. +- Uses `vllm.entrypoints.openai.api_server.init_app_state` to initialize server state. +- Uses `vllm.entrypoints.openai.api_server.build_async_engine_client` for engine setup. +- Uses `vllm.entrypoints.openai.api_server.serve_http` to run the Uvicorn server. +- The overall `run_server` function structure is adapted from the original script. + +TRL-Specific Additions & Differences from `trl/scripts/vllm_serve.py`: +- This script adds endpoints for TRL weight synchronization (`/init_communicator/`, + `/update_named_param/`, etc.), enabling dynamic model updates during training or RL. + The underlying mechanism (`collective_rpc`) closely mirrors the implementation + in the original `trl/scripts/vllm_serve.py`. +- However, unlike the original `trl/scripts/vllm_serve.py`, this script adopts + the standard vLLM OpenAI API server structure (`build_app`, `init_app_state`). +- As a result, it leverages the standard OpenAI-compatible endpoints + (e.g., `/v1/chat/completions`, `/v1/completions`) for inference and does *not* + include the custom synchronous `/generate` endpoint previously found in + `trl/scripts/vllm_serve.py`. +""" + +import os +import signal +import argparse +import warnings + +import torch + +from trl import TrlParser +from trl.import_utils import ( + is_fastapi_available, + is_pydantic_available, + is_uvicorn_available, + is_uvloop_available, + is_vllm_available, +) + + +if is_fastapi_available(): + from fastapi import BackgroundTasks, FastAPI + + +if is_pydantic_available(): + from pydantic import BaseModel + +if is_uvloop_available(): + import uvloop + + +if is_vllm_available(): + from vllm.logger import init_logger + from vllm.utils import FlexibleArgumentParser, set_ulimit, is_valid_ipv6_address + from vllm.v1.engine.async_llm import AsyncLLM + from vllm.reasoning import ReasoningParserManager + from vllm.entrypoints.openai.tool_parsers import ToolParserManager + from vllm.entrypoints.openai.api_server import ( + build_app, + build_async_engine_client, + serve_http, + init_app_state, + create_server_socket, + cli_env_setup, + make_arg_parser, + ) + from vllm.entrypoints.openai.cli_args import ( + make_arg_parser, + validate_parsed_serve_args, + ) + from vllm.version import __version__ as VLLM_VERSION + +TIMEOUT_KEEP_ALIVE = 5 # seconds + + +logger = init_logger(__name__) + +# We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following +# error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use +# the 'spawn' start method +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +def add_vllm_client_endpoints(app: FastAPI, llm: AsyncLLM): + @app.get("/get_world_size/") + async def get_world_size(): + """ + Retrieves the tensor parallel size from the LLM engine. + + Returns: + `dict`: + A dictionary containing the world size. + + Example response: + ```json + {"world_size": 8} + ``` + """ + tp = llm.vllm_config.parallel_config.tensor_parallel_size + dp = llm.vllm_config.parallel_config.data_parallel_size + return {"world_size": tp * dp} + + class InitCommunicatorRequest(BaseModel): + host: str + port: int + world_size: int + + @app.post("/init_communicator/") + async def init_communicator(request: InitCommunicatorRequest, background_tasks: BackgroundTasks): + """ + Initializes the communicator for synchronizing model weights between a client and multiple server + workers. + + Args: + request (`InitCommunicatorRequest`): + - `host` (`str`): Hostname or IP address of the master node. + - `port` (`int`): Port number to be used for communication. + - `world_size` (`int`): Total number of participating processes in the group. + """ + tp = llm.vllm_config.parallel_config.tensor_parallel_size + dp = llm.vllm_config.parallel_config.data_parallel_size + background_tasks.add_task( + llm.engine_core.collective_rpc_async, + "init_communicator", + args=(request.host, request.port, tp * dp + 1) + ) + return {"message": "Request received, initializing communicator"} + + class UpdateWeightsRequest(BaseModel): + name: str + dtype: str + shape: list[int] + + @app.post("/update_named_param/") + async def update_named_param(request: UpdateWeightsRequest, background_tasks: BackgroundTasks): + """ + Updates the model weights with the provided tensor. + + Once this endpoint is called, the client process should broadcast the updated weights to all server workers. + + Args: + request (`UpdateWeightsRequest`): + - `name` (`str`): Name of the weight tensor being updated. + - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`). + - `shape` (list of `int`): Shape of the weight + + """ + # The function is called this way: update_named_param(name="name", dtype=torch.float32, shape=(10, 10)) + # So with collect_rpc we need to call it this way: + # llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10))) + # And with background_tasks.add_task we need to call it this way: + # background_tasks.add_task(llm.collective_rpc, "update_named_param", args=("name", torch.float32, (10, 10))) + dtype = torch.__getattribute__(request.dtype.split(".")[-1]) + background_tasks.add_task( + llm.engine_core.collective_rpc_async, + "update_named_param", + args=(request.name, dtype, request.shape) + ) + return {"message": "Request received, updating named parameter"} + + @app.post("/reset_prefix_cache/") + async def reset_prefix_cache(background_tasks: BackgroundTasks): + """ + Resets the prefix cache for the model. + """ + background_tasks.add_task( + llm.engine_core.reset_prefix_cache_async + ) + return {"message": "Request received, resetting prefix cache"} + + @app.post("/close_communicator/") + async def close_communicator(background_tasks: BackgroundTasks): + """ + Closes the weight update group and cleans up associated resources. + """ + background_tasks.add_task( + llm.engine_core.collective_rpc_async, + "close_communicator" + ) + return {"message": "Request received, closing communicator"} + + +async def run_server(args, **uvicorn_kwargs): + if not is_fastapi_available(): + raise ImportError( + "FastAPI is required to run the vLLM serve script. Please install it using `pip install fastapi`." + ) + + if not is_pydantic_available(): + raise ImportError( + "Pydantic is required to run the vLLM serve script. Please install it using `pip install pydantic`." + ) + + if not is_uvicorn_available(): + raise ImportError( + "Uvicorn is required to run the vLLM serve script. Please install it using `pip install uvicorn`." + ) + + if not is_uvloop_available(): + raise ImportError( + "Uvloop is required to run the vLLM serve script. Please install it using `pip install uvloop`." + ) + + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valid_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valid_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parses)} }})") + + valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() + if args.enable_reasoning \ + and args.reasoning_parser not in valid_reasoning_parses: + raise KeyError( + f"invalid reasoning parser: {args.reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + add_vllm_client_endpoints(app, engine_client) # This is essentially the only difference from vllm original code + + vllm_config = await engine_client.get_vllm_config() + + await init_app_state(engine_client, vllm_config, app.state, args) + + def _listen_addr(a: str) -> str: + if is_valid_ipv6_address(a): + return '[' + a + ']' + return a or "0.0.0.0" + + is_ssl = args.ssl_keyfile and args.ssl_certfile + logger.info("Starting vLLM API server on http%s://%s:%d", + "s" if is_ssl else "", _listen_addr(sock_addr[0]), + sock_addr[1]) + + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + +def make_parser(subparsers: argparse._SubParsersAction = None): + if subparsers is not None: + parser = subparsers.add_parser("vllm-serve-async", add_help=False, help="Runs vLLM's OpenAI-compatible server with weight syncing.") + else: + parser = TrlParser() + return parser + +def main(): + cli_env_setup() + + parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server with weight syncing.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + # We can use the same worker extension class + args.worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension" + + # We encounter the same problem as in vllm_serve.py + if args.tensor_parallel_size == 1 and args.data_parallel_size > 1: + warnings.warn( + "Detected configuration: tensor_parallel_size=1 and data_parallel_size>1. This setup is known to " + "cause a crash when using the `trl vllm-serve` CLI entry point. As a workaround, please run the " + "server using the module path instead: `python -m trl.scripts.vllm_serve`", + RuntimeWarning, + ) + + uvloop.run(run_server(args)) + +if __name__ == "__main__": + main() + diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index fafdc30cb4b..f7dd009cb2e 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -86,6 +86,8 @@ class GRPOConfig(TrainingArguments): tokens. cache_implementation (`str` or `None`, *optional*, defaults to `None`): Implementation of the cache method for faster generation when use_vllm is set to False. + multi_turn (`bool`, *optional*, defaults to `False`): + (async WIP) Whether generations are multiturn. Controls how padding is masked in the generation. > Parameters that control generation acceleration powered by vLLM @@ -93,11 +95,14 @@ class GRPOConfig(TrainingArguments): Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation instead of the default model.generate(). Requires `vllm` to be installed. vllm_mode (`str`, *optional*, defaults to `"server"`): - Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or + Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"`, `"async_server"`, or `"colocate"`. - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM server is running (start with `trl vllm-serve`). + - (async WIP) `"async_server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve-async`). + The users writes their own .generate() methods by utilizing the openai-compatible API. - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a separate server but may cause resource contention with training. vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): @@ -309,6 +314,10 @@ class GRPOConfig(TrainingArguments): default=None, metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, ) + multi_turn: bool = field( + default=False, + metadata={"help": "(WIP) Whether client does multi-turn generation."}, + ) # Parameters that control generation acceleration powered by vLLM use_vllm: bool = field( @@ -321,9 +330,10 @@ class GRPOConfig(TrainingArguments): vllm_mode: str = field( default="server", metadata={ - "help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or " - "`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure a " - "TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same " + "help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server`, " + "`async_server`, or `colocate`. `'server'`: The trainer will send generation requests to a separate vLLM " + "server. Make sure a TRL vLLM server is running (start with `trl vllm-serve`). `'async_server'`: The " + "TRL vLLM server is running (start with `trl vllm-serve-async`). `'colocate'`: vLLM will run in the same " "process and share the training GPUs. This avoids the need for a separate server but may cause resource " "contention with training." }, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 513b8bfac1f..35852b546ae 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -18,7 +18,7 @@ from collections import defaultdict, deque from collections.abc import Sized from contextlib import nullcontext -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Type, Union import datasets import torch @@ -80,6 +80,13 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +# (WIP) Allow users to have custom rollouts by making API calls to the trl vllm-serve server. This could e.g. be +# ready-made AI products. +# What we call a rollout function is a callable that takes in the data and sampling parameters and returns a list of +# generation results. Those results must include "prompt" and "completion" field and can include an optional "tools" +# field. Any extra fields are forwarded to the reward functions. +RolloutFunc = Callable[[dict[str, Any], dict[str, Any]], list[dict[str, Any]]] + class RepeatSampler(Sampler): """ @@ -333,6 +340,10 @@ def reward_func(completions, **kwargs): [Using a custom reward function](#using-a-custom-reward-function). - A list of reward functions, where each item can independently be any of the above types. Mixing different types within the list (e.g., a string model ID and a custom reward function) is allowed. + rollout_func (`RolloutFunc`, *optional*, defaults to `None`): + Function to use for generating completions. It must take in the data and sampling parameters and return a list of + generation results. Those results must include "prompt" and "completion" field and can include optional "tools" field + and any other fields that are forwarded to the reward functions. args ([`GRPOConfig`], *optional*, defaults to `None`): Configuration for this trainer. If `None`, a default configuration is used. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): @@ -378,6 +389,7 @@ def __init__( model: Union[str, PreTrainedModel], reward_funcs: Union[RewardFunc, list[RewardFunc]], args: Optional[GRPOConfig] = None, + rollout_func: Optional[RolloutFunc] = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, processing_class: Optional[PreTrainedTokenizerBase] = None, @@ -498,6 +510,7 @@ def data_collator(features): # No data collation is needed in GRPO self.repetition_penalty = args.repetition_penalty self.use_vllm = args.use_vllm self.vllm_mode = args.vllm_mode + self.multi_turn = args.multi_turn self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.use_liger_loss = args.use_liger_loss @@ -615,13 +628,21 @@ def data_collator(features): # No data collation is needed in GRPO "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " "`pip install vllm` to use it." ) + + # TODO: Temp hack? + if self.vllm_mode == "async_server": + if rollout_func is None: + raise ValueError("rollout_func must be provided when vllm_mode is 'async_server'") + self.rollout_func = rollout_func + else: + if rollout_func is not None: + raise ValueError("rollout_func must be None when vllm_mode is not 'async_server'") - if self.vllm_mode == "server" and self.accelerator.is_main_process: + if self.vllm_mode in ["server", "async_server"] and self.accelerator.is_main_process: self.vllm_client = VLLMClient( args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout ) self.vllm_client.init_communicator() - elif self.vllm_mode == "colocate": # Make sure vllm_tensor_parallel_size group size evenly divides the world size - each group should have # the same number of ranks @@ -917,7 +938,7 @@ def _move_model_to_vllm(self): continue name = name.replace("modules_to_save.default.", "") - if self.vllm_mode == "server" and self.accelerator.is_main_process: + if self.vllm_mode in ["server", "async_server"] and self.accelerator.is_main_process: self.vllm_client.update_named_param(name, param.data) elif self.vllm_mode == "colocate": llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model @@ -932,14 +953,14 @@ def _move_model_to_vllm(self): else: for name, param in self.model.named_parameters(): with gather_if_zero3([param]): - if self.vllm_mode == "server" and self.accelerator.is_main_process: + if self.vllm_mode in ["server", "async_server"] and self.accelerator.is_main_process: self.vllm_client.update_named_param(name, param.data) elif self.vllm_mode == "colocate": llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model llm_model.load_weights([(name, param.data)]) # Reset cache on vLLM - if self.vllm_mode == "server" and self.accelerator.is_main_process: + if self.vllm_mode in ["server", "async_server"] and self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() elif self.vllm_mode == "colocate": self.llm.reset_prefix_cache() @@ -983,6 +1004,14 @@ def _generate_and_score_completions( device = self.accelerator.device mode = "train" if self.model.training else "eval" + # NOTE (WIP): `extra_reward_kwargs` is useful when reward functions need access to + # information beyond the raw prompt and completion strings. This is particularly + # relevant when using custom generation logic or external generation clients (e.g., async_server mode). + # Such structured data might include: `generated_diff` or `test_coverage` for coding agents, `program_runtime` for code optimization, `win_rate` for game-playing agents, ... + # Feedback from the environment is often not present in the completion, but it is a valid reward signal. + extra_reward_kwargs = {} + + # TODO: Find a non-breaking workaround, the async-server use case rarely has a simple "prompt" key, so currently I create a dummy "prompt" which is overwritten by the async_server prompts = [x["prompt"] for x in inputs] prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -1032,6 +1061,66 @@ def _generate_and_score_completions( (self.accelerator.process_index + 1) * len(prompts), ) completion_ids = completion_ids[process_slice] + elif self.vllm_mode == "async_server": + if self.accelerator.is_main_process: + with profiling_context(self, "Async vLLM.generate"): + outputs = self.rollout_func( # more lightweight than having to extend the VLLMClient as previously suggested + data=gather_object(inputs), + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) + else: + outputs = [None] * len(inputs) + + # Broadcast outputs to all processes + outputs = broadcast_object_list(outputs, from_process=0) + + # Each process selects its own slice of the outputs + process_start = self.accelerator.process_index * len(inputs) + process_end = (self.accelerator.process_index + 1) * len(inputs) + outputs = outputs[process_start:process_end] + + # Extract prompts, completions, and tools from outputs + prompts = [o["prompt"] for o in outputs] + completions = [o["completion"] for o in outputs] + tools = [o.get("tools", None) for o in outputs] + + # Prepare prompt input tensors + prompts_text = [ + self.processing_class.apply_chat_template(p, tools=t, tokenize=False) + for p, t in zip(prompts, tools) + ] + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) # Needed? Or is .to(device) better? + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + + # Prepare completion input tensors + completions_text = [ + self.processing_class.apply_chat_template(c, tokenize=False) + for c in completions + ] + completion_ids = self.processing_class(completions_text)["input_ids"] + completion_lens = [len(ids) for ids in completion_ids] # For masking right padding tokens + + # Gather any extra keys for reward functions + extra_reward_kwargs = {} + if outputs: + extra_reward_kwargs = { + k: [o[k] for o in outputs] + for k in outputs[0].keys() + if k not in ["prompt", "completion", "tools"] + } # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": @@ -1093,12 +1182,19 @@ def _generate_and_score_completions( prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] - # Mask everything after the first EOS token - is_eos = completion_ids == self.processing_class.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + if self.multi_turn and self.vllm_mode == "async_server": # TODO: Fix + # Create an integer mask based on completion lengths to mask right padding tokens + max_completion_len = completion_ids.size(1) + indices = torch.arange(max_completion_len, device=device).expand(completion_ids.size(0), -1) + completion_lens_tensor = torch.tensor(completion_lens, device=device).unsqueeze(1) + completion_mask = (indices < completion_lens_tensor).int() # all except last eos which is supposed to be there + else: + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need # to re-tokenize completions if the reward is computed from tokens. @@ -1163,7 +1259,7 @@ def _generate_and_score_completions( keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] reward_kwargs = {key: [example[key] for example in inputs] for key in keys} output_reward_func = reward_func( - prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs, **extra_reward_kwargs ) # Convert None values to NaN output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]