Fix loss function not patched for Qwen3.5 models#5442
Conversation
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Code Review
This pull request addresses an OOM issue in models like Qwen3.5 by ensuring that all loss mapping keys aliased to ForCausalLMLoss are correctly patched with the Unsloth kernel. It also introduces regression tests to verify the patching logic and ensure unrelated loss types are not affected. The review feedback suggests improving the exception handling in the patching logic by catching specific exceptions and adding debug logging instead of using a broad, silent pass.
| try: | ||
| import transformers.loss.loss_utils as _lu | ||
|
|
||
| _unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM") | ||
| if _unsloth_loss is not None: | ||
| _causal_lm_loss_name = "ForCausalLMLoss" | ||
| for _key, _fn in list(_lu.LOSS_MAPPING.items()): | ||
| if ( | ||
| _key != "ForCausalLM" | ||
| and getattr(_fn, "__name__", "") == _causal_lm_loss_name |
There was a problem hiding this comment.
Avoid using broad exception handlers like except Exception:. While logging the error is an improvement over a silent pass, you should catch specific exceptions such as ModuleNotFoundError and AttributeError to avoid suppressing unrelated issues. This aligns with the repository's guidelines on exception handling and optional dependencies.
| try: | |
| import transformers.loss.loss_utils as _lu | |
| _unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM") | |
| if _unsloth_loss is not None: | |
| _causal_lm_loss_name = "ForCausalLMLoss" | |
| for _key, _fn in list(_lu.LOSS_MAPPING.items()): | |
| if ( | |
| _key != "ForCausalLM" | |
| and getattr(_fn, "__name__", "") == _causal_lm_loss_name | |
| try: | |
| import transformers.loss.loss_utils as _lu | |
| _unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM") | |
| if _unsloth_loss is not None: | |
| _causal_lm_loss_name = "ForCausalLMLoss" | |
| for _key, _fn in list(_lu.LOSS_MAPPING.items()): | |
| if _key != "ForCausalLM" and getattr(_fn, "__name__", "") == _causal_lm_loss_name: | |
| _lu.LOSS_MAPPING[_key] = _unsloth_loss | |
| except (ModuleNotFoundError, AttributeError) as e: | |
| logger.debug(f"Unsloth: Failed to patch additional loss functions: {e}", exc_info=True) |
References
- When catching an
ImportErrorfor an optional dependency, prefer catching the more specificModuleNotFoundErrorand check the module name to avoid suppressing unrelated import errors. - When handling exceptions, avoid broad
except Exception: passclauses. Instead, catch specific exceptions and log them (at least at a debug level) to aid in troubleshooting. If a failure is expected, log the specific exception type and its details.
Replace bare except Exception with the only two compatibility errors we actually care about so genuine bugs in the sweep surface. Drop the redundant _key != "ForCausalLM" guard since the __name__ predicate already excludes the patched entry (UnslothForCausalLMLoss != ForCausalLMLoss).
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 812c514ff8
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| cg_loss = lu.LOSS_MAPPING.get("ForConditionalGeneration") | ||
| assert cg_loss is unsloth_loss, ( |
There was a problem hiding this comment.
Handle Transformers versions without conditional-generation loss
This assertion fails under supported Transformers releases that do not define LOSS_MAPPING['ForConditionalGeneration'] yet; for example the package constraint in pyproject.toml still allows transformers>=4.51.3, and 4.51.3's loss mapping only has ForCausalLM plus the older task keys. In that environment cg_loss is None, so this new drift test goes red even though there is no alias to patch. Please gate this check on the key being present, or test all existing keys whose original value was ForCausalLMLoss instead of requiring this newer key unconditionally.
Useful? React with 👍 / 👎.
|
Thanks for the PR! |
* fix: patch loss functions for Qwen3_5ForConditionalGeneration to prevent OOM errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Narrow except scope and simplify LOSS_MAPPING sweep Replace bare except Exception with the only two compatibility errors we actually care about so genuine bugs in the sweep surface. Drop the redundant _key != "ForCausalLM" guard since the __name__ predicate already excludes the patched entry (UnslothForCausalLMLoss != ForCausalLMLoss). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Qwen3.5 models (
Qwen3_5ForConditionalGeneration) have aloss_typeof"ForConditionalGeneration"rather than"ForCausalLM".patch_loss_functionsonly ever updated the"ForCausalLM"key inLOSS_MAPPING, so Qwen3.5 silently kept using HuggingFace's stockForCausalLMLoss— which casts logits to fp32 and immediately OOMs on anything ≤24GB at Qwen3.5's vocab size of 248k tokens.The fix is straightforward: instead of hardcoding
"ForCausalLM", the patch now scans all entries inLOSS_MAPPINGand replaces any that still point to the originalForCausalLMLosswith the Unsloth kernel. This means new model architectures with differentloss_typenames won't silently regress in the future.On the test side, added a check that every mapping key pointing to
ForCausalLMLossgets patched (not just"ForCausalLM"), a check that unrelated loss types like masked-LM or detection aren't accidentally overwritten, and made sure tests clean up after themselves so patching in one test doesn't bleed into another.Closes #5441, related to #4188.