Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/vllm_integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.

> [!WARNING]
> TRL currently only supports vLLM versions `0.10.2`, `0.11.0`, `0.11.1`, and `0.11.2`. Please ensure you have one of these versions installed to avoid compatibility issues.
> TRL currently only supports vLLM versions `0.10.2`, `0.11.0`, `0.11.1`, `0.11.2`, and `0.12.0`. Please ensure you have one of these versions installed to avoid compatibility issues.

> [!TIP]
> The following trainers currently support generation with vLLM:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ test = [
"pytest"
]
vllm = [
"vllm>=0.10.2,<0.12.0",
"vllm>=0.10.2,<0.13.0",
"fastapi",
"pydantic",
"requests",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,8 +918,8 @@ def test_training_vllm_and_peft(self):

@require_vllm
@pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_guided_decoding(self):
"""Test that training works with vLLM for generation with guided decoding."""
def test_training_vllm_structured_outputs(self):
"""Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
Expand All @@ -930,7 +930,7 @@ def test_training_vllm_guided_decoding(self):
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
vllm_structured_outputs_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
Expand All @@ -953,7 +953,7 @@ def test_training_vllm_guided_decoding(self):
@require_vllm
@pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_importance_sampling_correction(self):
"""Test that training works with vLLM for generation with guided decoding."""
"""Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,8 @@ def test_training_vllm_and_peft(self):

@require_vllm
@pytest.mark.skip(reason="We should add a mock for the vLLM server.")
def test_training_vllm_guided_decoding(self):
"""Test that training works with vLLM for generation with guided decoding."""
def test_training_vllm_structured_outputs(self):
"""Test that training works with vLLM for generation with structured outputs."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = RLOOConfig(
Expand All @@ -705,7 +705,7 @@ def test_training_vllm_guided_decoding(self):
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
vllm_structured_outputs_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
trainer = RLOOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
Expand Down
2 changes: 1 addition & 1 deletion trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
# Fix DisableTqdm
# Bug introduced in https://github.com/vllm-project/vllm/pull/52
# Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1)
# Since TRL currently only supports vLLM v0.10.2-0.11.2, we patch it here. This can be removed when TRL requires
# Since TRL currently only supports vLLM v0.10.2-0.12.0, we patch it here. This can be removed when TRL requires
# vLLM >=0.11.1
import vllm.model_executor.model_loader.weight_utils
from tqdm import tqdm
Expand Down
8 changes: 4 additions & 4 deletions trl/experimental/gold/gold_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ class GOLDConfig(SFTConfig):
to set this to a low value if the student and teacher models share the same GPU.
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`).
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding for the student model.
vllm_structured_outputs_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM structured outputs for the student model.
vllm_sync_frequency (`int`, *optional*, defaults to `1`):
Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after
every step.
Expand Down Expand Up @@ -303,9 +303,9 @@ class GOLDConfig(SFTConfig):
default=1,
metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'},
)
vllm_guided_decoding_regex: str | None = field(
vllm_structured_outputs_regex: str | None = field(
default=None,
metadata={"help": "Regex pattern used for vLLM guided decoding (optional)."},
metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."},
)
vllm_sync_frequency: int = field(
default=1,
Expand Down
16 changes: 9 additions & 7 deletions trl/experimental/gold/gold_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.sampling_params import StructuredOutputsParams

if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
Expand Down Expand Up @@ -978,7 +978,7 @@ def __init__(
self.accelerator.wait_for_everyone()
else:
raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}")
self.vllm_guided_decoding_regex = args.vllm_guided_decoding_regex
self.vllm_structured_outputs_regex = args.vllm_structured_outputs_regex
self.vllm_sync_frequency = args.vllm_sync_frequency
self._last_vllm_sync_step = -1

Expand Down Expand Up @@ -1675,7 +1675,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
top_k=top_k,
min_p=min_p,
max_tokens=max_completion_length,
guided_decoding_regex=self.vllm_guided_decoding_regex,
structured_outputs_regex=self.vllm_structured_outputs_regex,
)["completion_ids"]
else:
completion_ids = [None] * len(all_prompts_text)
Expand All @@ -1686,10 +1686,12 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
)
completion_ids = completion_ids[process_slice]
elif self.vllm_mode == "colocate":
if self.vllm_guided_decoding_regex:
guided_decoding = GuidedDecodingParams(backend="outlines", regex=self.vllm_guided_decoding_regex)
if self.vllm_structured_outputs_regex:
structured_outputs = StructuredOutputsParams(
backend="outlines", regex=self.vllm_structured_outputs_regex
)
else:
guided_decoding = None
structured_outputs = None
sampling_params = SamplingParams(
n=1,
repetition_penalty=repetition_penalty,
Expand All @@ -1698,7 +1700,7 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_
top_k=top_k,
min_p=min_p,
max_tokens=max_completion_length,
guided_decoding=guided_decoding,
structured_outputs=structured_outputs,
)

if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1:
Expand Down
8 changes: 4 additions & 4 deletions trl/experimental/online_dpo/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class may differ from those in [`~transformers.TrainingArguments`].
server is running (start with `trl vllm-serve`).
- `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
separate server but may cause resource contention with training.
vllm_guided_decoding_regex (`str`, *optional*):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
vllm_structured_outputs_regex (`str`, *optional*):
Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled.

> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)

Expand Down Expand Up @@ -306,9 +306,9 @@ class may differ from those in [`~transformers.TrainingArguments`].
"model implementation."
},
)
vllm_guided_decoding_regex: str | None = field(
vllm_structured_outputs_regex: str | None = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."},
)
vllm_gpu_memory_utilization: float | None = field(
default=0.55,
Expand Down
12 changes: 7 additions & 5 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@

if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.sampling_params import StructuredOutputsParams

if is_bitsandbytes_available():
import bitsandbytes as bnb
Expand Down Expand Up @@ -491,7 +491,7 @@ def __init__(
else:
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
# vLLM specific sampling arguments
self.guided_decoding_regex = args.vllm_guided_decoding_regex
self.structured_outputs_regex = args.vllm_structured_outputs_regex
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation

# Set up vLLM generation config
Expand All @@ -507,8 +507,8 @@ def __init__(
}
if args.generation_kwargs is not None:
generation_params.update(args.generation_kwargs)
if self.guided_decoding_regex:
generation_params["guided_decoding"] = GuidedDecodingParams(regex=self.guided_decoding_regex)
if self.structured_outputs_regex:
generation_params["structured_outputs"] = StructuredOutputsParams(regex=self.structured_outputs_regex)
self.generation_config = SamplingParams(**generation_params)

# When using vLLM, the main process is responsible for loading the model weights. This can cause process
Expand Down Expand Up @@ -748,7 +748,9 @@ def _generate_vllm_server(self, prompts, images=None):
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.generation_config.max_tokens,
guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
structured_outputs_regex=self.structured_outputs_regex
if hasattr(self, "structured_outputs_regex")
else None,
generation_kwargs=self.args.generation_kwargs,
)["completion_ids"]
# Flatten: each prompt generates 2 completions
Expand Down
10 changes: 5 additions & 5 deletions trl/experimental/openenv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

if is_vllm_available():
from vllm import SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.sampling_params import StructuredOutputsParams


def _build_colocate_sampling_params(
Expand All @@ -32,18 +32,18 @@ def _build_colocate_sampling_params(
*,
logprobs: bool = True,
) -> SamplingParams:
if trainer.guided_decoding_regex:
guided_decoding = GuidedDecodingParams(regex=trainer.guided_decoding_regex)
if trainer.structured_outputs_regex:
structured_outputs = StructuredOutputsParams(regex=trainer.structured_outputs_regex)
else:
guided_decoding = None
structured_outputs = None

generation_kwargs: dict[str, Any] = {
"n": 1,
"temperature": trainer.temperature,
"top_k": trainer.top_k,
"min_p": 0.0 if trainer.min_p is None else trainer.min_p,
"max_tokens": trainer.max_completion_length,
"guided_decoding": guided_decoding,
"structured_outputs": structured_outputs,
}
if trainer.repetition_penalty is not None:
generation_kwargs["repetition_penalty"] = trainer.repetition_penalty
Expand Down
12 changes: 6 additions & 6 deletions trl/extras/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def generate(
min_p: float = 0.0,
max_tokens: int = 16,
truncate_prompt_tokens: int | None = None,
guided_decoding_regex: str | None = None,
structured_outputs_regex: str | None = None,
generation_kwargs: dict | None = None,
) -> dict[str, list[list[int]]]:
"""
Expand Down Expand Up @@ -219,7 +219,7 @@ def generate(
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
disabled.
guided_decoding_regex (`str`, *optional*):
structured_outputs_regex (`str`, *optional*):
Regular expression used to constrain the structure of the generated output (vLLM structured outputs).
generation_kwargs (`dict`, *optional*):
Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like
Expand Down Expand Up @@ -253,7 +253,7 @@ def generate(
"min_p": min_p,
"max_tokens": max_tokens,
"truncate_prompt_tokens": truncate_prompt_tokens,
"guided_decoding_regex": guided_decoding_regex,
"structured_outputs_regex": structured_outputs_regex,
"generation_kwargs": generation_kwargs or {},
},
)
Expand All @@ -278,7 +278,7 @@ def chat(
min_p: float = 0.0,
max_tokens: int = 16,
truncate_prompt_tokens: int | None = None,
guided_decoding_regex: str | None = None,
structured_outputs_regex: str | None = None,
generation_kwargs: dict | None = None,
chat_template_kwargs: dict | None = None,
tools: list | None = None,
Expand Down Expand Up @@ -309,7 +309,7 @@ def chat(
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
disabled.
guided_decoding_regex (`str`, *optional*):
structured_outputs_regex (`str`, *optional*):
Regular expression used to constrain the structure of the generated output (vLLM structured outputs).
generation_kwargs (`dict`, *optional*):
Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like
Expand Down Expand Up @@ -360,7 +360,7 @@ def chat(
"min_p": min_p,
"max_tokens": max_tokens,
"truncate_prompt_tokens": truncate_prompt_tokens,
"guided_decoding_regex": guided_decoding_regex,
"structured_outputs_regex": structured_outputs_regex,
"generation_kwargs": generation_kwargs or {},
"chat_template_kwargs": chat_template_kwargs or {},
},
Expand Down
4 changes: 2 additions & 2 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def is_uvicorn_available() -> bool:
def is_vllm_available() -> bool:
_vllm_available, _vllm_version = _is_package_available("vllm", return_version=True)
if _vllm_available:
if not (version.parse("0.10.2") <= version.parse(_vllm_version) <= version.parse("0.11.2")):
if not (version.parse("0.10.2") <= version.parse(_vllm_version) <= version.parse("0.12.0")):
warnings.warn(
"TRL currently supports vLLM versions: 0.10.2, 0.11.0, 0.11.1, 0.11.2. You have version "
"TRL currently supports vLLM versions: 0.10.2, 0.11.0, 0.11.1, 0.11.2, 0.12.0. You have version "
f"{_vllm_version} installed. We recommend installing a supported version to avoid compatibility "
"issues.",
stacklevel=2,
Expand Down
Loading
Loading