diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 7f7cfecbbc7..74c8f48e5c2 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1600,97 +1600,55 @@ def test_prompt_tuning(self): @require_peft @require_bitsandbytes - def test_peft_model_with_quantization(self): - """SFTTrainer should not freeze layers of existing PeftModel. - - This test simulates a realistic QLoRA scenario where a quantized base model is first converted to a PeftModel, - then passed to SFTTrainer. The issue was that prepare_model_for_kbit_training would freeze all parameters - including the LoRA adapters, making training impossible. - """ + def test_peft_with_quantization(self): # Get the base model model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" - model = AutoModelForCausalLM.from_pretrained(model_id) - # Simulate a realistic QLoRA setup by mocking quantization attributes - # This mimics what happens when loading a model with load_in_4bit=True - model.is_loaded_in_4bit = True - model.is_loaded_in_8bit = False - - # Verify that this triggers the is_qlora condition - is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) - assert is_qlora, "Model should be detected as QLoRA (quantized)" - - # Create LoRA configuration suitable for QLoRA - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - target_modules=["q_proj", "v_proj"], - r=16, - lora_alpha=32, - lora_dropout=0.1, + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, ) - - # Convert the quantized model to a PeftModel (typical QLoRA workflow) - peft_model = get_peft_model(model, lora_config) - - # Verify the quantization attributes are preserved on the PeftModel - assert getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag" + model = AutoModelForCausalLM.from_pretrained(model_id, 
quantization_config=quantization_config) # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") - # Analyze parameters before SFTTrainer initialization - trainable_params_before = [] - base_params_before = [] - lora_params_before = [] - - for name, param in peft_model.named_parameters(): - if param.requires_grad: - trainable_params_before.append(name) - if "lora" in name.lower(): - lora_params_before.append(name) - else: - base_params_before.append(name) - - # Ensure we have the expected parameter distribution for QLoRA - assert len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially" - assert len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters" - assert len(base_params_before) == 0, "Base model parameters should already be frozen in PeftModel" - # Initialize the trainer with the already configured PeftModel - training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1) - trainer = SFTTrainer(model=peft_model, args=training_args, train_dataset=dataset) - - # Analyze parameters after SFTTrainer initialization - trainable_params_after = [] - lora_params_after = [] - - for name, param in trainer.model.named_parameters(): - if param.requires_grad: - trainable_params_after.append(name) - if "lora" in name.lower(): - lora_params_after.append(name) + training_args = SFTConfig(output_dir=self.tmp_dir, learning_rate=0.1, report_to="none") + trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset, peft_config=LoraConfig()) - # LoRA parameters should remain trainable - assert len(trainable_params_after) > 0, ( - f"PeftModel should still have trainable parameters after SFTTrainer initialization. " - f"Found {len(trainable_params_after)} trainable params. " - f"This test fails without the fix for issue #3926." 
- ) + # Save initial parameters to check they change during training + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - assert len(lora_params_after) > 0, ( - f"LoRA adapter parameters should remain trainable. " - f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original." - ) + trainer.train() - # Ensure the parameter counts are preserved (no additional freezing occurred) - assert len(trainable_params_before) == len(trainable_params_after), ( - "Number of trainable parameters should not change after SFTTrainer initialization" - ) + # Check that training completed successfully + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None - # Verify that all original LoRA parameters are still trainable - assert set(lora_params_before) == set(lora_params_after), ( - "All original LoRA parameters should remain trainable after SFTTrainer initialization" - ) + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # In bitsandbytes, bias parameters are automatically cast to the input dtype during the forward pass if + # their dtype doesn’t match. This causes the module to change unexpectedly during the first forward pass of + # the training. To handle this, we cast these specific bias parameters to float32 before comparison. + # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/45553f7392e524eacf400b132cfe01261f6477be/bitsandbytes/nn/modules.py#L518 + # We still need to investigate why the compute dtype ends up being different from the dtype of these parameters. 
+ if n in [ + "base_model.model.model.layers.1.self_attn.k_proj.bias", + "base_model.model.model.layers.1.self_attn.q_proj.base_layer.bias", + "base_model.model.model.layers.1.self_attn.v_proj.base_layer.bias", + ]: + param = param.float() + + if "lora" not in n: # We expect the base model parameters to be the same + assert torch.allclose(param, new_param), f"Parameter {n} has changed" + elif "lora" in n: # We expect the peft parameters to be different + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + else: + raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}") @require_peft def test_prompt_tuning_peft_model(self): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7e5c0d8c6bb..cbd7a454797 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -28,7 +28,6 @@ import pandas as pd import torch import torch.utils.data -import transformers from accelerate import logging from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset @@ -36,7 +35,6 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, @@ -61,13 +59,14 @@ from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_liger_kernel_available, is_vllm_available -from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation +from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ..models.utils import _ForwardRedirection from .base_trainer import BaseTrainer from .callbacks import SyncRefModelCallback from .grpo_config import GRPOConfig from .utils import ( RepeatSampler, + create_model_from_path, 
disable_dropout_in_model, ensure_master_addr_port, entropy_from_logits, @@ -87,7 +86,7 @@ if is_peft_available(): - from peft import PeftConfig, PeftModel + from peft import PeftConfig, PeftModel, get_peft_model if is_liger_kernel_available(): from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss @@ -254,28 +253,14 @@ def __init__( model_name = model_name.split("/")[-1] args = GRPOConfig(f"{model_name}-GRPO") - # Models - # Trained model - model_init_kwargs = args.model_init_kwargs or {} + # Model if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype", "auto") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass # dtype is already a torch.dtype or "auto" or None - elif isinstance(dtype, str): # it's a str, but not "auto" - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype - else: - raise ValueError( - "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - model = architecture.from_pretrained(model_id, **model_init_kwargs) + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) else: - model_id = get_config_model_id(model.config) if args.model_init_kwargs is not None: logger.warning( "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. 
" @@ -290,9 +275,6 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): - model = prepare_peft_model(model, peft_config, args) - # Processing class if processing_class is None: processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left") @@ -312,12 +294,40 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id + if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: + # If the model is already a PeftModel, we need to merge and unload it. + # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft + model = model.merge_and_unload() + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. 
See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + # Reward functions if not isinstance(reward_funcs, list): reward_funcs = [reward_funcs] self.reward_func_names = [] for i, reward_func in enumerate(reward_funcs): if isinstance(reward_func, str): + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( reward_func, num_labels=1, **model_init_kwargs ) @@ -476,9 +486,11 @@ def __init__( self.ref_model = None else: # For deepspeed, fsdp or non-distributed models, create a reference model from scratch - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if self.args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) # Disable dropout in the models if args.disable_dropout: diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 5c567331159..27b846770b2 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -25,7 +25,6 @@ import torch import torch.nn as nn -import transformers from accelerate import PartialState from accelerate.logging import get_logger from datasets import Dataset, 
IterableDataset @@ -42,14 +41,14 @@ from transformers.utils import is_peft_available from ..data_utils import is_conversational -from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model +from ..models import clone_chat_template, get_act_offloading_ctx_manager from .base_trainer import BaseTrainer from .reward_config import RewardConfig -from .utils import disable_dropout_in_model, get_config_model_id, pad, remove_none_values +from .utils import create_model_from_path, disable_dropout_in_model, get_config_model_id, pad, remove_none_values if is_peft_available(): - from peft import PeftConfig, PeftModel + from peft import PeftConfig, PeftModel, get_peft_model logger = get_logger(__name__) @@ -279,24 +278,13 @@ def __init__( args = RewardConfig(f"{model_name}-Reward") # Model - model_init_kwargs = args.model_init_kwargs or {} if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype", "auto") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass # dtype is already a torch.dtype or "auto" or None - elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: - model_init_kwargs["dtype"] = getattr(torch, dtype) - else: - raise ValueError( - "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing " - f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." 
- ) - model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") - with suppress_from_pretrained_warning(transformers.modeling_utils.logger): - model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs) + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, AutoModelForSequenceClassification, **model_init_kwargs) else: - model_id = get_config_model_id(model.config) if args.model_init_kwargs is not None: logger.warning( "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. " @@ -305,7 +293,7 @@ def __init__( # Processing class if processing_class is None: - processing_class = AutoTokenizer.from_pretrained(model_id) + processing_class = AutoTokenizer.from_pretrained(get_config_model_id(model.config)) # Handle pad token for processors or tokenizers if args.eos_token is not None: @@ -356,8 +344,29 @@ def __init__( else: peft_config.modules_to_save.append("lm_head") - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): - model = prepare_peft_model(model, peft_config, args) + if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: + # If the model is already a PeftModel, we need to merge and unload it. + # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft + model = model.merge_and_unload() + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. 
transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) # Disable dropout in the model if args.disable_dropout: diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index dc70b50a67a..c389005f882 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -48,7 +48,7 @@ prepare_multimodal_messages, truncate_dataset, ) -from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model +from ..models import clone_chat_template, get_act_offloading_ctx_manager from .base_trainer import BaseTrainer from .sft_config import SFTConfig from .utils import ( @@ -63,7 +63,7 @@ if is_peft_available(): - from peft import PeftConfig, PeftModel, PeftType + from peft import PeftConfig, PeftModel, PeftType, get_peft_model logger = logging.get_logger(__name__) @@ -693,12 +693,34 @@ def __init__( else: peft_config.modules_to_save.append("lm_head") + if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: + # If the model is already a PeftModel, we need to merge and unload it. 
+ # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft + model = model.merge_and_unload() + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. 
self.num_virtual_tokens = 0 - - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): - model = prepare_peft_model(model, peft_config, args) + if is_peft_available() and isinstance(model, PeftModel): if model.active_adapter in model.peft_config: peft_model_config = model.peft_config[model.active_adapter] self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 7f47a378ebc..4e1042879b3 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -48,6 +48,7 @@ TrainingArguments, is_comet_available, ) +from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.utils import ( ModelOutput, is_peft_available, @@ -1981,13 +1982,18 @@ def remove_none_values(example: TListOrMapping) -> TListOrMapping: raise TypeError("Input must be a list or a dictionary.") -def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel: +def create_model_from_path( + model_id: str, architecture: type[_BaseAutoModelClass] | None = None, **kwargs +) -> PreTrainedModel: """ Create a model from a given path using the specified initialization arguments. Args: model_id (`str`): Path to the model. Can be either a local directory or a model identifier from the Hugging Face Hub. + architecture (`type[_BaseAutoModelClass]` or `None`, *optional*): + Model architecture class to instantiate. The model is initialized using the `from_pretrained` method of + this class. If `None`, the architecture will be inferred from the model's configuration. kwargs (`dict`): Initialization keyword arguments to pass to the model's `from_pretrained` method. When `'dtype'` is specified, it can be either a `torch.dtype` or one of the strings: `'bfloat16'`, `'float16'`, `'float32'`, @@ -2008,8 +2014,9 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel: f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." 
) kwargs["device_map"] = kwargs.get("device_map", "auto") - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) + if architecture is None: + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) model = architecture.from_pretrained(model_id, **kwargs) return model