Skip to content

Flexattn refactor#4210

Closed
danielhanchen wants to merge 3444 commits into
unslothai:mainfrom
Datta0:flexattn_refactor
Closed

Flexattn refactor#4210
danielhanchen wants to merge 3444 commits into
unslothai:mainfrom
Datta0:flexattn_refactor

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

No description provided.

danielhanchen and others added 30 commits December 28, 2025 19:57
…nslothai#3780)

* fix(trainer): import psutil to prevent NameError in _prepare_dataset

Fixes unslothai#3777

* Update rl.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
* Guard optional trl.experimental.openenv usage in RL patches

* Simplify optional trl.openenv import handling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…3790)

* Fix is_contiguous() method call and remove duplicate imports

- Fix bug in rope_embedding.py where is_contiguous was used without
  parentheses, causing the method object (always truthy) to be evaluated
  instead of calling the method. This fixes issue unslothai#3781 where fast rope
  backpropagation was broken for zero strided/non-contiguous tensors.

- Remove duplicate `import torch` in rl.py (lines 20 and 25)
- Remove duplicate `import functools` and `import types` in vision.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix Boolean value of Tensor ambiguity error in mistral.py

Replace `or` operator with explicit `is None` check when getting
n_items from kwargs. The `or` operator fails when the value is a
Tensor because Python cannot determine the boolean value of a
multi-element tensor.

Fixes unslothai#3766

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Update rope_embedding.py

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
…lothai#3794)

Add "corda" as an allowed value for the init_lora_weights parameter
in FastLanguageModel.get_peft_model() and FastBaseModel.get_peft_model().

This enables users to use CorDA (Correlation-aware Decomposed Adaptation)
initialization from PEFT, which provides an alternative LoRA initialization
strategy for improved finetuning performance.

Fixes unslothai#3693

Signed-off-by: majiayu000 <1835304752@qq.com>
…lothai#3811)

* Fix correctness bugs in rl.py, rl_replacements.py, and vision.py

1. rl_replacements.py (lines 864, 870): Fixed undefined `nanmin`/`nanmax`
   functions by using `.nan_to_num(nan=inf/-inf).min()/.max()` pattern.
   PyTorch doesn't have torch.nanmin/nanmax, so we replace NaN values
   before computing min/max.

2. vision.py (line 150): Fixed bug where code checked for "input" key
   but then accessed kwargs["input_ids"] instead of kwargs["input"].

3. vision.py (line 159): Fixed bug where literal string "key" was used
   instead of the variable `key` when accessing kwargs.

4. rl.py (lines 903, 905): Fixed non-existent `MathError` exception
   by replacing with `ValueError`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1. cohere.py:347-348 - Fixed wrong variable names in QK normalization.
   Used `Q`/`K` but variables were named `Qn`/`Kn`. This caused NameError
   when `use_qk_norm=True` (e.g., c4ai-command-r-plus models).

2. cohere.py:482 - Fixed wrong object reference in inference loop.
   Used `self.mlp` but should be `decoder_layer.mlp` since we're
   iterating through decoder layers. Caused AttributeError during inference.

3. falcon_h1.py:459,461 - Fixed wrong attribute names in inference path.
   Used `post_attention_layernorm` and `mlp` but Falcon H1 uses
   `pre_ff_layernorm` and `feed_forward`. Caused AttributeError during generation.

4. qwen3_moe.py:210 - Fixed wrong module path with incorrect capitalization.
   Used `transformers.models.Qwen3Moe` but should be `transformers.models.qwen3_moe`.
   Caused AttributeError when patching rotary embeddings.

5. qwen3_moe.py:239 - Fixed wrong model_patcher class.
   Used `FastQwen3Model` but should be `FastQwen3MoeModel` for MoE models.
   Caused incorrect patching for Qwen3 MoE models.

6. hf_hub.py:21-22 - Fixed floor division and missing return for billion values.
   Used `//` instead of `/` for millions, and had no return for values >= 1B.
   Caused incorrect formatting and None return for large numbers.

7. save.py:550 - Fixed self-assignment that did nothing.
   `sharded_ram_usage = sharded_ram_usage` should be `= max_shard_size`.
   Caused integer shard sizes to be ignored.

8. rl.py:562-567 - Fixed orphan string not included in length_check.
   The elif branch for max_seq_length validation was a standalone string
   expression, not concatenated to length_check. Caused silent skip of
   the max_seq_length > model_max_seq_length warning.

