Skip to content
66 changes: 58 additions & 8 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from safetensors import safe_open
from safetensors.torch import save_file as safe_save_file
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import Cache, DynamicCache, EncoderDecoderCache, PreTrainedModel
from transformers import Cache, DynamicCache, EncoderDecoderCache, HybridCache, PreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

Expand Down Expand Up @@ -676,7 +676,9 @@ def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor:

return prompt_embeddings[0].detach().cpu()

def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
def get_prompt(
self, batch_size: int, task_ids: Optional[torch.Tensor] = None, max_cache_len: Optional[int] = None
) -> torch.Tensor:
"""
Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.
"""
Expand Down Expand Up @@ -705,12 +707,42 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -
)
if peft_config.num_transformer_submodules == 2:
past_key_values = torch.cat([past_key_values, past_key_values], dim=2)

# Transpose: 2 x [num_layers, batch_size, num_heads, num_virtual_tokens, head_dim]
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
peft_config.num_transformer_submodules * 2
)

base_model = self.get_base_model()
model_config = getattr(base_model, "config", None)
model_type = getattr(model_config, "model_type", "")
if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
past_key_values = post_process_fn(past_key_values)
elif ("gemma2" in model_type) or ("gemma3_text" in model_type):
# Gemma2 and Gemma3 only support HybridCache (which does not have the from_legacy_cache method)
if max_cache_len is None:
raise ValueError(
"max_cache_len is None but it should have been passed. Something went wrong, please open an "
"issue on GitHub with a reproducer: https://github.com/huggingface/peft/issues"
)
base_config = base_model.config
if hasattr(base_config, "get_text_config"):
base_config = base_config.get_text_config()
new_cache = HybridCache(
base_config,
max_batch_size=batch_size,
max_cache_len=max_cache_len,
dtype=past_key_values[0].dtype,
device=past_key_values[0].device,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the device can be different per layer which is currently handled by layer_device_map dict in kwargs. I am not sure how common it will be for users to do multiGPU peft tho, feel free to ignore if not needed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We add this to to map the device, is it still enough?

map_cache_to_layer_device_map(self.get_base_model(), past_key_values) # no-op if not a Cache instance

)
cache_position = torch.arange(peft_config.num_virtual_tokens)
for layer_idx in range(peft_config.num_layers):
key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
new_cache.update(
key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position}
)
past_key_values = new_cache
elif peft_config.num_transformer_submodules == 1:
# Dont' apply this to encoder-decoder models and not to models requiring special processing.
# local import in case users use a very old transformers version
Expand Down Expand Up @@ -1810,7 +1842,12 @@ def forward(

if peft_config.peft_type == PeftType.PREFIX_TUNING:
# overwrite past_kv in kwargs
kwargs["past_key_values"] = self.get_prompt(batch_size)
# some archs require max_cache_len to re-initialize the cache
if input_ids is not None:
max_cache_len = input_ids.shape[1] + peft_config.num_virtual_tokens
else:
max_cache_len = inputs_embeds.shape[1] + peft_config.num_virtual_tokens
kwargs["past_key_values"] = self.get_prompt(batch_size, max_cache_len=max_cache_len)
return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
elif peft_config.peft_type == PeftType.CPT:
return self._cpt_forward(input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs)
Expand Down Expand Up @@ -1936,12 +1973,20 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
if seq_len >= model_kwargs["input_ids"].shape[1]:
model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]

if model_kwargs.get("attention_mask", None) is not None:
if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
model_kwargs["attention_mask"] = torch.cat(
(prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
)
if attention_mask.dim() == 4:
# Transform the 4d attention mask to 2d, leave it up to the model to deal with it instead of trying
# to create a 4d attention mask here.
# from [batch_size, heads, input_ids_length, total_sequence_length]
# to [batch_size, total_sequence_length]
bs = attention_mask.shape[0]
total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
model_kwargs["attention_mask"] = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
else:
# 2d attention mask
model_kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

if model_kwargs.get("position_ids", None) is not None:
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
Expand All @@ -1960,7 +2005,12 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
)

if requires_prompt_injection and peft_config.peft_type == PeftType.PREFIX_TUNING:
new_past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
# some archs require max_cache_len to re-initialize the cache
max_cache_len = getattr(model_kwargs.get("past_key_values", None), "max_cache_len", None)
new_past_key_values = self.get_prompt(
batch_size=model_kwargs["input_ids"].shape[0],
max_cache_len=max_cache_len,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, does this update cache_position or we don't have cache_position at this point yet? Gemma3 like models rely on correct cache position to update cache, unlike dynamic cache where new key/value is simply appended at the end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We pop the cache_position below, which I think should result in the model assuming it starts at 0:

_ = model_kwargs.pop("cache_position", None)

model_kwargs["past_key_values"] = new_past_key_values
elif requires_prompt_injection:
inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
Expand Down
18 changes: 18 additions & 0 deletions src/peft/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"post_feedforward_layernorm",
"norm",
],
"gemma3_text": [
"input_layernorm",
"post_attention_layernorm",
"pre_feedforward_layernorm",
"post_feedforward_layernorm",
"norm",
],
"qwen2": ["post_attention_layernorm"],
}

