From 48c1d1f77490e7ef6af3480d0ad95da645a7931e Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 15 Mar 2026 11:42:18 +0000 Subject: [PATCH 1/2] Skip flex_attention on pre-Ampere GPUs (T4, V100) flex_attention Triton kernels require sm80+ (Ampere). On older GPUs the dense Python fallback runs instead, but sdpa_dense_backward has a dtype mismatch under fp16 autocast -- the matmul at line 904 of torch/_higher_order_ops/flex_attention.py does softmax_scores.to(query.dtype) @ grad_out where query.dtype is Half and grad_out is Float, producing "expected scalar type Float but found Half". This affected Ministral-3B/8B on T4 GPUs (issue #4295) and potentially any model using flex_attention on pre-Ampere hardware. Fix: check torch.cuda.get_device_capability() >= (8, 0) before enabling flex_attention. Falls back to sdpa on older GPUs. --- unsloth/models/_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cbaebcc7ac..550fdf0ea6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -234,6 +234,17 @@ def prefer_flex_attn_if_supported(model_class, config): model_class, "_supports_flex_attn", False ): return None + # flex_attention Triton kernels require sm80+ (Ampere and above). + # On older GPUs (T4/sm75, V100/sm70) the dense Python fallback runs + # instead, but sdpa_dense_backward has a dtype mismatch under fp16 + # autocast (Half @ Float matmul). Skip flex_attention there. + import torch + if torch.cuda.is_available(): + major, _ = torch.cuda.get_device_capability() + if major < 8: + return None + else: + return None # GPT-OSS, Mllama and Gemma3N use eager/sdpa attention during # inference since flex attention returns incorrect results or errors out. # GPT-OSS: left padding issues cause incorrect outputs. From 45feffde6fded824b3246caff7bdf2d3ad8bbb73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Mar 2026 11:42:52 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 550fdf0ea6..b529ba0dd8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -239,6 +239,7 @@ def prefer_flex_attn_if_supported(model_class, config): # instead, but sdpa_dense_backward has a dtype mismatch under fp16 # autocast (Half @ Float matmul). Skip flex_attention there. import torch + if torch.cuda.is_available(): major, _ = torch.cuda.get_device_capability() if major < 8: