Restore Flash > SDPA > Flex priority for non-gemma3 models #5455

Merged
mmathew23 merged 4 commits into unslothai:main from mmathew23:fix/attnorder
May 15, 2026

Conversation

@mmathew23
Collaborator

Summary

Tighten the scope of flex_attention selection in
resolve_attention_implementation so it stays restricted to
_FLEX_PREFERRED_MODELS (the gemma3 family) as intended, and lets every
other model fall back to SDPA when flash_attn is not installed.

PR #5346 ("Gemma attn", commit cbb41222) restored gemma3's preference for
flex_attention after #4210 / #4201 had it routing through SDPA. The
preference list
(_FLEX_PREFERRED_MODELS = ("gemma3", "gemma3_text", "shieldgemma2"))
is the right shape for the carve-out. This PR scopes the fallback chain to
match: only models in that list pick flex_attention over flash, and flex
stays as a last resort for models that do not support SDPA.

Current behavior

Today the priority chain is:

1. _is_eager_only(model_type)                              -> eager           (gemma3n)
2. prefers_flex_attention AND supports_flex_attention      -> flex_attention  (gemma3 family, the carve-out)
3. !prefers_flex AND !disabled AND HAS_FLASH AND supports_flash
                                                           -> flash_attention_2
4. supports_flex_attention                                 -> flex_attention
5. flash_attention_disabled                                -> SDPA/eager fallback
6. supports_sdpa                                           -> sdpa
7. else                                                    -> eager

Branch 4 picks flex_attention whenever a model has
_supports_flex_attn=True and flash_attn is not available, regardless of
whether the model is in _FLEX_PREFERRED_MODELS. Most modern decoder-only
families (qwen3, llama, mistral, qwen3_moe, ...) advertise
_supports_flex_attn=True, so on hosts without flash_attn they all land on
flex rather than SDPA. The carve-out list ends up acting as a "prefers flex"
hint rather than the gating list it reads as.
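
The effect described above can be reproduced with a small stand-alone model of the pre-change chain. This is a sketch, not the code in unsloth/models/_utils.py: the `Caps` dataclass and `resolve_before` are hypothetical names, and the flash-disabled branch is assumed to resolve to SDPA when supported and eager otherwise.

```python
from dataclasses import dataclass

# Model types the PR description lists in _FLEX_PREFERRED_MODELS.
FLEX_PREFERRED = ("gemma3", "gemma3_text", "shieldgemma2")


@dataclass
class Caps:
    """Hypothetical stand-in for the capability flags the resolver reads."""
    model_type: str
    supports_flash: bool = True
    supports_sdpa: bool = True
    supports_flex: bool = True
    eager_only: bool = False


def resolve_before(caps: Caps, has_flash: bool, flash_disabled: bool = False) -> str:
    """Pre-change priority chain, branch for branch."""
    prefers_flex = caps.model_type in FLEX_PREFERRED
    if caps.eager_only:                                   # branch 1
        return "eager"
    if prefers_flex and caps.supports_flex:               # branch 2
        return "flex_attention"
    if (not prefers_flex and not flash_disabled
            and has_flash and caps.supports_flash):       # branch 3
        return "flash_attention_2"
    if caps.supports_flex:                                # branch 4
        return "flex_attention"
    if flash_disabled:                                    # branch 5
        # Assumption: the real helper ends on SDPA or eager.
        return "sdpa" if caps.supports_sdpa else "eager"
    if caps.supports_sdpa:                                # branch 6
        return "sdpa"
    return "eager"                                        # branch 7


# A non-listed model that advertises flex support, on a host without flash_attn:
print(resolve_before(Caps("qwen3"), has_flash=False))
```

With flash_attn absent, the qwen3-like capability set never reaches branch 6, because branch 4 matches first.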

For LoRA / QLoRA this is invisible at runtime because the runtime
dispatcher (select_attention_backend in
unsloth/utils/attention_dispatch.py) ignores _attn_implementation
entirely. For full fine-tuning the compiled-cache shim does honor
_attn_implementation, so the choice is observable. Empirically on
qwen3-4B with xformers installed and flash_attn not installed:

Mode   config decided   runtime kernel that fired
QLoRA  flex_attention   xformers.ops.fmha.memory_efficient_attention x 216
LoRA   flex_attention   xformers.ops.fmha.memory_efficient_attention x 216
FFT    flex_attention   transformers.ALL_ATTENTION_FUNCTIONS['flex_attention'] x 216

The qwen3 FFT path ends up on flex even though qwen3 is not in the
preference list and SDPA is fully supported.

Change

One-function update in resolve_attention_implementation. Move the
flex-attention fallback to after the SDPA branch, so it runs only when
SDPA is also unavailable:

         if _is_eager_only(model_type):
             attn_impl = _set_attn_impl(config, "eager")
         elif prefers_flex_attention and supports_flex_attention:
+            # _FLEX_PREFERRED_MODELS (gemma3 family) prefers flex_attention
+            # over flash. Caller can still override by passing
+            # requested_attn_implementation="sdpa" (handled below).
             attn_impl = _set_attn_impl(config, "flex_attention")
         elif (
-            not prefers_flex_attention
-            and not flash_attention_disabled
+            not flash_attention_disabled
             and HAS_FLASH_ATTENTION
             and supports_flash_attention
         ):
             attn_impl = _set_attn_impl(config, "flash_attention_2")
-        elif supports_flex_attention:
-            attn_impl = _set_attn_impl(config, "flex_attention")
         elif flash_attention_disabled:
             attn_impl = _disable_flash_attention_if_needed(...)
         elif supports_sdpa:
             attn_impl = _set_attn_impl(config, "sdpa")
+        elif supports_flex_attention:
+            # Flex is a last-resort fallback for models that do not support
+            # SDPA. Without this branch such models would land on eager.
+            attn_impl = _set_attn_impl(config, "flex_attention")
         else:
             attn_impl = _set_attn_impl(config, "eager")

Updated priority chain:

1. _is_eager_only                                          -> eager           (gemma3n)
2. prefers_flex_attention AND supports_flex_attention      -> flex_attention  (gemma3 family, UNCHANGED)
3. !flash_attention_disabled AND HAS_FLASH AND supports_flash
                                                           -> flash_attention_2
4. flash_attention_disabled                                -> SDPA/eager fallback (head_dim>256)
5. supports_sdpa                                           -> sdpa            <- qwen3 etc. now land here
6. supports_flex_attention                                 -> flex_attention  (last-resort, SDPA-less models only)
7. else                                                    -> eager
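
The updated chain can be sketched as a pure function over the capability flags. The function name and keyword arguments are hypothetical, chosen only for illustration, and the flash-disabled branch is assumed to end on SDPA when supported and eager otherwise. The real resolver also mutates the config through _set_attn_impl, which this sketch leaves out.

```python
# Model types the PR description lists in _FLEX_PREFERRED_MODELS.
FLEX_PREFERRED = ("gemma3", "gemma3_text", "shieldgemma2")


def resolve_after(model_type, *, has_flash, supports_flash=True,
                  supports_sdpa=True, supports_flex=True,
                  flash_disabled=False, eager_only=False):
    """Post-change priority chain, numbered as in the list above."""
    prefers_flex = model_type in FLEX_PREFERRED
    if eager_only:                                          # 1
        return "eager"
    if prefers_flex and supports_flex:                      # 2
        return "flex_attention"
    if not flash_disabled and has_flash and supports_flash:  # 3
        return "flash_attention_2"
    if flash_disabled:                                      # 4
        # Assumption: the real helper ends on SDPA or eager.
        return "sdpa" if supports_sdpa else "eager"
    if supports_sdpa:                                       # 5
        return "sdpa"
    if supports_flex:                                       # 6
        return "flex_attention"
    return "eager"                                          # 7


# qwen3 without flash_attn now resolves to SDPA.
print(resolve_after("qwen3", has_flash=False))
# A hypothetical SDPA-less model still gets flex as the last resort.
print(resolve_after("custom_model", has_flash=False, supports_sdpa=False))
```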

The requested_attn_implementation override block at lines 527-538 is
untouched. Explicit attn_implementation="..." from the caller still wins
for both gemma3 and the others.

Consistency with #5346

Quoting the PR body:

▎ For gemma3, previously we had the flex_attn backend. But after the refactor
▎ in #4210 #4201 it seems to have been falling back to SDPA. This is an
▎ attempt to restore old behaviour via 1) set a list of models which PREFER
▎ flex attn 2) if the model is among those and we see flex attn available, we
▎ use that.

The carve-out and the preference list both stay. Branch 2 of the chain is
unchanged, so gemma3 / gemma3_text / shieldgemma2 still pick flex over
flash. This PR only demotes the old branch 4 (the generic flex fallback) to
sit below SDPA, so non-listed models are no longer swept into the same
preference.

The gemini-code-assist review on #5346 also raised the case of a gemma3
variant that loses flex support skipping FA2 entirely. The updated chain
covers that case too: a gemma3 model without supports_flex_attention falls
through to branch 3 (FA2 if available), then branch 4 (SDPA/eager fallback),
then branch 5 (SDPA), then branch 6 (flex last-resort), then branch 7
(eager).
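
A minimal sketch of that fall-through, under stated assumptions: `pick` is a hypothetical helper, the eager-only branch is omitted, and `legacy_flash_gate=True` re-adds only the `not prefers_flex_attention` condition that the diff removes from the flash branch.

```python
def pick(*, prefers_flex, supports_flex, has_flash, supports_flash=True,
         supports_sdpa=True, flash_disabled=False, legacy_flash_gate=False):
    """Post-change chain without the eager-only branch."""
    if prefers_flex and supports_flex:
        return "flex_attention"
    flash_ok = not flash_disabled and has_flash and supports_flash
    if legacy_flash_gate:
        # Pre-change condition: flex-preferred models never took flash.
        flash_ok = flash_ok and not prefers_flex
    if flash_ok:
        return "flash_attention_2"
    if flash_disabled:
        # Assumption: the real helper ends on SDPA or eager.
        return "sdpa" if supports_sdpa else "eager"
    if supports_sdpa:
        return "sdpa"
    if supports_flex:
        return "flex_attention"
    return "eager"


# A gemma3-like model (prefers flex) that has lost flex support:
gemma3_no_flex = dict(prefers_flex=True, supports_flex=False, has_flash=True)
print(pick(**gemma3_no_flex))                          # updated gate
print(pick(**gemma3_no_flex, legacy_flash_gate=True))  # old gate
```

With the updated gate the model reaches FA2; with the old gate it skips FA2 and ends on SDPA, which is the case the review raised.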

@mmathew23 mmathew23 requested a review from danielhanchen as a code owner May 15, 2026 16:03
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the attention implementation resolution logic to prioritize flex_attention for specific model families and introduces it as a fallback when SDPA is unavailable. A review comment identifies a logic inconsistency where flex_attention is bypassed as a fallback option when Flash Attention is explicitly disabled (e.g., due to head dimension constraints), suggesting that the fallback mechanism should be updated to maintain consistent priority across different failure modes.

Comment thread unsloth/models/_utils.py

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5758086ad0

Comment thread unsloth/models/_utils.py
Comment on lines +524 to +525
elif supports_flex_attention:
# Flex is only a fallback for models that don't support SDPA

[P2] Route flash-disabled non-SDPA models to flex fallback

This new fallback placement is skipped whenever flash_attention_disabled is true, so a model with supports_sdpa=False and supports_flex_attention=True now falls into _disable_flash_attention_if_needed(...) and is forced to eager instead of flex_attention. That is a regression from the previous behavior and contradicts the stated "flex as last-resort for non-SDPA models" logic, which can materially hurt throughput (and increase memory pressure) for configurations where Flash is disabled due to head-dim limits.
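
The scenario in this comment can be traced with a reduced model of the two orderings. This is a sketch under assumptions taken from the comment itself, not from the source: `resolve` is a hypothetical helper, the model is not in the preference list, flash is unusable, and _disable_flash_attention_if_needed is assumed to end on SDPA when supported and eager otherwise.

```python
def resolve(*, supports_sdpa, supports_flex, flash_disabled, old_order):
    """Tail of the chain only; eager-only, flex-preferred and flash
    branches are omitted because the scenario never takes them."""
    if old_order and supports_flex:
        # Pre-change: generic flex fallback ran before the flash-disabled check.
        return "flex_attention"
    if flash_disabled:
        # Assumption: stands in for _disable_flash_attention_if_needed(...).
        return "sdpa" if supports_sdpa else "eager"
    if supports_sdpa:
        return "sdpa"
    if supports_flex:
        # Post-change: flex only after SDPA has been ruled out.
        return "flex_attention"
    return "eager"


case = dict(supports_sdpa=False, supports_flex=True, flash_disabled=True)
print(resolve(**case, old_order=True))   # ordering before this PR
print(resolve(**case, old_order=False))  # ordering after this PR
```

Under these assumptions the pre-change ordering yields flex_attention and the post-change ordering yields eager, because the flash-disabled branch returns before the new last-resort flex branch is reached.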

Collaborator

@Datta0 Datta0 left a comment

LGTM

@mmathew23 mmathew23 merged commit 3596ce1 into unslothai:main May 15, 2026
37 of 39 checks passed