Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions trl/experimental/gold/gold_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ class GOLDConfig(SFTConfig):
default=None,
metadata={"help": "Regex pattern used for vLLM guided decoding (optional)."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)
vllm_sync_frequency: int = field(
default=1,
metadata={
Expand Down
1 change: 1 addition & 0 deletions trl/experimental/gold/gold_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,7 @@ def __init__(
# Feed identical seed for tp groups to ensure sampling results are the same across workers
seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
enable_sleep_mode=self.vllm_enable_sleep_mode,
quantization=self.args.vllm_quantization,
)

if self.vllm_enable_sleep_mode:
Expand Down
18 changes: 18 additions & 0 deletions trl/scripts/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import sys
from dataclasses import dataclass, field

import torch
from accelerate import logging
from datasets import load_dataset

Expand All @@ -38,7 +39,9 @@
ScriptArguments,
TrlParser,
get_dataset,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward

Expand Down Expand Up @@ -112,6 +115,21 @@ def main(script_args, training_args, model_args, dataset_args):
f"Could not load reward function '{func_name}'. Expected one of "
f"{list(reward_funcs_registry.keys())} or a valid import path."
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
)
quantization_config = get_quantization_config(model_args)

if quantization_config is not None:
# Passing `quantization_config=None` is not equivalent to omitting the argument, so these kwargs are only set when a quantization config was actually built.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

training_args.model_init_kwargs = model_kwargs

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,10 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)

# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_base_url: str | None = field(
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ def cast_outputs_to_original_dtype(module, args, output):
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
# Important so temperature scaling/logit tweaking affects the TIS log probs
logprobs_mode="processed_logprobs",
quantization=self.args.vllm_quantization,
)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ class ModelConfig:
default=False,
metadata={"help": "Whether to use nested quantization."},
)
bnb_4bit_quant_storage: str | None = field(
default=None,
metadata={"help": "Quantization storage dtype"},
)
# Deprecated params
torch_dtype: str | None = field(
default=None,
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ class may differ from those in [`~transformers.TrainingArguments`].
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)
vllm_gpu_memory_utilization: float | None = field(
default=0.55,
metadata={
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def __init__(
"seed": self.accelerator.process_index // self.vllm_tensor_parallel_size,
# Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768)
"max_num_batched_tokens": 4096,
"quantization": self.args.vllm_quantization,
}

# vLLM requires the environment variables to be set for distributed training.
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/rloo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ class RLOOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_quantization: str | None = field(
default=None,
metadata={"help": "Quantization method to use for vLLM. If `None` (default), no quantization is applied."},
)

# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_base_url: str | None = field(
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def __init__(
max_num_batched_tokens=4096,
model_impl=self.args.vllm_model_impl,
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
quantization=self.args.vllm_quantization,
)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | Non
bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype`
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.dtype,
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
Expand Down
Loading