Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def _disable_flash_attention_if_needed(
config,
attn_implementation = None,
supports_sdpa = False,
supports_flex_attention = False,
would_use_flash_attention = False,
disable_reason = None,
):
Expand All @@ -402,7 +403,12 @@ def _disable_flash_attention_if_needed(
if requested_attn_implementation == "eager":
return _set_attn_impl(config, "eager")

fallback_attn_implementation = "sdpa" if supports_sdpa else "eager"
if supports_sdpa:
fallback_attn_implementation = "sdpa"
elif supports_flex_attention:
fallback_attn_implementation = "flex_attention"
else:
fallback_attn_implementation = "eager"
if (
_is_flash_attention_requested(requested_attn_implementation)
or would_use_flash_attention
Expand Down Expand Up @@ -487,40 +493,44 @@ def resolve_attention_implementation(
getattr(model_class, "_supports_flash_attn_2", False)
or getattr(model_class, "_supports_flash_attn", False)
)
supports_flex_attention = _supports_flex_attention(model_class, config, model_type)
disable_reason = _get_flash_attention_disable_reason(config)
flash_attention_disabled = disable_reason is not None

if model_class is None:
attn_impl = _set_attn_impl(config, "sdpa" if supports_sdpa else "eager")
else:
supports_flex_attention = _supports_flex_attention(
model_class, config, model_type
)
prefers_flex_attention = _config_prefers_flex_attention(config)
if _is_eager_only(model_type):
attn_impl = _set_attn_impl(config, "eager")
elif prefers_flex_attention and supports_flex_attention:
# Models in _FLEX_PREFERRED_MODELS (gemma3 family) prefer flex_attention
# over flash. Caller can still override by passing
# requested_attn_implementation="sdpa" (handled below).
attn_impl = _set_attn_impl(config, "flex_attention")
elif (
not prefers_flex_attention
and not flash_attention_disabled
not flash_attention_disabled
and HAS_FLASH_ATTENTION
and supports_flash_attention
):
attn_impl = _set_attn_impl(config, "flash_attention_2")
elif supports_flex_attention:
attn_impl = _set_attn_impl(config, "flex_attention")
elif flash_attention_disabled:
attn_impl = _disable_flash_attention_if_needed(
config,
supports_sdpa = supports_sdpa,
supports_flex_attention = supports_flex_attention,
would_use_flash_attention = (
HAS_FLASH_ATTENTION and supports_flash_attention
),
disable_reason = disable_reason,
)
elif supports_sdpa:
attn_impl = _set_attn_impl(config, "sdpa")
elif supports_flex_attention:
# Flex is only a fallback for models that don't support SDPA
Comment on lines +529 to +530
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Route flash-disabled non-SDPA models to flex fallback

This new fallback placement is skipped whenever flash_attention_disabled is true, so a model with supports_sdpa=False and supports_flex_attention=True now falls into _disable_flash_attention_if_needed(...) and is forced to eager instead of flex_attention. That is a regression from the previous behavior and contradicts the stated "flex as last-resort for non-SDPA models" logic, which can materially hurt throughput (and increase memory pressure) for configurations where Flash is disabled due to head-dim limits.

Useful? React with 👍 / 👎.

# (e.g. some custom configurations). Without this fallback such
# models would land on eager.
attn_impl = _set_attn_impl(config, "flex_attention")
Comment thread
mmathew23 marked this conversation as resolved.
else:
attn_impl = _set_attn_impl(config, "eager")

Expand All @@ -531,6 +541,7 @@ def resolve_attention_implementation(
config,
requested_attn_implementation,
supports_sdpa = supports_sdpa,
supports_flex_attention = supports_flex_attention,
disable_reason = disable_reason,
)
else:
Expand Down
Loading