Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions unsloth_zoo/temporary_patches/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,36 @@
_UNSLOTH_FLEX_ATTENTION_DISABLED = os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0"


def _prepare_gemma3_sdpa_attention_mask(attention_mask, query_states, key_states, sliding_window=None):
if attention_mask is None or attention_mask.dim() != 2:
return attention_mask

q_len = query_states.shape[2]
kv_len = key_states.shape[2]
mask_len = attention_mask.shape[-1]
if mask_len < kv_len:
pad = torch.ones(
(attention_mask.shape[0], kv_len - mask_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
Comment on lines +50 to +54
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Pad missing SDPA mask positions as masked entries

When mask_len < kv_len, this branch pads the 2D attention mask with 1s, which marks the extra KV positions as attendable. In cache-backed generation (notably static/preallocated caches), those extra positions can correspond to not-yet-written cache slots, so the model may attend to invalid keys/values and produce incorrect logits. The fix is to pad with masked values (0 for this boolean-style mask) or derive validity from cache_position instead of assuming new positions are visible.

Useful? React with 👍 / 👎.

attention_mask = torch.cat((attention_mask, pad), dim=-1)
elif mask_len > kv_len:
attention_mask = attention_mask[:, -kv_len:]

padding_mask = attention_mask[:, None, None, :].to(query_states.device) != 0
if q_len == 1:
return padding_mask

q_positions = torch.arange(q_len, device=query_states.device)[:, None]
k_positions = torch.arange(kv_len, device=query_states.device)[None, :]
cache_offset = kv_len - q_len
causal_mask = k_positions <= (q_positions + cache_offset)
if sliding_window is not None:
causal_mask = causal_mask & (k_positions > (q_positions + cache_offset - sliding_window))
return padding_mask & causal_mask[None, None, :, :]
Comment on lines +42 to +69
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _prepare_gemma3_sdpa_attention_mask function has two significant issues related to sliding_window support:

  1. Sliding window ignored when attention_mask is None: If no attention mask is provided but a sliding_window is specified, the function returns None. This causes the calling code to set is_causal=True for SDPA, which uses a standard causal mask that does not account for the sliding window constraint.
  2. Sliding window ignored during generation (q_len == 1): The early return at line 60 only returns the padding_mask. This allows the model to attend to all previous tokens in the KV cache, violating the sliding window constraint if the sequence length exceeds the window size.

Additionally, moving the attention_mask to the target device at the beginning of the function would avoid performing operations like torch.cat on the CPU if the mask hasn't been moved yet.

def _prepare_gemma3_sdpa_attention_mask(attention_mask, query_states, key_states, sliding_window=None):
    if attention_mask is None:
        if sliding_window is None:
            return None
        # Create a default mask to trigger sliding window mask generation
        attention_mask = torch.ones(
            (query_states.shape[0], key_states.shape[2]),
            dtype=torch.bool,
            device=query_states.device,
        )
    elif attention_mask.dim() != 2:
        return attention_mask

    q_len = query_states.shape[2]
    kv_len = key_states.shape[2]
    
    # Move to device early to ensure all operations are on GPU
    attention_mask = attention_mask.to(query_states.device)
    
    mask_len = attention_mask.shape[-1]
    if mask_len < kv_len:
        pad = torch.ones(
            (attention_mask.shape[0], kv_len - mask_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat((attention_mask, pad), dim=-1)
    elif mask_len > kv_len:
        attention_mask = attention_mask[:, -kv_len:]

    padding_mask = attention_mask[:, None, None, :] != 0
    
    # Optimization: if q_len == 1 and no sliding window, causal mask is all True
    if q_len == 1 and sliding_window is None:
        return padding_mask

    q_positions = torch.arange(q_len, device=query_states.device)[:, None]
    k_positions = torch.arange(kv_len, device=query_states.device)[None, :]
    cache_offset = kv_len - q_len
    causal_mask = k_positions <= (q_positions + cache_offset)
    if sliding_window is not None:
        causal_mask = causal_mask & (k_positions > (q_positions + cache_offset - sliding_window))
    return padding_mask & causal_mask[None, None, :, :]



def _make_gemma3_attn_forwards(forward_function, has_cache_position):
"""Build past_key_value / past_key_values forward variants for Gemma3Attention."""
functions = []
Expand Down Expand Up @@ -530,6 +560,12 @@ def forward_function(
**kwargs,
)
else:
attn_mask_for_sdpa = _prepare_gemma3_sdpa_attention_mask(
attn_mask_for_sdpa,
query_states_fp32,
key_states_fp32,
getattr(self, "sliding_window", None),
)
is_causal = query_states_fp32.shape[2] > 1 and attn_mask_for_sdpa is None and getattr(self, "is_causal", True)
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
Expand Down Expand Up @@ -766,6 +802,12 @@ def forward_function(
**kwargs,
)
else:
attn_mask_for_sdpa = _prepare_gemma3_sdpa_attention_mask(
attn_mask_for_sdpa,
query_states_fp32,
key_states_fp32,
getattr(self, "sliding_window", None),
)
is_causal = query_states_fp32.shape[2] > 1 and attn_mask_for_sdpa is None and getattr(self, "is_causal", True)
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
Expand Down