fix: add FastQwen3_5Model with fused CE loss for Qwen3.5 OOM#4334
fix: add FastQwen3_5Model with fused CE loss for Qwen3.5 OOM#4334danielhanchen wants to merge 2 commits into
Conversation
Qwen3.5 has a 248,320-token vocabulary. At 8K context the full logits tensor is 8192 x 248320 x 4 = 7.68 GB, which causes OOM on T4/P100. The unsloth compiler already applies fused CE via apply_fused_lm_head, but this adds an explicit FastQwen3_5Model dispatch for cleaner routing and better error messages when Qwen3.5 is not supported. Changes: - Add unsloth/models/qwen3_5.py with FastQwen3_5Model that patches Qwen3_5ForConditionalGeneration and Qwen3_5ForCausalLM forwards to use unsloth_fused_ce_loss directly from hidden_states - Add loader dispatch for model_type == "qwen3_5" before "qwen3" - Version gate uses >= 5.0.0 (qwen3_5 only exists in transformers 5.x) - Guarded import in loader.py with try/except fallback - GDN layers intentionally left unpatched (flash-linear-attention) - 23 unit tests covering all 4 code paths Fixes from original PR #4331 by @vitalis: - Add explicit _get_dtype import (wildcard import skips _-prefixed names) - Single-token fast path now checks labels is None before returning early - Default model name corrected to Qwen/Qwen3.5-9B (8B does not exist) - Test assertion on nn.Linear removed (not a mock) - Unused imports removed Tested: Qwen3.5-0.8B 4bit training, 1.38 GB peak memory, 23/23 tests pass. Backwards compatible: import unsloth works on transformers 4.57.6.
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Unsloth library by providing optimized support for Qwen3.5 models. The primary goal is to resolve critical memory issues encountered during training due to the model's exceptionally large vocabulary, making it feasible to run these models on more constrained hardware. The changes ensure efficient memory utilization and improved stability for Qwen3.5, while also refining the integration process within the Unsloth ecosystem. Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces FastQwen3_5Model to address an out-of-memory issue with Qwen3.5 models by using a fused cross-entropy loss, avoiding the materialization of large logit tensors. The changes are well-implemented and include comprehensive unit tests, robust version checking, and graceful error handling. My review includes a couple of minor suggestions to improve code clarity and conciseness, but overall, this is a high-quality contribution that effectively solves the reported issue.
| # they already have Triton kernels via flash-linear-attention and are | ||
| # architecturally incompatible with Unsloth's standard attention optimisations. | ||
|
|
||
| from .llama import * |
There was a problem hiding this comment.
Using a wildcard import (from .llama import *) can make the code harder to read and maintain as it's not immediately clear where names are coming from. It's a best practice to import names explicitly. Based on the usage in this file, you only need unsloth_fused_ce_loss and EMPTY_LOGITS.
| from .llama import * | |
| from .llama import unsloth_fused_ce_loss, EMPTY_LOGITS |
| n_items = kwargs.get("num_items_in_batch") | ||
| if n_items is None: | ||
| n_items = kwargs.get("n_items") |
There was a problem hiding this comment.
Tests require transformers >= 5.0.0 which is not yet widely deployed. The fused CE path is already covered by the compiler's generic apply_fused_lm_head mechanism and verified via training runs.
|
Closing -- the unsloth compiler already applies fused CE to Qwen3.5 automatically via apply_fused_lm_head on transformers 5.x. The FORCE_FLOAT32 entry for qwen3_5 is already on main. No additional code needed. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a7b944ee57
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| return_dict if return_dict is not None else self.config.use_return_dict | ||
| ) | ||
| # Normalise both generation knobs | ||
| logits_to_keep = max(logits_to_keep, num_logits_to_keep) |
There was a problem hiding this comment.
Preserve non-scalar logits_to_keep values
This normalization assumes logits_to_keep is always an integer, but _qwen3_5_compute_loss_or_logits explicitly supports non-ints (e.g., index tensors/slices) in the partial-logits path. If callers pass an index tensor (as done in packed/incremental decoding flows), max(logits_to_keep, num_logits_to_keep) will raise before forward reaches loss/logit computation, causing Qwen3.5 generation/training calls to fail at runtime. Only scalar knobs should be merged with max, while tensor/slice logits_to_keep should be forwarded unchanged.
Useful? React with 👍 / 👎.
|
@danielhanchen — thanks for the detailed review and fixes in this PR. The 7 corrections (especially the However, I checked upstream
The I confirmed the OOM is still reproducible on Kaggle T4 with transformers 4.57.x + Qwen3.5-2B-Base without this patch. With it: 2.2/15.6 GB, trains to completion. Would you be willing to reopen this (or review #4331 with your fixes applied)? |
|
@danielhanchen — correction to my previous comment. You were right. After looking at The transformers 4.x / Kaggle concern I raised is also moot — Qwen3.5 wasn't added until transformers 5.2.0, so no one can even load the model on 4.57.x. Sorry for the noise. The OOM fix is indeed already handled. |
Summary
Fixes #4188. Adds
FastQwen3_5Modelfor Qwen3.5 models with explicit fused CE loss patching and loader dispatch.Qwen3.5 has a 248,320-token vocabulary (1.64x larger than Qwen3). At 8K context the full logits tensor is 7.68 GB, causing OOM on T4/P100. This patches both
Qwen3_5ForConditionalGenerationandQwen3_5ForCausalLMto useunsloth_fused_ce_lossdirectly from hidden states, bypassing logits materialization entirely.Note: The unsloth compiler already applies fused CE generically via
apply_fused_lm_headfor Qwen3.5. This PR adds an explicit dispatch for cleaner routing and better error messages.Changes
unsloth/models/qwen3_5.py--FastQwen3_5Modelwith patched forwards usingunsloth_fused_ce_lossunsloth/models/loader.py--SUPPORTS_QWEN3_5(>= 5.0.0), guarded import withtry/except, dispatch beforeqwen3to prevent misroutingunsloth/models/__init__.py-- Export withexcept ImportErrorguardtests/utils/test_qwen3_5.py-- 23 unit tests covering all 4 code pathsFixes applied (vs original PR #4331 by @vitalis)
_get_dtypemissing import --from .llama import *does not export_-prefixed names sincellama.pyhas no__all__. Every other model file imports_get_dtypeexplicitly. Addedfrom unsloth_zoo.utils import _get_dtype.>= 4.53.0to>= 5.0.0.transformers.models.qwen3_5does not exist in any 4.x release. The original gate would crashimport unslothon transformers 4.53.0 through 4.57.6.try/except ImportErrorsoimport unslothnever breaks even if qwen3_5 is unavailable.if bsz == 1 and q_len == 1:toif bsz == 1 and q_len == 1 and labels is None:. The original unconditionally returned(None, logits)even when labels were provided, silently skipping loss computation. Inllama.pythis path falls through to loss computation.Qwen/Qwen3.5-8Bdoes not exist on HuggingFace. Changed toQwen/Qwen3.5-9B.self_.lm_head.assert_not_called()which called a mock method on a realnn.Linear.Qwen3_5Attention,Qwen3_5TextModel,__version__,Version.Test results
import unslothworks,SUPPORTS_QWEN3_5 = False