
Patch every LOSS_MAPPING key aliased to ForCausalLMLoss #656

Merged
danielhanchen merged 1 commit into main from daniel/patch-loss-for-conditional-generation on May 17, 2026

@danielhanchen (Member)

Summary

PreTrainedModel.__init__ resolves loss_type via a regex match on the class name, so anything whose name isn't a literal key in LOSS_MAPPING lands on the regex-matched fallback. Qwen3_5ForConditionalGeneration resolves to LOSS_MAPPING["ForConditionalGeneration"], and that key (plus CsmForConditionalGeneration) is aliased to the stock ForCausalLMLoss in transformers.
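The lookup described above can be sketched roughly as follows. This is a simplified model of the resolution step, not the transformers source: the three-entry `LOSS_MAPPING`, the placeholder loss functions, and the helper name `resolve_loss_type` are all assumptions made for illustration.

```python
import re

# Placeholder losses standing in for the real transformers functions.
def ForCausalLMLoss(*args, **kwargs): ...
def ForMaskedLMLoss(*args, **kwargs): ...

# Toy stand-in for transformers' LOSS_MAPPING; the real table has many more keys.
LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForConditionalGeneration": ForCausalLMLoss,  # alias of the causal LM loss
    "ForMaskedLM": ForMaskedLMLoss,
}

def resolve_loss_type(class_name, mapping):
    """Return the mapping key a class name resolves to, or None if nothing matches."""
    if class_name in mapping:
        # Exact match: the class name is itself a key.
        return class_name
    # Otherwise search the class name for any key as a substring.
    pattern = "(" + "|".join(re.escape(key) for key in mapping) + ")"
    matches = re.findall(pattern, class_name)
    return matches[0] if matches else None

print(resolve_loss_type("Qwen3_5ForConditionalGeneration", LOSS_MAPPING))
print(resolve_loss_type("LlamaForCausalLM", LOSS_MAPPING))
```

Under these assumptions the Qwen3.5 class never touches the "ForCausalLM" key, which is why rewriting only that key leaves it on the stock loss.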

patch_loss_functions only ever rewrote LOSS_MAPPING["ForCausalLM"], leaving those aliases pointing at the un-patched loss. Affected models silently fall back to stock ForCausalLMLoss, which does logits.float() over the full (seq_len x vocab_size) tensor and OOMs on 24 GB GPUs at Qwen3.5's 248k vocab.
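A back-of-the-envelope estimate shows why the fp32 upcast alone is a problem. The vocabulary size below is the rounded "248k" from the description, and the sequence lengths are illustrative assumptions, not figures from the linked issue.

```python
def fp32_logits_gib(seq_len, vocab_size, batch_size=1):
    """Size in GiB of one fp32 (batch, seq_len, vocab_size) logits tensor."""
    bytes_per_fp32 = 4
    return batch_size * seq_len * vocab_size * bytes_per_fp32 / 2**30

vocab_size = 248_000  # rounded from "248k"; the exact vocabulary size is assumed
for seq_len in (2048, 8192, 32768):  # illustrative sequence lengths
    gib = fp32_logits_gib(seq_len, vocab_size)
    print(f"seq_len={seq_len:>6}: {gib:6.2f} GiB for the fp32 copy alone")
```

At 8192 tokens the fp32 copy is already about 7.6 GiB, on top of the original half-precision logits, the model weights, activations, and the gradient of the logits, which together leave little headroom on a 24 GB card.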

This patch sweeps every entry in LOSS_MAPPING whose function is currently named ForCausalLMLoss and replaces it with the Unsloth kernel. Idempotent (after the first call no entry reports __name__ == "ForCausalLMLoss") and leaves unrelated loss types alone (ForMaskedLMLoss, segmentation, detection, etc.).
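A minimal sketch of the sweep, run against a toy mapping rather than the real transformers table. The helper name `sweep_causal_lm_aliases` and the placeholder functions are hypothetical; the real `patch_loss_functions` builds the kernel itself and its signature is not reproduced here.

```python
# Placeholder losses; only their __name__ matters for this sketch.
def ForCausalLMLoss(*args, **kwargs): ...
def ForMaskedLMLoss(*args, **kwargs): ...
def UnslothForCausalLMLoss(*args, **kwargs): ...

# Toy stand-in for transformers' LOSS_MAPPING.
LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForConditionalGeneration": ForCausalLMLoss,  # alias of the stock loss
    "ForMaskedLM": ForMaskedLMLoss,               # unrelated loss type
}

def sweep_causal_lm_aliases(mapping, replacement):
    """Replace every entry whose function is named ForCausalLMLoss; return the keys changed."""
    changed = []
    # Iterate over a snapshot so the dict can be mutated safely.
    for key, fn in list(mapping.items()):
        if getattr(fn, "__name__", "") == "ForCausalLMLoss":
            mapping[key] = replacement
            changed.append(key)
    return changed

first = sweep_causal_lm_aliases(LOSS_MAPPING, UnslothForCausalLMLoss)
second = sweep_causal_lm_aliases(LOSS_MAPPING, UnslothForCausalLMLoss)
print("first call changed:", first)
print("second call changed:", second)
```

The second call finds nothing to change because no remaining entry is named `ForCausalLMLoss`, and the masked-LM entry is never touched.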

This is the source-side fix; the wrapper version of the same idea in unslothai/unsloth#5442 becomes a no-op once this lands.

Closes unslothai/unsloth#5441, related to unslothai/unsloth#4188.

Test plan

  • New tests/test_patch_loss_functions_coverage.py pins three invariants:
    • LOSS_MAPPING["ForConditionalGeneration"] resolves to the Unsloth kernel after patch_loss_functions().
    • Keys aliased to other loss functions (masked LM, segmentation, detection) are untouched.
    • Re-calling patch_loss_functions() is idempotent.
  • Suite passes on the patched branch; the conditional-generation test fails on main as expected.

PreTrainedModel.__init__ resolves loss_type via regex on the class name,
so Qwen3_5ForConditionalGeneration lands on
LOSS_MAPPING["ForConditionalGeneration"]. That key (and
CsmForConditionalGeneration) is aliased to the stock ForCausalLMLoss
in transformers, and patch_loss_functions only ever rewrote the
"ForCausalLM" entry. The result is that affected classes silently fall
back to the un-patched loss, which materialises (seq_len x vocab_size)
fp32 logits and OOMs on <= 24 GB GPUs at large vocab sizes.

Sweep every entry whose function is currently named ForCausalLMLoss
and replace it with the Unsloth kernel. Idempotent (after the first
call no key reports __name__ == "ForCausalLMLoss") and leaves
unrelated loss types alone.

Closes unslothai/unsloth#5441

@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 43cce1cf6d


def test_loss_mapping_for_conditional_generation_patched():
    lu = pytest.importorskip("transformers.loss.loss_utils")
    from unsloth_zoo import loss_utils as zoo_loss


[P2] Import the module under test with importorskip

In environments where transformers is installed but the separate unsloth package is not, this regular import turns the new regression tests into hard failures: unsloth_zoo.__init__ raises ImportError("Please install Unsloth..."), so pytest -q tests/test_patch_loss_functions_coverage.py fails instead of skipping like the existing standalone tests that use pytest.importorskip("unsloth_zoo..."). This makes the test suite fail in the repository’s supported/testable partial-install context even though the dependency is optional for these drift-style tests.
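One way to apply this suggestion is to route both imports through `pytest.importorskip`, so a missing or unimportable package produces a skip instead of a failure. This is only a sketch of the import pattern: the assertions in the body are placeholders, not the assertions of the merged test, and it assumes `patch_loss_functions` is exposed from `unsloth_zoo.loss_utils` as the comment thread indicates.

```python
import pytest

def test_loss_mapping_for_conditional_generation_patched():
    # Both imports skip the test when the module cannot be imported,
    # rather than failing collection or the test itself.
    lu = pytest.importorskip("transformers.loss.loss_utils")
    zoo_loss = pytest.importorskip("unsloth_zoo.loss_utils")

    # Placeholder checks only; the real test asserts on the patched mapping.
    assert hasattr(lu, "LOSS_MAPPING")
    assert callable(zoo_loss.patch_loss_functions)
```

Because the error described above is a plain `ImportError` rather than `ModuleNotFoundError`, it is worth checking how the installed pytest version treats that case before relying on the skip.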


@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request updates the loss function patching logic to iterate through all aliases of ForCausalLMLoss in the transformers library, preventing OOM issues for models like Qwen 3.5 that resolve to these aliases. A new test suite is included to verify coverage, idempotency, and isolation from other loss types. Feedback was provided to improve the patching logic by ensuring that subsequent calls to the patching function can update existing patches, rather than only targeting the original stock loss function.

Comment thread unsloth_zoo/loss_utils.py
Comment on lines +146 to +148
for _key, _fn in list(LOSS_MAPPING.items()):
    if getattr(_fn, "__name__", "") == "ForCausalLMLoss":
        LOSS_MAPPING[_key] = UnslothForCausalLMLoss


medium

The current implementation makes patch_loss_functions "sticky" for all aliases of ForCausalLMLoss. Because it only replaces functions with the exact name "ForCausalLMLoss", subsequent calls to patch_loss_functions (e.g., with different torch_compile settings or a different _fast_cross_entropy_loss) will not update the mapping if it has already been patched. This is a regression from the previous behavior where LOSS_MAPPING["ForCausalLM"] was updated unconditionally on every call.

To maintain the ability to re-configure the patch while still sweeping aliases, consider also checking if the function is the one currently assigned to "ForCausalLM" or if it matches the patched name. This identity check ensures we are modifying the correct function instance.

Suggested change
- for _key, _fn in list(LOSS_MAPPING.items()):
-     if getattr(_fn, "__name__", "") == "ForCausalLMLoss":
-         LOSS_MAPPING[_key] = UnslothForCausalLMLoss
+ current_causal_loss = LOSS_MAPPING.get("ForCausalLM")
+ for _key, _fn in list(LOSS_MAPPING.items()):
+     if _fn is current_causal_loss or getattr(_fn, "__name__", "") == "ForCausalLMLoss":
+         LOSS_MAPPING[_key] = UnslothForCausalLMLoss
References
  1. When unpatching or updating a patched function, perform an identity check to ensure the function being replaced is the one originally patched by your code. This is more robust than relying on the state of other modules.
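The difference between the two variants shows up when the patch is applied twice with different kernels. The sketch below uses a toy mapping and assumes the quoted loop is the only place the mapping is assigned; `make_kernel`, `sweep_by_name`, and `sweep_by_name_or_identity` are hypothetical names for illustration.

```python
def ForCausalLMLoss(*args, **kwargs): ...
def ForMaskedLMLoss(*args, **kwargs): ...

def make_kernel(tag):
    """Build a distinct replacement loss, standing in for a re-configured kernel."""
    def UnslothForCausalLMLoss(*args, **kwargs): ...
    UnslothForCausalLMLoss.tag = tag
    return UnslothForCausalLMLoss

def fresh_mapping():
    return {
        "ForCausalLM": ForCausalLMLoss,
        "ForConditionalGeneration": ForCausalLMLoss,
        "ForMaskedLM": ForMaskedLMLoss,
    }

def sweep_by_name(mapping, replacement):
    # Variant quoted in the comment thread: match on the stock name only.
    for key, fn in list(mapping.items()):
        if getattr(fn, "__name__", "") == "ForCausalLMLoss":
            mapping[key] = replacement

def sweep_by_name_or_identity(mapping, replacement):
    # Suggested variant: also replace whatever "ForCausalLM" currently points at.
    current = mapping.get("ForCausalLM")
    for key, fn in list(mapping.items()):
        if fn is current or getattr(fn, "__name__", "") == "ForCausalLMLoss":
            mapping[key] = replacement

kernel_a, kernel_b = make_kernel("a"), make_kernel("b")

sticky = fresh_mapping()
sweep_by_name(sticky, kernel_a)
sweep_by_name(sticky, kernel_b)   # no entry is named ForCausalLMLoss any more

updatable = fresh_mapping()
sweep_by_name_or_identity(updatable, kernel_a)
sweep_by_name_or_identity(updatable, kernel_b)

print("name only        :", sticky["ForConditionalGeneration"].tag)
print("name or identity :", updatable["ForConditionalGeneration"].tag)
```

In this toy setup the name-only sweep keeps the first kernel after the second call, while the identity-aware sweep picks up the second kernel and still leaves the masked-LM entry alone.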

@danielhanchen merged commit b6d292b into main on May 17, 2026
15 checks passed
danielhanchen added a commit that referenced this pull request May 17, 2026
is_enabled() now returns True unless UNSLOTH_FUSED_FORWARD is explicitly
set to "0". Updated docstrings and the __init__.py comment to reflect the
new default. The two-tier installer + LOSS_MAPPING backstop in #656 means
the worst case for any class we touch is no-op (refused via _UNMATCHED
or composite-head guard) -- never a worse forward than the stock path.

Test suite (23 cases, was 22 + new test_install_default_is_on): all
green. Refresh of the multi-model equivalence rerun with no env var set
versus UNSLOTH_FUSED_FORWARD=0 (Llama-3.2-1B / Qwen3-0.6B / Gemma-3-1B /
Mistral-7B-v0.3, seed 3407, max_steps=10, alpaca-cleaned):

  model           off enabled  default enabled  s1 eq  max|loss d|  max|grad d|
  Llama-3.2-1B          False            True   True      0.00410      0.02336
  Qwen3-0.6B            False            True   True      0.00680      0.02561
  Gemma-3-1B            False            True   True      0.00000      0.00000
  Mistral-7B-v0.3       False            True   True      0.00530      0.05310

Step 1 loss + grad_norm bitwise identical for every model; deltas across
the run stay within bf16 -> fp32 chunked-CE rounding noise. Audit
reports 11 classes patched at default and 0 patched when explicitly
disabled.
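The default-on behaviour described in this commit message can be sketched as a single environment check. This is an assumed shape for `is_enabled()`, not the actual implementation: the optional `environ` parameter is added here for testability, and whether the real function normalises whitespace or accepts other falsy spellings is not stated in the source.

```python
import os

def is_enabled(environ=None):
    """Return True unless UNSLOTH_FUSED_FORWARD is explicitly set to "0"."""
    env = os.environ if environ is None else environ
    # Unset, empty, or any value other than "0" leaves the feature on.
    return env.get("UNSLOTH_FUSED_FORWARD") != "0"

print(is_enabled({}))                              # variable not set
print(is_enabled({"UNSLOTH_FUSED_FORWARD": "0"}))  # explicit opt-out
print(is_enabled({"UNSLOTH_FUSED_FORWARD": "1"}))  # explicit opt-in
```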
Brishen pushed a commit to Brishen/unsloth-zoo that referenced this pull request May 19, 2026
Brishen pushed a commit to Brishen/unsloth-zoo that referenced this pull request May 19, 2026

Development

Successfully merging this pull request may close these issues.

[Bug] model.loss_function not patched for Qwen3_5ForConditionalGeneration, causes logits.float() OOM on ≤24GB GPU