9. granite.py:49-52 - Fixed wrong model name and version in error message.
   Said "Gemma2" and "4.42.3" but should be "Granite" and "4.45.0".
…tmul

Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
FIX: weight tying for LoRA embeddings and lm_head
Gemma3 models have a large vocabulary (262144 tokens) which causes
training loss to explode when using int8 embedding quantization.

This fix auto-detects Gemma3 models and switches from int8-int4
(phone-deployment) to int4 weight-only QAT for stable training.
…lity

Fix Gemma3 QAT training instability with int8-int4 scheme
When users load a model with fast_inference=False but then try to use
vLLM-style arguments with fast_generate, they previously got confusing
errors. This adds a wrapper that detects common mistakes and provides
helpful guidance:

- Using sampling_params: explains to use HF generate args instead
- Using lora_request: explains LoRA weights are already merged
- Passing text strings: shows how to tokenize input first

Changes:
- Add make_fast_generate_wrapper to _utils.py
- Apply wrapper in llama.py when fast_inference=False
- Apply wrapper in vision.py when fast_inference=False
…apper-helpful-errors

Add helpful error messages for fast_generate when fast_inference=False
Datta0 and others added 19 commits March 3, 2026 06:30
* Fix lm_head lora save

* Fix _need_to_train_embeddings guard for lm_head LoRA targets

When lm_head is already in final_modules as a LoRA target, the
_need_to_train_embeddings block should not also add it to
modules_to_save. This prevents dual-wrapping (LoRA + modules_to_save
on the same module) which causes assertion failures downstream.

Check if embed_tokens/lm_head are already being trained as LoRA
targets before adding them to modules_to_save. Also prevents
duplicate entries with elif guards.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add intel support for torch210

* fix for typo
…support (unslothai#4138)

* fix: update GGUF save paths to use ~/.unsloth/llama.cpp with Windows support

* fix: quote LLAMA_CPP_DEFAULT_DIR in fallback shell commands to handle paths with spaces

* refactor: deduplicate platform-specific build instructions in quantization error message

* chore: remove accidentally committed PR description file

* Fix import safety and f-string bugs in save.py

- H4: Add defensive try/except for LLAMA_CPP_DEFAULT_DIR and IS_WINDOWS imports
  with fallback defaults, so save.py works even if zoo PR unslothai#526 is not merged yet
- H5: Fix Kaggle error path using plain "Error: {e}" instead of f"Error: {e}",
  so the actual exception is shown to users

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fixup mapper issues and resolve properly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix broken wandb import crashing unsloth startup

When wandb is installed but broken (e.g., wandb < 0.19.11 with
protobuf >= 6.0), the import chain unsloth -> trl -> transformers ->
is_wandb_available() -> import wandb crashes with:

  ImportError: cannot import name 'Imports' from
  'wandb.proto.wandb_telemetry_pb2'

This happens because transformers' is_wandb_available() has no
try/except around `import wandb`. The error propagates up and kills
`from unsloth import FastLanguageModel` even though wandb is optional.

Add disable_broken_wandb() following the same pattern as
disable_torchcodec_if_broken(). It proactively tries importing wandb
during early init, and if the import fails, patches
is_wandb_available() to return False and sets WANDB_DISABLED=true.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…slothai#4148)

trl/trainer/callbacks.py imports is_wandb_available from
accelerate.utils, not from transformers. The original fix in unslothai#4147
only patched the transformers version, so `from trl import GRPOTrainer`
still crashed via the callbacks.py -> accelerate -> wandb path.

Must patch both the source module (accelerate.utils.imports) AND the
re-export namespace (accelerate.utils) since Python's
`from accelerate.utils import X` reads from the latter, which holds
its own cached reference.
* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
)

* Fix gpt temporary patch for grpo to happen after compile

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for trl 0.28 and above

Remove sync/reload weights calls , remove vllm.LLM instantiation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for trl 0.28 and above

Remove sync/reload weights calls , remove vllm.LLM instantiation

* patch rpc in openenv for newer trl

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pluesclues <136766175+pluesclues@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.4 → v0.15.5](astral-sh/ruff-pre-commit@v0.15.4...v0.15.5)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly streamlines how attention mechanisms are selected and configured across various models within the unsloth library. By introducing a unified, hierarchical function, it ensures that the most performant and compatible attention implementation (Flash Attention 2, Flex Attention, SDPA, or Eager) is automatically chosen based on available hardware, model support, and environmental settings. This change enhances efficiency and consistency in attention handling without altering the core behavior for models with known limitations.

