Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*.class
unsloth_compiled_cache/

# C extensions
*.so
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def CohereAttention_fast_forward(
past_key_value = (K, V) if use_cache else None

# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
# Xformers memory efficient attention
# Also has Flash Attention v2 dispatching
Q = Q.transpose(1, 2)
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def GraniteAttention_fast_forward(
past_key_value = (K, V) if use_cache else None

# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
# Xformers memory efficient attention
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
Expand Down
40 changes: 30 additions & 10 deletions unsloth/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def MistralAttention_fast_forward(
past_key_value = (K, V) if use_cache else None

# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
# Xformers memory efficient attention
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
Expand Down Expand Up @@ -191,15 +191,35 @@ def MistralForCausalLM_fast_forward(
if causal_mask is None and past_key_values is None:
bsz, q_len = input_ids.shape
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
causal_mask = xformers.attn_bias.LowerTriangularMask()
elif q_len <= sliding_window:
causal_mask = xformers.attn_bias.LowerTriangularMask()
else:
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
.from_seqlens([q_len]*bsz)\
.make_local_attention(window_size = sliding_window)
pass

if HAS_XFORMERS and attention_mask is None:
if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
causal_mask = xformers.attn_bias.LowerTriangularMask()
elif q_len <= sliding_window:
causal_mask = xformers.attn_bias.LowerTriangularMask()
else:
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
.from_seqlens([q_len]*bsz)\
.make_local_attention(window_size = sliding_window)

elif not HAS_XFORMERS and attention_mask is None:
if sliding_window is None or sliding_window == "null" or sliding_window <= 0 or q_len <= sliding_window:
# Fully causal mask
mask = torch.full((q_len, q_len), -torch.inf, device=input_ids.device)
mask = torch.triu(mask, diagonal=1)
attention_mask = mask.expand(bsz, 1, q_len, q_len)
else:
# Sliding window attention
q_indices = torch.arange(q_len, device=input_ids.device).view(-1, 1)
k_indices = torch.arange(q_len, device=input_ids.device).view(1, -1)

causal_bool_mask = k_indices <= q_indices
window_bool_mask = (q_indices - k_indices) < sliding_window

mask = torch.where(causal_bool_mask & window_bool_mask, 0.0, -torch.inf)
attention_mask = mask[None, None, :, :].expand(bsz, 1, q_len, q_len)

attention_mask = attention_mask.to(dtype=_get_dtype(self.config.torch_dtype))

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def Qwen3Attention_fast_forward(
past_key_value = (K, V) if use_cache else None

# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
# Xformers memory efficient attention
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
Expand Down