Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,13 @@ def LlamaAttention_fast_forward_inference(
# Knn, Vnn = Knn, Vnn
# pass

# when qlen==klen and attn_mask is None, we should use causal attention
Q_len = Qn.shape[-2]
K_len = Knn.shape[-2]
if attention_mask is None and Q_len == K_len:
is_causal = True
else:
is_causal = False
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Attention
if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
Expand Down Expand Up @@ -518,11 +525,18 @@ def LlamaAttention_fast_forward(
V = V.transpose(1, 2)
A = flash_attn_func(Q, K, V, causal = True)
else:
# when qlen==klen and attn_mask is None, we should use causal attention
Q_len = Q.shape[-2]
K_len = K.shape[-2]
if attention_mask is None and Q_len == K_len:
is_causal = True
else:
is_causal = False
Comment thread
leizhenyuan marked this conversation as resolved.
# Grouped query attention
if SDPA_HAS_GQA:
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not be both set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1)
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = is_causal, enable_gqa = n_groups != 1)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2)#.contiguous()
else:
Expand All @@ -537,7 +551,7 @@ def LlamaAttention_fast_forward(
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not be both set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = is_causal)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
Expand Down