diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index b2697d94623..635ee417de5 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -29,7 +29,7 @@ from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model from datasets import Dataset, IterableDataset from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from transformers import AutoTokenizer +from transformers import AutoTokenizer, is_bitsandbytes_available from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -84,6 +84,9 @@ from rich.table import Table from rich.text import Text +if is_bitsandbytes_available(): + import bitsandbytes as bnb + def print_prompt_completions_sample_uld( prompts: list[str], @@ -941,6 +944,15 @@ def __init__( os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) ensure_master_addr_port() + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") + self.vllm_engine = LLM( model=student_model_name_or_path, revision=self.model_revision, @@ -952,6 +964,7 @@ def __init__( # Feed identical seed for tp groups to ensure sampling results are the same across workers seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, enable_sleep_mode=self.vllm_enable_sleep_mode, + quantization=vllm_quantization, ) if self.vllm_enable_sleep_mode: diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py index 7e46a3858fc..ece0ae55808 100644 --- a/trl/scripts/grpo.py +++ b/trl/scripts/grpo.py @@ -27,6 +27,7 @@ import sys from dataclasses import dataclass, field +import torch from accelerate import logging from datasets import load_dataset @@ -38,7 +39,9 @@ ScriptArguments, TrlParser, get_dataset, + get_kbit_device_map, get_peft_config, + get_quantization_config, ) from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward @@ -112,6 +115,21 @@ def main(script_args, training_args, model_args, dataset_args): f"Could not load reward function '{func_name}'. Expected one of " f"{list(reward_funcs_registry.keys())} or a valid import path." ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + training_args.model_init_kwargs = model_kwargs # Load the dataset if dataset_args.datasets and script_args.dataset_name: diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b2517eb7903..3343f716c81 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -45,6 +45,7 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_bitsandbytes_available, is_trackio_available, is_wandb_available, ) @@ -101,6 +102,8 @@ if is_trackio_available(): import trackio +if is_bitsandbytes_available(): + import bitsandbytes as bnb logger = logging.get_logger(__name__) @@ -596,6 +599,15 @@ def cast_outputs_to_original_dtype(module, args, output): max_model_len = self.max_prompt_length + self.max_completion_length else: max_model_len = None + + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") self.llm = LLM( model=model.name_or_path, tensor_parallel_size=args.vllm_tensor_parallel_size, @@ -613,6 +625,7 @@ def cast_outputs_to_original_dtype(module, args, output): enable_sleep_mode=self.args.vllm_enable_sleep_mode, # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", + quantization=vllm_quantization, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py index a7556b47646..84bb1450b6a 100644 --- a/trl/trainer/model_config.py +++ b/trl/trainer/model_config.py @@ -176,6 +176,10 @@ class ModelConfig: default=False, metadata={"help": "Whether to use nested quantization."}, ) + bnb_4bit_quant_storage: str | None = field( + default=None, + metadata={"help": "Quantization storage dtype"}, + ) # Deprecated params torch_dtype: str | None = field( default=None, diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index cdbd5458fe8..1b4e487c861 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -45,6 +45,7 @@ ProcessorMixin, Trainer, TrainerCallback, + is_bitsandbytes_available, ) from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES from transformers.trainer_utils import EvalPrediction, seed_worker @@ -97,6 +98,8 @@ from vllm import LLM, SamplingParams from vllm.sampling_params import GuidedDecodingParams +if is_bitsandbytes_available(): + import bitsandbytes as bnb logger = logging.get_logger(__name__) @@ -477,6 +480,14 @@ def __init__( # after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough # space for them. # Configure vLLM parameters + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") vllm_kwargs = { "model": model.name_or_path, "tensor_parallel_size": self.vllm_tensor_parallel_size, @@ -489,6 +500,7 @@ def __init__( "seed": self.accelerator.process_index // self.vllm_tensor_parallel_size, # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) "max_num_batched_tokens": 4096, + "quantization": vllm_quantization, } # vLLM requires the environment variables to be set for distributed training. diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 2ee731dbd65..ca572caae7e 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -45,6 +45,7 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_bitsandbytes_available, is_trackio_available, is_wandb_available, ) @@ -97,6 +98,9 @@ if is_trackio_available(): import trackio +if is_bitsandbytes_available(): + import bitsandbytes as bnb + logger = logging.get_logger(__name__) @@ -513,6 +517,14 @@ def __init__( max_model_len = self.max_prompt_length + self.max_completion_length else: max_model_len = None + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") self.llm = LLM( model=model.name_or_path, tensor_parallel_size=args.vllm_tensor_parallel_size, @@ -528,6 +540,7 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, + quantization=vllm_quantization, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 529c6bb622a..02d4cc78073 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -662,7 +662,7 @@ def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | Non bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype` bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, - bnb_4bit_quant_storage=model_args.dtype, + bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, ) elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig(