Skip to content

fix: add FastQwen3_5Model with fused CE loss for Qwen3.5 OOM#4334

Closed
danielhanchen wants to merge 2 commits into
mainfrom
fix/qwen3_5-fused-ce-clean
Closed

fix: add FastQwen3_5Model with fused CE loss for Qwen3.5 OOM#4334
danielhanchen wants to merge 2 commits into
mainfrom
fix/qwen3_5-fused-ce-clean

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary

Fixes #4188. Adds FastQwen3_5Model for Qwen3.5 models with explicit fused CE loss patching and loader dispatch.

Qwen3.5 has a 248,320-token vocabulary (1.64x larger than Qwen3). At 8K context the full logits tensor is 7.68 GB, causing OOM on T4/P100. This patches both Qwen3_5ForConditionalGeneration and Qwen3_5ForCausalLM to use unsloth_fused_ce_loss directly from hidden states, bypassing logits materialization entirely.

Note: The unsloth compiler already applies fused CE generically via apply_fused_lm_head for Qwen3.5. This PR adds an explicit dispatch for cleaner routing and better error messages.

Changes

  • unsloth/models/qwen3_5.py -- FastQwen3_5Model with patched forwards using unsloth_fused_ce_loss
  • unsloth/models/loader.py -- SUPPORTS_QWEN3_5 (>= 5.0.0), guarded import with try/except, dispatch before qwen3 to prevent misrouting
  • unsloth/models/__init__.py -- Export with except ImportError guard
  • tests/utils/test_qwen3_5.py -- 23 unit tests covering all 4 code paths

Fixes applied (vs original PR #4331 by @vitalis)

  • _get_dtype missing import -- from .llama import * does not export _-prefixed names since llama.py has no __all__. Every other model file imports _get_dtype explicitly. Added from unsloth_zoo.utils import _get_dtype.
  • Version gate -- Changed from >= 4.53.0 to >= 5.0.0. transformers.models.qwen3_5 does not exist in any 4.x release. The original gate would crash import unsloth on transformers 4.53.0 through 4.57.6.
  • Guarded loader import -- Wrapped loader import in try/except ImportError so import unsloth never breaks even if qwen3_5 is unavailable.
  • Single-token fast path -- Changed if bsz == 1 and q_len == 1: to if bsz == 1 and q_len == 1 and labels is None:. The original unconditionally returned (None, logits) even when labels were provided, silently skipping loss computation. In llama.py this path falls through to loss computation.
  • Default model name -- Qwen/Qwen3.5-8B does not exist on HuggingFace. Changed to Qwen/Qwen3.5-9B.
  • Test assertion -- Removed self_.lm_head.assert_not_called() which called a mock method on a real nn.Linear.
  • Unused imports -- Removed Qwen3_5Attention, Qwen3_5TextModel, __version__, Version.

Test results

Environment Result
transformers 5.3.0, 23 unit tests 23/23 pass
Qwen3.5-0.8B 4bit training (10 steps) 1.38 GB peak, loss 1.88
Generic FastModel path (no dispatch) 1.38 GB peak, loss 1.88 (identical -- compiler handles it)
Backwards compat: transformers 4.57.6 import unsloth works, SUPPORTS_QWEN3_5 = False

Qwen3.5 has a 248,320-token vocabulary. At 8K context the full logits
tensor is 8192 x 248320 x 4 = 7.68 GB, which causes OOM on T4/P100.

The unsloth compiler already applies fused CE via apply_fused_lm_head,
but this adds an explicit FastQwen3_5Model dispatch for cleaner routing
and better error messages when Qwen3.5 is not supported.

Changes:
- Add unsloth/models/qwen3_5.py with FastQwen3_5Model that patches
  Qwen3_5ForConditionalGeneration and Qwen3_5ForCausalLM forwards to
  use unsloth_fused_ce_loss directly from hidden_states
- Add loader dispatch for model_type == "qwen3_5" before "qwen3"
- Version gate uses >= 5.0.0 (qwen3_5 only exists in transformers 5.x)
- Guarded import in loader.py with try/except fallback
- GDN layers intentionally left unpatched (flash-linear-attention)
- 23 unit tests covering all 4 code paths

Fixes from original PR #4331 by @vitalis:
- Add explicit _get_dtype import (wildcard import skips _-prefixed names)
- Single-token fast path now checks labels is None before returning early
- Default model name corrected to Qwen/Qwen3.5-9B (8B does not exist)
- Test assertion on nn.Linear removed (not a mock)
- Unused imports removed

Tested: Qwen3.5-0.8B 4bit training, 1.38 GB peak memory, 23/23 tests pass.
Backwards compatible: import unsloth works on transformers 4.57.6.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Unsloth library by providing optimized support for Qwen3.5 models. The primary goal is to resolve critical memory issues encountered during training due to the model's exceptionally large vocabulary, making it feasible to run these models on more constrained hardware. The changes ensure efficient memory utilization and improved stability for Qwen3.5, while also refining the integration process within the Unsloth ecosystem.

Highlights

  • Qwen3.5 OOM Fix: Introduced FastQwen3_5Model to explicitly patch Qwen3.5 models, addressing Out-Of-Memory (OOM) issues on T4/P100 GPUs caused by the large 248,320-token vocabulary of Qwen3.5.
  • Fused CE Loss Integration: Implemented unsloth_fused_ce_loss directly from hidden states for Qwen3_5ForConditionalGeneration and Qwen3_5ForCausalLM, completely bypassing the materialization of the large logits tensor during training.
  • Loader and Import Guarding: Updated unsloth/models/loader.py to include SUPPORTS_QWEN3_5 (gated for transformers >= 5.0.0) and added guarded imports with try/except ImportError to ensure robustness and prevent import failures on incompatible transformers versions.
  • Comprehensive Unit Tests: Added 23 new unit tests in tests/utils/test_qwen3_5.py to cover all four forward paths and ensure the correctness and stability of the Qwen3.5 model optimizations.
  • Refined Logic and Compatibility: Corrected several issues from a previous PR, including fixing a missing _get_dtype import, adjusting the version gate to >= 5.0.0 for transformers, refining the single-token fast path logic, updating the default model name, and removing unused imports and incorrect test assertions.
Changelog
  • tests/utils/test_qwen3_5.py
    • Added a new test file for unsloth/models/qwen3_5.py.
    • Included helper functions _make_fake_unsloth_fused_ce_loss, _make_self, and _make_outputs for testing.
    • Implemented TestComputeLossOrLogits to verify single-token decode, partial-logits, training (fused CE), and eval/inference paths.
    • Added tests for UNSLOTH_RETURN_LOGITS and n_items fallback logic.
    • Included TestForwardFunctionBehaviour to test num_logits_to_keep normalization and return_dict=False handling.
    • Added tests for UNSLOTH_RETURN_HIDDEN_STATES functionality.
    • Created TestPrePatch to confirm pre_patch() correctly replaces forward functions.
    • Implemented TestFromPretrained to ensure from_pretrained calls FastLlamaModel with the correct patcher.
    • Added TestLoaderRouting to verify qwen3_5 routes to FastQwen3_5Model and is in FORCE_FLOAT32.
    • Included TestInitExports to check FastQwen3_5Model importability and correct ImportError handling in __init__.py.
  • unsloth/models/init.py
    • Added guarded import for FastQwen3_5Model with an ImportError exception handler.
  • unsloth/models/loader.py
    • Defined SUPPORTS_QWEN3_5 to check for transformers version >= 5.0.0.
    • Added a try/except ImportError block for FastQwen3_5Model import, setting SUPPORTS_QWEN3_5 to False on failure.
    • Integrated qwen3_5 model type into the from_pretrained dispatch logic, including a version check and error message.
    • Ensured qwen3_5 remains in FORCE_FLOAT32 list due to RMSNorm characteristics.
  • unsloth/models/qwen3_5.py
    • Added _qwen3_5_compute_loss_or_logits helper function to manage loss and logits computation paths.
    • Implemented Qwen3_5ForConditionalGeneration_fast_forward to patch the forward method for conditional generation models.
    • Implemented Qwen3_5ForCausalLM_fast_forward to patch the forward method for causal language models.
    • Created FastQwen3_5Model class inheriting from FastLlamaModel to apply Qwen3.5 specific optimizations.
    • Overrode pre_patch method to assign fast-forward functions to Qwen3.5 model classes.
    • Overrode from_pretrained method to ensure correct model patching and loading for Qwen3.5 models.
Activity
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces FastQwen3_5Model to address an out-of-memory issue with Qwen3.5 models by using a fused cross-entropy loss, avoiding the materialization of large logit tensors. The changes are well-implemented and include comprehensive unit tests, robust version checking, and graceful error handling. My review includes a couple of minor suggestions to improve code clarity and conciseness, but overall, this is a high-quality contribution that effectively solves the reported issue.

Comment thread unsloth/models/qwen3_5.py
# they already have Triton kernels via flash-linear-attention and are
# architecturally incompatible with Unsloth's standard attention optimisations.

from .llama import *
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a wildcard import (from .llama import *) can make the code harder to read and maintain as it's not immediately clear where names are coming from. It's a best practice to import names explicitly. Based on the usage in this file, you only need unsloth_fused_ce_loss and EMPTY_LOGITS.

Suggested change
from .llama import *
from .llama import unsloth_fused_ce_loss, EMPTY_LOGITS

Comment thread unsloth/models/qwen3_5.py
Comment on lines +100 to +102
n_items = kwargs.get("num_items_in_batch")
if n_items is None:
n_items = kwargs.get("n_items")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for determining n_items can be made more concise by using the default argument of dict.get().

Suggested change
n_items = kwargs.get("num_items_in_batch")
if n_items is None:
n_items = kwargs.get("n_items")
n_items = kwargs.get("num_items_in_batch", kwargs.get("n_items"))

Tests require transformers >= 5.0.0 which is not yet widely deployed.
The fused CE path is already covered by the compiler's generic
apply_fused_lm_head mechanism and verified via training runs.
@danielhanchen
Copy link
Copy Markdown
Member Author

Closing -- the unsloth compiler already applies fused CE to Qwen3.5 automatically via apply_fused_lm_head on transformers 5.x. The FORCE_FLOAT32 entry for qwen3_5 is already on main. No additional code needed.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a7b944ee57

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/qwen3_5.py
return_dict if return_dict is not None else self.config.use_return_dict
)
# Normalise both generation knobs
logits_to_keep = max(logits_to_keep, num_logits_to_keep)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve non-scalar logits_to_keep values

