diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py index e5cfea0cd3..878689430f 100644 --- a/optimum/habana/trl/trainer/dpo_trainer.py +++ b/optimum/habana/trl/trainer/dpo_trainer.py @@ -35,7 +35,6 @@ DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, - peft_module_casting_to_bf16, ) from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments @@ -155,7 +154,7 @@ def make_inputs_require_grad(module, input, output): # get peft model with the given config model = get_peft_model(model, peft_config) if args.bf16: - peft_module_casting_to_bf16(model) + model = model.to(torch.bfloat16) # For models that use gradient_checkpoiting, we need to attach a hook that enables input # to explicitly have `requires_grad=True`, otherwise training will either silently diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py index 49b2525f4c..05e24a9155 100644 --- a/optimum/habana/trl/trainer/sft_trainer.py +++ b/optimum/habana/trl/trainer/sft_trainer.py @@ -33,7 +33,6 @@ from trl.import_utils import is_peft_available from trl.trainer.utils import ( DataCollatorForCompletionOnlyLM, - peft_module_casting_to_bf16, ) @@ -143,7 +142,7 @@ def make_inputs_require_grad(module, input, output): model = get_peft_model(model, peft_config) if args.bf16: - peft_module_casting_to_bf16(model) + model = model.to(torch.bfloat16) if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)