From 592b3c94a2d60f3e56d5fc1f626a85280fb7eaa1 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 6 Nov 2025 18:30:09 +0100 Subject: [PATCH 1/9] Add quantization option for vllm --- trl/trainer/grpo_config.py | 6 ++++++ trl/trainer/grpo_trainer.py | 1 + 2 files changed, 7 insertions(+) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 6001a8dc524..dc70dfa020e 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -462,6 +462,12 @@ class GRPOConfig(TrainingArguments): default=None, metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, ) + vllm_quantization: str | None = field( + default=None, + metadata={ + "help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied." + }, + ) # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) vllm_server_base_url: str | None = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 378dc2b6193..6ae2a93480a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -610,6 +610,7 @@ def cast_outputs_to_original_dtype(module, args, output): enable_sleep_mode=self.args.vllm_enable_sleep_mode, # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", + quantization=self.args.vllm_quantization, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) From df1203325e875d5690a849c3d86cd00082a26fb1 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 6 Nov 2025 18:30:35 +0100 Subject: [PATCH 2/9] Code quality --- trl/trainer/grpo_config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index dc70dfa020e..a7f228cdd85 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -464,9 +464,7 @@ class GRPOConfig(TrainingArguments): ) vllm_quantization: str | None = field( default=None, - metadata={ - "help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied." - }, + metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, ) # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) From 3fbdb4b13d0f993817581f4d02de159ac4ac3d21 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 12 Nov 2025 15:47:43 +0100 Subject: [PATCH 3/9] Extend to more trainers --- trl/experimental/gold/gold_config.py | 4 ++++ trl/experimental/gold/gold_trainer.py | 1 + trl/trainer/online_dpo_config.py | 4 ++++ trl/trainer/online_dpo_trainer.py | 1 + trl/trainer/rloo_config.py | 4 ++++ trl/trainer/rloo_trainer.py | 1 + 6 files changed, 15 insertions(+) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 76bb23edc36..23e9555c3b3 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -304,6 +304,10 @@ class GOLDConfig(SFTConfig): default=None, metadata={"help": "Regex pattern used for vLLM guided decoding (optional)."}, ) + vllm_quantization: str | None = field( + default=None, + metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, + ) vllm_sync_frequency: int = field( default=1, metadata={ diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index b2697d94623..6d744bc22f7 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -952,6 +952,7 @@ def __init__( # Feed identical seed for tp groups to ensure sampling results are the same across workers seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, enable_sleep_mode=self.vllm_enable_sleep_mode, + quantization=self.args.vllm_quantization, # 29519MiB / 81920MiB --> ) if self.vllm_enable_sleep_mode: diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 680aeb922f1..f31d782db96 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -297,6 +297,10 @@ class may differ from those in [`~transformers.TrainingArguments`]. default=None, metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, ) + vllm_quantization: str | None = field( + default=None, + metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, + ) vllm_gpu_memory_utilization: float | None = field( default=0.55, metadata={ diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 59cd5d3d1a4..a24fa020aac 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -484,6 +484,7 @@ def __init__( "seed": self.accelerator.process_index // self.vllm_tensor_parallel_size, # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) "max_num_batched_tokens": 4096, + "quantization": self.args.vllm_quantization, } # vLLM requires the environment variables to be set for distributed training. diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index 1645d4741bf..d1fe28dc41c 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -393,6 +393,10 @@ class RLOOConfig(TrainingArguments): default=None, metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, ) + vllm_quantization: str | None = field( + default=None, + metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, + ) # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) vllm_server_base_url: str | None = field( diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index d4a192f9bb3..126d3ff6c90 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -526,6 +526,7 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, + quantization=self.args.vllm_quantization, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) From 18a92e7b128901088bc5f0a87cdb26b270bc89a6 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 13 Nov 2025 18:00:13 +0100 Subject: [PATCH 4/9] Working example --- trl/experimental/gold/gold_trainer.py | 2 +- trl/scripts/grpo.py | 18 ++++++++++++++++++ trl/trainer/model_config.py | 4 ++++ trl/trainer/utils.py | 2 +- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 6d744bc22f7..de443dddcdf 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -952,7 +952,7 @@ def __init__( # Feed identical seed for tp groups to ensure sampling results are the same across workers seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, enable_sleep_mode=self.vllm_enable_sleep_mode, - quantization=self.args.vllm_quantization, # 29519MiB / 81920MiB --> + quantization=self.args.vllm_quantization, ) if self.vllm_enable_sleep_mode: diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py index 7e46a3858fc..ece0ae55808 100644 --- a/trl/scripts/grpo.py +++ b/trl/scripts/grpo.py @@ -27,6 +27,7 @@ import sys from dataclasses import dataclass, field +import torch from accelerate import logging from datasets import load_dataset @@ -38,7 +39,9 @@ ScriptArguments, TrlParser, get_dataset, + get_kbit_device_map, get_peft_config, + get_quantization_config, ) from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward @@ -112,6 +115,21 @@ def main(script_args, training_args, model_args, dataset_args): f"Could not load reward function '{func_name}'. Expected one of " f"{list(reward_funcs_registry.keys())} or a valid import path." ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + training_args.model_init_kwargs = model_kwargs # Load the dataset if dataset_args.datasets and script_args.dataset_name: diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py index 9e3d5fe8021..8e0cd669a6f 100644 --- a/trl/trainer/model_config.py +++ b/trl/trainer/model_config.py @@ -176,6 +176,10 @@ class ModelConfig: default=False, metadata={"help": "Whether to use nested quantization."}, ) + bnb_4bit_quant_storage: str | None = field( + default=None, + metadata={"help": "Quantization storage dtype"}, + ) # Deprecated params torch_dtype: str | None = field( default=None, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 529c6bb622a..02d4cc78073 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -662,7 +662,7 @@ def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | Non bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype` bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, - bnb_4bit_quant_storage=model_args.dtype, + bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, ) elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig( From 38c9a60f2832966798c227a52b7b69f32defb897 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 14 Nov 2025 10:42:31 +0100 Subject: [PATCH 5/9] Update vllm_quantization check --- trl/trainer/grpo_trainer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index aaa0af66ab6..0e35358833f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -45,6 +45,7 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_bitsandbytes_available, is_trackio_available, is_wandb_available, ) @@ -101,6 +102,8 @@ if is_trackio_available(): import trackio +if is_bitsandbytes_available(): + import bitsandbytes as bnb logger = logging.get_logger(__name__) @@ -596,6 +599,13 @@ def cast_outputs_to_original_dtype(module, args, output): max_model_len = self.max_prompt_length + self.max_completion_length else: max_model_len = None + + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break self.llm = LLM( model=model.name_or_path, tensor_parallel_size=args.vllm_tensor_parallel_size, @@ -613,7 +623,7 @@ def cast_outputs_to_original_dtype(module, args, output): enable_sleep_mode=self.args.vllm_enable_sleep_mode, # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", - quantization=self.args.vllm_quantization, + quantization=vllm_quantization, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) From 8d4c5a110b09f9266189936bb1f8472434a99363 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Fri, 14 Nov 2025 10:54:05 +0100 Subject: [PATCH 6/9] Apply suggestion from @kashif Co-authored-by: Kashif Rasul --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0e35358833f..acf11831fd4 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -603,7 +603,7 @@ def cast_outputs_to_original_dtype(module, args, output): vllm_quantization = None if is_bitsandbytes_available(): for _, module in model.named_modules(): - if isinstance(module, bnb.nn.Linear4bit): + if isinstance(module, (bnb.nn.Linear4bit, bnb.nn. Linear8bitLt)): vllm_quantization = "bitsandbytes" break self.llm = LLM( From dccec2c2d82a59f87fe7fccdcbca11e2245c41db Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 14 Nov 2025 11:20:56 +0100 Subject: [PATCH 7/9] 8bit case --- trl/trainer/grpo_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index acf11831fd4..3343f716c81 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -603,9 +603,11 @@ def cast_outputs_to_original_dtype(module, args, output): vllm_quantization = None if is_bitsandbytes_available(): for _, module in model.named_modules(): - if isinstance(module, (bnb.nn.Linear4bit, bnb.nn. Linear8bitLt)): + if isinstance(module, bnb.nn.Linear4bit): vllm_quantization = "bitsandbytes" break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") self.llm = LLM( model=model.name_or_path, tensor_parallel_size=args.vllm_tensor_parallel_size, From d7f8933823c629e91bf95f5c6c01e0317bd78cb3 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 14 Nov 2025 11:40:43 +0100 Subject: [PATCH 8/9] extend --- trl/experimental/gold/gold_trainer.py | 16 ++++++++++++++-- trl/trainer/online_dpo_config.py | 4 ---- trl/trainer/online_dpo_trainer.py | 13 ++++++++++++- trl/trainer/rloo_config.py | 4 ---- trl/trainer/rloo_trainer.py | 14 +++++++++++++- 5 files changed, 39 insertions(+), 12 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index de443dddcdf..635ee417de5 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -29,7 +29,7 @@ from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model from datasets import Dataset, IterableDataset from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from transformers import AutoTokenizer +from transformers import AutoTokenizer, is_bitsandbytes_available from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -84,6 +84,9 @@ from rich.table import Table from rich.text import Text +if is_bitsandbytes_available(): + import bitsandbytes as bnb + def print_prompt_completions_sample_uld( prompts: list[str], @@ -941,6 +944,15 @@ def __init__( os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) ensure_master_addr_port() + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") + self.vllm_engine = LLM( model=student_model_name_or_path, revision=self.model_revision, @@ -952,7 +964,7 @@ def __init__( # Feed identical seed for tp groups to ensure sampling results are the same across workers seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, enable_sleep_mode=self.vllm_enable_sleep_mode, - quantization=self.args.vllm_quantization, + quantization=vllm_quantization, ) if self.vllm_enable_sleep_mode: diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 9be4f823b3b..d2d21e3c3ee 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -307,10 +307,6 @@ class may differ from those in [`~transformers.TrainingArguments`]. default=None, metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, ) - vllm_quantization: str | None = field( - default=None, - metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, - ) vllm_gpu_memory_utilization: float | None = field( default=0.55, metadata={ diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 651cb494f1b..1b4e487c861 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -45,6 +45,7 @@ ProcessorMixin, Trainer, TrainerCallback, + is_bitsandbytes_available, ) from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES from transformers.trainer_utils import EvalPrediction, seed_worker @@ -97,6 +98,8 @@ from vllm import LLM, SamplingParams from vllm.sampling_params import GuidedDecodingParams +if is_bitsandbytes_available(): + import bitsandbytes as bnb logger = logging.get_logger(__name__) @@ -477,6 +480,14 @@ def __init__( # after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough # space for them. # Configure vLLM parameters + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") vllm_kwargs = { "model": model.name_or_path, "tensor_parallel_size": self.vllm_tensor_parallel_size, @@ -489,7 +500,7 @@ def __init__( "seed": self.accelerator.process_index // self.vllm_tensor_parallel_size, # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) "max_num_batched_tokens": 4096, - "quantization": self.args.vllm_quantization, + "quantization": vllm_quantization, } # vLLM requires the environment variables to be set for distributed training. diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index 9c18440ad92..eb893c604c8 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -415,10 +415,6 @@ class RLOOConfig(TrainingArguments): default=None, metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, ) - vllm_quantization: str | None = field( - default=None, - metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, - ) # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) vllm_server_base_url: str | None = field( diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 30fe0f28be2..ca572caae7e 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -45,6 +45,7 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_bitsandbytes_available, is_trackio_available, is_wandb_available, ) @@ -97,6 +98,9 @@ if is_trackio_available(): import trackio +if is_bitsandbytes_available(): + import bitsandbytes as bnb + logger = logging.get_logger(__name__) @@ -513,6 +517,14 @@ def __init__( max_model_len = self.max_prompt_length + self.max_completion_length else: max_model_len = None + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") self.llm = LLM( model=model.name_or_path, tensor_parallel_size=args.vllm_tensor_parallel_size, @@ -528,7 +540,7 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, - quantization=self.args.vllm_quantization, + quantization=vllm_quantization, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) From a0949226eead6c814e6f88ddc86ddb0155354516 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 14 Nov 2025 11:42:01 +0100 Subject: [PATCH 9/9] extend --- trl/experimental/gold/gold_config.py | 4 ---- trl/trainer/grpo_config.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 23e9555c3b3..76bb23edc36 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -304,10 +304,6 @@ class GOLDConfig(SFTConfig): default=None, metadata={"help": "Regex pattern used for vLLM guided decoding (optional)."}, ) - vllm_quantization: str | None = field( - default=None, - metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, - ) vllm_sync_frequency: int = field( default=1, metadata={ diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 92eb3cb2ed6..2d97d67bd8e 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -489,10 +489,6 @@ class GRPOConfig(TrainingArguments): default=None, metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, ) - vllm_quantization: str | None = field( - default=None, - metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."}, - ) # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) vllm_server_base_url: str | None = field(