Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,11 @@
from ...extras.profiling import profiling_context
from ...extras.vllm_client import VLLMClient
from ...import_utils import is_vllm_available
from ...models.utils import (
create_reference_model,
prepare_deepspeed,
prepare_fsdp,
prepare_peft_model,
unwrap_model_for_generation,
)
from ...models.utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
from ...trainer.base_trainer import BaseTrainer
from ...trainer.utils import disable_dropout_in_model, empty_cache, ensure_master_addr_port, get_config_model_id, pad
from ..judges import BasePairwiseJudge
from ..utils import SIMPLE_CHAT_TEMPLATE, DPODataCollatorWithPadding, truncate_right
from ..utils import SIMPLE_CHAT_TEMPLATE, DPODataCollatorWithPadding, prepare_peft_model, truncate_right
from .online_dpo_config import OnlineDPOConfig


Expand Down
2 changes: 1 addition & 1 deletion trl/experimental/prm/prm_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ...models import prepare_peft_model
from ...trainer.base_trainer import BaseTrainer
from ...trainer.utils import disable_dropout_in_model
from ..utils import prepare_peft_model
from .prm_config import PRMConfig


Expand Down
132 changes: 130 additions & 2 deletions trl/experimental/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,23 @@

# This file contains utility classes and functions that are used across more than one experimental trainer or feature.

import inspect
from dataclasses import dataclass
from typing import Any

import torch
from accelerate.utils import is_peft_model
from packaging import version
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments
from transformers.utils import is_peft_available

from ..trainer.utils import first_true_indices, pad
from ..trainer.utils import first_true_indices, pad, peft_module_casting_to_bf16


if is_peft_available():
import peft
from peft import PeftConfig, PeftModel, get_peft_model


@dataclass
Expand Down Expand Up @@ -306,3 +315,122 @@ def add_eos_token_if_needed(
rejected_tokens["input_ids"].append(eos_token_id)
rejected_tokens["attention_mask"].append(1)
return chosen_tokens, rejected_tokens


def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
r"""
Prepare a k-bit quantized transformers model for training (PEFT/QLoRA).
"""
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"]
is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr(
model, "hqq_quantized", False
)

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}

for _, param in model.named_parameters():
# freeze all parameters
param.requires_grad = False

# Enable gradient checkpointing if needed
if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
# backward-compatible hook
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
inspect.signature(model.gradient_checkpointing_enable).parameters
)
gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {}
model.gradient_checkpointing_enable(**gc_kwargs)

return model


def enable_gradient_checkpointing(
model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(model):
model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
else:
model.gradient_checkpointing_enable()

gradient_checkpointing_kwargs = gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)

if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:

def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

return model


def prepare_peft_model(
model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments
) -> PreTrainedModel:
"""Prepares a model for PEFT training."""
if not is_peft_available():
raise ImportError("PEFT is required to use a peft model. Run `pip install peft`.")

# If the model is already a PeftModel, we need to merge and unload it.
# Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
if isinstance(model, PeftModel) and peft_config is not None:
model = model.merge_and_unload()

# Handle quantized models (QLoRA)
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)

is_sharded_qlora = False
if getattr(model, "is_loaded_in_4bit", False):
# Check if model is sharded (FSDP/DS-Zero3)
for _, param in model.named_parameters():
if param.__class__.__name__ == "Params4bit":
is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
break

# Prepare model for kbit training if needed
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=args.gradient_checkpointing,
gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
)
# Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
args.gradient_checkpointing = False
elif args.gradient_checkpointing:
model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)

# Create PEFT model
if peft_config is not None:
if (
version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
and getattr(model, "is_loaded_in_4bit", False)
and is_sharded_qlora
):
model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
else:
model = get_peft_model(model, peft_config)

# Handle bf16 casting for 4-bit models
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
peft_module_casting_to_bf16(model)

return model
18 changes: 2 additions & 16 deletions trl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,15 @@
"activation_offloading": ["get_act_offloading_ctx_manager"],
"modeling_base": ["PreTrainedModelWrapper"],
"modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
"utils": [
"create_reference_model",
"prepare_deepspeed",
"prepare_fsdp",
"prepare_model_for_kbit_training",
"prepare_peft_model",
"unwrap_model_for_generation",
],
"utils": ["create_reference_model", "prepare_deepspeed", "prepare_fsdp", "unwrap_model_for_generation"],
}


if TYPE_CHECKING:
from .activation_offloading import get_act_offloading_ctx_manager
from .modeling_base import PreTrainedModelWrapper
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import (
create_reference_model,
prepare_deepspeed,
prepare_fsdp,
prepare_model_for_kbit_training,
prepare_peft_model,
unwrap_model_for_generation,
)
from .utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
else:
import sys

Expand Down
129 changes: 1 addition & 128 deletions trl/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import itertools
import logging
from collections.abc import Callable
Expand All @@ -22,16 +21,9 @@

import torch
import torch.nn as nn
from accelerate.utils import is_peft_model
from packaging import version
from transformers import PreTrainedModel, TrainingArguments
from transformers import PreTrainedModel
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available


if is_peft_available():
import peft
from peft import PeftConfig, PeftModel, get_peft_model


if TYPE_CHECKING:
Expand Down Expand Up @@ -258,72 +250,6 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.
pass


def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
r"""
Prepare a k-bit quantized transformers model for training (PEFT/QLoRA).
"""
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"]
is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr(
model, "hqq_quantized", False
)

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}

for _, param in model.named_parameters():
# freeze all parameters
param.requires_grad = False

# Enable gradient checkpointing if needed
if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
# backward-compatible hook
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
inspect.signature(model.gradient_checkpointing_enable).parameters
)
gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {}
model.gradient_checkpointing_enable(**gc_kwargs)

return model


def enable_gradient_checkpointing(
model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(model):
model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
else:
model.gradient_checkpointing_enable()

gradient_checkpointing_kwargs = gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)

if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:

def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

return model


def peft_module_casting_to_bf16(model):
for name, module in model.named_modules():
if isinstance(module, torch.nn.LayerNorm) or "norm" in name:
Expand All @@ -334,59 +260,6 @@ def peft_module_casting_to_bf16(model):
module = module.to(torch.bfloat16)


def prepare_peft_model(
model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments
) -> PreTrainedModel:
"""Prepares a model for PEFT training."""
if not is_peft_available():
raise ImportError("PEFT is required to use a peft model. Run `pip install peft`.")

# If the model is already a PeftModel, we need to merge and unload it.
# Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
if isinstance(model, PeftModel) and peft_config is not None:
model = model.merge_and_unload()

# Handle quantized models (QLoRA)
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)

is_sharded_qlora = False
if getattr(model, "is_loaded_in_4bit", False):
# Check if model is sharded (FSDP/DS-Zero3)
for _, param in model.named_parameters():
if param.__class__.__name__ == "Params4bit":
is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
break

# Prepare model for kbit training if needed
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=args.gradient_checkpointing,
gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
)
# Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
args.gradient_checkpointing = False
elif args.gradient_checkpointing:
model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)

# Create PEFT model
if peft_config is not None:
if (
version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
and getattr(model, "is_loaded_in_4bit", False)
and is_sharded_qlora
):
model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
else:
model = get_peft_model(model, peft_config)

# Handle bf16 casting for 4-bit models
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
peft_module_casting_to_bf16(model)

return model


@contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None = None):
"""
Expand Down
Loading