Highlights

  • Refactored Attention Implementation Logic: The core logic for determining the attention mechanism has been refactored into a new function, determine_attention_implementation, replacing the previous prefer_flex_attn_if_supported.
  • Hierarchical Attention Selection: The new function introduces a clear hierarchy for selecting attention implementations: Flash Attention 2 is prioritized, followed by Flex Attention, then SDPA, and finally falling back to Eager attention.
  • Model-Specific Exclusions Maintained: Specific model types like GPT-OSS, Mllama, and Gemma3 variants continue to be excluded from Flash Attention 2 and Flex Attention due to known compatibility issues, ensuring stable operation.
  • Unified Attention Configuration: The llama.py and vision.py models now consistently use the new determine_attention_implementation function to configure their attention mechanisms, simplifying the codebase and improving maintainability.
Changelog
  • unsloth/models/_utils.py
    • Renamed prefer_flex_attn_if_supported to determine_attention_implementation.
    • Implemented a new hierarchical logic for selecting attention mechanisms (Flash Attention 2, Flex Attention, SDPA, Eager).
    • Added checks for HAS_FLASH_ATTENTION and model-specific support for different attention types.
    • Maintained exclusions for certain model types (gpt_oss, mllama, gemma3) from Flash Attention 2 and Flex Attention.
  • unsloth/models/llama.py
    • Updated the from_pretrained method to use the new determine_attention_implementation function for selecting the preferred attention implementation.
  • unsloth/models/vision.py
    • Removed previous logic for flex_attn_impl and default_attn_impl.
    • Integrated the new determine_attention_implementation function to set attn_impl.
    • Simplified the logic for setting attn_implementation in kwargs, removing specific gemma3n checks.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the attention implementation selection logic by introducing a new determine_attention_implementation function. This function centralizes the logic for choosing the best available attention backend, prioritizing them in the order of Flash Attention 2, Flex Attention, SDPA, and Eager. This is a great improvement for code clarity and maintainability. I have one suggestion to further refactor the new function to reduce some code duplication.

Comment thread unsloth/models/_utils.py
Comment on lines +225 to +290
def determine_attention_implementation(model_class, config):
model_type = getattr(config, "model_type", "").lower()

if not is_torch_flex_attn_available():
return None
if model_class is None or not getattr(
model_class, "_supports_flex_attn", False
):
return None
# GPT-OSS, Mllama and Gemma3N use eager/sdpa attention during
# inference since flex attention returns incorrect results or errors out.
# GPT-OSS: left padding issues cause incorrect outputs.
# Mllama: _update_causal_mask uses make_flex_block_causal_mask which
# creates BlockMask with Q_LEN=KV_LEN=total_seq_len, but during
# decode q_len=1, causing ValueError. Needs transformers update.
# Gemma3N: timm vision wrappers (eg Gemma3nVisionConfig) do not
# support flex_attention.
model_type = getattr(config, "model_type", "") if config else ""
if model_type in ("gpt_oss", "mllama") or str(model_type).startswith("gemma3n"):
return None
# 1. Flash Attention 2
if (
HAS_FLASH_ATTENTION
and model_type not in ("gpt_oss", "mllama")
and not model_type.startswith("gemma3")
):
supports_fa2 = False
if model_class is not None:
supports_fa2 = getattr(
model_class, "_supports_flash_attn_2", False
) or getattr(model_class, "_supports_flash_attn", False)

if supports_fa2:
if config is not None:
setattr(config, "_attn_implementation", "flash_attention_2")
if hasattr(config, "attn_implementation"):
setattr(config, "attn_implementation", "flash_attention_2")
return "flash_attention_2"

# 2. Flex Attention
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
try:
from transformers.utils.import_utils import is_torch_flex_attn_available

if (
is_torch_flex_attn_available()
and (model_class is not None)
and getattr(model_class, "_supports_flex_attn", False)
):
# GPT-OSS, Mllama and Gemma3 use eager/sdpa attention during
# inference since flex attention returns incorrect results or errors out.
# GPT-OSS: left padding issues cause incorrect outputs.
# Mllama: _update_causal_mask uses make_flex_block_causal_mask which
# creates BlockMask with Q_LEN=KV_LEN=total_seq_len, but during
# decode q_len=1, causing ValueError. Needs transformers update.
# Gemma3N: timm vision wrappers (eg Gemma3nVisionConfig) do not
# support flex_attention.
if model_type not in (
"gpt_oss",
"mllama",
) and not model_type.startswith("gemma3"):
if config is not None:
setattr(config, "_attn_implementation", "flex_attention")
if hasattr(config, "attn_implementation"):
setattr(config, "attn_implementation", "flex_attention")
return "flex_attention"
except Exception:
pass