This normalization assumes logits_to_keep is always an integer, but _qwen3_5_compute_loss_or_logits explicitly supports non-ints (e.g., index tensors/slices) in the partial-logits path. If callers pass an index tensor (as done in packed/incremental decoding flows), max(logits_to_keep, num_logits_to_keep) will raise before forward reaches loss/logit computation, causing Qwen3.5 generation/training calls to fail at runtime. Only scalar knobs should be merged with max, while tensor/slice logits_to_keep should be forwarded unchanged.

Useful? React with 👍 / 👎.

@vitalis
Copy link
Copy Markdown

vitalis commented Mar 17, 2026

@danielhanchen — thanks for the detailed review and fixes in this PR. The 7 corrections (especially the _get_dtype import, version gate, and single-token fast path bug) are all valid and I'll apply them to #4331.

However, I checked upstream main after you closed this and the issue is not actually resolved:

unsloth/models/loader.py on main has qwen3_5 only in FORCE_FLOAT32. There is no elif model_type == "qwen3_5": dispatch — it falls through to FastModel.from_pretrained() with no fast-forward patching. The OOM path is unchanged.

The apply_fused_lm_head compiler path — does it cover models that go through FastModel with FORCE_FLOAT32? I couldn't find evidence of this in the codebase, and Kaggle users (the main affected environment) are on transformers 4.57.x, not 5.x.

I confirmed the OOM is still reproducible on Kaggle T4 with transformers 4.57.x + Qwen3.5-2B-Base without this patch. With it: 2.2/15.6 GB, trains to completion.

Would you be willing to reopen this (or review #4331 with your fixes applied)?

@vitalis
Copy link
Copy Markdown

vitalis commented Mar 17, 2026

@danielhanchen — correction to my previous comment. You were right.

After looking at unsloth_zoo/compiler.py, unsloth_compile_transformers() is fully generic — it builds the module path from model_type dynamically and regex-rewrites all GenerationMixin subclass forward() methods to use unsloth_fused_ce_loss. FastModel.from_pretrained() calls it with fuse_lm_head=True, so Qwen3.5 is covered without a dedicated dispatch.

The transformers 4.x / Kaggle concern I raised is also moot — Qwen3.5 wasn't added until transformers 5.2.0, so no one can even load the model on 4.57.x.

Sorry for the noise. The OOM fix is indeed already handled.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Extremely high CPU/VRAM usage and slow training with Qwen3.5

2 participants