diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index df498e89fb..7bf8866f46 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -383,6 +383,7 @@ def _disable_flash_attention_if_needed( config, attn_implementation = None, supports_sdpa = False, + supports_flex_attention = False, would_use_flash_attention = False, disable_reason = None, ): @@ -402,7 +403,12 @@ def _disable_flash_attention_if_needed( if requested_attn_implementation == "eager": return _set_attn_impl(config, "eager") - fallback_attn_implementation = "sdpa" if supports_sdpa else "eager" + if supports_sdpa: + fallback_attn_implementation = "sdpa" + elif supports_flex_attention: + fallback_attn_implementation = "flex_attention" + else: + fallback_attn_implementation = "eager" if ( _is_flash_attention_requested(requested_attn_implementation) or would_use_flash_attention @@ -487,33 +493,32 @@ def resolve_attention_implementation( getattr(model_class, "_supports_flash_attn_2", False) or getattr(model_class, "_supports_flash_attn", False) ) + supports_flex_attention = _supports_flex_attention(model_class, config, model_type) disable_reason = _get_flash_attention_disable_reason(config) flash_attention_disabled = disable_reason is not None if model_class is None: attn_impl = _set_attn_impl(config, "sdpa" if supports_sdpa else "eager") else: - supports_flex_attention = _supports_flex_attention( - model_class, config, model_type - ) prefers_flex_attention = _config_prefers_flex_attention(config) if _is_eager_only(model_type): attn_impl = _set_attn_impl(config, "eager") elif prefers_flex_attention and supports_flex_attention: + # Models in _FLEX_PREFERRED_MODELS (gemma3 family) prefer flex_attention + # over flash. Caller can still override by passing + # requested_attn_implementation="sdpa" (handled below). attn_impl = _set_attn_impl(config, "flex_attention") elif ( - not prefers_flex_attention - and not flash_attention_disabled + not flash_attention_disabled and HAS_FLASH_ATTENTION and supports_flash_attention ): attn_impl = _set_attn_impl(config, "flash_attention_2") - elif supports_flex_attention: - attn_impl = _set_attn_impl(config, "flex_attention") elif flash_attention_disabled: attn_impl = _disable_flash_attention_if_needed( config, supports_sdpa = supports_sdpa, + supports_flex_attention = supports_flex_attention, would_use_flash_attention = ( HAS_FLASH_ATTENTION and supports_flash_attention ), @@ -521,6 +526,11 @@ def resolve_attention_implementation( ) elif supports_sdpa: attn_impl = _set_attn_impl(config, "sdpa") + elif supports_flex_attention: + # Flex is only a fallback for models that don't support SDPA + # (e.g. some custom configurations). Without this fallback such + # models would land on eager. + attn_impl = _set_attn_impl(config, "flex_attention") else: attn_impl = _set_attn_impl(config, "eager") @@ -531,6 +541,7 @@ def resolve_attention_implementation( config, requested_attn_implementation, supports_sdpa = supports_sdpa, + supports_flex_attention = supports_flex_attention, disable_reason = disable_reason, ) else: