diff --git a/.gitignore b/.gitignore index 014a60a024..b7b6522364 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *.class +unsloth_compiled_cache/ # C extensions *.so diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 25704301d1..bfa833be9b 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -128,7 +128,7 @@ def CohereAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 0f88bbc55e..243922fc13 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -115,7 +115,7 @@ def GraniteAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): # Xformers memory efficient attention Q = Q.transpose(1, 2) K = K.transpose(1, 2) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index ef71c89c4a..a3e07e3b0f 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -99,7 +99,7 @@ def MistralAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): # Xformers memory efficient attention Q = Q.transpose(1, 2) K = K.transpose(1, 2) @@ -191,15 +191,35 @@ def MistralForCausalLM_fast_forward( if causal_mask is None and past_key_values is None: bsz, q_len = input_ids.shape sliding_window = getattr(self.config, "sliding_window", None) - if sliding_window is None or sliding_window == "null" or sliding_window <= 0: - causal_mask = xformers.attn_bias.LowerTriangularMask() - elif q_len <= sliding_window: - causal_mask = xformers.attn_bias.LowerTriangularMask() - else: - causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\ - .from_seqlens([q_len]*bsz)\ - .make_local_attention(window_size = sliding_window) - pass + + if HAS_XFORMERS and attention_mask is None: + if sliding_window is None or sliding_window == "null" or sliding_window <= 0: + causal_mask = xformers.attn_bias.LowerTriangularMask() + elif q_len <= sliding_window: + causal_mask = xformers.attn_bias.LowerTriangularMask() + else: + causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\ + .from_seqlens([q_len]*bsz)\ + .make_local_attention(window_size = sliding_window) + + elif not HAS_XFORMERS and attention_mask is None: + if sliding_window is None or sliding_window == "null" or sliding_window <= 0 or q_len <= sliding_window: + # Fully causal mask + mask = torch.full((q_len, q_len), -torch.inf, device=input_ids.device) + mask = torch.triu(mask, diagonal=1) + attention_mask = mask.expand(bsz, 1, q_len, q_len) + else: + # Sliding window attention + q_indices = torch.arange(q_len, device=input_ids.device).view(-1, 1) + k_indices = torch.arange(q_len, device=input_ids.device).view(1, -1) + + causal_bool_mask = k_indices <= q_indices + window_bool_mask = (q_indices - k_indices) < sliding_window + + mask = torch.where(causal_bool_mask & window_bool_mask, 0.0, -torch.inf) + attention_mask = mask[None, None, :, :].expand(bsz, 1, q_len, q_len) + + attention_mask = attention_mask.to(dtype=_get_dtype(self.config.torch_dtype)) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index 83c9dbea0a..80bd6ee7c9 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -126,7 +126,7 @@ def Qwen3Attention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): # Xformers memory efficient attention Q = Q.transpose(1, 2) K = K.transpose(1, 2)