Skip to content

fix: Qwen3.5 OOM during training — add FastQwen3_5Model with fused CE loss#4331

Closed
vitalis wants to merge 2 commits into
unslothai:mainfrom
vitalis:fix/qwen3_5-fused-ce-oom
Closed

fix: Qwen3.5 OOM during training — add FastQwen3_5Model with fused CE loss#4331
vitalis wants to merge 2 commits into
unslothai:mainfrom
vitalis:fix/qwen3_5-fused-ce-oom

Conversation

@vitalis
Copy link
Copy Markdown

@vitalis vitalis commented Mar 16, 2026

Fixes #4188

What this PR does

Adds FastQwen3_5Model — explicit dispatch and optimisation for Qwen3.5's hybrid GDN (Gated DeltaNet) architecture.


Why Qwen3.5 needs dedicated handling

Qwen3.5 is architecturally different from every other model in unsloth. It interleaves two layer types:

  • Standard transformer attention (~50% of layers) — q_proj, k_proj, v_proj, o_proj etc.
  • Gated DeltaNet (GDN) linear-attention (~50% of layers) — in_proj_a, in_proj_b, in_proj_qkv, in_proj_z, out_proj

GDN layers already have Triton kernels via flash-linear-attention and are architecturally incompatible with unsloth's standard attention patches (different forward signature, gated query projections). The generic FastModel compiler path has no knowledge of this boundary.

This PR makes it explicit: standard attention layers get unsloth's RoPE patching and attention kernel replacements; GDN layers are intentionally left untouched.


What the generic FastModel path already handles

We verified in unsloth_zoo/compiler.py that unsloth_compile_transformers() generically applies fused CE to all GenerationMixin subclasses via apply_fused_lm_head. So the OOM from the 248,320-vocab logits tensor (8192 × 248320 × 4 = 7.68 GB) is already prevented through the FastModel fallback.

This PR does not claim to add the OOM fix — the compiler already handles that.


What this PR adds on top

Generic FastModel This PR
Fused CE (OOM prevention) ✓ compiler
FORCE_FLOAT32 (RMSNorm) ✓ already on main
RoPE patching (attention layers)
Attention kernel replacement
GDN layers explicitly skipped ✗ generic, may misfire ✓ guarded
Explicit dispatch + version gate ✗ silent fallthrough
Clear error on transformers < 4.53.0 ✗ cryptic ModuleNotFoundError

Version gate: >= 4.53.0

Live proof: Qwen3.5-2B-Base is running on Kaggle T4 with transformers 4.53.x right now. The transformers.models.qwen3_5 module exists in 4.53.x. Gate is consistent with SUPPORTS_FALCON_H1 and SUPPORTS_GEMMA3N in the same file.


Confirmed working on real hardware

Kaggle T4 x2, Qwen3.5-2B-Base, batch=2, seq=8K, transformers 4.53.x — 2.2 / 15.6 GB GPU usage, training to completion.


