diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 883c51560cb..8862baa8119 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -237,7 +237,14 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - model = get_peft_model(model, peft_config) + if ( + "autocast_adapter_dtype" in list(inspect.signature(get_peft_model).parameters) + and getattr(model, "is_loaded_in_4bit", False) + and is_sharded_qlora + ): + model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) + else: + model = get_peft_model(model, peft_config) if ( args is not None and args.bf16