# 3. SDPA
if model_class is not None and getattr(model_class, "_supports_sdpa", False):
if config is not None:
setattr(config, "_attn_implementation", "flex_attention")
setattr(config, "_attn_implementation", "sdpa")
if hasattr(config, "attn_implementation"):
setattr(config, "attn_implementation", "flex_attention")
return "flex_attention"
except Exception:
return None
setattr(config, "attn_implementation", "sdpa")
return "sdpa"

# 4. Eager
if config is not None:
setattr(config, "_attn_implementation", "eager")
if hasattr(config, "attn_implementation"):
setattr(config, "attn_implementation", "eager")
return "eager"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function has a good structure, but there's a lot of repeated code for setting the _attn_implementation and attn_implementation attributes on the config object. This can be refactored into a helper function to improve maintainability and reduce code duplication. I've also taken the liberty to slightly simplify some of the conditional logic.

def determine_attention_implementation(model_class, config):
    model_type = getattr(config, "model_type", "").lower()

    def _set_attn_impl_and_return(implementation):
        if config is not None:
            setattr(config, "_attn_implementation", implementation)
            if hasattr(config, "attn_implementation"):
                setattr(config, "attn_implementation", implementation)
        return implementation

    # 1. Flash Attention 2
    if (
        HAS_FLASH_ATTENTION
        and model_type not in ("gpt_oss", "mllama")
        and not model_type.startswith("gemma3")
    ):
        supports_fa2 = (model_class is not None) and (
            getattr(model_class, "_supports_flash_attn_2", False)
            or getattr(model_class, "_supports_flash_attn", False)
        )
        if supports_fa2:
            return _set_attn_impl_and_return("flash_attention_2")

    # 2. Flex Attention
    if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
        try:
            from transformers.utils.import_utils import is_torch_flex_attn_available

            if (
                is_torch_flex_attn_available()
                and (model_class is not None)
                and getattr(model_class, "_supports_flex_attn", False)
                and model_type not in ("gpt_oss", "mllama")
                and not model_type.startswith("gemma3")
            ):
                # GPT-OSS, Mllama and Gemma3 use eager/sdpa attention during
                # inference since flex attention returns incorrect results or errors out.
                # GPT-OSS: left padding issues cause incorrect outputs.
                # Mllama: _update_causal_mask uses make_flex_block_causal_mask which
                # creates BlockMask with Q_LEN=KV_LEN=total_seq_len, but during
                # decode q_len=1, causing ValueError. Needs transformers update.
                # Gemma3N: timm vision wrappers (eg Gemma3nVisionConfig) do not
                # support flex_attention.
                return _set_attn_impl_and_return("flex_attention")
        except Exception:
            pass

    # 3. SDPA
    if model_class is not None and getattr(model_class, "_supports_sdpa", False):
        return _set_attn_impl_and_return("sdpa")

    # 4. Eager
    return _set_attn_impl_and_return("eager")

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c0b9cd98e8

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/vision.py
default_attn_impl = "flex_attention" if flex_attn_impl else "sdpa"
if not ("attn_implementation" in kwargs):
kwargs["attn_implementation"] = default_attn_impl
kwargs["attn_implementation"] = attn_impl
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore Gemma3N eager default for auto attention selection

This assignment now applies determine_attention_implementation() directly, which removed the previous gemma3n-specific eager override from FastBaseModel.from_pretrained. With the new helper, gemma3n is excluded from flash/flex and then falls through to SDPA when _supports_sdpa is true, so Gemma3N defaults to sdpa instead of the prior eager-safe path. That is a behavior regression for Gemma3N loads (especially when callers do not pass attn_implementation) and can reintroduce the attention incompatibility that the earlier Gemma3N guard in this loader was added to avoid.

Useful? React with 👍 / 👎.

@Datta0
Copy link
Copy Markdown
Collaborator

Datta0 commented Mar 12, 2026

I guess I'll let @mmathew23 review this as it doesn't make much sense to approve my code myself lol 😭

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.