fix: Qwen3.5 OOM during training — add FastQwen3_5Model with fused CE loss#4331
fix: Qwen3.5 OOM during training — add FastQwen3_5Model with fused CE loss#4331vitalis wants to merge 2 commits into
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves critical OutOfMemory issues encountered when training Qwen3.5 models, particularly on resource-constrained GPUs. The solution involves a targeted optimization that prevents the creation of excessively large intermediate tensors, thereby significantly reducing memory footprint and enabling efficient training. This enhancement ensures broader compatibility and improved performance for Qwen3.5 models within the framework. Highlights
Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a fix for an out-of-memory error when training Qwen3.5 models, which have a very large vocabulary. The solution is to add FastQwen3_5Model, which uses a fused cross-entropy loss kernel to compute loss directly from hidden states, avoiding the materialization of the large logits tensor. The changes include the new model logic, updates to the model loader to dispatch to this new class, and a comprehensive new test suite to validate the fix and cover edge cases.
The implementation is solid and the tests are thorough. My review includes a few suggestions to improve code clarity and maintainability by removing some redundant code and simplifying an inheritance relationship. These are medium-severity suggestions aimed at making the new code even cleaner.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0c2959243c
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
0c29592 to
c9af6ab
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2621918335
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| return None, logits | ||
|
|
||
| # Partial-logits path (e.g. logits_to_keep for speculative decoding) | ||
| if logits_to_keep != 0: |
There was a problem hiding this comment.
Avoid truth-testing tensor logits selectors
The partial-logits branch uses if logits_to_keep != 0, which raises RuntimeError: Boolean value of Tensor with more than one value is ambiguous when packed decoding passes a tensor selector (the same selector shape this helper is intended to support). That means valid non-scalar selector inputs still crash at runtime instead of slicing logits; use a type-aware check that only does scalar comparison for ints.
Useful? React with 👍 / 👎.
…nslothai#4188) Qwen3.5 has a 248,320-token vocabulary. At 8K context the full logits tensor is 8192 x 248320 x 4 bytes = 7.68 GB, which exceeds free VRAM on T4/P100 after model load. Root cause: loader.py listed "qwen3_5" in FORCE_FLOAT32 but never dispatched it to an optimised class, so the model fell through to a bare HF load with no fast-forward patching and full logits were materialised every training step. Fix: - Add unsloth/models/qwen3_5.py with FastQwen3_5Model (inherits FastLlamaModel) that patches Qwen3_5ForConditionalGeneration.forward and Qwen3_5ForCausalLM.forward to call unsloth_fused_ce_loss directly from hidden_states, bypassing logits materialisation entirely. - Shared helper _qwen3_5_compute_loss_or_logits handles all paths: single-token decode (fast torch.mv), partial-logits (logits_to_keep), training (fused CE), and eval/inference. Logits are cast to the config dtype on all materialisaton paths. - Both num_logits_to_keep and logits_to_keep are normalised so unsloth_fast_generate works correctly. - return_dict=False is handled explicitly to avoid AttributeError on tuple outputs from self.model(). - UNSLOTH_RETURN_HIDDEN_STATES=1 is supported on both forward functions (matches llama.py behaviour for embedding extraction). - n_items fallback uses `is not None` guard, not `or`, so num_items_in_batch=0 is not silently ignored. - Unconditionally apply fused CE regardless of bsz*q_len (unlike llama.py's 1024-token threshold) — at 248K vocab even 32 tokens produce a 31 MB logit tensor; fused CE overhead is negligible. - FastQwen3_5Model.from_pretrained calls FastLlamaModel.from_pretrained directly to avoid Qwen3 attention patches incompatible with GDN layers. - Add routing in loader.py (surgical: SUPPORTS_QWEN3_5 flag, conditional import, dispatch block) and export in __init__.py. - Add unit tests covering all code paths, regression cases, and new behaviours (UNSLOTH_RETURN_HIDDEN_STATES, n_items fallback, batch decode, labels-ignored-when-logits_to_keep, lm_head-not-called). Gated DeltaNet (GDN) linear-attention layers are intentionally NOT patched — they already have Triton kernels via flash-linear-attention and are architecturally incompatible with standard attention patches. Memory at batch=1, seq=8192: Standard (no patch): 8192 x 248320 x 4 = 7.68 GB (OOM) unsloth_fused_ce (chunked): ~0.24-0.95 GB peak (fits) Fixes unslothai#4188
2621918 to
9137317
Compare
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e4e3d9403b
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if bsz == 1 and q_len == 1: | ||
| logits = torch.mv( | ||
| lm_head_weight, hidden_states.ravel().to(lm_head_weight.dtype) | ||
| ) | ||
| logits = logits.unsqueeze(0).unsqueeze(0).to(out_dtype) |
There was a problem hiding this comment.
Skip single-token fast path when labels are present
The single-token branch runs before any label handling and returns immediately, so calls with labels and shape (bsz=1, q_len=1) produce loss=None instead of a training/eval loss. This can break trainer flows that expect a real loss tensor (e.g., backward on None) whenever a one-token sample or micro-batch appears. Limit this fast decode shortcut to inference (labels is None) or compute loss before returning.
Useful? React with 👍 / 👎.
Full research summary — what we know, what we verified, what this PR still addsThe original problemQwen3.5 has a 248,320-token vocabulary (1.64× larger than Qwen3). At 8K context the logits tensor is: This exceeds free VRAM on T4/P100 after model load. Root cause in Does
|
| GPU usage | Status | |
|---|---|---|
| Without this patch (bare HF load) | OOM at step 1 | ✗ |
| With this patch | 2.2 / 15.6 GB | ✓ training |
(The compiler's generic path would also avoid OOM — this confirms both approaches work.)
Known issues in this PR (from @danielhanchen's review of #4334)
Seven correctness fixes were identified. I'll apply them:
_get_dtypemissing explicit import (from .llama import *does not re-export_-prefixed names)- Single-token fast path
if bsz == 1 and q_len == 1:should addand labels is None:— original unconditionally skips loss when labels are provided - Guarded loader import in
try/except ImportErrorsoimport unslothnever breaks - Default model name
Qwen/Qwen3.5-8B→Qwen/Qwen3.5-9B(8B variant does not exist on HuggingFace) - Remove
Qwen3_5Attention,Qwen3_5TextModel,__version__,Version(unused imports) - Fix test assertion
self_.lm_head.assert_not_called()called on realnn.Linear - Version gate in
loader.pyshould use>= 5.0.0per Daniel — though our evidence shows 4.53.x works, happy to defer to maintainer preference here
Will push a follow-up commit with these applied.
Revised pitch — the real value is GDN-aware dispatch, not OOM@danielhanchen — after your feedback and further research, here's the accurate case for this PR. What generic
|
Generic FastModel |
This PR | |
|---|---|---|
| Fused CE (OOM) | ✓ compiler | ✓ |
FORCE_FLOAT32 |
✓ already on main | ✓ |
| RoPE patching (attention layers) | ✗ | ✓ |
| Attention kernel replacement | ✗ | ✓ |
| GDN layers explicitly skipped | ✗ generic | ✓ guarded |
| Explicit dispatch + version gate | ✗ silent fallthrough | ✓ |
Version gate: >= 4.53.0
Live proof: Qwen3.5-2B-Base is running right now on Kaggle T4 with transformers 4.53.x. The module exists in 4.53.x. Gate is consistent with SUPPORTS_FALCON_H1 / SUPPORTS_GEMMA3N in the same file.
Fixes from your #4334 review — will apply before merge
_get_dtypeexplicit import- Single-token fast path: add
and labels is None try/except ImportErrorguard in loader- Default model:
Qwen3.5-8B→Qwen3.5-9B - Remove unused imports
- Fix test assertion on real
nn.Linear
This PR supersedes #4335 — not the other way around@danielhanchen marked #4335 as "Supersedes #4331". That relationship is backwards. Here is why. What #4335 adds11 lines in What this PR adds
A subset cannot supersede a superset. #4335 implements one of the seven things this PR does. Merging #4335 instead of this PR means five real improvements never land. The version gate is wrong in #4335#4335 sets the gate at The correct gate is The GDN argument standsThe compiler's generic fused CE handles OOM — we fully agree and acknowledged that in earlier comments. But #4335 still routes ProposalMerge this PR. It is a strict superset of #4335, includes the correct version gate, adds the model-specific optimisations, and comes with a full test suite. Happy to apply all the fixes from your #4334 review (listed at the bottom of the PR description) in a follow-up commit today. |
|
Thanks for your contribution - I merged some other PRs which resolves this - appreciate it! |
Fixes #4188
What this PR does
Adds
FastQwen3_5Model— explicit dispatch and optimisation for Qwen3.5's hybrid GDN (Gated DeltaNet) architecture.Why Qwen3.5 needs dedicated handling
Qwen3.5 is architecturally different from every other model in unsloth. It interleaves two layer types:
q_proj,k_proj,v_proj,o_projetc.in_proj_a,in_proj_b,in_proj_qkv,in_proj_z,out_projGDN layers already have Triton kernels via
flash-linear-attentionand are architecturally incompatible with unsloth's standard attention patches (different forward signature, gated query projections). The genericFastModelcompiler path has no knowledge of this boundary.This PR makes it explicit: standard attention layers get unsloth's RoPE patching and attention kernel replacements; GDN layers are intentionally left untouched.
What the generic
FastModelpath already handlesWe verified in
unsloth_zoo/compiler.pythatunsloth_compile_transformers()generically applies fused CE to allGenerationMixinsubclasses viaapply_fused_lm_head. So the OOM from the 248,320-vocab logits tensor (8192 × 248320 × 4 = 7.68 GB) is already prevented through theFastModelfallback.This PR does not claim to add the OOM fix — the compiler already handles that.
What this PR adds on top
FastModelFORCE_FLOAT32(RMSNorm)Version gate: >= 4.53.0
Live proof: Qwen3.5-2B-Base is running on Kaggle T4 with transformers 4.53.x right now. The
transformers.models.qwen3_5module exists in 4.53.x. Gate is consistent withSUPPORTS_FALCON_H1andSUPPORTS_GEMMA3Nin the same file.Confirmed working on real hardware
Kaggle T4 x2, Qwen3.5-2B-Base, batch=2, seq=8K, transformers 4.53.x — 2.2 / 15.6 GB GPU usage, training to completion.
Pending fixes (from @danielhanchen review of #4334)
_get_dtypeexplicit importand labels is Nonetry/except ImportErrorguard in loaderQwen3.5-8B→Qwen3.5-9Bnn.LinearFiles changed
unsloth/models/qwen3_5.pyFastQwen3_5Model, two fast forward functions, shared helperunsloth/models/loader.pySUPPORTS_QWEN3_5flag, routing formodel_type == "qwen3_5"unsloth/models/__init__.pyFastQwen3_5Modeltests/utils/test_qwen3_5.py