diff --git a/trl/generation/__init__.py b/trl/generation/__init__.py new file mode 100644 index 00000000000..22e7cf6d884 --- /dev/null +++ b/trl/generation/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generation backends for TRL trainers.""" + +from ..import_utils import is_vllm_available + + +__all__ = [] + +if is_vllm_available(): + from .vllm_generation import VLLMGeneration + + __all__.append("VLLMGeneration") diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py new file mode 100644 index 00000000000..0eaa734a717 --- /dev/null +++ b/trl/generation/vllm_generation.py @@ -0,0 +1,688 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""vLLM-based generation backend for TRL trainers.""" + +import json +import os +from collections.abc import Callable +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from accelerate.utils import broadcast_object_list, gather_object, is_peft_model +from packaging.version import Version +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, is_bitsandbytes_available + +from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages_vllm +from ..extras.profiling import ProfilingContext, profiling_decorator +from ..extras.vllm_client import VLLMClient +from ..import_utils import is_vllm_available +from ..trainer.utils import ensure_master_addr_port + + +if TYPE_CHECKING: + from accelerate import Accelerator + from peft import PeftModel + + +if is_vllm_available(): + import vllm + from vllm import LLM, SamplingParams + + if Version(vllm.__version__) <= Version("0.10.2"): + from vllm.sampling_params import GuidedDecodingParams + else: + from vllm.sampling_params import StructuredOutputsParams + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + +class VLLMGeneration: + """Handles vLLM-based generation for trainers. + + Extracts all vLLM-specific logic (initialization, generation, weight sync) from trainers into a separate, testable + class. + + Args: + model ([`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to use for generation. + accelerator ([`~accelerate.Accelerator`]): + Accelerator for distributed training. 
+ is_fsdp_enabled (`bool`): + Whether FSDP is enabled. + processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`]): + Tokenizer or processor for the model. + + > Parameters for vLLM: + + mode (`str`, *optional*, defaults to `"server"`): vLLM mode. Must be one of `"server"` or + `"colocate"`. + + - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve`). + - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a + separate server but may cause resource contention with training. + structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. + + > Parameters for "server" vLLM mode: + + server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `server_host` and + `server_port` are ignored. + server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server to connect to. Ignored if `server_base_url` is provided. + server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server to connect to. Ignored if `server_base_url` is provided. + server_timeout (`float`, *optional*, defaults to `240.0`): + Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port + is occupied, there is no need to change it. + + > Parameters for "colocate" vLLM mode: + + tensor_parallel_size (`int`, *optional*, defaults to `1`): + The number of GPUs to use for distributed execution with tensor parallelism. This setting only applies when + `mode` is set to `"colocate"`. If you are using `mode="server"`, this parameter must be passed separately + when launching the vLLM server via the `--vllm_tensor_parallel_size` flag. + gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's throughput. However, if the value is + too high, it may cause out-of- memory (OOM) errors. This setting only applies when `mode` is set to + `"colocate"`. If you are using `mode="server"`, this parameter must be passed separately when launching the + vLLM server via the `--vllm_gpu_memory_utilization` flag. + max_model_length (`int`, *optional*): + Model context length (prompt and completion). Set it to at least the maximum prompt length in the dataset + plus `max_completion_length`; if omitted, it is inferred from the model config. + max_num_seqs (`int`, *optional*): + Maximum number of sequences to process in parallel, effectively capping the batch size. + enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Whether to enable sleep mode for the engine to offload weights/cache during the optimizer step. Keeps GPU + memory usage low, but waking the engine adds host–device transfer latency. + model_impl (`str`, *optional*, defaults to `"auto"`): + Model implementation to use for vLLM. + - "auto" will try to use the vLLM implementation, if it exists, and fall back to the Transformers + implementation if no vLLM implementation is available. 
+ - "vllm" will use the vLLM model implementation. + - "transformers" will use the Transformers model implementation. + - "terratorch" will use the TerraTorch model implementation. + + > Parameters for generation: + + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Parameter for repetition penalty. It penalizes new tokens based on whether they appear in the prompt and + the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the + model to repeat tokens. Default `1.0` means no penalty. + temperature(`float`, *optional*, defaults to `1.0`): + Sampling temperature. It controls the randomness of the sampling. Lower values make the model more + deterministic, while higher values make the model more random and increase diversity. + top_p: (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. It controls the cumulative probability of the top tokens to consider. Defaults to + `1.0` to consider all tokens. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter. It controls the number of top tokens to consider. Defaults to `0` to consider all + tokens. + min_p (`float`, *optional*, defaults to `0.0`): + Min-p sampling parameter. It represents the minimum probability for a token to be considered, relative to + the probability of the most likely token. Default `0.0` means min-p is disabled. + max_completion_length (`int`, *optional*, defaults to `16`): + Maximum number of tokens to generate for each prompt. + generation_kwargs (`dict`, *optional*): + Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like + `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they will + override them. + + > Parameters for chat/tools: + + chat_template (`str`, *optional*): + Template to use for structuring the chat. If not provided, the model's default chat template will be used. + chat_template_kwargs (`dict`, *optional*): + Additional keyword arguments to customize the chat template used by the model. + tools (`list`, *optional*): + Tools available for tool calling during chat generation. + rollout_func (`Callable`, *optional*): Optional custom rollout function that accepts prompts and returns + a dict with 'prompt_ids', 'completion_ids', 'logprobs', and optional extra fields. Should be a + single-argument callable: rollout_func(prompts) -> dict. To pass additional context (e.g., trainer), use a + closure or functools.partial: + rollout_func = lambda prompts: my_custom_rollout(prompts, trainer) + The closure will hold a reference to trainer and see its state updates. 
+ """ + + def __init__( + self, + model: "PreTrainedModel | PeftModel", + accelerator: "Accelerator", + is_fsdp_enabled: bool, + processing_class: PreTrainedTokenizerBase | ProcessorMixin, + # vLLM configuration + mode: str = "server", + structured_outputs_regex: str | None = None, + # Server mode configuration + server_base_url: str | None = None, + server_host: str = "0.0.0.0", + server_port: int = 8000, + server_timeout: float = 240.0, + group_port: int = 51216, + # Colocate mode configuration + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, + max_model_length: int | None = None, + max_num_seqs: int | None = None, + enable_sleep_mode: bool = False, + model_impl: str = "auto", + # Generation configuration + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float = 0.0, + max_completion_length: int = 16, + generation_kwargs: dict | None = None, + # Chat/tool configuration + chat_template: str | None = None, + chat_template_kwargs: dict | None = None, + tools: list | None = None, + rollout_func: Callable | None = None, + ): + self.model = model + self.accelerator = accelerator + self.is_fsdp_enabled = is_fsdp_enabled + self.processing_class = processing_class + + # vLLM configuration + self.mode = mode + self.structured_outputs_regex = structured_outputs_regex + + # Server mode configuration + self.server_base_url = server_base_url + self.server_host = server_host + self.server_port = server_port + self.group_port = group_port + self.server_timeout = server_timeout + + # Colocate mode configuration + self.tensor_parallel_size = tensor_parallel_size + self.gpu_memory_utilization = gpu_memory_utilization + self.max_model_length = max_model_length + self.max_num_seqs = max_num_seqs + self.enable_sleep_mode = enable_sleep_mode + self.model_impl = model_impl + + # Generation configuration + self.repetition_penalty = repetition_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.min_p = min_p + self.max_completion_length = max_completion_length + self.generation_kwargs = generation_kwargs or {} + + # Chat/tool configuration + self.chat_template = chat_template + self.chat_template_kwargs = chat_template_kwargs or {} + self.tools = tools + self.rollout_func = rollout_func + + self._init_vllm() + + def _init_vllm(self): + """Initialize vLLM in server or colocate mode.""" + model = self.model + accelerator = self.accelerator + + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install trl[vllm]` to use it." + ) + + if self.mode == "server": + if accelerator.is_main_process: + if self.server_base_url is not None: + base_url = self.server_base_url + else: + base_url = f"http://{self.server_host}:{self.server_port}" + self.vllm_client = VLLMClient( + base_url=base_url, group_port=self.group_port, connection_timeout=self.server_timeout + ) + self.vllm_client.init_communicator(device=torch.cuda.current_device()) + + elif self.mode == "colocate": + # Make sure tensor_parallel_size group size evenly divides the world size - each group should have + # the same number of ranks + if not accelerator.num_processes % self.tensor_parallel_size == 0: + raise ValueError( + f"tensor_parallel_size ({self.tensor_parallel_size}) must divide world size " + f"({accelerator.num_processes}) evenly." 
+ ) + + if self.tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `tensor_parallel_size` ranks. + # For example, if world_size=8 and tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( + [ + list(range(i * self.tensor_parallel_size, (i + 1) * self.tensor_parallel_size)) + for i in range(accelerator.num_processes // self.tensor_parallel_size) + ] + ) + + # vLLM requires the environment variables to be set for distributed training. + os.environ["RANK"] = str(accelerator.process_index) + os.environ["LOCAL_RANK"] = str(accelerator.local_process_index) + os.environ["WORLD_SIZE"] = str(accelerator.num_processes) + # Ensure distributed rendezvous variables are set without colliding across concurrent runs + ensure_master_addr_port() + + quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") + + # Build LLM initialization kwargs + self.llm = LLM( + model=model.name_or_path, + tensor_parallel_size=self.tensor_parallel_size, + gpu_memory_utilization=self.gpu_memory_utilization, + max_model_len=self.max_model_length, + max_num_seqs=self.max_num_seqs, + enable_sleep_mode=self.enable_sleep_mode, + model_impl=self.model_impl, + distributed_executor_backend="external_launcher", + # Feed identical seed for tp groups to ensure sampling results are the same across workers + seed=accelerator.process_index // self.tensor_parallel_size, + # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory + max_num_batched_tokens=4096, + # Important so temperature scaling/logit tweaking affects the TIS log probs + logprobs_mode="processed_logprobs", + quantization=quantization, + ) + if self.enable_sleep_mode: + self.llm.sleep(level=2) + else: + raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.mode}'.") + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. 
+ accelerator.wait_for_everyone() + + def _fix_param_name_to_vllm(self, name: str, extra_prefixes: list[str] | None = None) -> str: + """Fix parameter name for vLLM compatibility.""" + extra_prefixes = extra_prefixes or [] + prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes + for prefix in prefixes: + name = name.replace(prefix, "") + return name + + def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited: set[str] | None = None): + """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" + # For FSDP1, we need to recurse into children and also use summon_full_params + accelerator = self.accelerator + + if visited is None: + visited = set() + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + self._sync_fsdp1_params_to_vllm( + child_module, prefix=child_prefix, visited=visited + ) # recurse into the child + + if isinstance(module, FSDP): + with FSDP.summon_full_params(module, recurse=False, writeback=False): + for param_name, param in module.named_parameters(): + full_name = f"{prefix}.{param_name}" if prefix else param_name + full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) + + if full_name in visited: + continue # skip FSDP subtrees already traversed + visited.add(full_name) + + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(full_name, param.data) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(full_name, param.data)]) + + def _sync_fsdp2_params_to_vllm(self, module: nn.Module): + """FSDP2-specific parameter synchronization.""" + accelerator = self.accelerator + + # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion + for name, param in module.state_dict().items(): + # When using PEFT, we need to recover the original parameter name + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + # Skip PEFT layers: they don't exist in vLLM, and they are merged already. + if is_peft_model(module) and module.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if param.is_cpu: + param = param.to(torch.device("cuda")) + param = param.full_tensor() + + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(name, param) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param)]) + + @profiling_decorator + def sync_weights(self): + """Synchronize model weights to vLLM. + + Handles FSDP, DeepSpeed, PEFT weight synchronization. 
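+
+        Typically called once per training step, before `generate`, and only when the policy weights have changed
+        since the last synchronization (the trainers guard the call by comparing `state.global_step` to the step of
+        the last sync).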
+ """ + model = self.model + accelerator = self.accelerator + is_fsdp_enabled = self.is_fsdp_enabled + + # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations + deepspeed_plugin = accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + + if is_peft_model(model): + # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as + # merging adapters in a sharded manner is not supported. + # TODO: does this work with FSDP? + with gather_if_zero3(list(model.parameters())): + model.merge_adapter() + + # Update vLLM weights while parameters are gathered + if is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext + # Update vLLM weights while parameters are gathered + # For PEFT with FSDP we need to use the memory efficient post-order traversal + fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm(model) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(model) + else: + # DeepSpeed ZeRO-3 with PEFT + for name, param in model.named_parameters(): + # When using PEFT, we need to recover the original parameter name + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + # Skip PEFT layers: they don't exist in vLLM, and they are merged already. + if model.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + # Unmerge adapters while parameters are still gathered + model.unmerge_adapter() + # Parameters will automatically be repartitioned when exiting the context + else: + # For non-PEFT models, simply gather (if needed) and update each parameter individually. + if is_fsdp_enabled: + fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm(model) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(model) + else: + for name, param in model.named_parameters(): + name = self._fix_param_name_to_vllm(name) + with gather_if_zero3([param]): + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.mode == "colocate": + self.llm.reset_prefix_cache() + + def generate(self, prompts: list, num_generations: int, profiler: ProfilingContext | None = None) -> tuple: + """Generate completions using vLLM. 
+ + Args: + prompts: List of prompts (strings or chat conversations) + num_generations: Number of generations per prompt + profiler: Optional profiler for performance tracking + + Returns: + Tuple of (prompt_ids, completion_ids, logprobs, extra_fields) + """ + profiler = profiler or nullcontext() + accelerator = self.accelerator + rollout_func = self.rollout_func + temperature = self.temperature + top_p = self.top_p + top_k = self.top_k + min_p = self.min_p + repetition_penalty = self.repetition_penalty + max_completion_length = self.max_completion_length + processing_class = self.processing_class + chat_template_kwargs = self.chat_template_kwargs + tools = self.tools + chat_template = self.chat_template + + # Wake up colocated vLLM instances if needed + if self.mode == "colocate" and self.enable_sleep_mode: + torch.cuda.empty_cache() # required to avoid OOM in some cases + self.llm.wake_up(tags=["weights"]) + # Work around for https://github.com/vllm-project/vllm/issues/29341 + self.llm.collective_rpc("reload_weights") + + if is_conversational({"prompt": prompts[0]}): + prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + + # In vLLM, tool call arguments must be JSON strings. See https://github.com/vllm-project/vllm/pull/28820 + for prompt in prompts: # iterate over each conversation + if is_conversational({"prompt": prompt}): + for message in prompt: # iterate over each message + if "tool_calls" in message: # check if message has tool calls + for call in message["tool_calls"]: + args_value = call["function"]["arguments"] + if isinstance(args_value, dict): # only convert dict → JSON string + call["function"]["arguments"] = json.dumps(args_value) + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + if self.mode == "server": + all_prompts = gather_object(prompts) + + if accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. 
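+                # For example, with num_generations=2 the gathered prompts look like [p0, p0, p1, p1];
+                # all_prompts[::num_generations] keeps [p0, p1], and vLLM then samples 2 completions for each.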
+ ordered_set_of_prompts = all_prompts[::num_generations] + + sampling_params = { + "n": num_generations, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": 0.0 if min_p is None else min_p, + "max_tokens": max_completion_length, + "structured_outputs_regex": self.structured_outputs_regex, + "generation_kwargs": self.generation_kwargs, + } + with profiler: # TODO: profiling_context(trainer, "vLLM.generate"): + if rollout_func is not None: + rollout_prompts = ordered_set_of_prompts + if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): + rollout_prompts = [ + apply_chat_template({"prompt": p}, processing_class, **chat_template_kwargs)["prompt"] + for p in rollout_prompts + ] + output = rollout_func(rollout_prompts) + else: + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat( + messages=ordered_set_of_prompts, + **sampling_params, + chat_template_kwargs=chat_template_kwargs, + tools=tools, + chat_template=chat_template, + ) + else: + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) + # Extract required fields and collect any extra fields for reward functions + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) + else: + payload = None + + # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. + obj_list = [payload] + broadcast_object_list(obj_list, from_process=0) + all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0] + + # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times + all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)] + + process_slice = slice( + accelerator.process_index * len(prompts), + (accelerator.process_index + 1) * len(prompts), + ) + prompt_ids = all_prompt_ids[process_slice] + completion_ids = all_completion_ids[process_slice] + logprobs = all_logprobs[process_slice] + + # Slice extra fields dict-of-lists per process (extra fields are per-completion, like completion_ids) + extra_fields = {} + for key, values in all_extra_fields.items(): + if isinstance(values, list): + extra_fields[key] = values[process_slice] + else: + extra_fields[key] = values + + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts + elif self.mode == "colocate": + if rollout_func is not None: + rollout_prompts = prompts + if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): + rollout_prompts = [ + apply_chat_template({"prompt": prompt}, processing_class, **chat_template_kwargs)["prompt"] + for prompt in rollout_prompts + ] + output = rollout_func(rollout_prompts) + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + prompt_ids = output["prompt_ids"] + completion_ids = output["completion_ids"] + logprobs = output["logprobs"] + else: + if Version(vllm.__version__) <= Version("0.10.2"): + structured_outputs_key = "guided_decoding" + if self.structured_outputs_regex: + structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None + else: + structured_outputs_key 
= "structured_outputs" + if self.structured_outputs_regex: + structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None + + generation_kwargs = { + "n": 1, # vLLM on each GPU generates only 1 in colocate mode + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": 0.0 if min_p is None else min_p, + "max_tokens": max_completion_length, + "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only + } + generation_kwargs[structured_outputs_key] = structured_outputs + generation_kwargs.update(self.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) + + if self.tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + orig_size = len(prompts) + gathered_prompts = [None for _ in range(self.tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts = prompts + + if self.enable_sleep_mode: + self.llm.wake_up(tags=["kv_cache"]) + + with profiler: # TODO: profiling_context(trainer, "vLLM.generate"): + if is_conversational({"prompt": prompts[0]}): + all_outputs = self.llm.chat( + all_prompts, + sampling_params=sampling_params, + use_tqdm=False, + chat_template_kwargs=chat_template_kwargs, + tools=tools, + chat_template=chat_template, + ) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) + + all_prompt_ids = [output.prompt_token_ids for output in all_outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_logprobs = [ + [next(iter(lp.values())).logprob for lp in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] + + if self.tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs — we keep only our share. 
+ local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + prompt_ids = all_prompt_ids[tp_slice] + completion_ids = all_completion_ids[tp_slice] + logprobs = all_logprobs[tp_slice] + else: + prompt_ids = all_prompt_ids + completion_ids = all_completion_ids + logprobs = all_logprobs + + extra_fields = {} # No extra fields for colocate mode + + if self.enable_sleep_mode: + self.llm.sleep(level=2) + + return prompt_ids, completion_ids, logprobs, extra_fields diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 6164768c842..fd73d77136a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -16,7 +16,6 @@ import atexit import copy import inspect -import json import os import sys import textwrap @@ -35,7 +34,7 @@ import torch.utils.data import transformers from accelerate.logging import get_logger -from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from accelerate.utils import gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset from packaging.version import Version from torch import nn @@ -50,7 +49,6 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, - is_bitsandbytes_available, is_trackio_available, is_wandb_available, ) @@ -62,11 +60,10 @@ apply_chat_template, is_conversational, prepare_multimodal_messages, - prepare_multimodal_messages_vllm, ) from ..extras.profiling import profiling_context, profiling_decorator -from ..extras.vllm_client import VLLMClient -from ..import_utils import is_jmespath_available, is_liger_kernel_available, is_vllm_available +from ..generation.vllm_generation import VLLMGeneration +from ..import_utils import is_jmespath_available, is_liger_kernel_available from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ..models.utils import _ForwardRedirection, disable_gradient_checkpointing from .base_trainer import BaseTrainer @@ -76,7 +73,6 @@ RepeatSampler, create_model_from_path, disable_dropout_in_model, - ensure_master_addr_port, entropy_from_logits, get_config_model_id, identity, @@ -102,14 +98,6 @@ if is_liger_kernel_available(): from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss -if is_vllm_available(): - import vllm - from vllm import LLM, SamplingParams - - if Version(vllm.__version__) <= Version("0.10.2"): - from vllm.sampling_params import GuidedDecodingParams - else: - from vllm.sampling_params import StructuredOutputsParams if is_wandb_available(): import wandb @@ -117,9 +105,6 @@ if is_trackio_available(): import trackio -if is_bitsandbytes_available(): - import bitsandbytes as bnb - logger = get_logger(__name__) # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of @@ -657,90 +642,52 @@ def cast_outputs_to_original_dtype(module, args, output): set_seed(args.seed, device_specific=True) if self.use_vllm: - if not is_vllm_available(): - raise ImportError( - "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install trl[vllm]` to use it." 
- ) - - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" - self.vllm_client = VLLMClient( - base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout - ) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) - - elif self.vllm_mode == "colocate": - # Make sure vllm_tensor_parallel_size group size evenly divides the world size - each group should have - # the same number of ranks - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " - f"({self.accelerator.num_processes}) evenly." - ) - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. - # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ] - ) - - # vLLM requires the environment variables to be set for distributed training. - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - # Ensure distributed rendezvous variables are set without colliding across concurrent runs - ensure_master_addr_port() - - vllm_quantization = None - if is_bitsandbytes_available(): - for _, module in model.named_modules(): - if isinstance(module, bnb.nn.Linear4bit): - vllm_quantization = "bitsandbytes" - break - elif isinstance(module, bnb.nn.Linear8bitLt): - raise ValueError("vLLM does not support in-flight 8-bit quantization.") - self.llm = LLM( - model=model.name_or_path, - tensor_parallel_size=args.vllm_tensor_parallel_size, - gpu_memory_utilization=self.vllm_gpu_memory_utilization, - max_num_seqs=self.args.per_device_train_batch_size - * self.vllm_tensor_parallel_size - * self.args.steps_per_generation, - max_model_len=self.args.vllm_max_model_length, - distributed_executor_backend="external_launcher", - # Feed identical seed for tp groups to ensure sampling results are the same across workers - seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory - max_num_batched_tokens=4096, - model_impl=self.args.vllm_model_impl, - enable_sleep_mode=self.args.vllm_enable_sleep_mode, - # Important so temperature scaling/logit tweaking affects the TIS log probs - logprobs_mode="processed_logprobs", - quantization=vllm_quantization, - ) - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=2) - else: - raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") - - # vLLM specific sampling arguments - self.structured_outputs_regex = args.vllm_structured_outputs_regex - + # Initialize vLLM generation backend + # Wrap rollout_func to capture trainer context if provided + rollout_func = None + if self.rollout_func is not None: + + def rollout_func(prompts): + return self.rollout_func(prompts, self) + + self.vllm_generation 
= VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + # vLLM configuration + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + # Server mode configuration + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + # Colocate mode configuration + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size + * args.vllm_tensor_parallel_size + * args.steps_per_generation, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + # Generation configuration + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + max_completion_length=self.max_completion_length, + generation_kwargs=args.generation_kwargs, + # Chat/tool configuration + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, + rollout_func=rollout_func, + ) self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation - - # When using vLLM, the main process is responsible for loading the model weights. This can cause process - # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we - # synchronize all processes after vLLM has been fully initialized. - self.accelerator.wait_for_everyone() else: generation_kwargs = { "max_new_tokens": self.max_completion_length, @@ -1038,140 +985,6 @@ def _get_per_token_logps_and_entropies( entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None return logps, entropies - def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None): - extra_prefixes = extra_prefixes or [] - prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes - for prefix in prefixes: - name = name.replace(prefix, "") - return name - - def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): - """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" - # For FSDP1, we need to recurse into children and also use summon_full_params - if visited is None: - visited = set() - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix else child_name - self._sync_fsdp1_params_to_vllm( - child_module, prefix=child_prefix, visited=visited - ) # recurse into the child - - if isinstance(module, FSDP): - with FSDP.summon_full_params(module, recurse=False, writeback=False): - for param_name, param in module.named_parameters(): - full_name = f"{prefix}.{param_name}" if prefix else param_name - full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) - - if full_name in visited: - continue # skip FSDP subtrees already traversed - visited.add(full_name) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(full_name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(full_name, param.data)]) - - def _sync_fsdp2_params_to_vllm(self, module: 
nn.Module): - # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion - for name, param in module.state_dict().items(): - # When using PEFT, we need to recover the original parameter name - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - # Skip PEFT layers: they don’t exist in vLLM, and they are merged already. - if is_peft_model(module) and module.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) - - if param.is_cpu: - param = param.to(torch.device("cuda")) - param = param.full_tensor() - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param)]) - - @profiling_decorator - def _move_model_to_vllm(self): - # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - - if is_peft_model(self.model): - # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as - # merging adapters in a sharded manner is not supported. - # TODO: does this work with FSDP? - with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext - # Update vLLM weights while parameters are gathered - # For PEFT with FSDP we need to use the memory efficient post-order traversal - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - self._sync_fsdp1_params_to_vllm( - self.model - ) # use memory-efficient post-order traversal for FSDP - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(self.model) - else: - # DeepSpeed ZeRO-3 with PEFT - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - # Skip PEFT layers: they don’t exist in vLLM, and they are merged already. - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather (if needed) and update each parameter individually. 
- if self.is_fsdp_enabled: - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(self.model) - else: - for name, param in self.model.named_parameters(): - name = self._fix_param_name_to_vllm(name) - with gather_if_zero3([param]): - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - - # Reset cache on vLLM - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - elif self.vllm_mode == "colocate": - self.llm.reset_prefix_cache() - def training_step(self, model, inputs, num_items_in_batch): time_before = time.perf_counter() output = super().training_step(model, inputs, num_items_in_batch) @@ -1303,205 +1116,16 @@ def _generate_single_turn(self, prompts: list): # Generate completions using either vLLM or regular generation if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up(tags=["weights"]) - # Work around for https://github.com/vllm-project/vllm/issues/29341 - self.llm.collective_rpc("reload_weights") - - # First, update the vLLM weights if needed + # Sync weights if training step changed if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() + self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step - if is_conversational({"prompt": prompts[0]}): - prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] - - # In vLLM, tool call arguments must be JSON strings. See https://github.com/vllm-project/vllm/pull/28820 - for prompt in prompts: # iterate over each conversation - if is_conversational({"prompt": prompt}): - for message in prompt: # iterate over each message - if "tool_calls" in message: # check if message has tool calls - for call in message["tool_calls"]: - args = call["function"]["arguments"] - if isinstance(args, dict): # only convert dict → JSON string - call["function"]["arguments"] = json.dumps(args) - - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - if self.vllm_mode == "server": - all_prompts = gather_object(prompts) - num_generations = self.num_generations if mode == "train" else self.num_generations_eval - - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. 
- ordered_set_of_prompts = all_prompts[::num_generations] - - sampling_params = { - "n": num_generations, - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "structured_outputs_regex": self.structured_outputs_regex, - "generation_kwargs": self.args.generation_kwargs, - } - with profiling_context(self, "vLLM.generate"): - if self.rollout_func is not None: - rollout_prompts = ordered_set_of_prompts - if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): - rollout_prompts = [ - apply_chat_template( - {"prompt": p}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for p in rollout_prompts - ] - output = self.rollout_func(rollout_prompts, self) - else: - if is_conversational({"prompt": ordered_set_of_prompts[0]}): - output = self.vllm_client.chat( - messages=ordered_set_of_prompts, - **sampling_params, - chat_template_kwargs=self.chat_template_kwargs, - tools=self.tools, - chat_template=self.chat_template, - ) - else: - output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) - # Extract required fields and collect any extra fields for reward functions - required_keys = {"prompt_ids", "completion_ids", "logprobs"} - extra_fields = {k: v for k, v in output.items() if k not in required_keys} - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) - else: - payload = None - - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0] - - # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times - all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)] - - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - prompt_ids = all_prompt_ids[process_slice] - completion_ids = all_completion_ids[process_slice] - logprobs = all_logprobs[process_slice] - - # Slice extra fields dict-of-lists per process (extra fields are per-completion, like completion_ids) - extra_fields = {} - for key, values in all_extra_fields.items(): - if isinstance(values, list): - extra_fields[key] = values[process_slice] - else: - extra_fields[key] = values - - # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts - elif self.vllm_mode == "colocate": - if self.rollout_func is not None: - rollout_prompts = prompts - if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): - rollout_prompts = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for prompt in rollout_prompts - ] - output = self.rollout_func(rollout_prompts, self) - required_keys = {"prompt_ids", "completion_ids", "logprobs"} - extra_fields = {k: v for k, v in output.items() if k not in required_keys} - prompt_ids = output["prompt_ids"] - completion_ids = output["completion_ids"] - logprobs = output["logprobs"] - else: - if Version(vllm.__version__) <= Version("0.10.2"): - structured_outputs_key = "guided_decoding" - if self.structured_outputs_regex: - structured_outputs = 
GuidedDecodingParams(regex=self.structured_outputs_regex) - else: - structured_outputs = None - else: - structured_outputs_key = "structured_outputs" - if self.structured_outputs_regex: - structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex) - else: - structured_outputs = None - - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only - } - generation_kwargs[structured_outputs_key] = structured_outputs - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**generation_kwargs) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) - all_prompts = [p for sublist in gathered_prompts for p in sublist] - else: - all_prompts = prompts - - if self.args.vllm_enable_sleep_mode: - self.llm.wake_up(tags=["kv_cache"]) - - with profiling_context(self, "vLLM.generate"): - if is_conversational({"prompt": prompts[0]}): - all_outputs = self.llm.chat( - all_prompts, - sampling_params=sampling_params, - use_tqdm=False, - chat_template_kwargs=self.chat_template_kwargs, - tools=self.tools, - chat_template=self.chat_template, - ) - else: - all_outputs = self.llm.generate( - all_prompts, sampling_params=sampling_params, use_tqdm=False - ) - - all_prompt_ids = [output.prompt_token_ids for output in all_outputs] - all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. 
- local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - prompt_ids = all_prompt_ids[tp_slice] - completion_ids = all_completion_ids[tp_slice] - logprobs = all_logprobs[tp_slice] - else: - prompt_ids = all_prompt_ids - completion_ids = all_completion_ids - logprobs = all_logprobs - - extra_fields = {} # No extra fields for colocate mode - - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=2) + # Generate using vLLM + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + prompt_ids, completion_ids, logprobs, extra_fields = self.vllm_generation.generate( + prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate") + ) elif self.use_transformers_paged: if is_conversational({"prompt": prompts[0]}): diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index e2689c2e668..0d9ad62634f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -16,7 +16,6 @@ import atexit import copy import inspect -import os import textwrap import time from collections import defaultdict, deque @@ -31,9 +30,8 @@ import torch import torch.utils.data from accelerate.logging import get_logger -from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from accelerate.utils import gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset -from packaging.version import Version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler @@ -46,7 +44,6 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, - is_bitsandbytes_available, is_trackio_available, is_wandb_available, ) @@ -57,11 +54,9 @@ apply_chat_template, is_conversational, prepare_multimodal_messages, - prepare_multimodal_messages_vllm, ) from ..extras.profiling import profiling_context, profiling_decorator -from ..extras.vllm_client import VLLMClient -from ..import_utils import is_vllm_available +from ..generation.vllm_generation import VLLMGeneration from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ..models.utils import disable_gradient_checkpointing from .base_trainer import BaseTrainer @@ -71,7 +66,6 @@ RepeatSampler, create_model_from_path, disable_dropout_in_model, - ensure_master_addr_port, entropy_from_logits, get_config_model_id, identity, @@ -94,14 +88,6 @@ if is_peft_available(): from peft import PeftConfig, PeftModel, get_peft_model -if is_vllm_available(): - import vllm - from vllm import LLM, SamplingParams - - if Version(vllm.__version__) <= Version("0.10.2"): - from vllm.sampling_params import GuidedDecodingParams - else: - from vllm.sampling_params import StructuredOutputsParams if is_wandb_available(): import wandb @@ -109,8 +95,6 @@ if is_trackio_available(): import trackio -if is_bitsandbytes_available(): - import bitsandbytes as bnb logger = get_logger(__name__) @@ -505,88 +489,42 @@ def __init__( set_seed(args.seed, device_specific=True) if self.use_vllm: - if not is_vllm_available(): - raise ImportError( - "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install trl[vllm]` to use it." 
- ) - - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" - self.vllm_client = VLLMClient( - base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout - ) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) - - elif self.vllm_mode == "colocate": - # Make sure vllm_tensor_parallel_size group size evenly divides the world size - each group should have - # the same number of ranks - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " - f"({self.accelerator.num_processes}) evenly." - ) - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. - # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ] - ) - - # vLLM requires the environment variables to be set for distributed training. - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - # Ensure distributed rendezvous variables are set without colliding across concurrent runs - ensure_master_addr_port() - - vllm_quantization = None - if is_bitsandbytes_available(): - for _, module in model.named_modules(): - if isinstance(module, bnb.nn.Linear4bit): - vllm_quantization = "bitsandbytes" - break - elif isinstance(module, bnb.nn.Linear8bitLt): - raise ValueError("vLLM does not support in-flight 8-bit quantization.") - self.llm = LLM( - model=model.name_or_path, - tensor_parallel_size=args.vllm_tensor_parallel_size, - gpu_memory_utilization=self.vllm_gpu_memory_utilization, - max_num_seqs=self.args.per_device_train_batch_size - * self.vllm_tensor_parallel_size - * self.args.steps_per_generation, - max_model_len=self.args.vllm_max_model_length, - distributed_executor_backend="external_launcher", - # Feed identical seed for tp groups to ensure sampling results are the same across workers - seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory - max_num_batched_tokens=4096, - model_impl=self.args.vllm_model_impl, - enable_sleep_mode=self.args.vllm_enable_sleep_mode, - quantization=vllm_quantization, - ) - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=2) - else: - raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") - - # vLLM specific sampling arguments - self.structured_outputs_regex = args.vllm_structured_outputs_regex - + # Initialize vLLM generation backend + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + # vLLM configuration + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + # Server mode configuration + 
+            # Initialize vLLM generation backend
+            self.vllm_generation = VLLMGeneration(
+                model=self.model,
+                accelerator=self.accelerator,
+                is_fsdp_enabled=self.is_fsdp_enabled,
+                processing_class=self.processing_class,
+                # vLLM configuration
+                mode=args.vllm_mode,
+                structured_outputs_regex=args.vllm_structured_outputs_regex,
+                # Server mode configuration
+                server_base_url=args.vllm_server_base_url,
+                server_host=args.vllm_server_host,
+                server_port=args.vllm_server_port,
+                group_port=args.vllm_group_port,
+                server_timeout=args.vllm_server_timeout,
+                # Colocate mode configuration
+                tensor_parallel_size=args.vllm_tensor_parallel_size,
+                gpu_memory_utilization=args.vllm_gpu_memory_utilization,
+                max_model_length=args.vllm_max_model_length,
+                max_num_seqs=args.per_device_train_batch_size
+                * args.vllm_tensor_parallel_size
+                * args.steps_per_generation,
+                enable_sleep_mode=args.vllm_enable_sleep_mode,
+                model_impl=args.vllm_model_impl,
+                # Generation configuration
+                repetition_penalty=self.repetition_penalty,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                min_p=self.min_p,
+                max_completion_length=self.max_completion_length,
+                generation_kwargs=args.generation_kwargs,
+                # Chat/tool configuration
+                chat_template_kwargs=self.chat_template_kwargs,
+            )
             self._last_loaded_step = -1  # tag to avoid useless loading during grad accumulation
-
-            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
-            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
-            # synchronize all processes after vLLM has been fully initialized.
-            self.accelerator.wait_for_everyone()
         else:
             generation_kwargs = {
                 "max_new_tokens": self.max_completion_length,
@@ -801,140 +739,6 @@ def _get_per_token_logps_and_entropies(
         entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None

         return logps, entropies

-    def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None):
-        extra_prefixes = extra_prefixes or []
-        prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
-        for prefix in prefixes:
-            name = name.replace(prefix, "")
-        return name
-
-    def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
-        """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
-        # For FSDP1, we need to recurse into children and also use summon_full_params
-        if visited is None:
-            visited = set()
-        for child_name, child_module in module.named_children():
-            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
-            self._sync_fsdp1_params_to_vllm(
-                child_module, prefix=child_prefix, visited=visited
-            )  # recurse into the child
-
-        if isinstance(module, FSDP):
-            with FSDP.summon_full_params(module, recurse=False, writeback=False):
-                for param_name, param in module.named_parameters():
-                    full_name = f"{prefix}.{param_name}" if prefix else param_name
-                    full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])
-
-                    if full_name in visited:
-                        continue  # skip FSDP subtrees already traversed
-                    visited.add(full_name)
-
-                    if self.vllm_mode == "server" and self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(full_name, param.data)
-                    elif self.vllm_mode == "colocate":
-                        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-                        llm_model.load_weights([(full_name, param.data)])
-
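# ----------------------------------------------------------------------------
# Illustrative sketch (editorial aside, not part of the diff). The removed
# `_fix_param_name_to_vllm` / `_sync_fsdp1_params_to_vllm` helpers above (now
# owned by the generation backend) normalize parameter names so that weights
# exported from a checkpoint-wrapped or FSDP-wrapped module match the names
# vLLM expects. A minimal restatement of the name normalization, with a
# hypothetical example name:
# ----------------------------------------------------------------------------
def strip_wrapper_prefixes(name: str, extra_prefixes: list[str] | None = None) -> str:
    prefixes = ["_checkpoint_wrapped_module."] + (extra_prefixes or [])
    for prefix in prefixes:
        name = name.replace(prefix, "")
    return name


assert (
    strip_wrapper_prefixes(
        "_fsdp_wrapped_module.model.layers.0.self_attn.q_proj.weight",
        extra_prefixes=["_fsdp_wrapped_module."],
    )
    == "model.layers.0.self_attn.q_proj.weight"
)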
-    def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
-        # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion
-        for name, param in module.state_dict().items():
-            # When using PEFT, we need to recover the original parameter name
-            name = name.removeprefix("base_model.model.").replace(".base_layer", "")
-            # Skip PEFT layers: they don’t exist in vLLM, and they are merged already.
-            if is_peft_model(module) and module.prefix in name:
-                continue
-            # When module to save, remove its prefix and discard the original module
-            if "original_module" in name:
-                continue
-            name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."])
-
-            if param.is_cpu:
-                param = param.to(torch.device("cuda"))
-            param = param.full_tensor()
-
-            if self.vllm_mode == "server" and self.accelerator.is_main_process:
-                self.vllm_client.update_named_param(name, param)
-            elif self.vllm_mode == "colocate":
-                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-                llm_model.load_weights([(name, param)])
-
-    @profiling_decorator
-    def _move_model_to_vllm(self):
-        # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-        if zero_stage_3:
-            import deepspeed
-
-            gather_if_zero3 = deepspeed.zero.GatheredParameters
-        else:
-            gather_if_zero3 = nullcontext
-
-        if is_peft_model(self.model):
-            # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
-            # merging adapters in a sharded manner is not supported.
-            # TODO: does this work with FSDP?
-            with gather_if_zero3(list(self.model.parameters())):
-                self.model.merge_adapter()
-
-                # Update vLLM weights while parameters are gathered
-                if self.is_fsdp_enabled:  # note if using FSDP, gather_if_zero3 is nullcontext
-                    # Update vLLM weights while parameters are gathered
-                    # For PEFT with FSDP we need to use the memory efficient post-order traversal
-                    fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
-                    fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
-                    if fsdp_version == 1:
-                        self._sync_fsdp1_params_to_vllm(
-                            self.model
-                        )  # use memory-efficient post-order traversal for FSDP
-                    elif fsdp_version == 2:
-                        self._sync_fsdp2_params_to_vllm(self.model)
-                else:
-                    # DeepSpeed ZeRO-3 with PEFT
-                    for name, param in self.model.named_parameters():
-                        # When using PEFT, we need to recover the original parameter name
-                        name = name.removeprefix("base_model.model.").replace(".base_layer", "")
-                        # Skip PEFT layers: they don’t exist in vLLM, and they are merged already.
-                        if self.model.prefix in name:
-                            continue
-                        # When module to save, remove its prefix and discard the original module
-                        if "original_module" in name:
-                            continue
-                        name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."])
-
-                        if self.vllm_mode == "server" and self.accelerator.is_main_process:
-                            self.vllm_client.update_named_param(name, param.data)
-                        elif self.vllm_mode == "colocate":
-                            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-                            llm_model.load_weights([(name, param.data)])
-                # Unmerge adapters while parameters are still gathered
-                self.model.unmerge_adapter()
-                # Parameters will automatically be repartitioned when exiting the context
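# ----------------------------------------------------------------------------
# Illustrative sketch (editorial aside, not part of the diff). The removed
# `_move_model_to_vllm` above streams weights to vLLM one parameter at a time;
# under DeepSpeed ZeRO-3 each parameter only materializes inside a
# `GatheredParameters` context, otherwise a no-op context is used. A hedged
# sketch of that pattern; `push_to_vllm` is a hypothetical callback standing
# in for the client/engine update calls.
# ----------------------------------------------------------------------------
from contextlib import nullcontext


def stream_params_to_vllm(model, push_to_vllm, zero_stage_3: bool = False):
    if zero_stage_3:
        import deepspeed  # only needed when ZeRO-3 shards the parameters

        gather_if_zero3 = deepspeed.zero.GatheredParameters
    else:
        gather_if_zero3 = nullcontext
    for name, param in model.named_parameters():
        with gather_if_zero3([param]):  # full tensor is available inside the context
            push_to_vllm(name, param.data)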
-        else:
-            # For non-PEFT models, simply gather (if needed) and update each parameter individually.
-            if self.is_fsdp_enabled:
-                fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
-                fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
-                if fsdp_version == 1:
-                    self._sync_fsdp1_params_to_vllm(self.model)  # use memory-efficient post-order traversal for FSDP
-                elif fsdp_version == 2:
-                    self._sync_fsdp2_params_to_vllm(self.model)
-            else:
-                for name, param in self.model.named_parameters():
-                    name = self._fix_param_name_to_vllm(name)
-                    with gather_if_zero3([param]):
-                        if self.vllm_mode == "server" and self.accelerator.is_main_process:
-                            self.vllm_client.update_named_param(name, param.data)
-                        elif self.vllm_mode == "colocate":
-                            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-                            llm_model.load_weights([(name, param.data)])
-
-        # Reset cache on vLLM
-        if self.vllm_mode == "server" and self.accelerator.is_main_process:
-            self.vllm_client.reset_prefix_cache()
-        elif self.vllm_mode == "colocate":
-            self.llm.reset_prefix_cache()
-
     def training_step(self, model, inputs, num_items_in_batch):
         time_before = time.perf_counter()
         output = super().training_step(model, inputs, num_items_in_batch)
@@ -1066,140 +870,16 @@ def _generate_single_turn(self, prompts: list):

         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
-            if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
-                # wake up colocated vLLM instances if needed
-                torch.cuda.empty_cache()  # required to avoid OOM in some cases
-                self.llm.wake_up(tags=["weights"])
-                # Work around for https://github.com/vllm-project/vllm/issues/29341
-                self.llm.collective_rpc("reload_weights")
-
-            # First, update the vLLM weights if needed
+            # Sync weights if training step changed
             if self.state.global_step != self._last_loaded_step:
-                self._move_model_to_vllm()
+                self.vllm_generation.sync_weights()
                 self._last_loaded_step = self.state.global_step
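# ----------------------------------------------------------------------------
# Illustrative sketch (editorial aside, not part of the diff). In the server
# branch removed just below, every prompt arrives duplicated `num_generations`
# times; the main process keeps one copy of each unique prompt
# (`all_prompts[::num_generations]`), requests `num_generations` completions
# per unique prompt, then each process slices the flat, broadcast result back
# to its own share. A worked example of that indexing arithmetic (pure Python,
# hypothetical prompt values):
# ----------------------------------------------------------------------------
num_generations = 2
per_process = ["a", "a", "b", "b"], ["c", "c", "d", "d"]  # prompts held by rank 0 and rank 1
all_prompts = [p for rank_prompts in per_process for p in rank_prompts]  # gathered, with duplicates

unique_prompts = all_prompts[::num_generations]  # ["a", "b", "c", "d"]
all_completions = [f"{p}{i}" for p in unique_prompts for i in range(num_generations)]
# the flat result is ordered exactly like the duplicated prompt list:
assert all_completions == ["a0", "a1", "b0", "b1", "c0", "c1", "d0", "d1"]

for rank, rank_prompts in enumerate(per_process):
    process_slice = slice(rank * len(rank_prompts), (rank + 1) * len(rank_prompts))
    assert len(all_completions[process_slice]) == len(rank_prompts)
    # rank 0 -> ["a0", "a1", "b0", "b1"], rank 1 -> ["c0", "c1", "d0", "d1"]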
-
-            if is_conversational({"prompt": prompts[0]}):
-                prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]
-
-            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
-            if self.vllm_mode == "server":
-                all_prompts = gather_object(prompts)
-                num_generations = self.num_generations if mode == "train" else self.num_generations_eval
-
-                if self.accelerator.is_main_process:
-                    # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
-                    # num_generations outputs for each one. This is faster than generating outputs for each duplicate
-                    # prompt individually.
-                    ordered_set_of_prompts = all_prompts[::num_generations]
-
-                    sampling_params = {
-                        "n": num_generations,
-                        "repetition_penalty": self.repetition_penalty,
-                        "temperature": self.temperature,
-                        "top_p": self.top_p,
-                        "top_k": self.top_k,
-                        "min_p": 0.0 if self.min_p is None else self.min_p,
-                        "max_tokens": self.max_completion_length,
-                        "structured_outputs_regex": self.structured_outputs_regex,
-                        "generation_kwargs": self.args.generation_kwargs,
-                    }
-                    with profiling_context(self, "vLLM.generate"):
-                        if is_conversational({"prompt": ordered_set_of_prompts[0]}):
-                            output = self.vllm_client.chat(
-                                messages=ordered_set_of_prompts,
-                                **sampling_params,
-                                chat_template_kwargs=self.chat_template_kwargs,
-                            )
-                        else:
-                            output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
-                    payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
-                else:
-                    payload = None
-
-                # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
-                obj_list = [payload]
-                broadcast_object_list(obj_list, from_process=0)
-                all_prompt_ids, all_completion_ids, _ = obj_list[0]
-
-                # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
-                all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
-
-                process_slice = slice(
-                    self.accelerator.process_index * len(prompts),
-                    (self.accelerator.process_index + 1) * len(prompts),
-                )
-                prompt_ids = all_prompt_ids[process_slice]
-                completion_ids = all_completion_ids[process_slice]
-
-            # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
-            elif self.vllm_mode == "colocate":
-                if Version(vllm.__version__) <= Version("0.10.2"):
-                    structured_outputs_key = "guided_decoding"
-                    if self.structured_outputs_regex:
-                        structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex)
-                    else:
-                        structured_outputs = None
-                else:
-                    structured_outputs_key = "structured_outputs"
-                    if self.structured_outputs_regex:
-                        structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
-                    else:
-                        structured_outputs = None
-
-                generation_kwargs = {
-                    "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
-                    "repetition_penalty": self.repetition_penalty,
-                    "temperature": self.temperature,
-                    "top_p": self.top_p,
-                    "top_k": self.top_k,
-                    "min_p": 0.0 if self.min_p is None else self.min_p,
-                    "max_tokens": self.max_completion_length,
-                }
-                generation_kwargs[structured_outputs_key] = structured_outputs
-                if self.args.generation_kwargs is not None:
-                    generation_kwargs.update(self.args.generation_kwargs)
-                sampling_params = SamplingParams(**generation_kwargs)
-
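# ----------------------------------------------------------------------------
# Illustrative sketch (editorial aside, not part of the diff). The colocate
# branch above selects the SamplingParams keyword by vLLM version:
# `guided_decoding=GuidedDecodingParams(...)` up to 0.10.2 and
# `structured_outputs=StructuredOutputsParams(...)` afterwards. A hedged,
# import-guarded restatement of that switch (requires vLLM at call time);
# `build_sampling_params` is a hypothetical helper name.
# ----------------------------------------------------------------------------
def build_sampling_params(regex: str | None, max_tokens: int, temperature: float = 1.0):
    import vllm
    from packaging.version import Version
    from vllm import SamplingParams

    kwargs = {"n": 1, "temperature": temperature, "max_tokens": max_tokens}
    if Version(vllm.__version__) <= Version("0.10.2"):
        from vllm.sampling_params import GuidedDecodingParams

        kwargs["guided_decoding"] = GuidedDecodingParams(regex=regex) if regex else None
    else:
        from vllm.sampling_params import StructuredOutputsParams

        kwargs["structured_outputs"] = StructuredOutputsParams(regex=regex) if regex else None
    return SamplingParams(**kwargs)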
-                if self.vllm_tensor_parallel_size > 1:
-                    # Gather prompts from all ranks in the TP group and flatten.
-                    # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
-                    orig_size = len(prompts)
-                    gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
-                    torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
-                    all_prompts = [p for sublist in gathered_prompts for p in sublist]
-                else:
-                    all_prompts = prompts
-
-                if self.args.vllm_enable_sleep_mode:
-                    self.llm.wake_up(tags=["kv_cache"])
-
-                with profiling_context(self, "vLLM.generate"):
-                    if is_conversational({"prompt": prompts[0]}):
-                        all_outputs = self.llm.chat(
-                            all_prompts,
-                            sampling_params=sampling_params,
-                            use_tqdm=False,
-                            chat_template_kwargs=self.chat_template_kwargs,
-                        )
-                    else:
-                        all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
-
-                all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
-                all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
-
-                if self.vllm_tensor_parallel_size > 1:
-                    # Slice completions for this rank within its TP group.
-                    # Each rank generates all outputs — we keep only our share.
-                    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
-                    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
-                    prompt_ids = all_prompt_ids[tp_slice]
-                    completion_ids = all_completion_ids[tp_slice]
-                else:
-                    prompt_ids = all_prompt_ids
-                    completion_ids = all_completion_ids
-
-                if self.args.vllm_enable_sleep_mode:
-                    self.llm.sleep(level=2)
+            # Generate using vLLM (note: RLOO doesn't use logprobs from generation, so we ignore them)
+            num_generations = self.num_generations if mode == "train" else self.num_generations_eval
+            prompt_ids, completion_ids, _, _ = self.vllm_generation.generate(
+                prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate")
+            )
         elif self.use_transformers_paged:
             if is_conversational({"prompt": prompts[0]}):