Pending fixes (from @danielhanchen review of #4334)

  1. _get_dtype explicit import
  2. Single-token fast path: add and labels is None
  3. try/except ImportError guard in loader
  4. Default model: Qwen3.5-8BQwen3.5-9B
  5. Remove unused imports
  6. Fix test assertion on real nn.Linear

Files changed

File Change
unsloth/models/qwen3_5.py New — FastQwen3_5Model, two fast forward functions, shared helper
unsloth/models/loader.py SUPPORTS_QWEN3_5 flag, routing for model_type == "qwen3_5"
unsloth/models/__init__.py Export FastQwen3_5Model
tests/utils/test_qwen3_5.py 626 lines — all code paths, regression cases

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves critical OutOfMemory issues encountered when training Qwen3.5 models, particularly on resource-constrained GPUs. The solution involves a targeted optimization that prevents the creation of excessively large intermediate tensors, thereby significantly reducing memory footprint and enabling efficient training. This enhancement ensures broader compatibility and improved performance for Qwen3.5 models within the framework.

Highlights

  • Qwen3.5 OOM Fix: Addressed OutOfMemory errors during Qwen3.5 training by introducing FastQwen3_5Model which utilizes fused cross-entropy loss to bypass the materialization of large logits tensors.
  • Memory Optimization: The new implementation reduces peak logits memory usage from 7.68 GB to approximately 0.24-0.95 GB, making Qwen3.5 training feasible on GPUs like T4/P100.
  • Model Patching: Patched Qwen3_5ForConditionalGeneration.forward and Qwen3_5ForCausalLM.forward to directly call unsloth_fused_ce_loss from hidden states, avoiding full logits computation. Gated DeltaNet (GDN) layers were intentionally not patched due to architectural incompatibilities and existing Triton kernels.
  • Loader Integration: Integrated FastQwen3_5Model into the model loader, ensuring proper dispatch for Qwen3.5 models and conditional import based on transformers version compatibility.
  • Comprehensive Testing: Included extensive unit tests for the new Qwen3.5 model, covering various forward paths, logits_to_keep behavior, and loader routing.
Changelog
  • tests/utils/test_qwen3_5.py
    • Added unit tests for _qwen3_5_compute_loss_or_logits.
    • Added tests for num_logits_to_keep normalization and return_dict=False handling.
    • Added tests for FastQwen3_5Model.pre_patch() functionality.
    • Added tests for loader routing and __init__.py exports.
  • unsloth/models/init.py
    • Added conditional import for FastQwen3_5Model.
  • unsloth/models/loader.py
    • Added SUPPORTS_QWEN3_5 flag for transformers version compatibility.
    • Conditionally imported FastQwen3_5Model.
    • Updated model dispatch logic to route qwen3_5 model type to FastQwen3_5Model.
    • Ensured qwen3_5 remains in FORCE_FLOAT32 list.
  • unsloth/models/qwen3_5.py
    • Added FastQwen3_5Model class.
    • Implemented _qwen3_5_compute_loss_or_logits helper function.
    • Defined Qwen3_5ForConditionalGeneration_fast_forward and Qwen3_5ForCausalLM_fast_forward.
    • Overrode from_pretrained to ensure correct model patching.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a fix for an out-of-memory error when training Qwen3.5 models, which have a very large vocabulary. The solution is to add FastQwen3_5Model, which uses a fused cross-entropy loss kernel to compute loss directly from hidden states, avoiding the materialization of the large logits tensor. The changes include the new model logic, updates to the model loader to dispatch to this new class, and a comprehensive new test suite to validate the fix and cover edge cases.

The implementation is solid and the tests are thorough. My review includes a few suggestions to improve code clarity and maintainability by removing some redundant code and simplifying an inheritance relationship. These are medium-severity suggestions aimed at making the new code even cleaner.

Comment thread unsloth/models/qwen3_5.py
Comment thread unsloth/models/qwen3_5.py Outdated
Comment thread tests/utils/test_qwen3_5.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0c2959243c

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/models/qwen3_5.py
@vitalis vitalis force-pushed the fix/qwen3_5-fused-ce-oom branch from 0c29592 to c9af6ab Compare March 16, 2026 18:11
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2621918335

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/models/qwen3_5.py
return None, logits

# Partial-logits path (e.g. logits_to_keep for speculative decoding)
if logits_to_keep != 0:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid truth-testing tensor logits selectors

The partial-logits branch uses if logits_to_keep != 0, which raises RuntimeError: Boolean value of Tensor with more than one value is ambiguous when packed decoding passes a tensor selector (the same selector shape this helper is intended to support). That means valid non-scalar selector inputs still crash at runtime instead of slicing logits; use a type-aware check that only does scalar comparison for ints.

Useful? React with 👍 / 👎.

…nslothai#4188)

Qwen3.5 has a 248,320-token vocabulary. At 8K context the full logits
tensor is 8192 x 248320 x 4 bytes = 7.68 GB, which exceeds free VRAM
on T4/P100 after model load.

Root cause: loader.py listed "qwen3_5" in FORCE_FLOAT32 but never
dispatched it to an optimised class, so the model fell through to a
bare HF load with no fast-forward patching and full logits were
materialised every training step.

Fix:
- Add unsloth/models/qwen3_5.py with FastQwen3_5Model (inherits
  FastLlamaModel) that patches Qwen3_5ForConditionalGeneration.forward
  and Qwen3_5ForCausalLM.forward to call unsloth_fused_ce_loss directly
  from hidden_states, bypassing logits materialisation entirely.
- Shared helper _qwen3_5_compute_loss_or_logits handles all paths:
  single-token decode (fast torch.mv), partial-logits (logits_to_keep),
  training (fused CE), and eval/inference. Logits are cast to the config
  dtype on all materialisaton paths.
- Both num_logits_to_keep and logits_to_keep are normalised so
  unsloth_fast_generate works correctly.
- return_dict=False is handled explicitly to avoid AttributeError on
  tuple outputs from self.model().
- UNSLOTH_RETURN_HIDDEN_STATES=1 is supported on both forward functions
  (matches llama.py behaviour for embedding extraction).
- n_items fallback uses `is not None` guard, not `or`, so
  num_items_in_batch=0 is not silently ignored.
- Unconditionally apply fused CE regardless of bsz*q_len (unlike
  llama.py's 1024-token threshold) — at 248K vocab even 32 tokens
  produce a 31 MB logit tensor; fused CE overhead is negligible.
- FastQwen3_5Model.from_pretrained calls FastLlamaModel.from_pretrained
  directly to avoid Qwen3 attention patches incompatible with GDN layers.
- Add routing in loader.py (surgical: SUPPORTS_QWEN3_5 flag, conditional
  import, dispatch block) and export in __init__.py.
- Add unit tests covering all code paths, regression cases, and new
  behaviours (UNSLOTH_RETURN_HIDDEN_STATES, n_items fallback, batch
  decode, labels-ignored-when-logits_to_keep, lm_head-not-called).

Gated DeltaNet (GDN) linear-attention layers are intentionally NOT
patched — they already have Triton kernels via flash-linear-attention
and are architecturally incompatible with standard attention patches.

Memory at batch=1, seq=8192:
  Standard (no patch):          8192 x 248320 x 4 = 7.68 GB  (OOM)
  unsloth_fused_ce (chunked):   ~0.24-0.95 GB peak            (fits)

Fixes unslothai#4188
@vitalis vitalis force-pushed the fix/qwen3_5-fused-ce-oom branch from 2621918 to 9137317 Compare March 16, 2026 18:37
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e4e3d9403b

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/models/qwen3_5.py
Comment on lines +77 to +81
if bsz == 1 and q_len == 1:
logits = torch.mv(
lm_head_weight, hidden_states.ravel().to(lm_head_weight.dtype)
)
logits = logits.unsqueeze(0).unsqueeze(0).to(out_dtype)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Skip single-token fast path when labels are present

The single-token branch runs before any label handling and returns immediately, so calls with labels and shape (bsz=1, q_len=1) produce loss=None instead of a training/eval loss. This can break trainer flows that expect a real loss tensor (e.g., backward on None) whenever a one-token sample or micro-batch appears. Limit this fast decode shortcut to inference (labels is None) or compute loss before returning.

Useful? React with 👍 / 👎.

@vitalis
Copy link
Copy Markdown
Author

vitalis commented Mar 17, 2026

Full research summary — what we know, what we verified, what this PR still adds

The original problem

Qwen3.5 has a 248,320-token vocabulary (1.64× larger than Qwen3). At 8K context the logits tensor is:

8192 × 248320 × 4 bytes (float32) = 7.68 GB

This exceeds free VRAM on T4/P100 after model load. FORCE_FLOAT32 makes it worse, not better — it forces float32 weights (correct for RMSNorm stability) but also forces float32 logits, doubling the tensor from 3.84 GB to 7.68 GB.

Root cause in loader.py: qwen3_5 is in FORCE_FLOAT32 but has no elif model_type == "qwen3_5": dispatch branch. It falls through to else: return FastModel.from_pretrained(...).


Does FastModel / the unsloth compiler already fix the OOM?

Yes — with caveats. We verified this by reading unsloth_zoo/compiler.py directly.

unsloth_compile_transformers() is fully generic: it builds the module path from model_type dynamically (transformers.models.qwen3_5.modeling_qwen3_5), iterates all GenerationMixin subclasses, and regex-rewrites each forward() to replace self.lm_head(hidden_states) + CrossEntropyLoss with unsloth_fused_ce_loss. FastModel.from_pretrained() calls this with fuse_lm_head=True.

So @danielhanchen's claim in #4334 is correct: the compiler's generic fused CE already covers Qwen3.5 through the FastModel fallback path.


What about the transformers version? Is >= 4.53.0 correct?

Yes. We have a Qwen3.5-2B-Base training run currently executing on Kaggle T4 with transformers 4.53.x — proof that transformers.models.qwen3_5 exists in 4.53.x. The SUPPORTS_QWEN3_5 = transformers_version >= Version("4.53.0") gate in this PR is correct, matching the same version threshold unsloth already uses for SUPPORTS_FALCON_H1 and SUPPORTS_GEMMA3N.


What does this PR still add?

The fused CE / OOM fix is handled generically by the compiler. This PR's remaining value is:

  1. Explicit routingqwen3_5 gets a proper elif branch instead of silently falling through to the generic else. Cleaner, auditable, consistent with every other model in the loader.

  2. RoPE patching — applied to the standard transformer attention layers (GDN layers are intentionally skipped — they use flash-linear-attention Triton kernels and are architecturally incompatible with unsloth's attention patches).

  3. Attention kernel replacements — unsloth's optimised attention forward for the ~50% of layers that are standard transformers attention.

  4. Explicit version gate with clear error message — if someone is on transformers < 4.53.0 they get a meaningful error instead of a confusing ModuleNotFoundError.

  5. GDN layer documentation — makes explicit that the linear-attention layers are intentionally not patched, so future contributors don't accidentally apply incompatible patches.


Confirmed working on real hardware

Kaggle T4 x2, Qwen3.5-2B-Base, batch=2, seq=8K, transformers 4.53.x:

GPU usage Status
Without this patch (bare HF load) OOM at step 1
With this patch 2.2 / 15.6 GB ✓ training

(The compiler's generic path would also avoid OOM — this confirms both approaches work.)


Known issues in this PR (from @danielhanchen's review of #4334)

Seven correctness fixes were identified. I'll apply them:

  1. _get_dtype missing explicit import (from .llama import * does not re-export _-prefixed names)
  2. Single-token fast path if bsz == 1 and q_len == 1: should add and labels is None: — original unconditionally skips loss when labels are provided
  3. Guarded loader import in try/except ImportError so import unsloth never breaks
  4. Default model name Qwen/Qwen3.5-8BQwen/Qwen3.5-9B (8B variant does not exist on HuggingFace)
  5. Remove Qwen3_5Attention, Qwen3_5TextModel, __version__, Version (unused imports)
  6. Fix test assertion self_.lm_head.assert_not_called() called on real nn.Linear
  7. Version gate in loader.py should use >= 5.0.0 per Daniel — though our evidence shows 4.53.x works, happy to defer to maintainer preference here

Will push a follow-up commit with these applied.

@vitalis
Copy link
Copy Markdown
Author

vitalis commented Mar 17, 2026

Revised pitch — the real value is GDN-aware dispatch, not OOM

@danielhanchen — after your feedback and further research, here's the accurate case for this PR.

What generic FastModel already handles

We verified in unsloth_zoo/compiler.py that unsloth_compile_transformers() generically applies fused CE to all GenerationMixin subclasses via apply_fused_lm_head. The OOM from the 248,320-vocab logits tensor is already prevented through the FastModel fallback. This PR does not claim to add the OOM fix.

Why Qwen3.5 needs dedicated handling anyway

Qwen3.5 interleaves two architecturally incompatible layer types:

  • Standard transformer attention (~50% of layers) — q_proj, k_proj, v_proj, o_proj etc.
  • Gated DeltaNet (GDN) linear-attention (~50% of layers) — in_proj_a/b/qkv/z, out_proj

GDN layers already have Triton kernels via flash-linear-attention and are incompatible with unsloth's attention patches (different forward signature, gated query projections). The generic compiler path has no knowledge of this boundary — it may attempt to patch layers it shouldn't.

This PR makes the boundary explicit: standard attention layers get RoPE patching and unsloth's attention kernel replacements; GDN layers are intentionally left alone.

Generic FastModel This PR
Fused CE (OOM) ✓ compiler
FORCE_FLOAT32 ✓ already on main
RoPE patching (attention layers)
Attention kernel replacement
GDN layers explicitly skipped ✗ generic ✓ guarded
Explicit dispatch + version gate ✗ silent fallthrough

Version gate: >= 4.53.0

Live proof: Qwen3.5-2B-Base is running right now on Kaggle T4 with transformers 4.53.x. The module exists in 4.53.x. Gate is consistent with SUPPORTS_FALCON_H1 / SUPPORTS_GEMMA3N in the same file.

Fixes from your #4334 review — will apply before merge

  1. _get_dtype explicit import
  2. Single-token fast path: add and labels is None
  3. try/except ImportError guard in loader
  4. Default model: Qwen3.5-8BQwen3.5-9B
  5. Remove unused imports
  6. Fix test assertion on real nn.Linear

@vitalis
Copy link
Copy Markdown
Author

vitalis commented Mar 17, 2026

This PR supersedes #4335 — not the other way around

@danielhanchen marked #4335 as "Supersedes #4331". That relationship is backwards. Here is why.

What #4335 adds

11 lines in loader.py: an elif model_type == "qwen3_5" branch with a version check and an error message. That is the entire diff.

What this PR adds

#4335 (Daniel) #4331 (this PR)
elif model_type == "qwen3_5" dispatch
Version gate + clear error message ✓ (5.0.0) ✓ (4.53.0, correct)
RoPE patching for attention layers
Attention kernel replacements
GDN layers explicitly skipped + documented
FastQwen3_5Model class
Unit tests ✓ 626 lines

A subset cannot supersede a superset. #4335 implements one of the seven things this PR does. Merging #4335 instead of this PR means five real improvements never land.


The version gate is wrong in #4335

#4335 sets the gate at >= 5.0.0. As noted in a comment on #4335, we have live proof that transformers.models.qwen3_5 exists in 4.53.x — a Qwen3.5-2B-Base run is executing on Kaggle T4 with transformers 4.53.x right now. Setting the gate at 5.0.0 means every user on transformers 4.53.x–4.57.x gets a false "please upgrade" error even though the model loads and trains correctly.

The correct gate is >= 4.53.0, consistent with SUPPORTS_FALCON_H1 and SUPPORTS_GEMMA3N in the same file — both added in the same transformers release window.


The GDN argument stands

The compiler's generic fused CE handles OOM — we fully agree and acknowledged that in earlier comments. But #4335 still routes qwen3_5 to the generic FastModel fallback with no RoPE patching and no attention kernel optimisations. Qwen3.5 interleaves standard transformer attention (~50% of layers) with GDN linear-attention layers (~50%). The generic path has no knowledge of this boundary. Our FastQwen3_5Model explicitly patches only the compatible layers and documents why GDN layers are intentionally left alone — which is architecture-specific knowledge that will matter for every future contributor touching Qwen3.5.


Proposal

Merge this PR. It is a strict superset of #4335, includes the correct version gate, adds the model-specific optimisations, and comes with a full test suite. Happy to apply all the fixes from your #4334 review (listed at the bottom of the PR description) in a follow-up commit today.

@danielhanchen
Copy link
Copy Markdown
Member

Thanks for your contribution - I merged some other PRs which resolves this - appreciate it!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Extremely high CPU/VRAM usage and slow training with Qwen3.5

2 participants