diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index d0ff413925..c4b8581147 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -318,6 +318,13 @@ def LlamaAttention_fast_forward_inference(
     # Knn, Vnn = Knn, Vnn
     # pass
 
+    # when qlen==klen and attn_mask is None, we should use causal attention
+    Q_len = Qn.shape[-2]
+    K_len = Knn.shape[-2]
+    if attention_mask is None and Q_len == K_len:
+        is_causal = True
+    else:
+        is_causal = False
     # Attention
     if bsz == 1:
         Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
@@ -518,11 +525,18 @@ def LlamaAttention_fast_forward(
         V = V.transpose(1, 2)
         A = flash_attn_func(Q, K, V, causal = True)
     else:
+        # when qlen==klen and attn_mask is None, we should use causal attention
+        Q_len = Q.shape[-2]
+        K_len = K.shape[-2]
+        if attention_mask is None and Q_len == K_len:
+            is_causal = True
+        else:
+            is_causal = False
         # Grouped query attention
         if SDPA_HAS_GQA:
             # Needs (batch_size, n_heads, seq_len, head_dim)
             # is_casual and attention_mask must not be both set!
-            A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1)
+            A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = is_causal, enable_gqa = n_groups != 1)
             # Go back to (batch_size, seq_len, n_heads, head_dim)
             A = A.transpose(1, 2)#.contiguous()
         else:
@@ -537,7 +551,7 @@ def LlamaAttention_fast_forward(
             Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
             # Needs (batch_size, n_heads, seq_len, head_dim)
             # is_casual and attention_mask must not be both set!
-            A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+            A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = is_causal)
             # Go back to (batch_size, seq_len, n_heads, head_dim)
             A = A.transpose(1, 2).contiguous()
         pass