Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion trl/experimental/gold/gold_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model
from datasets import Dataset, IterableDataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoTokenizer
from transformers import AutoTokenizer, is_bitsandbytes_available
from transformers.data.data_collator import DataCollator
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.generation.configuration_utils import GenerationConfig
Expand Down Expand Up @@ -84,6 +84,9 @@
from rich.table import Table
from rich.text import Text

if is_bitsandbytes_available():
import bitsandbytes as bnb


def print_prompt_completions_sample_uld(
prompts: list[str],
Expand Down Expand Up @@ -941,6 +944,15 @@ def __init__(
os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes)
ensure_master_addr_port()

vllm_quantization = None
if is_bitsandbytes_available():
for _, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
vllm_quantization = "bitsandbytes"
break
elif isinstance(module, bnb.nn.Linear8bitLt):
raise ValueError("vLLM does not support in-flight 8-bit quantization.")

self.vllm_engine = LLM(
model=student_model_name_or_path,
revision=self.model_revision,
Expand All @@ -952,6 +964,7 @@ def __init__(
# Feed identical seed for tp groups to ensure sampling results are the same across workers
seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
enable_sleep_mode=self.vllm_enable_sleep_mode,
quantization=vllm_quantization,
)

if self.vllm_enable_sleep_mode:
Expand Down
18 changes: 18 additions & 0 deletions trl/scripts/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import sys
from dataclasses import dataclass, field

import torch
from accelerate import logging
from datasets import load_dataset

Expand All @@ -38,7 +39,9 @@
ScriptArguments,
TrlParser,
get_dataset,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward

Expand Down Expand Up @@ -112,6 +115,21 @@ def main(script_args, training_args, model_args, dataset_args):
f"Could not load reward function '{func_name}'. Expected one of "
f"{list(reward_funcs_registry.keys())} or a valid import path."
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
)
quantization_config = get_quantization_config(model_args)

if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

training_args.model_init_kwargs = model_kwargs

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand Down
13 changes: 13 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
is_bitsandbytes_available,
is_trackio_available,
is_wandb_available,
)
Expand Down Expand Up @@ -101,6 +102,8 @@
if is_trackio_available():
import trackio

if is_bitsandbytes_available():
import bitsandbytes as bnb

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -596,6 +599,15 @@ def cast_outputs_to_original_dtype(module, args, output):
max_model_len = self.max_prompt_length + self.max_completion_length
else:
max_model_len = None

vllm_quantization = None
if is_bitsandbytes_available():
for _, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
vllm_quantization = "bitsandbytes"
break
elif isinstance(module, bnb.nn.Linear8bitLt):
raise ValueError("vLLM does not support in-flight 8-bit quantization.")
self.llm = LLM(
model=model.name_or_path,
tensor_parallel_size=args.vllm_tensor_parallel_size,
Expand All @@ -613,6 +625,7 @@ def cast_outputs_to_original_dtype(module, args, output):
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
# Important so temperature scaling/logit tweaking affects the TIS log probs
logprobs_mode="processed_logprobs",
quantization=vllm_quantization,
)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ class ModelConfig:
default=False,
metadata={"help": "Whether to use nested quantization."},
)
bnb_4bit_quant_storage: str | None = field(
default=None,
metadata={"help": "Quantization storage dtype"},
)
# Deprecated params
torch_dtype: str | None = field(
default=None,
Expand Down
12 changes: 12 additions & 0 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ProcessorMixin,
Trainer,
TrainerCallback,
is_bitsandbytes_available,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from transformers.trainer_utils import EvalPrediction, seed_worker
Expand Down Expand Up @@ -97,6 +98,8 @@
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

if is_bitsandbytes_available():
import bitsandbytes as bnb

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -477,6 +480,14 @@ def __init__(
# after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough
# space for them.
# Configure vLLM parameters
vllm_quantization = None
if is_bitsandbytes_available():
for _, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
vllm_quantization = "bitsandbytes"
break
elif isinstance(module, bnb.nn.Linear8bitLt):
raise ValueError("vLLM does not support in-flight 8-bit quantization.")
vllm_kwargs = {
"model": model.name_or_path,
"tensor_parallel_size": self.vllm_tensor_parallel_size,
Expand All @@ -489,6 +500,7 @@ def __init__(
"seed": self.accelerator.process_index // self.vllm_tensor_parallel_size,
# Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768)
"max_num_batched_tokens": 4096,
"quantization": vllm_quantization,
}

# vLLM requires the environment variables to be set for distributed training.
Expand Down
13 changes: 13 additions & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
is_bitsandbytes_available,
is_trackio_available,
is_wandb_available,
)
Expand Down Expand Up @@ -97,6 +98,9 @@
if is_trackio_available():
import trackio

if is_bitsandbytes_available():
import bitsandbytes as bnb


logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -513,6 +517,14 @@ def __init__(
max_model_len = self.max_prompt_length + self.max_completion_length
else:
max_model_len = None
vllm_quantization = None
if is_bitsandbytes_available():
for _, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
vllm_quantization = "bitsandbytes"
break
elif isinstance(module, bnb.nn.Linear8bitLt):
raise ValueError("vLLM does not support in-flight 8-bit quantization.")
self.llm = LLM(
model=model.name_or_path,
tensor_parallel_size=args.vllm_tensor_parallel_size,
Expand All @@ -528,6 +540,7 @@ def __init__(
max_num_batched_tokens=4096,
model_impl=self.args.vllm_model_impl,
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
quantization=vllm_quantization,
)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | Non
bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype`
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.dtype,
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
Expand Down
Loading