From 5758086ad0a9dc87ca7c3e9ee4b088fdb0c18284 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Fri, 15 May 2026 14:35:21 +0000 Subject: [PATCH 1/3] update attn preferences --- unsloth/models/_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index df498e89fb..3b4c83dc6e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -500,16 +500,16 @@ def resolve_attention_implementation( if _is_eager_only(model_type): attn_impl = _set_attn_impl(config, "eager") elif prefers_flex_attention and supports_flex_attention: + # Models in _FLEX_PREFERRED_MODELS (gemma3 family) prefer flex_attention + # over flash. Caller can still override by passing + # requested_attn_implementation="sdpa" (handled below). attn_impl = _set_attn_impl(config, "flex_attention") elif ( - not prefers_flex_attention - and not flash_attention_disabled + not flash_attention_disabled and HAS_FLASH_ATTENTION and supports_flash_attention ): attn_impl = _set_attn_impl(config, "flash_attention_2") - elif supports_flex_attention: - attn_impl = _set_attn_impl(config, "flex_attention") elif flash_attention_disabled: attn_impl = _disable_flash_attention_if_needed( config, @@ -521,6 +521,11 @@ def resolve_attention_implementation( ) elif supports_sdpa: attn_impl = _set_attn_impl(config, "sdpa") + elif supports_flex_attention: + # Flex is only a fallback for models that don't support SDPA + # (e.g. some custom configurations). Without this fallback such + # models would land on eager. + attn_impl = _set_attn_impl(config, "flex_attention") else: attn_impl = _set_attn_impl(config, "eager") From ecc6b6d4ee5e2eae56518e67569b73967211d016 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Fri, 15 May 2026 11:25:12 -0500 Subject: [PATCH 2/3] address gemini review suggestion --- unsloth/models/_utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3b4c83dc6e..936d0f8b0a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -383,6 +383,7 @@ def _disable_flash_attention_if_needed( config, attn_implementation = None, supports_sdpa = False, + supports_flex_attention = False, would_use_flash_attention = False, disable_reason = None, ): @@ -402,7 +403,12 @@ def _disable_flash_attention_if_needed( if requested_attn_implementation == "eager": return _set_attn_impl(config, "eager") - fallback_attn_implementation = "sdpa" if supports_sdpa else "eager" + if supports_sdpa: + fallback_attn_implementation = "sdpa" + elif supports_flex_attention: + fallback_attn_implementation = "flex_attention" + else: + fallback_attn_implementation = "eager" if ( _is_flash_attention_requested(requested_attn_implementation) or would_use_flash_attention @@ -487,15 +493,15 @@ def resolve_attention_implementation( getattr(model_class, "_supports_flash_attn_2", False) or getattr(model_class, "_supports_flash_attn", False) ) + supports_flex_attention = _supports_flex_attention( + model_class, config, model_type + ) disable_reason = _get_flash_attention_disable_reason(config) flash_attention_disabled = disable_reason is not None if model_class is None: attn_impl = _set_attn_impl(config, "sdpa" if supports_sdpa else "eager") else: - supports_flex_attention = _supports_flex_attention( - model_class, config, model_type - ) prefers_flex_attention = _config_prefers_flex_attention(config) if _is_eager_only(model_type): attn_impl = _set_attn_impl(config, "eager") @@ -514,6 +520,7 @@ def resolve_attention_implementation( attn_impl = _disable_flash_attention_if_needed( config, supports_sdpa = supports_sdpa, + supports_flex_attention = supports_flex_attention, would_use_flash_attention = ( HAS_FLASH_ATTENTION and supports_flash_attention ), @@ -536,6 +543,7 @@ def resolve_attention_implementation( config, requested_attn_implementation, supports_sdpa = supports_sdpa, + supports_flex_attention = supports_flex_attention, disable_reason = disable_reason, ) else: From 3d775d34076b1e624cd91f8756366426694b1f97 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 16:25:49 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 936d0f8b0a..7bf8866f46 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -493,9 +493,7 @@ def resolve_attention_implementation( getattr(model_class, "_supports_flash_attn_2", False) or getattr(model_class, "_supports_flash_attn", False) ) - supports_flex_attention = _supports_flex_attention( - model_class, config, model_type - ) + supports_flex_attention = _supports_flex_attention(model_class, config, model_type) disable_reason = _get_flash_attention_disable_reason(config) flash_attention_disabled = disable_reason is not None