FIX Prompt learning issue with 4d attention mask #2458
Changes from all commits: 450059d, 8aa14a1, 2b3fbc9, 337febd, 352a140, a323d8f, 25e36ae, c43e0c3, ee8933c, 8d5e36f
@@ -34,7 +34,7 @@
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers import Cache, DynamicCache, EncoderDecoderCache, PreTrainedModel
+from transformers import Cache, DynamicCache, EncoderDecoderCache, HybridCache, PreTrainedModel
 from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.utils import PushToHubMixin
@@ -676,7 +676,9 @@ def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor:

         return prompt_embeddings[0].detach().cpu()

-    def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def get_prompt(
+        self, batch_size: int, task_ids: Optional[torch.Tensor] = None, max_cache_len: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.
         """
@@ -705,12 +707,42 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
            )
            if peft_config.num_transformer_submodules == 2:
                past_key_values = torch.cat([past_key_values, past_key_values], dim=2)

            # Transpose: 2 x [num_layers, batch_size, num_heads, num_virtual_tokens, head_dim]
            past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
                peft_config.num_transformer_submodules * 2
            )

            base_model = self.get_base_model()
            model_config = getattr(base_model, "config", None)
            model_type = getattr(model_config, "model_type", "")
            if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
                post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
                past_key_values = post_process_fn(past_key_values)
            elif ("gemma2" in model_type) or ("gemma3_text" in model_type):
                # Gemma2 and Gemma3 only support HybridCache (which does not have the from_legacy_cache method)
                if max_cache_len is None:
                    raise ValueError(
                        "max_cache_len is None but it should have been passed. Something went wrong, please open an "
                        "issue on GitHub with a reproducer: https://github.com/huggingface/peft/issues"
                    )
                base_config = base_model.config
                if hasattr(base_config, "get_text_config"):
                    base_config = base_config.get_text_config()
                new_cache = HybridCache(
                    base_config,
                    max_batch_size=batch_size,
                    max_cache_len=max_cache_len,
                    dtype=past_key_values[0].dtype,
                    device=past_key_values[0].device,
                )
                cache_position = torch.arange(peft_config.num_virtual_tokens)
                for layer_idx in range(peft_config.num_layers):
                    key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
                    new_cache.update(
                        key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position}
                    )
                past_key_values = new_cache
            elif peft_config.num_transformer_submodules == 1:
                # Dont' apply this to encoder-decoder models and not to models requiring special processing.
                # local import in case users use a very old transformers version
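Not part of the diff: a standalone sketch of the cache pre-fill loop above, using a tiny `Gemma2Config` and random key/value states standing in for the prompt encoder output. The config values, shapes, and sliding window are arbitrary assumptions, and it presumes a transformers version whose `HybridCache` constructor and `update()` signature match the ones used in this PR.

```python
import torch
from transformers import Gemma2Config, HybridCache

batch_size, num_virtual_tokens, max_cache_len = 2, 10, 64

# Tiny, illustrative config; real Gemma-2 configs are much larger.
config = Gemma2Config(
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    head_dim=8,
    hidden_size=32,
    intermediate_size=64,
    sliding_window=16,
)

cache = HybridCache(
    config, max_batch_size=batch_size, max_cache_len=max_cache_len, dtype=torch.float32, device="cpu"
)
cache_position = torch.arange(num_virtual_tokens)
for layer_idx in range(config.num_hidden_layers):
    # [batch_size, num_key_value_heads, num_virtual_tokens, head_dim], the shape produced
    # by the permute/split above (random tensors here, just to exercise the cache).
    key_states = torch.randn(batch_size, config.num_key_value_heads, num_virtual_tokens, config.head_dim)
    value_states = torch.randn_like(key_states)
    cache.update(key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position})
```

The point is only that the virtual-token states get written into the first `num_virtual_tokens` slots of a fixed-size cache, which is what lets Gemma-2/Gemma-3 reuse them during the real forward pass.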
@@ -1810,7 +1842,12 @@ def forward(

         if peft_config.peft_type == PeftType.PREFIX_TUNING:
             # overwrite past_kv in kwargs
-            kwargs["past_key_values"] = self.get_prompt(batch_size)
+            # some archs require max_cache_len to re-initialize the cache
+            if input_ids is not None:
+                max_cache_len = input_ids.shape[1] + peft_config.num_virtual_tokens
+            else:
+                max_cache_len = inputs_embeds.shape[1] + peft_config.num_virtual_tokens
+            kwargs["past_key_values"] = self.get_prompt(batch_size, max_cache_len=max_cache_len)
             return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
         elif peft_config.peft_type == PeftType.CPT:
             return self._cpt_forward(input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs)
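Not part of the diff: a minimal end-to-end sketch of how the new `max_cache_len` plumbing is exercised from `forward()`. The model id and hyperparameters are placeholders; it assumes a Gemma-2 text checkpoint is available and that the installed peft and transformers versions include this fix.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, TaskType, get_peft_model

model_id = "google/gemma-2-2b"  # assumption: any gemma2 / gemma3_text checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10)
model = get_peft_model(base_model, peft_config)

inputs = tokenizer("Prefix tuning test", return_tensors="pt")
# forward() now derives max_cache_len = sequence length + num_virtual_tokens and passes it
# to get_prompt(), which builds the HybridCache shown in the hunk above for Gemma-2/Gemma-3.
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)
```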
@@ -1936,12 +1973,20 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
             if seq_len >= model_kwargs["input_ids"].shape[1]:
                 model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]

-        if model_kwargs.get("attention_mask", None) is not None:
+        if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
             size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
             prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
-            model_kwargs["attention_mask"] = torch.cat(
-                (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
-            )
+            if attention_mask.dim() == 4:
+                # Transform the 4d attention mask to 2d, leave it up to the model to deal with it instead of trying
+                # to create a 4d attention mask here.
+                # from [batch_size, heads, input_ids_length, total_sequence_length]
+                # to [batch_size, total_sequence_length]
+                bs = attention_mask.shape[0]
+                total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
+                model_kwargs["attention_mask"] = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
+            else:
+                # 2d attention mask
+                model_kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

         if model_kwargs.get("position_ids", None) is not None:
             warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
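Not part of the diff: a shape-only illustration of the 4d to 2d conversion above; the batch size, head count, and sequence lengths are arbitrary.

```python
import torch

bs, heads, input_ids_length, total_len, num_virtual_tokens = 2, 1, 5, 5, 10
attention_mask = torch.ones(bs, heads, input_ids_length, total_len, dtype=torch.long)  # 4d mask
prefix_attention_mask = torch.ones(bs, num_virtual_tokens)

# Same arithmetic as the branch above: prefix length + input_ids_length.
total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
print(attention_mask_2d.shape)  # torch.Size([2, 15]) == [batch_size, total_sequence_length]
```

The all-ones 2d mask simply marks the virtual tokens and the real tokens as attended to, leaving it to the base model to rebuild whatever 4d mask it needs internally.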
@@ -1960,7 +2005,12 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
             )

         if requires_prompt_injection and peft_config.peft_type == PeftType.PREFIX_TUNING:
-            new_past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
+            # some archs require max_cache_len to re-initialize the cache
+            max_cache_len = getattr(model_kwargs.get("past_key_values", None), "max_cache_len", None)
+            new_past_key_values = self.get_prompt(
+                batch_size=model_kwargs["input_ids"].shape[0],
+                max_cache_len=max_cache_len,
+            )
Member: btw, does this update

Member (Author): We pop the
Line 1976 in 3a67a44
             model_kwargs["past_key_values"] = new_past_key_values
         elif requires_prompt_injection:
             inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
Comment: the device can be different per layer, which is currently handled by the `layer_device_map` dict in kwargs. I am not sure how common it will be for users to do multi-GPU PEFT though; feel free to ignore if not needed.

Reply (author): We add this to map the device, is it still enough?
peft/src/peft/peft_model.py, Line 726 in 3a67a44
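Not part of the diff: a short generation sketch that exercises the prepare_inputs_for_generation changes above, continuing the hypothetical Gemma-2 example from the forward() hunk. With Gemma-2, generate() pre-allocates a HybridCache, so the getattr fallback above can read max_cache_len from it before get_prompt() rebuilds the prefix cache.

```python
# Assumes `model` and `tokenizer` from the earlier prefix-tuning sketch (illustrative only).
import torch

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```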