Restore Flash > SDPA > Flex priority for non-gemma3 models#5455
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the attention implementation resolution logic to prioritize flex_attention for specific model families and introduces it as a fallback when SDPA is unavailable. A review comment identifies a logic inconsistency where flex_attention is bypassed as a fallback option when Flash Attention is explicitly disabled (e.g., due to head dimension constraints), suggesting that the fallback mechanism should be updated to maintain consistent priority across different failure modes.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 5758086ad0
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| elif supports_flex_attention: | ||
| # Flex is only a fallback for models that don't support SDPA |
There was a problem hiding this comment.
Route flash-disabled non-SDPA models to flex fallback
This new fallback placement is skipped whenever flash_attention_disabled is true, so a model with supports_sdpa=False and supports_flex_attention=True now falls into _disable_flash_attention_if_needed(...) and is forced to eager instead of flex_attention. That is a regression from the previous behavior and contradicts the stated "flex as last-resort for non-SDPA models" logic, which can materially hurt throughput (and increase memory pressure) for configurations where Flash is disabled due to head-dim limits.
Useful? React with 👍 / 👎.
for more information, see https://pre-commit.ci
Summary
Tighten the scope of
flex_attentionselection inresolve_attention_implementationso it stays restricted to_FLEX_PREFERRED_MODELS(the gemma3 family) as intended, and lets everyother model fall back to SDPA when flash_attn is not installed.
PR #5346 ("Gemma attn", commit
cbb41222) restored gemma3's preference forflex_attentionafter #4210 / #4201 had it routing through SDPA. Thepreference list (
_FLEX_PREFERRED_MODELS = ("gemma3", "gemma3_text", "shieldgemma2")) is the right shape for the carve-out. This PR scopes thefallback chain to match: only models in that list pick
flex_attentionoverflash, and flex stays as a last-resort for models that do not support SDPA.
Current behavior
Today the priority chain is:
-> flash_attention_2
Branch 4 picks
flex_attentionwhenever a model has_supports_flex_attn=Trueand flash_attn is not available, regardless ofwhether the model is in
_FLEX_PREFERRED_MODELS. Most modern decoder-onlyfamilies (qwen3, llama, mistral, qwen3_moe, ...) advertise
_supports_flex_attn=True, so on hosts without flash_attn they all land onflex rather than SDPA. The carve-out list ends up acting as a "prefers flex"
hint rather than the gating list it reads as.
For LoRA / QLoRA this is invisible at runtime because the runtime
dispatcher (
select_attention_backendinunsloth/utils/attention_dispatch.py) ignores_attn_implementationentirely. For full fine-tuning the compiled-cache shim does honor
_attn_implementation, so the choice is observable. Empirically onqwen3-4B with
xformersinstalled andflash_attnnot installed:flex_attentionxformers.ops.fmha.memory_efficient_attentionx 216flex_attentionxformers.ops.fmha.memory_efficient_attentionx 216flex_attentiontransformers.ALL_ATTENTION_FUNCTIONS['flex_attention']x 216The qwen3 FFT path ends up on flex even though qwen3 is not in the
preference list and SDPA is fully supported.
Change
One-function update in
resolve_attention_implementation. Move theflex-attention fallback to after the SDPA branch, so it runs only when
SDPA is also unavailable:
if _is_eager_only(model_type): attn_impl = _set_attn_impl(config, "eager") elif prefers_flex_attention and supports_flex_attention: + # _FLEX_PREFERRED_MODELS (gemma3 family) prefers flex_attention + # over flash. Caller can still override by passing + # requested_attn_implementation="sdpa" (handled below). attn_impl = _set_attn_impl(config, "flex_attention") elif ( - not prefers_flex_attention - and not flash_attention_disabled + not flash_attention_disabled and HAS_FLASH_ATTENTION and supports_flash_attention ): attn_impl = _set_attn_impl(config, "flash_attention_2") - elif supports_flex_attention: - attn_impl = _set_attn_impl(config, "flex_attention") elif flash_attention_disabled: attn_impl = _disable_flash_attention_if_needed(...) elif supports_sdpa: attn_impl = _set_attn_impl(config, "sdpa") + elif supports_flex_attention: + # Flex is a last-resort fallback for models that do not support + # SDPA. Without this branch such models would land on eager. + attn_impl = _set_attn_impl(config, "flex_attention") else: attn_impl = _set_attn_impl(config, "eager") Updated priority chain: 1. _is_eager_only -> eager (gemma3n) 2. prefers_flex_attention AND supports_flex_attention -> flex_attention (gemma3 family, UNCHANGED) 3. !flash_attention_disabled AND HAS_FLASH AND supports_flash -> flash_attention_2 4. flash_attention_disabled -> SDPA/eager fallback (head_dim>256) 5. supports_sdpa -> sdpa <- qwen3 etc. now land here 6. supports_flex_attention -> flex_attention (last-resort, SDPA-less models only) 7. else -> eager The requested_attn_implementation override block at lines 527-538 is untouched. Explicit attn_implementation="..." from the caller still wins for both gemma3 and the others. Consistency with #5346 Quoting the PR body: ▎ For gemma3, previously we had the flex_attn backend. But after the refactor ▎ in #4210 #4201 it seems to have been falling back to SDPA. This is an ▎ attempt to restore old behaviour via 1) set a list of models which PREFER ▎ flex attn 2) if the model is among those and we see flex attn available, we ▎ use that. The carve-out and the preference list both stay. Branch 2 of the chain is unchanged, so gemma3 / gemma3_text / shieldgemma2 still pick flex over flash. This PR only narrows branch 4 so non-listed models are no longer swept into the same preference. The gemini-code-assist review on #5346 also raised the case of a gemma3 variant that loses flex support skipping FA2 entirely. The updated chain covers that case too: a gemma3 model without supports_flex_attention falls through to branch 3 (FA2 if available), then branch 4 (SDPA/eager fallback), then branch 5 (SDPA), then branch 6 (flex last-resort), then branch 7 (eager).