diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8ceee2d1d45a..19c35c667b56 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -285,6 +285,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    cached_mask = None
+
     def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
@@ -358,11 +360,6 @@ def forward(
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-
         bsz, q_len, _ = hidden_states.size()
 
         if self.config.pretraining_tp > 1:
@@ -418,12 +415,19 @@ def forward(
                 f" {attn_weights.size()}"
             )
 
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
+        if (attention_mask is None and LlamaAttention.cached_mask is None) or self.layer_idx == 0:
+            # create the 4d mask and cache it
+            attention_mask = LlamaAttention.cached_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (bsz, q_len), hidden_states, kv_seq_len - q_len
+            )
+        elif LlamaAttention.cached_mask is not None and self.layer_idx != 0:  # use the cached value
+            attention_mask = LlamaAttention.cached_mask
+
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+        attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -544,6 +548,9 @@ def forward(
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
+        if 0 not in attention_mask:
+            attention_mask = None
+
         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )
@@ -711,17 +718,20 @@ def forward(
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            if LlamaAttention.cached_mask is None or self.layer_idx == 0:
+                # create the 4d mask and cache it
+                attention_mask = LlamaAttention.cached_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask, (bsz, q_len), hidden_states, kv_seq_len - q_len
                 )
+            elif LlamaAttention.cached_mask is not None:
+                attention_mask = LlamaAttention.cached_mask
 
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and attention_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda":
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
@@ -955,8 +965,6 @@ def __init__(self, config: LlamaConfig):
         self.layers = nn.ModuleList(
             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
-        self._use_sdpa = config._attn_implementation == "sdpa"
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
@@ -1024,24 +1032,6 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
-
         # embed positions
         hidden_states = inputs_embeds
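For reference, here is a minimal, self-contained sketch of the class-level mask caching pattern the diff moves into the attention layers: layer 0 builds the 4D additive causal mask once per forward pass and stores it on the class, and later layers reuse the cached tensor instead of rebuilding it. `build_causal_mask` and `CachedMaskAttention` are illustrative stand-ins, not the real `_prepare_4d_causal_attention_mask` helper or `LlamaAttention`, which additionally handle padding masks, dtype minimums, and the KV-cache offset.

import torch
from torch import nn


def build_causal_mask(bsz, q_len, kv_len, dtype, device):
    # (bsz, 1, q_len, kv_len) additive mask: 0 where attention is allowed,
    # dtype-min above the causal diagonal (offset by the cached KV length).
    mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    mask = torch.triu(mask, diagonal=1 + kv_len - q_len)
    return mask[None, None, :, :].expand(bsz, 1, q_len, kv_len)


class CachedMaskAttention(nn.Module):
    # Class attribute shared by every layer instance: layer 0 rebuilds it on each
    # forward pass, the remaining layers reuse it (mirrors LlamaAttention.cached_mask).
    cached_mask = None

    def __init__(self, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx

    def forward(self, attn_weights):
        bsz, _, q_len, kv_len = attn_weights.shape
        if CachedMaskAttention.cached_mask is None or self.layer_idx == 0:
            CachedMaskAttention.cached_mask = build_causal_mask(
                bsz, q_len, kv_len, attn_weights.dtype, attn_weights.device
            )
        attn_weights = attn_weights + CachedMaskAttention.cached_mask
        return torch.softmax(attn_weights, dim=-1)

Because the cache lives on the class rather than on an instance, this pattern assumes layer 0 runs first in every forward pass and that all layers share the same mask shape, dtype, and device; a new batch or sequence length is only handled because layer 0 rebuilds the mask each call.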