Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,18 @@ def prefer_flex_attn_if_supported(model_class, config):
model_class, "_supports_flex_attn", False
):
return None
# flex_attention Triton kernels require sm80+ (Ampere and above).
# On older GPUs (T4/sm75, V100/sm70) the dense Python fallback runs
# instead, but sdpa_dense_backward has a dtype mismatch under fp16
# autocast (Half @ Float matmul). Skip flex_attention there.
import torch

if torch.cuda.is_available():
major, _ = torch.cuda.get_device_capability()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Evaluate compute capability across all visible GPUs

This check only reads torch.cuda.get_device_capability() for the current CUDA device, but prefer_flex_attn_if_supported sets one global attention implementation for the whole model. In mixed-GPU runs (for example, device_map="auto" with both Ampere and pre-Ampere cards), if the current device is sm80+ and another shard lands on sm70/sm75, flex_attention will still be enabled and the same backward dtype mismatch this patch aims to avoid can still occur on the older shard.

Useful? React with 👍 / 👎.

if major < 8:
return None
else:
Comment on lines +241 to +247
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block can be made more concise and readable. The import torch is redundant since it's already imported at the top of the file, and the conditional logic can be simplified by checking for the negative case first and returning early. This flattens the code structure.

Suggested change
import torch
if torch.cuda.is_available():
major, _ = torch.cuda.get_device_capability()
if major < 8:
return None
else:
# Check for CUDA availability and compute capability.
# Return early if not supported to avoid nested ifs.
if not torch.cuda.is_available():
return None
major, _ = torch.cuda.get_device_capability()
if major < 8:
return None

return None
# GPT-OSS, Mllama and Gemma3N use eager/sdpa attention during
# inference since flex attention returns incorrect results or errors out.
# GPT-OSS: left padding issues cause incorrect outputs.
Expand Down
Loading