Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions trl/experimental/gold/gold_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ class GOLDConfig(SFTConfig):
default=None,
metadata={"help": "Regex pattern used for vLLM guided decoding (optional)."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)
vllm_sync_frequency: int = field(
default=1,
metadata={
Expand Down
1 change: 1 addition & 0 deletions trl/experimental/gold/gold_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,7 @@ def __init__(
# Feed identical seed for tp groups to ensure sampling results are the same across workers
seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
enable_sleep_mode=self.vllm_enable_sleep_mode,
quantization=self.args.vllm_quantization,
)

if self.vllm_enable_sleep_mode:
Expand Down
18 changes: 18 additions & 0 deletions trl/scripts/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import sys
from dataclasses import dataclass, field

import torch
from accelerate import logging
from datasets import load_dataset

Expand All @@ -38,7 +39,9 @@
ScriptArguments,
TrlParser,
get_dataset,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward

Expand Down Expand Up @@ -112,6 +115,21 @@ def main(script_args, training_args, model_args, dataset_args):
f"Could not load reward function '{func_name}'. Expected one of "
f"{list(reward_funcs_registry.keys())} or a valid import path."
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
)
quantization_config = get_quantization_config(model_args)

if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

training_args.model_init_kwargs = model_kwargs

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,10 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)

# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_base_url: str | None = field(
Expand Down
13 changes: 13 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
is_bitsandbytes_available,
is_trackio_available,
is_wandb_available,
)
Expand Down Expand Up @@ -101,6 +102,8 @@
if is_trackio_available():
import trackio

if is_bitsandbytes_available():
import bitsandbytes as bnb

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -596,6 +599,15 @@ def cast_outputs_to_original_dtype(module, args, output):
max_model_len = self.max_prompt_length + self.max_completion_length
else:
max_model_len = None

vllm_quantization = None
if is_bitsandbytes_available():
for _, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
Comment thread
sergiopaniego marked this conversation as resolved.
vllm_quantization = "bitsandbytes"
break
elif isinstance(module, bnb.nn.Linear8bitLt):
raise ValueError("vLLM does not support in-flight 8-bit quantization.")
self.llm = LLM(
model=model.name_or_path,
tensor_parallel_size=args.vllm_tensor_parallel_size,
Expand All @@ -613,6 +625,7 @@ def cast_outputs_to_original_dtype(module, args, output):
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
# Important so temperature scaling/logit tweaking affects the TIS log probs
logprobs_mode="processed_logprobs",
quantization=vllm_quantization,
)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ class ModelConfig:
default=False,
metadata={"help": "Whether to use nested quantization."},
)
bnb_4bit_quant_storage: str | None = field(
default=None,
metadata={"help": "Quantization storage dtype"},
)
# Deprecated params
torch_dtype: str | None = field(
default=None,
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ class may differ from those in [`~transformers.TrainingArguments`].
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)
vllm_gpu_memory_utilization: float | None = field(
default=0.55,
metadata={
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def __init__(
"seed": self.accelerator.process_index // self.vllm_tensor_parallel_size,
# Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768)
"max_num_batched_tokens": 4096,
"quantization": self.args.vllm_quantization,
}

# vLLM requires the environment variables to be set for distributed training.
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/rloo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ class RLOOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)

# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_base_url: str | None = field(
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def __init__(
max_num_batched_tokens=4096,
model_impl=self.args.vllm_model_impl,
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
quantization=self.args.vllm_quantization,
)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | Non
bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype`
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.dtype,
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
Expand Down
Loading