diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index 559a3e74588..63e6ebacfbd 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -3,7 +3,7 @@ This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. > [!WARNING] -> TRL currently only supports vLLM versions `0.10.2`, `0.11.0`, `0.11.1`, and `0.11.2`. Please ensure you have one of these versions installed to avoid compatibility issues. +> TRL currently only supports vLLM versions `0.10.2`, `0.11.0`, `0.11.1`, `0.11.2` and `0.12.0`. Please ensure you have one of these versions installed to avoid compatibility issues. > [!TIP] > The following trainers currently support generation with vLLM: diff --git a/pyproject.toml b/pyproject.toml index 25993a5b109..366185f42bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ test = [ "pytest" ] vllm = [ - "vllm>=0.10.2,<0.12.0", + "vllm>=0.10.2,<0.13.0", "fastapi", "pydantic", "requests", diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 2b09bd2152d..ba509cd85db 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -918,8 +918,8 @@ def test_training_vllm_and_peft(self): @require_vllm @pytest.mark.skip(reason="We should add a mock for the vLLM server.") - def test_training_vllm_guided_decoding(self): - """Test that training works with vLLM for generation with guided decoding.""" + def test_training_vllm_structured_outputs(self): + """Test that training works with vLLM for generation with structured outputs.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") training_args = GRPOConfig( @@ -930,7 +930,7 @@ def test_training_vllm_guided_decoding(self): max_completion_length=8, # reduce the completion length to 
reduce memory usage report_to="none", use_vllm=True, - vllm_guided_decoding_regex=r"\n.*\n\n\n.*\n", + vllm_structured_outputs_regex=r"\n.*\n\n\n.*\n", ) trainer = GRPOTrainer( model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM @@ -953,7 +953,7 @@ def test_training_vllm_guided_decoding(self): @require_vllm @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_importance_sampling_correction(self): - """Test that training works with vLLM for generation with guided decoding.""" + """Test that training works with vLLM for generation with importance sampling correction.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") training_args = GRPOConfig( diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 741028a6a8f..24d1006c67a 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -693,8 +693,8 @@ def test_training_vllm_and_peft(self): @require_vllm @pytest.mark.skip(reason="We should add a mock for the vLLM server.") - def test_training_vllm_guided_decoding(self): - """Test that training works with vLLM for generation with guided decoding.""" + def test_training_vllm_structured_outputs(self): + """Test that training works with vLLM for generation with structured outputs.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") training_args = RLOOConfig( @@ -705,7 +705,7 @@ def test_training_vllm_guided_decoding(self): max_completion_length=8, # reduce the completion length to reduce memory usage report_to="none", use_vllm=True, - vllm_guided_decoding_regex=r"\n.*\n\n\n.*\n", + vllm_structured_outputs_regex=r"\n.*\n\n\n.*\n", ) trainer = RLOOTrainer( model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM diff --git a/trl/__init__.py b/trl/__init__.py index 696019c8df5..9f39b367c9c 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -194,7 +194,7 @@ # Fix DisableTqdm # Bug introduced in 
https://github.com/vllm-project/vllm/pull/52 # Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1) - # Since TRL currently only supports vLLM v0.10.2-0.11.2, we patch it here. This can be removed when TRL requires + # Since TRL currently only supports vLLM v0.10.2-0.12.0, we patch it here. This can be removed when TRL requires # vLLM >=0.11.1 import vllm.model_executor.model_loader.weight_utils from tqdm import tqdm diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index d9307352e2e..046cf71e5f5 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -84,8 +84,8 @@ class GOLDConfig(SFTConfig): to set this to a low value if the student and teacher models share the same GPU. vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`). - vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): - Regex for vLLM guided decoding for the student model. + vllm_structured_outputs_regex (`str` or `None`, *optional*, defaults to `None`): + Regex for vLLM structured outputs for the student model. vllm_sync_frequency (`int`, *optional*, defaults to `1`): Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after every step. 
@@ -303,9 +303,9 @@ class GOLDConfig(SFTConfig): default=1, metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'}, ) - vllm_guided_decoding_regex: str | None = field( + vllm_structured_outputs_regex: str | None = field( default=None, - metadata={"help": "Regex pattern used for vLLM guided decoding (optional)."}, + metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."}, ) vllm_sync_frequency: int = field( default=1, diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index fc19398c140..f3a9cf9105b 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -72,7 +72,7 @@ if is_vllm_available(): from vllm import LLM, SamplingParams - from vllm.sampling_params import GuidedDecodingParams + from vllm.sampling_params import StructuredOutputsParams if is_liger_kernel_available(): from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss @@ -978,7 +978,7 @@ def __init__( self.accelerator.wait_for_everyone() else: raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") - self.vllm_guided_decoding_regex = args.vllm_guided_decoding_regex + self.vllm_structured_outputs_regex = args.vllm_structured_outputs_regex self.vllm_sync_frequency = args.vllm_sync_frequency self._last_vllm_sync_step = -1 @@ -1675,7 +1675,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_ top_k=top_k, min_p=min_p, max_tokens=max_completion_length, - guided_decoding_regex=self.vllm_guided_decoding_regex, + structured_outputs_regex=self.vllm_structured_outputs_regex, )["completion_ids"] else: completion_ids = [None] * len(all_prompts_text) @@ -1686,10 +1686,12 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_ ) completion_ids = completion_ids[process_slice] elif self.vllm_mode == "colocate": - if self.vllm_guided_decoding_regex: - guided_decoding = GuidedDecodingParams(backend="outlines", 
regex=self.vllm_guided_decoding_regex) + if self.vllm_structured_outputs_regex: + structured_outputs = StructuredOutputsParams( + backend="outlines", regex=self.vllm_structured_outputs_regex + ) else: - guided_decoding = None + structured_outputs = None sampling_params = SamplingParams( n=1, repetition_penalty=repetition_penalty, @@ -1698,7 +1700,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_ top_k=top_k, min_p=min_p, max_tokens=max_completion_length, - guided_decoding=guided_decoding, + structured_outputs=structured_outputs, ) if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: diff --git a/trl/experimental/online_dpo/online_dpo_config.py b/trl/experimental/online_dpo/online_dpo_config.py index 6d373875028..ec6d6fac2e0 100644 --- a/trl/experimental/online_dpo/online_dpo_config.py +++ b/trl/experimental/online_dpo/online_dpo_config.py @@ -106,8 +106,8 @@ class may differ from those in [`~transformers.TrainingArguments`]. server is running (start with `trl vllm-serve`). - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a separate server but may cause resource contention with training. - vllm_guided_decoding_regex (`str`, *optional*): - Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. + vllm_structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) @@ -306,9 +306,9 @@ class may differ from those in [`~transformers.TrainingArguments`]. "model implementation." }, ) - vllm_guided_decoding_regex: str | None = field( + vllm_structured_outputs_regex: str | None = field( default=None, - metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, + metadata={"help": "Regex for vLLM structured outputs. 
If `None` (default), structured outputs is disabled."}, ) vllm_gpu_memory_utilization: float | None = field( default=0.55, diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py index f16c98db156..08e81c1cdd7 100644 --- a/trl/experimental/online_dpo/online_dpo_trainer.py +++ b/trl/experimental/online_dpo/online_dpo_trainer.py @@ -78,7 +78,7 @@ if is_vllm_available(): from vllm import LLM, SamplingParams - from vllm.sampling_params import GuidedDecodingParams + from vllm.sampling_params import StructuredOutputsParams if is_bitsandbytes_available(): import bitsandbytes as bnb @@ -491,7 +491,7 @@ def __init__( else: raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") # vLLM specific sampling arguments - self.guided_decoding_regex = args.vllm_guided_decoding_regex + self.structured_outputs_regex = args.vllm_structured_outputs_regex self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation # Set up vLLM generation config @@ -507,8 +507,8 @@ def __init__( } if args.generation_kwargs is not None: generation_params.update(args.generation_kwargs) - if self.guided_decoding_regex: - generation_params["guided_decoding"] = GuidedDecodingParams(regex=self.guided_decoding_regex) + if self.structured_outputs_regex: + generation_params["structured_outputs"] = StructuredOutputsParams(regex=self.structured_outputs_regex) self.generation_config = SamplingParams(**generation_params) # When using vLLM, the main process is responsible for loading the model weights. 
This can cause process @@ -748,7 +748,9 @@ def _generate_vllm_server(self, prompts, images=None): top_k=-1 if self.top_k is None else self.top_k, min_p=0.0 if self.min_p is None else self.min_p, max_tokens=self.generation_config.max_tokens, - guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None, + structured_outputs_regex=self.structured_outputs_regex + if hasattr(self, "structured_outputs_regex") + else None, generation_kwargs=self.args.generation_kwargs, )["completion_ids"] # Flatten: each prompt generates 2 completions diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py index 08ead93e6e8..3d40dcc390c 100644 --- a/trl/experimental/openenv/utils.py +++ b/trl/experimental/openenv/utils.py @@ -23,7 +23,7 @@ if is_vllm_available(): from vllm import SamplingParams - from vllm.sampling_params import GuidedDecodingParams + from vllm.sampling_params import StructuredOutputsParams def _build_colocate_sampling_params( @@ -32,10 +32,10 @@ def _build_colocate_sampling_params( *, logprobs: bool = True, ) -> SamplingParams: - if trainer.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=trainer.guided_decoding_regex) + if trainer.structured_outputs_regex: + structured_outputs = StructuredOutputsParams(regex=trainer.structured_outputs_regex) else: - guided_decoding = None + structured_outputs = None generation_kwargs: dict[str, Any] = { "n": 1, @@ -43,7 +43,7 @@ def _build_colocate_sampling_params( "top_k": trainer.top_k, "min_p": 0.0 if trainer.min_p is None else trainer.min_p, "max_tokens": trainer.max_completion_length, - "guided_decoding": guided_decoding, + "structured_outputs": structured_outputs, } if trainer.repetition_penalty is not None: generation_kwargs["repetition_penalty"] = trainer.repetition_penalty diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index e21df6d837e..22f1877ce51 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py 
@@ -190,7 +190,7 @@ def generate( min_p: float = 0.0, max_tokens: int = 16, truncate_prompt_tokens: int | None = None, - guided_decoding_regex: str | None = None, + structured_outputs_regex: str | None = None, generation_kwargs: dict | None = None, ) -> dict[str, list[list[int]]]: """ @@ -219,7 +219,7 @@ def generate( If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is disabled. - guided_decoding_regex (`str`, *optional*): + structured_outputs_regex (`str`, *optional*): Regular expression to guide the decoding process. generation_kwargs (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like @@ -253,7 +253,7 @@ def generate( "min_p": min_p, "max_tokens": max_tokens, "truncate_prompt_tokens": truncate_prompt_tokens, - "guided_decoding_regex": guided_decoding_regex, + "structured_outputs_regex": structured_outputs_regex, "generation_kwargs": generation_kwargs or {}, }, ) @@ -278,7 +278,7 @@ def chat( min_p: float = 0.0, max_tokens: int = 16, truncate_prompt_tokens: int | None = None, - guided_decoding_regex: str | None = None, + structured_outputs_regex: str | None = None, generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None, tools: list | None = None, @@ -309,7 +309,7 @@ def chat( If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is disabled. - guided_decoding_regex (`str`, *optional*): + structured_outputs_regex (`str`, *optional*): Regular expression to guide the decoding process. generation_kwargs (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. 
This can include parameters like @@ -360,7 +360,7 @@ def chat( "min_p": min_p, "max_tokens": max_tokens, "truncate_prompt_tokens": truncate_prompt_tokens, - "guided_decoding_regex": guided_decoding_regex, + "structured_outputs_regex": structured_outputs_regex, "generation_kwargs": generation_kwargs or {}, "chat_template_kwargs": chat_template_kwargs or {}, }, diff --git a/trl/import_utils.py b/trl/import_utils.py index 2cce6df9633..5106b782e66 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -79,9 +79,9 @@ def is_uvicorn_available() -> bool: def is_vllm_available() -> bool: _vllm_available, _vllm_version = _is_package_available("vllm", return_version=True) if _vllm_available: - if not (version.parse("0.10.2") <= version.parse(_vllm_version) <= version.parse("0.11.2")): + if not (version.parse("0.10.2") <= version.parse(_vllm_version) <= version.parse("0.12.0")): warnings.warn( - "TRL currently supports vLLM versions: 0.10.2, 0.11.0, 0.11.1, 0.11.2. You have version " + "TRL currently supports vLLM versions: 0.10.2, 0.11.0, 0.11.1, 0.11.2, 0.12.0. You have version " f"{_vllm_version} installed. 
We recommend installing a supported version to avoid compatibility " "issues.", stacklevel=2, diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index d5884b8290a..19b37d98c46 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -26,6 +26,7 @@ import torch import torch.distributed.distributed_c10d as c10d +from packaging import version from transformers import is_torch_xpu_available, is_vision_available from trl import TrlParser @@ -55,13 +56,18 @@ if is_vllm_available(): + import vllm from vllm import LLM, SamplingParams from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.parallel_state import get_world_group from vllm.distributed.utils import StatelessProcessGroup - from vllm.sampling_params import GuidedDecodingParams from vllm.utils import get_open_port + if version.parse(vllm.__version__) <= version.parse("0.10.2"): + from vllm.sampling_params import GuidedDecodingParams + else: + from vllm.sampling_params import StructuredOutputsParams + if is_vllm_ascend_available(): from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator @@ -495,7 +501,7 @@ class GenerateRequest(BaseModel): min_p: float = 0.0 max_tokens: int = 16 truncate_prompt_tokens: int | None = None - guided_decoding_regex: str | None = None + structured_outputs_regex: str | None = None generation_kwargs: dict = field(default_factory=dict) class GenerateResponse(BaseModel): @@ -528,8 +534,8 @@ async def generate(request: GenerateRequest): - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is disabled. - - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the - model will only generate tokens that match this regex pattern. 
+ - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided, + the model will only generate tokens that match this regex pattern. - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they will override them. @@ -564,11 +570,19 @@ async def generate(request: GenerateRequest): row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} prompts.append(row) - # Guided decoding, if enabled - if request.guided_decoding_regex is not None: - guided_decoding = GuidedDecodingParams(regex=request.guided_decoding_regex) + # Structured outputs, if enabled + if version.parse(vllm.__version__) <= version.parse("0.10.2"): + structured_outputs_key = "guided_decoding" + if request.structured_outputs_regex is not None: + structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None else: - guided_decoding = None + structured_outputs_key = "structured_outputs" + if request.structured_outputs_regex is not None: + structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None generation_kwargs = { "n": request.n, @@ -579,9 +593,9 @@ async def generate(request: GenerateRequest): "min_p": request.min_p, "max_tokens": request.max_tokens, "truncate_prompt_tokens": request.truncate_prompt_tokens, - "guided_decoding": guided_decoding, "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only } + generation_kwargs[structured_outputs_key] = structured_outputs generation_kwargs.update(request.generation_kwargs) sampling_params = SamplingParams(**generation_kwargs) @@ -625,7 +639,7 @@ class ChatRequest(BaseModel): min_p: float = 0.0 max_tokens: int = 16 truncate_prompt_tokens: int | None = None - 
guided_decoding_regex: str | None = None + structured_outputs_regex: str | None = None generation_kwargs: dict = field(default_factory=dict) chat_template_kwargs: dict = field(default_factory=dict) @@ -658,8 +672,8 @@ async def chat(request: ChatRequest): - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is disabled. - - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the - model will only generate tokens that match this regex pattern. + - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided, + the model will only generate tokens that match this regex pattern. - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they will override them. 
@@ -697,11 +711,19 @@ async def chat(request: ChatRequest): if part["type"] == "image_pil": part["image_pil"] = Image.open(BytesIO(base64.b64decode(part["image_pil"]))) - # Guided decoding, if enabled - if request.guided_decoding_regex is not None: - guided_decoding = GuidedDecodingParams(regex=request.guided_decoding_regex) + # Structured outputs, if enabled + if version.parse(vllm.__version__) <= version.parse("0.10.2"): + structured_outputs_key = "guided_decoding" + if request.structured_outputs_regex is not None: + structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None else: - guided_decoding = None + structured_outputs_key = "structured_outputs" + if request.structured_outputs_regex is not None: + structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None generation_kwargs = { "n": request.n, @@ -712,9 +734,9 @@ async def chat(request: ChatRequest): "min_p": request.min_p, "max_tokens": request.max_tokens, "truncate_prompt_tokens": request.truncate_prompt_tokens, - "guided_decoding": guided_decoding, "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only } + generation_kwargs[structured_outputs_key] = structured_outputs generation_kwargs.update(request.generation_kwargs) sampling_params = SamplingParams(**generation_kwargs) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 09370c73d57..a52af5cac31 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -122,8 +122,8 @@ class GRPOConfig(TrainingArguments): Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model implementation. - vllm_guided_decoding_regex (`str`, *optional*): - Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. 
+ vllm_structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) @@ -301,6 +301,15 @@ class GRPOConfig(TrainingArguments): Parameter `max_prompt_length` is deprecated and will be removed in version 0.28.0. You should instead filter your dataset before training to ensure that prompts do not exceed your desired length. + + + vllm_guided_decoding_regex: + + + + Parameter `vllm_guided_decoding_regex` is deprecated and will be removed in version 0.28.0. You should + instead use `vllm_structured_outputs_regex`. + """ @@ -517,9 +526,9 @@ class GRPOConfig(TrainingArguments): "usage low, but waking the engine adds host–device transfer latency." }, ) - vllm_guided_decoding_regex: str | None = field( + vllm_structured_outputs_regex: str | None = field( default=None, - metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, + metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."}, ) # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) @@ -797,6 +806,10 @@ class GRPOConfig(TrainingArguments): "desired length." }, ) + vllm_guided_decoding_regex: str | None = field( + default=None, + metadata={"help": "Deprecated, use `vllm_structured_outputs_regex` instead."}, + ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 @@ -880,3 +893,11 @@ def __post_init__(self): FutureWarning, stacklevel=2, ) + if self.vllm_guided_decoding_regex is not None: + warnings.warn( + "The `vllm_guided_decoding_regex` argument is deprecated and will be removed in version 0.28.0. 
You " + "should instead use `vllm_structured_outputs_regex`.", + FutureWarning, + stacklevel=2, + ) + self.vllm_structured_outputs_regex = self.vllm_guided_decoding_regex diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 1c9a2fab1cc..4749df33933 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -36,6 +36,7 @@ from accelerate.logging import get_logger from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset +from packaging import version from packaging.version import Version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -101,8 +102,13 @@ from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss if is_vllm_available(): + import vllm from vllm import LLM, SamplingParams - from vllm.sampling_params import GuidedDecodingParams + + if version.parse(vllm.__version__) <= version.parse("0.10.2"): + from vllm.sampling_params import GuidedDecodingParams + else: + from vllm.sampling_params import StructuredOutputsParams if is_wandb_available(): import wandb @@ -701,7 +707,7 @@ def cast_outputs_to_original_dtype(module, args, output): raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") # vLLM specific sampling arguments - self.guided_decoding_regex = args.vllm_guided_decoding_regex + self.structured_outputs_regex = args.vllm_structured_outputs_regex self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation @@ -1315,7 +1321,7 @@ def _generate_single_turn(self, prompts: list): "top_k": self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - "guided_decoding_regex": self.guided_decoding_regex, + "structured_outputs_regex": self.structured_outputs_regex, "generation_kwargs": self.args.generation_kwargs, } with profiling_context(self, "vLLM.generate"): @@ -1389,10 +1395,18 
@@ def _generate_single_turn(self, prompts: list): completion_ids = output["completion_ids"] logprobs = output["logprobs"] else: - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) + if version.parse(vllm.__version__) <= version.parse("0.10.2"): + structured_outputs_key = "guided_decoding" + if self.structured_outputs_regex: + structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None else: - guided_decoding = None + structured_outputs_key = "structured_outputs" + if self.structured_outputs_regex: + structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None generation_kwargs = { "n": 1, # vLLM on each GPU generates only 1 in colocate mode @@ -1402,9 +1416,9 @@ def _generate_single_turn(self, prompts: list): "top_k": self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - "guided_decoding": guided_decoding, "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only } + generation_kwargs[structured_outputs_key] = structured_outputs if self.args.generation_kwargs is not None: generation_kwargs.update(self.args.generation_kwargs) sampling_params = SamplingParams(**generation_kwargs) diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index c5f16a93220..285600ae665 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -117,8 +117,8 @@ class RLOOConfig(TrainingArguments): Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model implementation. - vllm_guided_decoding_regex (`str`, *optional*): - Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. 
+ vllm_structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) @@ -212,6 +212,15 @@ class RLOOConfig(TrainingArguments): Parameter `max_prompt_length` is deprecated and will be removed in version 0.29.0. You should instead filter your dataset before training to ensure that prompts do not exceed your desired length. + + + vllm_guided_decoding_regex: + + + + Parameter `vllm_guided_decoding_regex` is deprecated and will be removed in version 0.28.0. You should + instead use `vllm_structured_outputs_regex`. + """ @@ -419,9 +428,9 @@ class RLOOConfig(TrainingArguments): "usage low, but waking the engine adds host–device transfer latency." }, ) - vllm_guided_decoding_regex: str | None = field( + vllm_structured_outputs_regex: str | None = field( default=None, - metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, + metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."}, ) # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) @@ -580,6 +589,10 @@ class RLOOConfig(TrainingArguments): "desired length." }, ) + vllm_guided_decoding_regex: str | None = field( + default=None, + metadata={"help": "Deprecated, use `vllm_structured_outputs_regex` instead."}, + ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 @@ -649,3 +662,11 @@ def __post_init__(self): FutureWarning, stacklevel=2, ) + if self.vllm_guided_decoding_regex is not None: + warnings.warn( + "The `vllm_guided_decoding_regex` argument is deprecated and will be removed in version 0.28.0. 
You " + "should instead use `vllm_structured_outputs_regex`.", + FutureWarning, + stacklevel=2, + ) + self.vllm_structured_outputs_regex = self.vllm_guided_decoding_regex diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 8fa3e814479..06fe04b1a7c 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -33,6 +33,7 @@ from accelerate.logging import get_logger from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset +from packaging import version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler @@ -93,8 +94,13 @@ from peft import PeftConfig, PeftModel, get_peft_model if is_vllm_available(): + import vllm from vllm import LLM, SamplingParams - from vllm.sampling_params import GuidedDecodingParams + + if version.parse(vllm.__version__) <= version.parse("0.10.2"): + from vllm.sampling_params import GuidedDecodingParams + else: + from vllm.sampling_params import StructuredOutputsParams if is_wandb_available(): import wandb @@ -562,7 +568,7 @@ def __init__( raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") # vLLM specific sampling arguments - self.guided_decoding_regex = args.vllm_guided_decoding_regex + self.structured_outputs_regex = args.vllm_structured_outputs_regex self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation @@ -1083,7 +1089,7 @@ def _generate_single_turn(self, prompts: list): "top_k": self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - "guided_decoding_regex": self.guided_decoding_regex, + "structured_outputs_regex": self.structured_outputs_regex, "generation_kwargs": self.args.generation_kwargs, } with profiling_context(self, "vLLM.generate"): @@ -1116,10 +1122,18 @@ def _generate_single_turn(self, prompts: 
list): # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) + if version.parse(vllm.__version__) <= version.parse("0.10.2"): + structured_outputs_key = "guided_decoding" + if self.structured_outputs_regex: + structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None else: - guided_decoding = None + structured_outputs_key = "structured_outputs" + if self.structured_outputs_regex: + structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None generation_kwargs = { "n": 1, # vLLM on each GPU generates only 1 in colocate mode @@ -1129,8 +1143,8 @@ def _generate_single_turn(self, prompts: list): "top_k": self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - "guided_decoding": guided_decoding, } + generation_kwargs[structured_outputs_key] = structured_outputs if self.args.generation_kwargs is not None: generation_kwargs.update(self.args.generation_kwargs) sampling_params = SamplingParams(**generation_kwargs)