diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md
index 559a3e74588..63e6ebacfbd 100644
--- a/docs/source/vllm_integration.md
+++ b/docs/source/vllm_integration.md
@@ -3,7 +3,7 @@
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.
> [!WARNING]
-> TRL currently only supports vLLM versions `0.10.2`, `0.11.0`, `0.11.1`, and `0.11.2`. Please ensure you have one of these versions installed to avoid compatibility issues.
+> TRL currently only supports vLLM versions `0.10.2`, `0.11.0`, `0.11.1`, `0.11.2`, and `0.12.0`. Please ensure you have one of these versions installed to avoid compatibility issues.
> [!TIP]
> The following trainers currently support generation with vLLM:
diff --git a/pyproject.toml b/pyproject.toml
index 25993a5b109..366185f42bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,7 +78,7 @@ test = [
"pytest"
]
vllm = [
- "vllm>=0.10.2,<0.12.0",
+ "vllm>=0.10.2,<0.13.0",
"fastapi",
"pydantic",
"requests",
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 2b09bd2152d..ba509cd85db 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -918,8 +918,8 @@ def test_training_vllm_and_peft(self):
@require_vllm
@pytest.mark.skip(reason="We should add a mock for the vLLM server.")
- def test_training_vllm_guided_decoding(self):
- """Test that training works with vLLM for generation with guided decoding."""
+ def test_training_vllm_structured_outputs(self):
+ """Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
@@ -930,7 +930,7 @@ def test_training_vllm_guided_decoding(self):
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
- vllm_guided_decoding_regex=r"\n.*\n\n\n.*\n",
+ vllm_structured_outputs_regex=r"\n.*\n\n\n.*\n",
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
@@ -953,7 +953,7 @@ def test_training_vllm_guided_decoding(self):
@require_vllm
@pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_importance_sampling_correction(self):
- """Test that training works with vLLM for generation with guided decoding."""
+ """Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 741028a6a8f..24d1006c67a 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -693,8 +693,8 @@ def test_training_vllm_and_peft(self):
@require_vllm
@pytest.mark.skip(reason="We should add a mock for the vLLM server.")
- def test_training_vllm_guided_decoding(self):
- """Test that training works with vLLM for generation with guided decoding."""
+ def test_training_vllm_structured_outputs(self):
+ """Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = RLOOConfig(
@@ -705,7 +705,7 @@ def test_training_vllm_guided_decoding(self):
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
- vllm_guided_decoding_regex=r"\n.*\n\n\n.*\n",
+ vllm_structured_outputs_regex=r"\n.*\n\n\n.*\n",
)
trainer = RLOOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
diff --git a/trl/__init__.py b/trl/__init__.py
index 696019c8df5..9f39b367c9c 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -194,7 +194,7 @@
# Fix DisableTqdm
# Bug introduced in https://github.com/vllm-project/vllm/pull/52
# Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1)
- # Since TRL currently only supports vLLM v0.10.2-0.11.2, we patch it here. This can be removed when TRL requires
+ # Since TRL currently only supports vLLM v0.10.2-0.12.0, we patch it here. This can be removed when TRL requires
# vLLM >=0.11.1
import vllm.model_executor.model_loader.weight_utils
from tqdm import tqdm
diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py
index d9307352e2e..046cf71e5f5 100644
--- a/trl/experimental/gold/gold_config.py
+++ b/trl/experimental/gold/gold_config.py
@@ -84,8 +84,8 @@ class GOLDConfig(SFTConfig):
to set this to a low value if the student and teacher models share the same GPU.
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`).
- vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
- Regex for vLLM guided decoding for the student model.
+ vllm_structured_outputs_regex (`str` or `None`, *optional*, defaults to `None`):
+ Regex for vLLM structured outputs for the student model.
vllm_sync_frequency (`int`, *optional*, defaults to `1`):
Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after
every step.
@@ -303,9 +303,9 @@ class GOLDConfig(SFTConfig):
default=1,
metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'},
)
- vllm_guided_decoding_regex: str | None = field(
+ vllm_structured_outputs_regex: str | None = field(
default=None,
- metadata={"help": "Regex pattern used for vLLM guided decoding (optional)."},
+ metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."},
)
vllm_sync_frequency: int = field(
default=1,
diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py
index fc19398c140..f3a9cf9105b 100644
--- a/trl/experimental/gold/gold_trainer.py
+++ b/trl/experimental/gold/gold_trainer.py
@@ -72,7 +72,7 @@
if is_vllm_available():
from vllm import LLM, SamplingParams
- from vllm.sampling_params import GuidedDecodingParams
+ from vllm.sampling_params import StructuredOutputsParams
if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
@@ -978,7 +978,7 @@ def __init__(
self.accelerator.wait_for_everyone()
else:
raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}")
- self.vllm_guided_decoding_regex = args.vllm_guided_decoding_regex
+ self.vllm_structured_outputs_regex = args.vllm_structured_outputs_regex
self.vllm_sync_frequency = args.vllm_sync_frequency
self._last_vllm_sync_step = -1
@@ -1675,7 +1675,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
top_k=top_k,
min_p=min_p,
max_tokens=max_completion_length,
- guided_decoding_regex=self.vllm_guided_decoding_regex,
+ structured_outputs_regex=self.vllm_structured_outputs_regex,
)["completion_ids"]
else:
completion_ids = [None] * len(all_prompts_text)
@@ -1686,10 +1686,12 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
)
completion_ids = completion_ids[process_slice]
elif self.vllm_mode == "colocate":
- if self.vllm_guided_decoding_regex:
- guided_decoding = GuidedDecodingParams(backend="outlines", regex=self.vllm_guided_decoding_regex)
+ if self.vllm_structured_outputs_regex:
+ structured_outputs = StructuredOutputsParams(
+ backend="outlines", regex=self.vllm_structured_outputs_regex
+ )
else:
- guided_decoding = None
+ structured_outputs = None
sampling_params = SamplingParams(
n=1,
repetition_penalty=repetition_penalty,
@@ -1698,7 +1700,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
top_k=top_k,
min_p=min_p,
max_tokens=max_completion_length,
- guided_decoding=guided_decoding,
+ structured_outputs=structured_outputs,
)
if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1:
diff --git a/trl/experimental/online_dpo/online_dpo_config.py b/trl/experimental/online_dpo/online_dpo_config.py
index 6d373875028..ec6d6fac2e0 100644
--- a/trl/experimental/online_dpo/online_dpo_config.py
+++ b/trl/experimental/online_dpo/online_dpo_config.py
@@ -106,8 +106,8 @@ class may differ from those in [`~transformers.TrainingArguments`].
server is running (start with `trl vllm-serve`).
- `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
separate server but may cause resource contention with training.
- vllm_guided_decoding_regex (`str`, *optional*):
- Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
+ vllm_structured_outputs_regex (`str`, *optional*):
+ Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled.
> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
@@ -306,9 +306,9 @@ class may differ from those in [`~transformers.TrainingArguments`].
"model implementation."
},
)
- vllm_guided_decoding_regex: str | None = field(
+ vllm_structured_outputs_regex: str | None = field(
default=None,
- metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
+ metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."},
)
vllm_gpu_memory_utilization: float | None = field(
default=0.55,
diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py
index f16c98db156..08e81c1cdd7 100644
--- a/trl/experimental/online_dpo/online_dpo_trainer.py
+++ b/trl/experimental/online_dpo/online_dpo_trainer.py
@@ -78,7 +78,7 @@
if is_vllm_available():
from vllm import LLM, SamplingParams
- from vllm.sampling_params import GuidedDecodingParams
+ from vllm.sampling_params import StructuredOutputsParams
if is_bitsandbytes_available():
import bitsandbytes as bnb
@@ -491,7 +491,7 @@ def __init__(
else:
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
# vLLM specific sampling arguments
- self.guided_decoding_regex = args.vllm_guided_decoding_regex
+ self.structured_outputs_regex = args.vllm_structured_outputs_regex
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
# Set up vLLM generation config
@@ -507,8 +507,8 @@ def __init__(
}
if args.generation_kwargs is not None:
generation_params.update(args.generation_kwargs)
- if self.guided_decoding_regex:
- generation_params["guided_decoding"] = GuidedDecodingParams(regex=self.guided_decoding_regex)
+ if self.structured_outputs_regex:
+ generation_params["structured_outputs"] = StructuredOutputsParams(regex=self.structured_outputs_regex)
self.generation_config = SamplingParams(**generation_params)
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
@@ -748,7 +748,9 @@ def _generate_vllm_server(self, prompts, images=None):
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.generation_config.max_tokens,
- guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
+ structured_outputs_regex=self.structured_outputs_regex
+ if hasattr(self, "structured_outputs_regex")
+ else None,
generation_kwargs=self.args.generation_kwargs,
)["completion_ids"]
# Flatten: each prompt generates 2 completions
diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py
index 08ead93e6e8..3d40dcc390c 100644
--- a/trl/experimental/openenv/utils.py
+++ b/trl/experimental/openenv/utils.py
@@ -23,7 +23,7 @@
if is_vllm_available():
from vllm import SamplingParams
- from vllm.sampling_params import GuidedDecodingParams
+ from vllm.sampling_params import StructuredOutputsParams
def _build_colocate_sampling_params(
@@ -32,10 +32,10 @@ def _build_colocate_sampling_params(
*,
logprobs: bool = True,
) -> SamplingParams:
- if trainer.guided_decoding_regex:
- guided_decoding = GuidedDecodingParams(regex=trainer.guided_decoding_regex)
+ if trainer.structured_outputs_regex:
+ structured_outputs = StructuredOutputsParams(regex=trainer.structured_outputs_regex)
else:
- guided_decoding = None
+ structured_outputs = None
generation_kwargs: dict[str, Any] = {
"n": 1,
@@ -43,7 +43,7 @@ def _build_colocate_sampling_params(
"top_k": trainer.top_k,
"min_p": 0.0 if trainer.min_p is None else trainer.min_p,
"max_tokens": trainer.max_completion_length,
- "guided_decoding": guided_decoding,
+ "structured_outputs": structured_outputs,
}
if trainer.repetition_penalty is not None:
generation_kwargs["repetition_penalty"] = trainer.repetition_penalty
diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index e21df6d837e..22f1877ce51 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -190,7 +190,7 @@ def generate(
min_p: float = 0.0,
max_tokens: int = 16,
truncate_prompt_tokens: int | None = None,
- guided_decoding_regex: str | None = None,
+ structured_outputs_regex: str | None = None,
generation_kwargs: dict | None = None,
) -> dict[str, list[list[int]]]:
"""
@@ -219,7 +219,7 @@ def generate(
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
disabled.
- guided_decoding_regex (`str`, *optional*):
+ structured_outputs_regex (`str`, *optional*):
Regular expression to guide the decoding process.
generation_kwargs (`dict`, *optional*):
Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like
@@ -253,7 +253,7 @@ def generate(
"min_p": min_p,
"max_tokens": max_tokens,
"truncate_prompt_tokens": truncate_prompt_tokens,
- "guided_decoding_regex": guided_decoding_regex,
+ "structured_outputs_regex": structured_outputs_regex,
"generation_kwargs": generation_kwargs or {},
},
)
@@ -278,7 +278,7 @@ def chat(
min_p: float = 0.0,
max_tokens: int = 16,
truncate_prompt_tokens: int | None = None,
- guided_decoding_regex: str | None = None,
+ structured_outputs_regex: str | None = None,
generation_kwargs: dict | None = None,
chat_template_kwargs: dict | None = None,
tools: list | None = None,
@@ -309,7 +309,7 @@ def chat(
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
disabled.
- guided_decoding_regex (`str`, *optional*):
+ structured_outputs_regex (`str`, *optional*):
Regular expression to guide the decoding process.
generation_kwargs (`dict`, *optional*):
Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like
@@ -360,7 +360,7 @@ def chat(
"min_p": min_p,
"max_tokens": max_tokens,
"truncate_prompt_tokens": truncate_prompt_tokens,
- "guided_decoding_regex": guided_decoding_regex,
+ "structured_outputs_regex": structured_outputs_regex,
"generation_kwargs": generation_kwargs or {},
"chat_template_kwargs": chat_template_kwargs or {},
},
diff --git a/trl/import_utils.py b/trl/import_utils.py
index 2cce6df9633..5106b782e66 100644
--- a/trl/import_utils.py
+++ b/trl/import_utils.py
@@ -79,9 +79,9 @@ def is_uvicorn_available() -> bool:
def is_vllm_available() -> bool:
_vllm_available, _vllm_version = _is_package_available("vllm", return_version=True)
if _vllm_available:
- if not (version.parse("0.10.2") <= version.parse(_vllm_version) <= version.parse("0.11.2")):
+ if not (version.parse("0.10.2") <= version.parse(_vllm_version) <= version.parse("0.12.0")):
warnings.warn(
- "TRL currently supports vLLM versions: 0.10.2, 0.11.0, 0.11.1, 0.11.2. You have version "
+ "TRL currently supports vLLM versions: 0.10.2, 0.11.0, 0.11.1, 0.11.2, 0.12.0. You have version "
f"{_vllm_version} installed. We recommend installing a supported version to avoid compatibility "
"issues.",
stacklevel=2,
diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py
index d5884b8290a..19b37d98c46 100644
--- a/trl/scripts/vllm_serve.py
+++ b/trl/scripts/vllm_serve.py
@@ -26,6 +26,7 @@
import torch
import torch.distributed.distributed_c10d as c10d
+from packaging import version
from transformers import is_torch_xpu_available, is_vision_available
from trl import TrlParser
@@ -55,13 +56,18 @@
if is_vllm_available():
+ import vllm
from vllm import LLM, SamplingParams
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.utils import StatelessProcessGroup
- from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import get_open_port
+ if version.parse(vllm.__version__) <= version.parse("0.10.2"):
+ from vllm.sampling_params import GuidedDecodingParams
+ else:
+ from vllm.sampling_params import StructuredOutputsParams
+
if is_vllm_ascend_available():
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
@@ -495,7 +501,7 @@ class GenerateRequest(BaseModel):
min_p: float = 0.0
max_tokens: int = 16
truncate_prompt_tokens: int | None = None
- guided_decoding_regex: str | None = None
+ structured_outputs_regex: str | None = None
generation_kwargs: dict = field(default_factory=dict)
class GenerateResponse(BaseModel):
@@ -528,8 +534,8 @@ async def generate(request: GenerateRequest):
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
truncation). If set to `None`, truncation is disabled.
- - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
- model will only generate tokens that match this regex pattern.
+ - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided,
+ the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
`SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains
keys that conflict with the other parameters, they will override them.
@@ -564,11 +570,19 @@ async def generate(request: GenerateRequest):
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
prompts.append(row)
- # Guided decoding, if enabled
- if request.guided_decoding_regex is not None:
- guided_decoding = GuidedDecodingParams(regex=request.guided_decoding_regex)
+ # Structured outputs, if enabled
+ if version.parse(vllm.__version__) <= version.parse("0.10.2"):
+ structured_outputs_key = "guided_decoding"
+ if request.structured_outputs_regex is not None:
+ structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex)
+ else:
+ structured_outputs = None
else:
- guided_decoding = None
+ structured_outputs_key = "structured_outputs"
+ if request.structured_outputs_regex is not None:
+ structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex)
+ else:
+ structured_outputs = None
generation_kwargs = {
"n": request.n,
@@ -579,9 +593,9 @@ async def generate(request: GenerateRequest):
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"truncate_prompt_tokens": request.truncate_prompt_tokens,
- "guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
+ generation_kwargs[structured_outputs_key] = structured_outputs
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
@@ -625,7 +639,7 @@ class ChatRequest(BaseModel):
min_p: float = 0.0
max_tokens: int = 16
truncate_prompt_tokens: int | None = None
- guided_decoding_regex: str | None = None
+ structured_outputs_regex: str | None = None
generation_kwargs: dict = field(default_factory=dict)
chat_template_kwargs: dict = field(default_factory=dict)
@@ -658,8 +672,8 @@ async def chat(request: ChatRequest):
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
truncation). If set to `None`, truncation is disabled.
- - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
- model will only generate tokens that match this regex pattern.
+ - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided,
+ the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
`SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains
keys that conflict with the other parameters, they will override them.
@@ -697,11 +711,19 @@ async def chat(request: ChatRequest):
if part["type"] == "image_pil":
part["image_pil"] = Image.open(BytesIO(base64.b64decode(part["image_pil"])))
- # Guided decoding, if enabled
- if request.guided_decoding_regex is not None:
- guided_decoding = GuidedDecodingParams(regex=request.guided_decoding_regex)
+ # Structured outputs, if enabled
+ if version.parse(vllm.__version__) <= version.parse("0.10.2"):
+ structured_outputs_key = "guided_decoding"
+ if request.structured_outputs_regex is not None:
+ structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex)
+ else:
+ structured_outputs = None
else:
- guided_decoding = None
+ structured_outputs_key = "structured_outputs"
+ if request.structured_outputs_regex is not None:
+ structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex)
+ else:
+ structured_outputs = None
generation_kwargs = {
"n": request.n,
@@ -712,9 +734,9 @@ async def chat(request: ChatRequest):
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"truncate_prompt_tokens": request.truncate_prompt_tokens,
- "guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
+ generation_kwargs[structured_outputs_key] = structured_outputs
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 09370c73d57..a52af5cac31 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -122,8 +122,8 @@ class GRPOConfig(TrainingArguments):
Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
implementation.
- vllm_guided_decoding_regex (`str`, *optional*):
- Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
+ vllm_structured_outputs_regex (`str`, *optional*):
+ Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled.
> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
@@ -301,6 +301,15 @@ class GRPOConfig(TrainingArguments):
Parameter `max_prompt_length` is deprecated and will be removed in version 0.28.0. You should instead
filter your dataset before training to ensure that prompts do not exceed your desired length.
+
+
+ vllm_guided_decoding_regex:
+
+
+
+ Parameter `vllm_guided_decoding_regex` is deprecated and will be removed in version 0.28.0. You should
+ instead use `vllm_structured_outputs_regex`.
+
"""
@@ -517,9 +526,9 @@ class GRPOConfig(TrainingArguments):
"usage low, but waking the engine adds host–device transfer latency."
},
)
- vllm_guided_decoding_regex: str | None = field(
+ vllm_structured_outputs_regex: str | None = field(
default=None,
- metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
+ metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."},
)
# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
@@ -797,6 +806,10 @@ class GRPOConfig(TrainingArguments):
"desired length."
},
)
+ vllm_guided_decoding_regex: str | None = field(
+ default=None,
+ metadata={"help": "Deprecated, use `vllm_structured_outputs_regex` instead."},
+ )
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
@@ -880,3 +893,11 @@ def __post_init__(self):
FutureWarning,
stacklevel=2,
)
+ if self.vllm_guided_decoding_regex is not None:
+ warnings.warn(
+ "The `vllm_guided_decoding_regex` argument is deprecated and will be removed in version 0.28.0. You "
+ "should instead use `vllm_structured_outputs_regex`.",
+ FutureWarning,
+ stacklevel=2,
+ )
+ self.vllm_structured_outputs_regex = self.vllm_guided_decoding_regex
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 1c9a2fab1cc..4749df33933 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -36,6 +36,7 @@
from accelerate.logging import get_logger
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
+from packaging import version
from packaging.version import Version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -101,8 +102,13 @@
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
if is_vllm_available():
+ import vllm
from vllm import LLM, SamplingParams
- from vllm.sampling_params import GuidedDecodingParams
+
+ if version.parse(vllm.__version__) <= version.parse("0.10.2"):
+ from vllm.sampling_params import GuidedDecodingParams
+ else:
+ from vllm.sampling_params import StructuredOutputsParams
if is_wandb_available():
import wandb
@@ -701,7 +707,7 @@ def cast_outputs_to_original_dtype(module, args, output):
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
# vLLM specific sampling arguments
- self.guided_decoding_regex = args.vllm_guided_decoding_regex
+ self.structured_outputs_regex = args.vllm_structured_outputs_regex
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
@@ -1315,7 +1321,7 @@ def _generate_single_turn(self, prompts: list):
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
- "guided_decoding_regex": self.guided_decoding_regex,
+ "structured_outputs_regex": self.structured_outputs_regex,
"generation_kwargs": self.args.generation_kwargs,
}
with profiling_context(self, "vLLM.generate"):
@@ -1389,10 +1395,18 @@ def _generate_single_turn(self, prompts: list):
completion_ids = output["completion_ids"]
logprobs = output["logprobs"]
else:
- if self.guided_decoding_regex:
- guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
+ if version.parse(vllm.__version__) <= version.parse("0.10.2"):
+ structured_outputs_key = "guided_decoding"
+ if self.structured_outputs_regex:
+ structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex)
+ else:
+ structured_outputs = None
else:
- guided_decoding = None
+ structured_outputs_key = "structured_outputs"
+ if self.structured_outputs_regex:
+ structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
+ else:
+ structured_outputs = None
generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
@@ -1402,9 +1416,9 @@ def _generate_single_turn(self, prompts: list):
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
- "guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
+ generation_kwargs[structured_outputs_key] = structured_outputs
if self.args.generation_kwargs is not None:
generation_kwargs.update(self.args.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py
index c5f16a93220..285600ae665 100644
--- a/trl/trainer/rloo_config.py
+++ b/trl/trainer/rloo_config.py
@@ -117,8 +117,8 @@ class RLOOConfig(TrainingArguments):
Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
implementation.
- vllm_guided_decoding_regex (`str`, *optional*):
- Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
+ vllm_structured_outputs_regex (`str`, *optional*):
+ Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled.
> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
@@ -212,6 +212,15 @@ class RLOOConfig(TrainingArguments):
Parameter `max_prompt_length` is deprecated and will be removed in version 0.29.0. You should instead
filter your dataset before training to ensure that prompts do not exceed your desired length.
+
+
+ vllm_guided_decoding_regex:
+
+
+
+ Parameter `vllm_guided_decoding_regex` is deprecated and will be removed in version 0.28.0. You should
+ instead use `vllm_structured_outputs_regex`.
+
"""
@@ -419,9 +428,9 @@ class RLOOConfig(TrainingArguments):
"usage low, but waking the engine adds host–device transfer latency."
},
)
- vllm_guided_decoding_regex: str | None = field(
+ vllm_structured_outputs_regex: str | None = field(
default=None,
- metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
+ metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."},
)
# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
@@ -580,6 +589,10 @@ class RLOOConfig(TrainingArguments):
"desired length."
},
)
+ vllm_guided_decoding_regex: str | None = field(
+ default=None,
+ metadata={"help": "Deprecated, use `vllm_structured_outputs_regex` instead."},
+ )
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
@@ -649,3 +662,11 @@ def __post_init__(self):
FutureWarning,
stacklevel=2,
)
+ if self.vllm_guided_decoding_regex is not None:
+ warnings.warn(
+ "The `vllm_guided_decoding_regex` argument is deprecated and will be removed in version 0.28.0. You "
+ "should instead use `vllm_structured_outputs_regex`.",
+ FutureWarning,
+ stacklevel=2,
+ )
+ self.vllm_structured_outputs_regex = self.vllm_guided_decoding_regex
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 8fa3e814479..06fe04b1a7c 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -33,6 +33,7 @@
from accelerate.logging import get_logger
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
+from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
@@ -93,8 +94,13 @@
from peft import PeftConfig, PeftModel, get_peft_model
if is_vllm_available():
+ import vllm
from vllm import LLM, SamplingParams
- from vllm.sampling_params import GuidedDecodingParams
+
+ if version.parse(vllm.__version__) <= version.parse("0.10.2"):
+ from vllm.sampling_params import GuidedDecodingParams
+ else:
+ from vllm.sampling_params import StructuredOutputsParams
if is_wandb_available():
import wandb
@@ -562,7 +568,7 @@ def __init__(
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
# vLLM specific sampling arguments
- self.guided_decoding_regex = args.vllm_guided_decoding_regex
+ self.structured_outputs_regex = args.vllm_structured_outputs_regex
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
@@ -1083,7 +1089,7 @@ def _generate_single_turn(self, prompts: list):
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
- "guided_decoding_regex": self.guided_decoding_regex,
+ "structured_outputs_regex": self.structured_outputs_regex,
"generation_kwargs": self.args.generation_kwargs,
}
with profiling_context(self, "vLLM.generate"):
@@ -1116,10 +1122,18 @@ def _generate_single_turn(self, prompts: list):
# Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
elif self.vllm_mode == "colocate":
- if self.guided_decoding_regex:
- guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
+ if version.parse(vllm.__version__) <= version.parse("0.10.2"):
+ structured_outputs_key = "guided_decoding"
+ if self.structured_outputs_regex:
+ structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex)
+ else:
+ structured_outputs = None
else:
- guided_decoding = None
+ structured_outputs_key = "structured_outputs"
+ if self.structured_outputs_regex:
+ structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
+ else:
+ structured_outputs = None
generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
@@ -1129,8 +1143,8 @@ def _generate_single_turn(self, prompts: list):
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
- "guided_decoding": guided_decoding,
}
+ generation_kwargs[structured_outputs_key] = structured_outputs
if self.args.generation_kwargs is not None:
generation_kwargs.update(self.args.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)