Skip to content

Fix loss function not patched for Qwen3.5 models#5442

Merged
danielhanchen merged 4 commits into
unslothai:mainfrom
rycerzes:fix/patch-loss-function-qwen3-5-conditional-generation
May 19, 2026
Merged

Fix loss function not patched for Qwen3.5 models#5442
danielhanchen merged 4 commits into
unslothai:mainfrom
rycerzes:fix/patch-loss-function-qwen3-5-conditional-generation

Conversation

@rycerzes

Copy link
Copy Markdown
Contributor

Qwen3.5 models (Qwen3_5ForConditionalGeneration) have a loss_type of "ForConditionalGeneration" rather than "ForCausalLM". patch_loss_functions only ever updated the "ForCausalLM" key in LOSS_MAPPING, so Qwen3.5 silently kept using HuggingFace's stock ForCausalLMLoss — which casts logits to fp32 and immediately OOMs on anything ≤24GB at Qwen3.5's vocab size of 248k tokens.

The fix is straightforward: instead of hardcoding "ForCausalLM", the patch now scans all entries in LOSS_MAPPING and replaces any that still point to the original ForCausalLMLoss with the Unsloth kernel. This means new model architectures with different loss_type names won't silently regress in the future.

On the test side, added a check that every mapping key pointing to ForCausalLMLoss gets patched (not just "ForCausalLM"), a check that unrelated loss types like masked-LM or detection aren't accidentally overwritten, and made sure tests clean up after themselves so patching in one test doesn't bleed into another.

Closes #5441, related to #4188.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses an OOM issue in models like Qwen3.5 by ensuring that all loss mapping keys aliased to ForCausalLMLoss are correctly patched with the Unsloth kernel. It also introduces regression tests to verify the patching logic and ensure unrelated loss types are not affected. The review feedback suggests improving the exception handling in the patching logic by catching specific exceptions and adding debug logging instead of using a broad, silent pass.

Comment thread unsloth/kernels/cross_entropy_loss.py Outdated
Comment on lines +471 to +480
try:
import transformers.loss.loss_utils as _lu

_unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM")
if _unsloth_loss is not None:
_causal_lm_loss_name = "ForCausalLMLoss"
for _key, _fn in list(_lu.LOSS_MAPPING.items()):
if (
_key != "ForCausalLM"
and getattr(_fn, "__name__", "") == _causal_lm_loss_name

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid using broad exception handlers like except Exception:. While logging the error is an improvement over a silent pass, you should catch specific exceptions such as ModuleNotFoundError and AttributeError to avoid suppressing unrelated issues. This aligns with the repository's guidelines on exception handling and optional dependencies.

Suggested change
try:
import transformers.loss.loss_utils as _lu
_unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM")
if _unsloth_loss is not None:
_causal_lm_loss_name = "ForCausalLMLoss"
for _key, _fn in list(_lu.LOSS_MAPPING.items()):
if (
_key != "ForCausalLM"
and getattr(_fn, "__name__", "") == _causal_lm_loss_name
try:
import transformers.loss.loss_utils as _lu
_unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM")
if _unsloth_loss is not None:
_causal_lm_loss_name = "ForCausalLMLoss"
for _key, _fn in list(_lu.LOSS_MAPPING.items()):
if _key != "ForCausalLM" and getattr(_fn, "__name__", "") == _causal_lm_loss_name:
_lu.LOSS_MAPPING[_key] = _unsloth_loss
except (ModuleNotFoundError, AttributeError) as e:
logger.debug(f"Unsloth: Failed to patch additional loss functions: {e}", exc_info=True)
References
  1. When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
  2. When handling exceptions, avoid broad except Exception: pass clauses. Instead, catch specific exceptions and log them (at least at a debug level) to aid in troubleshooting. If a failure is expected, log the specific exception type and its details.

danielhanchen and others added 2 commits May 19, 2026 07:52
Replace bare except Exception with the only two compatibility errors we
actually care about so genuine bugs in the sweep surface. Drop the
redundant _key != "ForCausalLM" guard since the __name__ predicate
already excludes the patched entry (UnslothForCausalLMLoss != ForCausalLMLoss).

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 812c514ff8

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +579 to +580
cg_loss = lu.LOSS_MAPPING.get("ForConditionalGeneration")
assert cg_loss is unsloth_loss, (

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Handle Transformers versions without conditional-generation loss

This assertion fails under supported Transformers releases that do not define LOSS_MAPPING['ForConditionalGeneration'] yet; for example the package constraint in pyproject.toml still allows transformers>=4.51.3, and 4.51.3's loss mapping only has ForCausalLM plus the older task keys. In that environment cg_loss is None, so this new drift test goes red even though there is no alias to patch. Please gate this check on the key being present, or test all existing keys whose original value was ForCausalLMLoss instead of requiring this newer key unconditionally.

Useful? React with 👍 / 👎.

@danielhanchen

Copy link
Copy Markdown
Member

Thanks for the PR!

@danielhanchen danielhanchen merged commit 66cfbea into unslothai:main May 19, 2026
1 check passed
rsd-darshan pushed a commit to rsd-darshan/unsloth that referenced this pull request Jun 3, 2026
* fix: patch loss functions for Qwen3_5ForConditionalGeneration to prevent OOM errors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Narrow except scope and simplify LOSS_MAPPING sweep

Replace bare except Exception with the only two compatibility errors we
actually care about so genuine bugs in the sweep surface. Drop the
redundant _key != "ForCausalLM" guard since the __name__ predicate
already excludes the patched entry (UnslothForCausalLMLoss != ForCausalLMLoss).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] model.loss_function not patched for Qwen3_5ForConditionalGeneration, causes logits.float() OOM on ≤24GB GPU

2 participants