Expand Down Expand Up @@ -116,6 +123,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"phi": ["q_proj", "v_proj", "fc1", "fc2"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"gemma3_text": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand Down Expand Up @@ -144,6 +152,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"phi": ["q_proj", "v_proj", "fc2"],
"gemma": ["q_proj", "v_proj", "down_proj"],
"gemma2": ["q_proj", "v_proj", "down_proj"],
"gemma3_text": ["q_proj", "v_proj", "down_proj"],
"qwen2": ["q_proj", "v_proj", "down_proj"],
}

Expand Down Expand Up @@ -172,6 +181,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"phi": ["fc2"],
"gemma": ["down_proj"],
"gemma2": ["down_proj"],
"gemma3_text": ["down_proj"],
"qwen2": ["down_proj"],
}

Expand All @@ -195,6 +205,9 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gpt_bigcode": ["c_attn"],
"deberta": ["in_proj"],
# "layoutlm": ["query", "value"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"gemma3_text": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand Down Expand Up @@ -232,6 +245,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"phi": ["q_proj", "v_proj"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"gemma3_text": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand Down Expand Up @@ -268,6 +282,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"phi": ["q_proj", "v_proj", "fc1", "fc2"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"gemma3_text": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand All @@ -288,6 +303,9 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
"gpt_bigcode": ["c_attn"],
"deberta": ["in_proj"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"gemma3_text": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@

PEFT_DECODER_MODELS_TO_TEST = [
"hf-internal-testing/tiny-random-OPTForCausalLM",
"hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"hf-internal-testing/tiny-random-GPT2LMHeadModel",
"hf-internal-testing/tiny-random-BloomForCausalLM",
"hf-internal-testing/tiny-random-gpt_neo",
"hf-internal-testing/tiny-random-GPTJForCausalLM",
"hf-internal-testing/tiny-random-GPTBigCodeForCausalLM",
"trl-internal-testing/tiny-random-LlamaForCausalLM",
"peft-internal-testing/tiny-dummy-qwen2",
"hf-internal-testing/tiny-random-Gemma3ForCausalLM",
]

SMALL_GRID_MODELS = [
Expand Down
26 changes: 23 additions & 3 deletions tests/testing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,10 @@ def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs):
if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

if "gemma" in model_id.lower():
# TODO: could be related to tied weights
self.skipTest("Merging currently fails with gemma")

with hub_online_once(model_id):
model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
Expand Down Expand Up @@ -792,6 +796,10 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs):
if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

if "gemma" in model_id.lower():
# TODO: could be related to tied weights
self.skipTest("Merging currently fails with gemma")

with hub_online_once(model_id):
model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
Expand Down Expand Up @@ -1326,7 +1334,7 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
nb_trainable = 0

for n, param in model.named_parameters():
if "lora" in n or (has_trainable_tokens and "trainable_tokens" in n):
if model.prefix in n or (has_trainable_tokens and "trainable_tokens" in n):
assert param.grad is not None
nb_trainable += 1
else:
Expand All @@ -1343,6 +1351,7 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
logits_from_pretrained = model_from_pretrained(**inputs)[0][0]
assert torch.allclose(logits, logits_from_pretrained, atol=1e-4, rtol=1e-4)

# check the nb of trainable params again but without layers_to_transform
model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
base_model_name_or_path=model_id,
Expand All @@ -1352,10 +1361,16 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
nb_trainable_all = 0

for n, param in model.named_parameters():
if "lora" in n or (has_trainable_tokens and "trainable_tokens" in n):
if model.prefix in n or (has_trainable_tokens and "trainable_tokens" in n):
nb_trainable_all += 1

assert nb_trainable < nb_trainable_all
mod_list = next((m for m in model.modules() if isinstance(m, torch.nn.ModuleList)), None)
if mod_list and len(mod_list) == 1:
# there is only a single layer
assert nb_trainable == nb_trainable_all
else:
# more than 1 layer, i.e. setting layers_to_transform=[0] should target fewer layers
assert nb_trainable < nb_trainable_all

def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
if config_cls == PrefixTuningConfig:
Expand Down Expand Up @@ -1433,6 +1448,9 @@ def _test_peft_model_device_map(self, model_id, config_cls, config_kwargs):
def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwargs):
if not issubclass(config_cls, PromptLearningConfig):
return pytest.skip(f"Test not applicable for {config_cls}")
if ("gemma" in model_id.lower()) and (config_cls == PrefixTuningConfig):
# TODO might be caused by the 4d causal attention mask of gemma
return pytest.skip("Prefix tuning + gemma is currently failing")

with hub_online_once(model_id):
model = self.transformers_class.from_pretrained(model_id)
Expand Down Expand Up @@ -1885,6 +1903,8 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
if model_id.endswith("qwen2"):
# Qwen2 fails with weighted adapter combinations using SVD
return pytest.skip(f"Test does not work with model {model_id}")
if "gemma" in model_id.lower():
return pytest.skip("Combining Gemma adapters with SVD is currently failing")

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
Expand Down
Loading