Skip to content

Commit

Permalink
FIX [PEFT / Trainer ] Handle better peft + quantized compiled mod…
Browse files Browse the repository at this point in the history
…els (#29055)

* handle peft + compiled models

* add tests

* fixup

* adapt from suggestions

* clarify comment
  • Loading branch information
younesbelkada authored Feb 20, 2024
1 parent 5e95dca commit efdd436
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,12 @@ def __init__(
getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
)

# Filter out quantized + compiled models
if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
raise ValueError(
"You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
)

# At this stage the model is already loaded
if _is_quantized_and_base_model and not _is_peft_model(model):
raise ValueError(
Expand Down
37 changes: 37 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
require_deepspeed,
require_intel_extension_for_pytorch,
require_optuna,
require_peft,
require_ray,
require_safetensors,
require_sentencepiece,
Expand Down Expand Up @@ -873,6 +874,42 @@ def test_number_of_steps_in_training_with_ipex(self):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)

@require_peft
@require_bitsandbytes
def test_bnb_compile(self):
from peft import LoraConfig, get_peft_model

# Simply tests if initializing a Trainer with a PEFT + compiled model works out of the box
# QLoRA + torch compile is not really supported yet, but we should at least support the model
# loading and let torch throw the
tiny_model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True
)

peft_config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
tiny_model = get_peft_model(tiny_model, peft_config)

tiny_model = torch.compile(tiny_model)

x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
tmp_dir,
learning_rate=1e-9,
logging_steps=5,
)
with self.assertRaises(ValueError):
_ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa

@require_bitsandbytes
def test_rmsprop_bnb(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
Expand Down

0 comments on commit efdd436

Please sign in to comment.