Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

cached_mask = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as seen with you offline, this cannot work as-is due to sharing across model instances

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes will either not cache it at the class level but instance level a tril. Or pass it as kwargs. Jax does not seem to care so should not be too bad

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we pass the attention_mask just into the cache.update(...) function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check that as well, but that doesn't help for all type of attention we have which need a pre-processed mask, will work after the pre-processing tho


def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
Expand Down Expand Up @@ -358,11 +360,6 @@ def forward(
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
Expand Down Expand Up @@ -418,12 +415,19 @@ def forward(
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if (attention_mask is None and LlamaAttention.cached_mask is None) or self.layer_idx == 0:
# create the 4d mask and cache it
attention_mask = LlamaAttention.cached_mask = _prepare_4d_causal_attention_mask(
attention_mask, (bsz, q_len), hidden_states, kv_seq_len - q_len
)
elif LlamaAttention.cached_mask is not None and self.layer_idx != 0: # use the cached value
attention_mask = LlamaAttention.cached_mask

if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
Expand Down Expand Up @@ -544,6 +548,9 @@ def forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

if 0 not in attention_mask:
attention_mask = None

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
Expand Down Expand Up @@ -711,17 +718,20 @@ def forward(
value_states = repeat_kv(value_states, self.num_key_value_groups)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
if LlamaAttention.cached_mask is None or self.layer_idx == 0:
# create the 4d mask and cache it
attention_mask = LlamaAttention.cached_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, (bsz, q_len), hidden_states, kv_seq_len - q_len
)
elif LlamaAttention.cached_mask is not None:
attention_mask = LlamaAttention.cached_mask

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda":
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
Expand Down Expand Up @@ -955,8 +965,6 @@ def __init__(self, config: LlamaConfig):
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
Expand Down Expand Up @@ -1024,24 +1032,6 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# embed positions
hidden_states = inputs_embeds

Expand Down