From 2f95f7b24d86118c1ac84d1e4ea7df94169515e1 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 15 May 2026 16:33:02 +0530 Subject: [PATCH 1/4] fix: patch loss functions for Qwen3_5ForConditionalGeneration to prevent OOM errors --- tests/test_import_fixes_drift.py | 66 +++++++++++++++++++++++++++ unsloth/kernels/cross_entropy_loss.py | 17 +++++++ 2 files changed, 83 insertions(+) diff --git a/tests/test_import_fixes_drift.py b/tests/test_import_fixes_drift.py index f90556bf66..6e7ece60c2 100644 --- a/tests/test_import_fixes_drift.py +++ b/tests/test_import_fixes_drift.py @@ -545,6 +545,72 @@ def test_transformers_pretrained_model_has_get_input_embeddings(): # =========================================================================== +# =========================================================================== +# transformers LOSS_MAPPING -- patch_loss_functions() coverage +# Regression for https://github.com/unslothai/unsloth/issues/4188: +# Qwen3_5ForConditionalGeneration has loss_type='ForConditionalGeneration', +# a separate LOSS_MAPPING key that was never patched, leaving the model with +# the stock ForCausalLMLoss which does logits.float() and OOMs on <=24 GB GPUs. +# =========================================================================== + + +def _reset_loss_mapping(mapping, saved): + mapping.clear() + mapping.update(saved) + + +def test_patch_loss_functions_covers_conditional_generation(): + """After patch_loss_functions(), every LOSS_MAPPING key that was aliased + to ForCausalLMLoss must also point at the Unsloth kernel -- not just + LOSS_MAPPING['ForCausalLM'].""" + lu = pytest.importorskip("transformers.loss.loss_utils") + cel = pytest.importorskip("unsloth.kernels.cross_entropy_loss") + + saved = dict(lu.LOSS_MAPPING) + try: + cel.patch_loss_functions(torch_compile=False) + + unsloth_loss = lu.LOSS_MAPPING.get("ForCausalLM") + assert unsloth_loss is not None + assert "Unsloth" in str(unsloth_loss), ( + f"LOSS_MAPPING['ForCausalLM'] was not replaced: {unsloth_loss}" + ) + + cg_loss = lu.LOSS_MAPPING.get("ForConditionalGeneration") + assert cg_loss is unsloth_loss, ( + f"LOSS_MAPPING['ForConditionalGeneration'] not patched: {cg_loss}. " + f"Qwen3_5ForConditionalGeneration will silently use the stock " + f"ForCausalLMLoss and OOM at large sequence lengths." + ) + finally: + _reset_loss_mapping(lu.LOSS_MAPPING, saved) + + +def test_patch_loss_functions_does_not_touch_other_loss_types(): + """patch_loss_functions() must not overwrite unrelated loss types + (segmentation, detection, masked-LM, etc.) with the causal-LM kernel.""" + lu = pytest.importorskip("transformers.loss.loss_utils") + cel = pytest.importorskip("unsloth.kernels.cross_entropy_loss") + + non_causal_keys = { + k for k, v in lu.LOSS_MAPPING.items() + if getattr(v, "__name__", "") != "ForCausalLMLoss" + } + + saved = dict(lu.LOSS_MAPPING) + try: + cel.patch_loss_functions(torch_compile=False) + + unsloth_loss = lu.LOSS_MAPPING.get("ForCausalLM") + for key in non_causal_keys: + assert lu.LOSS_MAPPING.get(key) is not unsloth_loss, ( + f"patch_loss_functions() incorrectly overwrote " + f"LOSS_MAPPING['{key}'] with the Unsloth ForCausalLM kernel." + ) + finally: + _reset_loss_mapping(lu.LOSS_MAPPING, saved) + + def test_accelerate_utils_imports_module_present(): """``disable_broken_wandb`` + ``fix_trl_vllm_ascend`` (import_fixes.py 493-516, 1320-1372). Both reach into accelerate.utils.imports.""" diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index d92229314f..f8c2575692 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -461,3 +461,20 @@ def fast_cross_entropy_loss( # Patch CE Losses in transformers def patch_loss_functions(torch_compile = True): _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile) + + # _patch_loss_functions only updates LOSS_MAPPING["ForCausalLM"]. + # Some model classes (e.g. Qwen3_5ForConditionalGeneration) have + # loss_type="ForConditionalGeneration", which is a separate key in + # LOSS_MAPPING that still points to the old ForCausalLMLoss. + # Since the property reads LOSS_MAPPING live, we just need to update + # every key that is currently aliased to ForCausalLMLoss. + try: + import transformers.loss.loss_utils as _lu + _unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM") + if _unsloth_loss is not None: + _causal_lm_loss_name = "ForCausalLMLoss" + for _key, _fn in list(_lu.LOSS_MAPPING.items()): + if _key != "ForCausalLM" and getattr(_fn, "__name__", "") == _causal_lm_loss_name: + _lu.LOSS_MAPPING[_key] = _unsloth_loss + except Exception: + pass From 9e1e5bc6e8cf584a0ca5afc88950b386a959db65 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 11:06:40 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_import_fixes_drift.py | 13 +++++++------ unsloth/kernels/cross_entropy_loss.py | 6 +++++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_import_fixes_drift.py b/tests/test_import_fixes_drift.py index 6e7ece60c2..099b65d09a 100644 --- a/tests/test_import_fixes_drift.py +++ b/tests/test_import_fixes_drift.py @@ -568,13 +568,13 @@ def test_patch_loss_functions_covers_conditional_generation(): saved = dict(lu.LOSS_MAPPING) try: - cel.patch_loss_functions(torch_compile=False) + cel.patch_loss_functions(torch_compile = False) unsloth_loss = lu.LOSS_MAPPING.get("ForCausalLM") assert unsloth_loss is not None - assert "Unsloth" in str(unsloth_loss), ( - f"LOSS_MAPPING['ForCausalLM'] was not replaced: {unsloth_loss}" - ) + assert "Unsloth" in str( + unsloth_loss + ), f"LOSS_MAPPING['ForCausalLM'] was not replaced: {unsloth_loss}" cg_loss = lu.LOSS_MAPPING.get("ForConditionalGeneration") assert cg_loss is unsloth_loss, ( @@ -593,13 +593,14 @@ def test_patch_loss_functions_does_not_touch_other_loss_types(): cel = pytest.importorskip("unsloth.kernels.cross_entropy_loss") non_causal_keys = { - k for k, v in lu.LOSS_MAPPING.items() + k + for k, v in lu.LOSS_MAPPING.items() if getattr(v, "__name__", "") != "ForCausalLMLoss" } saved = dict(lu.LOSS_MAPPING) try: - cel.patch_loss_functions(torch_compile=False) + cel.patch_loss_functions(torch_compile = False) unsloth_loss = lu.LOSS_MAPPING.get("ForCausalLM") for key in non_causal_keys: diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f8c2575692..0616954c02 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -470,11 +470,15 @@ def patch_loss_functions(torch_compile = True): # every key that is currently aliased to ForCausalLMLoss. try: import transformers.loss.loss_utils as _lu + _unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM") if _unsloth_loss is not None: _causal_lm_loss_name = "ForCausalLMLoss" for _key, _fn in list(_lu.LOSS_MAPPING.items()): - if _key != "ForCausalLM" and getattr(_fn, "__name__", "") == _causal_lm_loss_name: + if ( + _key != "ForCausalLM" + and getattr(_fn, "__name__", "") == _causal_lm_loss_name + ): _lu.LOSS_MAPPING[_key] = _unsloth_loss except Exception: pass From 212e5b04a282c9e73cf7727cab12502d52508eeb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 19 May 2026 07:52:40 +0000 Subject: [PATCH 3/4] Narrow except scope and simplify LOSS_MAPPING sweep Replace bare except Exception with the only two compatibility errors we actually care about so genuine bugs in the sweep surface. Drop the redundant _key != "ForCausalLM" guard since the __name__ predicate already excludes the patched entry (UnslothForCausalLMLoss != ForCausalLMLoss). --- unsloth/kernels/cross_entropy_loss.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 0616954c02..4bd0885ed6 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -462,23 +462,16 @@ def fast_cross_entropy_loss( def patch_loss_functions(torch_compile = True): _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile) - # _patch_loss_functions only updates LOSS_MAPPING["ForCausalLM"]. - # Some model classes (e.g. Qwen3_5ForConditionalGeneration) have - # loss_type="ForConditionalGeneration", which is a separate key in - # LOSS_MAPPING that still points to the old ForCausalLMLoss. - # Since the property reads LOSS_MAPPING live, we just need to update - # every key that is currently aliased to ForCausalLMLoss. + # Defense-in-depth sweep for LOSS_MAPPING aliases still pointing at the + # stock ForCausalLMLoss (e.g. ForConditionalGeneration for Qwen3.5, + # CsmForConditionalGeneration). unsloth_zoo also does this; remove once + # the floor pin moves past unslothai/unsloth-zoo#656. try: import transformers.loss.loss_utils as _lu - _unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM") if _unsloth_loss is not None: - _causal_lm_loss_name = "ForCausalLMLoss" for _key, _fn in list(_lu.LOSS_MAPPING.items()): - if ( - _key != "ForCausalLM" - and getattr(_fn, "__name__", "") == _causal_lm_loss_name - ): + if getattr(_fn, "__name__", "") == "ForCausalLMLoss": _lu.LOSS_MAPPING[_key] = _unsloth_loss - except Exception: + except (ImportError, AttributeError): pass From 812c514ff81369e4910be54ec9446b604ccfdd69 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 May 2026 07:52:53 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/kernels/cross_entropy_loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 4bd0885ed6..4a8f83ad04 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -468,6 +468,7 @@ def patch_loss_functions(torch_compile = True): # the floor pin moves past unslothai/unsloth-zoo#656. try: import transformers.loss.loss_utils as _lu + _unsloth_loss = _lu.LOSS_MAPPING.get("ForCausalLM") if _unsloth_loss is not None: for _key, _fn in list(_lu.LOSS_MAPPING.items()):