diff --git a/unsloth_zoo/temporary_patches/gemma.py b/unsloth_zoo/temporary_patches/gemma.py index 9935d9340..c7db09443 100644 --- a/unsloth_zoo/temporary_patches/gemma.py +++ b/unsloth_zoo/temporary_patches/gemma.py @@ -39,6 +39,36 @@ _UNSLOTH_FLEX_ATTENTION_DISABLED = os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0" +def _prepare_gemma3_sdpa_attention_mask(attention_mask, query_states, key_states, sliding_window=None): + if attention_mask is None or attention_mask.dim() != 2: + return attention_mask + + q_len = query_states.shape[2] + kv_len = key_states.shape[2] + mask_len = attention_mask.shape[-1] + if mask_len < kv_len: + pad = torch.ones( + (attention_mask.shape[0], kv_len - mask_len), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat((attention_mask, pad), dim=-1) + elif mask_len > kv_len: + attention_mask = attention_mask[:, -kv_len:] + + padding_mask = attention_mask[:, None, None, :].to(query_states.device) != 0 + if q_len == 1: + return padding_mask + + q_positions = torch.arange(q_len, device=query_states.device)[:, None] + k_positions = torch.arange(kv_len, device=query_states.device)[None, :] + cache_offset = kv_len - q_len + causal_mask = k_positions <= (q_positions + cache_offset) + if sliding_window is not None: + causal_mask = causal_mask & (k_positions > (q_positions + cache_offset - sliding_window)) + return padding_mask & causal_mask[None, None, :, :] + + def _make_gemma3_attn_forwards(forward_function, has_cache_position): """Build past_key_value / past_key_values forward variants for Gemma3Attention.""" functions = [] @@ -530,6 +560,12 @@ def forward_function( **kwargs, ) else: + attn_mask_for_sdpa = _prepare_gemma3_sdpa_attention_mask( + attn_mask_for_sdpa, + query_states_fp32, + key_states_fp32, + getattr(self, "sliding_window", None), + ) is_causal = query_states_fp32.shape[2] > 1 and attn_mask_for_sdpa is None and getattr(self, "is_causal", True) # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # We convert it to a bool for the SDPA kernel that only accepts bools. @@ -766,6 +802,12 @@ def forward_function( **kwargs, ) else: + attn_mask_for_sdpa = _prepare_gemma3_sdpa_attention_mask( + attn_mask_for_sdpa, + query_states_fp32, + key_states_fp32, + getattr(self, "sliding_window", None), + ) is_causal = query_states_fp32.shape[2] > 1 and attn_mask_for_sdpa is None and getattr(self, "is_causal", True) # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # We convert it to a bool for the SDPA kernel that only accepts bools.