Patch every LOSS_MAPPING key aliased to ForCausalLMLoss #656
PreTrainedModel.__init__ resolves loss_type via regex on the class name, so Qwen3_5ForConditionalGeneration lands on LOSS_MAPPING["ForConditionalGeneration"]. That key (and CsmForConditionalGeneration) is aliased to the stock ForCausalLMLoss in transformers, and patch_loss_functions only ever rewrote the "ForCausalLM" entry. The result is that affected classes silently fall back to the un-patched loss, which materialises (seq_len x vocab_size) fp32 logits and OOMs on <= 24 GB GPUs at large vocab sizes. Sweep every entry whose function is currently named ForCausalLMLoss and replace it with the Unsloth kernel. Idempotent (after the first call no key reports __name__ == "ForCausalLMLoss") and leaves unrelated loss types alone. Closes unslothai/unsloth#5441
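The lookup described above can be illustrated with a toy mapping. This is a simplified sketch that assumes the resolution is "exact class name first, otherwise the first LOSS_MAPPING key found inside the class name"; `resolve_loss_type` and the placeholder loss functions are hypothetical stand-ins, not the transformers implementation.

```python
import re

# Toy stand-ins; the real LOSS_MAPPING lives in transformers.loss.loss_utils.
def ForCausalLMLoss(*args, **kwargs):
    return "stock causal-LM loss"

def ForMaskedLMLoss(*args, **kwargs):
    return "stock masked-LM loss"

LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForConditionalGeneration": ForCausalLMLoss,  # alias of the same function
    "ForMaskedLM": ForMaskedLMLoss,
}

def resolve_loss_type(class_name, mapping):
    """Assumed lookup: exact key first, then the first key found in the name."""
    if class_name in mapping:
        return class_name
    pattern = "(" + "|".join(re.escape(key) for key in mapping) + ")"
    matches = re.findall(pattern, class_name)
    return matches[0] if matches else None

loss_type = resolve_loss_type("Qwen3_5ForConditionalGeneration", LOSS_MAPPING)
print(loss_type)  # ForConditionalGeneration
# The alias points at the very same function object as the "ForCausalLM" key,
# so patching only "ForCausalLM" leaves this entry on the stock loss.
print(LOSS_MAPPING[loss_type] is LOSS_MAPPING["ForCausalLM"])  # True
```

Because the alias is a separate dictionary entry, rebinding `LOSS_MAPPING["ForCausalLM"]` alone does not change what `"ForConditionalGeneration"` resolves to.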
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 43cce1cf6d
    def test_loss_mapping_for_conditional_generation_patched():
        lu = pytest.importorskip("transformers.loss.loss_utils")
        from unsloth_zoo import loss_utils as zoo_loss
Import the module under test with importorskip
In environments where transformers is installed but the separate unsloth package is not, this regular import turns the new regression tests into hard failures: unsloth_zoo.__init__ raises ImportError("Please install Unsloth..."), so pytest -q tests/test_patch_loss_functions_coverage.py fails instead of skipping like the existing standalone tests that use pytest.importorskip("unsloth_zoo..."). This makes the test suite fail in the repository’s supported/testable partial-install context even though the dependency is optional for these drift-style tests.
Code Review
This pull request updates the loss function patching logic to iterate through all aliases of ForCausalLMLoss in the transformers library, preventing OOM issues for models like Qwen 3.5 that resolve to these aliases. A new test suite is included to verify coverage, idempotency, and isolation from other loss types. Feedback was provided to improve the patching logic by ensuring that subsequent calls to the patching function can update existing patches, rather than only targeting the original stock loss function.
    for _key, _fn in list(LOSS_MAPPING.items()):
        if getattr(_fn, "__name__", "") == "ForCausalLMLoss":
            LOSS_MAPPING[_key] = UnslothForCausalLMLoss
The current implementation makes patch_loss_functions "sticky" for all aliases of ForCausalLMLoss. Because it only replaces functions with the exact name "ForCausalLMLoss", subsequent calls to patch_loss_functions (e.g., with different torch_compile settings or a different _fast_cross_entropy_loss) will not update the mapping if it has already been patched. This is a regression from the previous behavior where LOSS_MAPPING["ForCausalLM"] was updated unconditionally on every call.
To maintain the ability to re-configure the patch while still sweeping aliases, consider also checking if the function is the one currently assigned to "ForCausalLM" or if it matches the patched name. This identity check ensures we are modifying the correct function instance.
    - for _key, _fn in list(LOSS_MAPPING.items()):
    -     if getattr(_fn, "__name__", "") == "ForCausalLMLoss":
    -         LOSS_MAPPING[_key] = UnslothForCausalLMLoss
    + current_causal_loss = LOSS_MAPPING.get("ForCausalLM")
    + for _key, _fn in list(LOSS_MAPPING.items()):
    +     if _fn is current_causal_loss or getattr(_fn, "__name__", "") == "ForCausalLMLoss":
    +         LOSS_MAPPING[_key] = UnslothForCausalLMLoss
References
- When unpatching or updating a patched function, perform an identity check to ensure the function being replaced is the one originally patched by your code. This is more robust than relying on the state of other modules.
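The difference between the two sweeps in this review thread can be reproduced on a toy dictionary. The sketch below is illustrative only: `make_kernel`, `fresh_mapping`, and the two `sweep_*` helpers are hypothetical names, and the placeholder functions stand in for the real loss functions.

```python
def ForCausalLMLoss(logits, labels):
    return "stock"

def ForMaskedLMLoss(logits, labels):
    return "masked"

def make_kernel(tag):
    # Each call builds a new replacement, as a re-configured patch would.
    def UnslothForCausalLMLoss(logits, labels):
        return f"unsloth:{tag}"
    return UnslothForCausalLMLoss

def fresh_mapping():
    return {
        "ForCausalLM": ForCausalLMLoss,
        "ForConditionalGeneration": ForCausalLMLoss,
        "ForMaskedLM": ForMaskedLMLoss,
    }

def sweep_by_name(mapping, kernel):
    # Only entries still named "ForCausalLMLoss" are replaced.
    for key, fn in list(mapping.items()):
        if getattr(fn, "__name__", "") == "ForCausalLMLoss":
            mapping[key] = kernel

def sweep_by_identity_or_name(mapping, kernel):
    # Also replace whatever object currently sits under "ForCausalLM".
    current = mapping.get("ForCausalLM")
    for key, fn in list(mapping.items()):
        if fn is current or getattr(fn, "__name__", "") == "ForCausalLMLoss":
            mapping[key] = kernel

sticky = fresh_mapping()
sweep_by_name(sticky, make_kernel("a"))
sweep_by_name(sticky, make_kernel("b"))  # finds nothing left to replace
sticky_result = sticky["ForConditionalGeneration"](None, None)

flexible = fresh_mapping()
sweep_by_identity_or_name(flexible, make_kernel("a"))
sweep_by_identity_or_name(flexible, make_kernel("b"))  # swaps in the new kernel
flexible_result = flexible["ForConditionalGeneration"](None, None)

print(sticky_result, flexible_result)  # unsloth:a unsloth:b
```

In both variants the `"ForMaskedLM"` entry is never touched, which matches the isolation property the PR description claims.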
is_enabled() now returns True unless UNSLOTH_FUSED_FORWARD is explicitly set to "0". Updated docstrings and the __init__.py comment to reflect the new default.

The two-tier installer + LOSS_MAPPING backstop in #656 means the worst case for any class we touch is no-op (refused via _UNMATCHED or composite-head guard) -- never a worse forward than the stock path.

Test suite (23 cases, was 22 + new test_install_default_is_on): all green.

Refresh of the multi-model equivalence rerun with no env var set versus UNSLOTH_FUSED_FORWARD=0 (Llama-3.2-1B / Qwen3-0.6B / Gemma-3-1B / Mistral-7B-v0.3, seed 3407, max_steps=10, alpaca-cleaned):

| model | off enabled | default enabled | s1 eq | max\|loss d\| | max\|grad d\| |
| Llama-3.2-1B | False | True | True | 0.00410 | 0.02336 |
| Qwen3-0.6B | False | True | True | 0.00680 | 0.02561 |
| Gemma-3-1B | False | True | True | 0.00000 | 0.00000 |
| Mistral-7B-v0.3 | False | True | True | 0.00530 | 0.05310 |

Step 1 loss + grad_norm bitwise identical for every model; deltas across the run stay within bf16 -> fp32 chunked-CE rounding noise. Audit reports 11 classes patched at default and 0 patched when explicitly disabled.
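The default-on behaviour described in this commit message amounts to an opt-out environment check. A minimal sketch, assuming nothing beyond what the message states; the `environ` parameter is a hypothetical addition for testability and this is not the actual Unsloth implementation:

```python
import os

def is_enabled(environ=None):
    """Return True unless UNSLOTH_FUSED_FORWARD is explicitly set to "0"."""
    env = os.environ if environ is None else environ
    return env.get("UNSLOTH_FUSED_FORWARD") != "0"

print(is_enabled({}))                              # True: unset means on
print(is_enabled({"UNSLOTH_FUSED_FORWARD": "0"}))  # False: explicit opt-out
print(is_enabled({"UNSLOTH_FUSED_FORWARD": "1"}))  # True
```

Under this reading, only the literal string "0" disables the fused forward path; any other value, or no value, leaves it on.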
Summary
PreTrainedModel.__init__ resolves loss_type via a regex match on the class name, so anything whose name isn't a literal key in LOSS_MAPPING lands on the regex-matched fallback. Qwen3_5ForConditionalGeneration resolves to LOSS_MAPPING["ForConditionalGeneration"], and that key (plus CsmForConditionalGeneration) is aliased to the stock ForCausalLMLoss in transformers. patch_loss_functions only ever rewrote LOSS_MAPPING["ForCausalLM"], leaving those aliases pointing at the un-patched loss. Affected models silently fall back to stock ForCausalLMLoss, which does logits.float() over the full (seq_len x vocab_size) tensor and OOMs on 24 GB GPUs at Qwen3.5's 248k vocab.

This patch sweeps every entry in LOSS_MAPPING whose function is currently named ForCausalLMLoss and replaces it with the Unsloth kernel. Idempotent (after the first call no entry reports __name__ == "ForCausalLMLoss") and leaves unrelated loss types alone (ForMaskedLMLoss, segmentation, detection, etc.).

This is the source-side fix; the wrapper version of the same idea in unslothai/unsloth#5442 becomes a no-op once this lands.
Closes unslothai/unsloth#5441, related to unslothai/unsloth#4188.
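To put the (seq_len x vocab_size) fp32 tensor from the summary in perspective, the arithmetic can be sketched as follows. The vocabulary size is the summary's "248k" taken as a round 248,000, and the sequence length of 8192 is purely an assumed example value; `fp32_logits_gib` is a hypothetical helper.

```python
def fp32_logits_gib(seq_len, vocab_size, batch_size=1):
    """Size in GiB of one float32 logits tensor of shape (batch, seq, vocab)."""
    bytes_per_float32 = 4
    return batch_size * seq_len * vocab_size * bytes_per_float32 / 2**30

VOCAB_SIZE = 248_000  # "248k" from the summary, rounded
SEQ_LEN = 8192        # assumption, not taken from the PR

size_gib = fp32_logits_gib(SEQ_LEN, VOCAB_SIZE)
print(f"{size_gib:.2f} GiB")  # 7.57 GiB
```

That figure covers only the upcast copy of the logits for a single sequence; the original lower-precision logits and any gradients come on top, which is consistent with the summary's report of OOMs on 24 GB GPUs.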
Test plan
tests/test_patch_loss_functions_coverage.py pins three invariants:
- LOSS_MAPPING["ForConditionalGeneration"] resolves to the Unsloth kernel after patch_loss_functions().
- patch_loss_functions() is idempotent.
main as expected.