-
Notifications
You must be signed in to change notification settings - Fork 265
Mask for gemma3 attn #635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mask for gemma3 attn #635
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,36 @@ | |
| _UNSLOTH_FLEX_ATTENTION_DISABLED = os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0" | ||
|
|
||
|
|
||
| def _prepare_gemma3_sdpa_attention_mask(attention_mask, query_states, key_states, sliding_window=None): | ||
| if attention_mask is None or attention_mask.dim() != 2: | ||
| return attention_mask | ||
|
|
||
| q_len = query_states.shape[2] | ||
| kv_len = key_states.shape[2] | ||
| mask_len = attention_mask.shape[-1] | ||
| if mask_len < kv_len: | ||
| pad = torch.ones( | ||
| (attention_mask.shape[0], kv_len - mask_len), | ||
| dtype=attention_mask.dtype, | ||
| device=attention_mask.device, | ||
| ) | ||
| attention_mask = torch.cat((attention_mask, pad), dim=-1) | ||
| elif mask_len > kv_len: | ||
| attention_mask = attention_mask[:, -kv_len:] | ||
|
|
||
| padding_mask = attention_mask[:, None, None, :].to(query_states.device) != 0 | ||
| if q_len == 1: | ||
| return padding_mask | ||
|
|
||
| q_positions = torch.arange(q_len, device=query_states.device)[:, None] | ||
| k_positions = torch.arange(kv_len, device=query_states.device)[None, :] | ||
| cache_offset = kv_len - q_len | ||
| causal_mask = k_positions <= (q_positions + cache_offset) | ||
| if sliding_window is not None: | ||
| causal_mask = causal_mask & (k_positions > (q_positions + cache_offset - sliding_window)) | ||
| return padding_mask & causal_mask[None, None, :, :] | ||
|
Comment on lines
+42
to
+69
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The
Additionally, moving the def _prepare_gemma3_sdpa_attention_mask(attention_mask, query_states, key_states, sliding_window=None):
if attention_mask is None:
if sliding_window is None:
return None
# Create a default mask to trigger sliding window mask generation
attention_mask = torch.ones(
(query_states.shape[0], key_states.shape[2]),
dtype=torch.bool,
device=query_states.device,
)
elif attention_mask.dim() != 2:
return attention_mask
q_len = query_states.shape[2]
kv_len = key_states.shape[2]
# Move to device early to ensure all operations are on GPU
attention_mask = attention_mask.to(query_states.device)
mask_len = attention_mask.shape[-1]
if mask_len < kv_len:
pad = torch.ones(
(attention_mask.shape[0], kv_len - mask_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat((attention_mask, pad), dim=-1)
elif mask_len > kv_len:
attention_mask = attention_mask[:, -kv_len:]
padding_mask = attention_mask[:, None, None, :] != 0
# Optimization: if q_len == 1 and no sliding window, causal mask is all True
if q_len == 1 and sliding_window is None:
return padding_mask
q_positions = torch.arange(q_len, device=query_states.device)[:, None]
k_positions = torch.arange(kv_len, device=query_states.device)[None, :]
cache_offset = kv_len - q_len
causal_mask = k_positions <= (q_positions + cache_offset)
if sliding_window is not None:
causal_mask = causal_mask & (k_positions > (q_positions + cache_offset - sliding_window))
return padding_mask & causal_mask[None, None, :, :] |
||
|
|
||
|
|
||
| def _make_gemma3_attn_forwards(forward_function, has_cache_position): | ||
| """Build past_key_value / past_key_values forward variants for Gemma3Attention.""" | ||
| functions = [] | ||
|
|
@@ -530,6 +560,12 @@ def forward_function( | |
| **kwargs, | ||
| ) | ||
| else: | ||
| attn_mask_for_sdpa = _prepare_gemma3_sdpa_attention_mask( | ||
| attn_mask_for_sdpa, | ||
| query_states_fp32, | ||
| key_states_fp32, | ||
| getattr(self, "sliding_window", None), | ||
| ) | ||
| is_causal = query_states_fp32.shape[2] > 1 and attn_mask_for_sdpa is None and getattr(self, "is_causal", True) | ||
| # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. | ||
| # We convert it to a bool for the SDPA kernel that only accepts bools. | ||
|
|
@@ -766,6 +802,12 @@ def forward_function( | |
| **kwargs, | ||
| ) | ||
| else: | ||
| attn_mask_for_sdpa = _prepare_gemma3_sdpa_attention_mask( | ||
| attn_mask_for_sdpa, | ||
| query_states_fp32, | ||
| key_states_fp32, | ||
| getattr(self, "sliding_window", None), | ||
| ) | ||
| is_causal = query_states_fp32.shape[2] > 1 and attn_mask_for_sdpa is None and getattr(self, "is_causal", True) | ||
| # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. | ||
| # We convert it to a bool for the SDPA kernel that only accepts bools. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When `mask_len < kv_len`, this branch pads the 2D attention mask with `1`s, which marks the extra KV positions as attendable. In cache-backed generation (notably static/preallocated caches), those extra positions can correspond to not-yet-written cache slots, so the model may attend to invalid keys/values and produce incorrect logits. The fix is to pad with masked values (`0` for this boolean-style mask) or derive validity from `cache_position` instead of assuming new positions are visible.
Useful? React with 👍 / 👎.