From 4d7345acb5b3ea5ce592dbef516f87d6cc895edd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 12 Dec 2025 21:24:44 +0000 Subject: [PATCH 1/3] align rloo and grpo --- trl/trainer/rloo_trainer.py | 91 +++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 24b18280705..dca61fce688 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -83,10 +83,14 @@ split_tensor_dict, unsplit_pixel_values_by_grid, ) +from .utils import ( + RepeatSampler, + create_model_from_path) if is_peft_available(): from peft import PeftConfig, PeftModel + from peft import PeftConfig, PeftModel, get_peft_model if is_vllm_available(): from vllm import LLM, SamplingParams @@ -120,21 +124,15 @@ class RLOOTrainer(BaseTrainer): ```python from datasets import load_dataset from trl import RLOOTrainer + from trl.rewards import accuracy_reward - dataset = load_dataset("trl-lib/tldr", split="train") - - - def reward_func(completions, **kwargs): - # Dummy reward function that rewards completions with more unique letters. - return [float(len(set(completion))) for completion in completions] - + dataset = load_dataset("trl-lib/DeepMath-103K", split="train") trainer = RLOOTrainer( model="Qwen/Qwen2-0.5B-Instruct", - reward_funcs=reward_func, + reward_funcs=accuracy_reward, train_dataset=dataset, ) - trainer.train() ``` @@ -231,8 +229,8 @@ def reward_func(completions, **kwargs): def __init__( self, - model: str | PreTrainedModel = None, - reward_funcs: RewardFunc | list[RewardFunc] = None, + model: str | PreTrainedModel, + reward_funcs: RewardFunc | list[RewardFunc], args: RLOOConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, @@ -248,28 +246,14 @@ def __init__( model_name = model_name.split("/")[-1] args = RLOOConfig(f"{model_name}-RLOO") - # Models - # Trained model - model_init_kwargs = args.model_init_kwargs or {} + # Model if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype", "auto") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass # dtype is already a torch.dtype or "auto" or None - elif isinstance(dtype, str): # it's a str, but not "auto" - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype - else: - raise ValueError( - "Invalid `dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - model = architecture.from_pretrained(model_id, **model_init_kwargs) + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) else: - model_id = get_config_model_id(model.config) if args.model_init_kwargs is not None: logger.warning( "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. " @@ -284,12 +268,11 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): - model = prepare_peft_model(model, peft_config, args) - # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left") + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -306,12 +289,40 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id + if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: + # If the model is already a PeftModel, we need to merge and unload it. + # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft + model = model.merge_and_unload() + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + # Reward functions if not isinstance(reward_funcs, list): reward_funcs = [reward_funcs] self.reward_func_names = [] for i, reward_func in enumerate(reward_funcs): if isinstance(reward_func, str): + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( reward_func, num_labels=1, **model_init_kwargs ) @@ -362,7 +373,7 @@ def __init__( self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length self.num_generations = args.num_generations - self.num_generations_eval = args.num_generations_eval or args.num_generations + self.num_generations_eval = args.num_generations_eval or self.num_generations self.chat_template_kwargs = args.chat_template_kwargs or {} self.temperature = args.temperature self.top_p = args.top_p @@ -433,9 +444,11 @@ def __init__( self.ref_model = None else: # For deepspeed, fsdp or non-distributed models, create a reference model from scratch - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if self.args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) # Disable dropout in the models if args.disable_dropout: From 21549b29a8bfbe3141c87a6ec7bd08b25c8017ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 12 Dec 2025 21:27:13 +0000 Subject: [PATCH 2/3] style --- trl/trainer/rloo_trainer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index dca61fce688..2b37cd912fd 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -27,7 +27,6 @@ import pandas as pd import torch import torch.utils.data -import transformers from accelerate import logging from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset @@ -35,7 +34,6 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, @@ -60,13 +58,14 @@ from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_vllm_available -from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation +from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ..models.utils import disable_gradient_checkpointing from .base_trainer import BaseTrainer from .callbacks import SyncRefModelCallback from .rloo_config import RLOOConfig from .utils import ( RepeatSampler, + create_model_from_path, disable_dropout_in_model, ensure_master_addr_port, entropy_from_logits, @@ -83,13 +82,9 @@ split_tensor_dict, unsplit_pixel_values_by_grid, ) -from .utils import ( - RepeatSampler, - create_model_from_path) if is_peft_available(): - from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel, get_peft_model if is_vllm_available(): From 8be5fc0df84d94f0f1cc6233f35422630425d688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 16 Dec 2025 08:43:46 -0700 Subject: [PATCH 3/3] Move `prepare_model_for_kbit_training`, `enable_gradient_checkpointing`, `prepare_peft_model` to `experimental.utils` (#4686) --- .../online_dpo/online_dpo_trainer.py | 10 +- trl/experimental/prm/prm_trainer.py | 2 +- trl/experimental/utils.py | 132 +++++++++++++++++- trl/models/__init__.py | 18 +-- trl/models/utils.py | 129 +---------------- 5 files changed, 136 insertions(+), 155 deletions(-) diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py index cd94f645bfe..c292f7ead92 100644 --- a/trl/experimental/online_dpo/online_dpo_trainer.py +++ b/trl/experimental/online_dpo/online_dpo_trainer.py @@ -55,17 +55,11 @@ from ...extras.profiling import profiling_context from ...extras.vllm_client import VLLMClient from ...import_utils import is_vllm_available -from ...models.utils import ( - create_reference_model, - prepare_deepspeed, - prepare_fsdp, - prepare_peft_model, - unwrap_model_for_generation, -) +from ...models.utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation from ...trainer.base_trainer import BaseTrainer from ...trainer.utils import disable_dropout_in_model, empty_cache, ensure_master_addr_port, get_config_model_id, pad from ..judges import BasePairwiseJudge -from ..utils import SIMPLE_CHAT_TEMPLATE, DPODataCollatorWithPadding, truncate_right +from ..utils import SIMPLE_CHAT_TEMPLATE, DPODataCollatorWithPadding, prepare_peft_model, truncate_right from .online_dpo_config import OnlineDPOConfig diff --git a/trl/experimental/prm/prm_trainer.py b/trl/experimental/prm/prm_trainer.py index 26e251f929d..b6c77e7c45c 100644 --- a/trl/experimental/prm/prm_trainer.py +++ b/trl/experimental/prm/prm_trainer.py @@ -35,9 +35,9 @@ from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from ...models import prepare_peft_model from ...trainer.base_trainer import BaseTrainer from ...trainer.utils import disable_dropout_in_model +from ..utils import prepare_peft_model from .prm_config import PRMConfig diff --git a/trl/experimental/utils.py b/trl/experimental/utils.py index 43298703045..93881fe29a2 100644 --- a/trl/experimental/utils.py +++ b/trl/experimental/utils.py @@ -14,14 +14,23 @@ # This file contains utility classes and functions that are used across more than one experimental trainer or feature. +import inspect from dataclasses import dataclass from typing import Any import torch +from accelerate.utils import is_peft_model +from packaging import version from torch.nn.utils.rnn import pad_sequence -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments +from transformers.utils import is_peft_available -from ..trainer.utils import first_true_indices, pad +from ..trainer.utils import first_true_indices, pad, peft_module_casting_to_bf16 + + +if is_peft_available(): + import peft + from peft import PeftConfig, PeftModel, get_peft_model @dataclass @@ -306,3 +315,122 @@ def add_eos_token_if_needed( rejected_tokens["input_ids"].append(eos_token_id) rejected_tokens["attention_mask"].append(1) return chosen_tokens, rejected_tokens + + +def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): + r""" + Prepare a k-bit quantized transformers model for training (PEFT/QLoRA). + """ + loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) + quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"] + is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr( + model, "hqq_quantized", False + ) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + for _, param in model.named_parameters(): + # freeze all parameters + param.requires_grad = False + + # Enable gradient checkpointing if needed + if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + # backward-compatible hook + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(model.gradient_checkpointing_enable).parameters + ) + gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {} + model.gradient_checkpointing_enable(**gc_kwargs) + + return model + + +def enable_gradient_checkpointing( + model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None +) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + # Enable gradient checkpointing on the base model for PEFT + if is_peft_model(model): + model.base_model.gradient_checkpointing_enable() + # Enable gradient checkpointing for non-PEFT models + else: + model.gradient_checkpointing_enable() + + gradient_checkpointing_kwargs = gradient_checkpointing_kwargs or {} + use_reentrant = ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] + ) + + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + return model + + +def prepare_peft_model( + model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments +) -> PreTrainedModel: + """Prepares a model for PEFT training.""" + if not is_peft_available(): + raise ImportError("PEFT is required to use a peft model. Run `pip install peft`.") + + # If the model is already a PeftModel, we need to merge and unload it. + # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft + if isinstance(model, PeftModel) and peft_config is not None: + model = model.merge_and_unload() + + # Handle quantized models (QLoRA) + is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) + + is_sharded_qlora = False + if getattr(model, "is_loaded_in_4bit", False): + # Check if model is sharded (FSDP/DS-Zero3) + for _, param in model.named_parameters(): + if param.__class__.__name__ == "Params4bit": + is_sharded_qlora = param.data.device.type in {"cpu", "meta"} + break + + # Prepare model for kbit training if needed + if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel): + model = prepare_model_for_kbit_training( + model, + use_gradient_checkpointing=args.gradient_checkpointing, + gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {}, + ) + # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training + args.gradient_checkpointing = False + elif args.gradient_checkpointing: + model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs) + + # Create PEFT model + if peft_config is not None: + if ( + version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12 + and getattr(model, "is_loaded_in_4bit", False) + and is_sharded_qlora + ): + model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) + else: + model = get_peft_model(model, peft_config) + + # Handle bf16 casting for 4-bit models + if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora: + peft_module_casting_to_bf16(model) + + return model diff --git a/trl/models/__init__.py b/trl/models/__init__.py index 4dc013c891f..dccde7b9157 100644 --- a/trl/models/__init__.py +++ b/trl/models/__init__.py @@ -21,14 +21,7 @@ "activation_offloading": ["get_act_offloading_ctx_manager"], "modeling_base": ["PreTrainedModelWrapper"], "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"], - "utils": [ - "create_reference_model", - "prepare_deepspeed", - "prepare_fsdp", - "prepare_model_for_kbit_training", - "prepare_peft_model", - "unwrap_model_for_generation", - ], + "utils": ["create_reference_model", "prepare_deepspeed", "prepare_fsdp", "unwrap_model_for_generation"], } @@ -36,14 +29,7 @@ from .activation_offloading import get_act_offloading_ctx_manager from .modeling_base import PreTrainedModelWrapper from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead - from .utils import ( - create_reference_model, - prepare_deepspeed, - prepare_fsdp, - prepare_model_for_kbit_training, - prepare_peft_model, - unwrap_model_for_generation, - ) + from .utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation else: import sys diff --git a/trl/models/utils.py b/trl/models/utils.py index 76674c00b6c..6b29cec527c 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import itertools import logging from collections.abc import Callable @@ -22,16 +21,9 @@ import torch import torch.nn as nn -from accelerate.utils import is_peft_model from packaging import version -from transformers import PreTrainedModel, TrainingArguments +from transformers import PreTrainedModel from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from transformers.utils import is_peft_available - - -if is_peft_available(): - import peft - from peft import PeftConfig, PeftModel, get_peft_model if TYPE_CHECKING: @@ -258,72 +250,6 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn. pass -def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): - r""" - Prepare a k-bit quantized transformers model for training (PEFT/QLoRA). - """ - loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) - quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"] - is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr( - model, "hqq_quantized", False - ) - - if gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {} - - for _, param in model.named_parameters(): - # freeze all parameters - param.requires_grad = False - - # Enable gradient checkpointing if needed - if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing: - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - # backward-compatible hook - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( - inspect.signature(model.gradient_checkpointing_enable).parameters - ) - gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {} - model.gradient_checkpointing_enable(**gc_kwargs) - - return model - - -def enable_gradient_checkpointing( - model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None -) -> PreTrainedModel: - """Enables gradient checkpointing for the model.""" - # Enable gradient checkpointing on the base model for PEFT - if is_peft_model(model): - model.base_model.gradient_checkpointing_enable() - # Enable gradient checkpointing for non-PEFT models - else: - model.gradient_checkpointing_enable() - - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs or {} - use_reentrant = ( - "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] - ) - - if use_reentrant: - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - return model - - def peft_module_casting_to_bf16(model): for name, module in model.named_modules(): if isinstance(module, torch.nn.LayerNorm) or "norm" in name: @@ -334,59 +260,6 @@ def peft_module_casting_to_bf16(model): module = module.to(torch.bfloat16) -def prepare_peft_model( - model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments -) -> PreTrainedModel: - """Prepares a model for PEFT training.""" - if not is_peft_available(): - raise ImportError("PEFT is required to use a peft model. Run `pip install peft`.") - - # If the model is already a PeftModel, we need to merge and unload it. - # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft - if isinstance(model, PeftModel) and peft_config is not None: - model = model.merge_and_unload() - - # Handle quantized models (QLoRA) - is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) - - is_sharded_qlora = False - if getattr(model, "is_loaded_in_4bit", False): - # Check if model is sharded (FSDP/DS-Zero3) - for _, param in model.named_parameters(): - if param.__class__.__name__ == "Params4bit": - is_sharded_qlora = param.data.device.type in {"cpu", "meta"} - break - - # Prepare model for kbit training if needed - if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel): - model = prepare_model_for_kbit_training( - model, - use_gradient_checkpointing=args.gradient_checkpointing, - gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {}, - ) - # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training - args.gradient_checkpointing = False - elif args.gradient_checkpointing: - model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs) - - # Create PEFT model - if peft_config is not None: - if ( - version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12 - and getattr(model, "is_loaded_in_4bit", False) - and is_sharded_qlora - ): - model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) - else: - model = get_peft_model(model, peft_config) - - # Handle bf16 casting for 4-bit models - if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora: - peft_module_casting_to_bf16(model) - - return model - - @contextmanager def disable_gradient_checkpointing(model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None = None): """