Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions tests/test_patch_loss_functions_coverage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

"""Regression for unslothai/unsloth#5441.

PreTrainedModel.__init__ resolves loss_type from the class name. Anything
whose name doesn't appear as a literal LOSS_MAPPING key falls back to a
regex match, which means `Qwen3_5ForConditionalGeneration` lands on
`LOSS_MAPPING["ForConditionalGeneration"]`. That entry, plus
`CsmForConditionalGeneration`, is aliased to the stock `ForCausalLMLoss`
in transformers. `patch_loss_functions()` only rewrote
`LOSS_MAPPING["ForCausalLM"]`, leaving those aliases pointing at the
un-patched loss which does `logits.float()` and OOMs on <= 24 GB GPUs at
large vocab sizes.

This suite pins:
- Every key originally aliased to ForCausalLMLoss is replaced with
the Unsloth kernel.
- Keys aliased to other loss types (ForMaskedLMLoss, segmentation,
detection, etc.) are not overwritten.
- The patch is idempotent.
"""

from __future__ import annotations

import pytest


def _restore(mapping, saved):
mapping.clear()
mapping.update(saved)


def test_loss_mapping_for_conditional_generation_patched():
    """Pin that ``LOSS_MAPPING["ForConditionalGeneration"]`` gets the patched loss.

    After ``patch_loss_functions`` runs, the ``"ForCausalLM"`` entry must no
    longer be the stock ``ForCausalLMLoss`` and the
    ``"ForConditionalGeneration"`` entry must be the very same replacement
    object. The original mapping is restored afterwards so other tests see an
    untouched ``LOSS_MAPPING``.
    """
    lu = pytest.importorskip("transformers.loss.loss_utils")
    # Use importorskip for the module under test as well: per review feedback,
    # importing unsloth_zoo raises ImportError when transformers is installed
    # but the separate unsloth package is not. A plain import would turn this
    # regression test into a hard failure in that partial-install setup
    # instead of a skip.
    zoo_loss = pytest.importorskip("unsloth_zoo.loss_utils")
    from unsloth_zoo.fused_losses import unsloth_fused_ce_loss  # noqa: F401

    saved = dict(lu.LOSS_MAPPING)
    try:
        # A naive cross_entropy stub keeps torch.compile out of the picture and
        # makes the regression test pure-Python.
        def _fast_ce(logits, labels, n_items=None, **kw):
            import torch
            return torch.nn.functional.cross_entropy(
                logits.float(), labels, ignore_index=-100,
            )
        zoo_loss.patch_loss_functions(_fast_ce, torch_compile=False)

        forcausal = lu.LOSS_MAPPING.get("ForCausalLM")
        assert forcausal is not None
        assert getattr(forcausal, "__name__", "") != "ForCausalLMLoss", (
            f"LOSS_MAPPING['ForCausalLM'] was not replaced: {forcausal}"
        )

        # The alias must point at the exact same replacement, not merely at
        # some non-stock function.
        cg = lu.LOSS_MAPPING.get("ForConditionalGeneration")
        assert cg is forcausal, (
            f"LOSS_MAPPING['ForConditionalGeneration'] still aliases the stock "
            f"ForCausalLMLoss; Qwen3_5ForConditionalGeneration would OOM via "
            f"logits.float(). got: {cg}"
        )
    finally:
        # Always undo the global patch, even when an assertion fails.
        _restore(lu.LOSS_MAPPING, saved)


def test_loss_mapping_other_losses_left_alone():
    """Pin that entries not aliased to ``ForCausalLMLoss`` survive the patch.

    Every ``LOSS_MAPPING`` key whose function is not named ``ForCausalLMLoss``
    before patching must map to the identical object afterwards, and must not
    have been swapped for the loss installed under ``"ForCausalLM"``. The
    original mapping is restored when the test finishes.
    """
    lu = pytest.importorskip("transformers.loss.loss_utils")
    # Skip, rather than fail, when unsloth_zoo itself cannot be imported (for
    # example a partial install without the separate unsloth package).
    zoo_loss = pytest.importorskip("unsloth_zoo.loss_utils")

    # Keys not currently aliased to ForCausalLMLoss must survive the sweep.
    non_causal = {
        k: v for k, v in lu.LOSS_MAPPING.items()
        if getattr(v, "__name__", "") != "ForCausalLMLoss"
    }
    saved = dict(lu.LOSS_MAPPING)
    try:
        # Plain cross-entropy stand-in so the test does not need torch.compile.
        def _fast_ce(logits, labels, n_items=None, **kw):
            import torch
            return torch.nn.functional.cross_entropy(logits.float(), labels, ignore_index=-100)
        zoo_loss.patch_loss_functions(_fast_ce, torch_compile=False)

        unsloth_loss = lu.LOSS_MAPPING["ForCausalLM"]
        for key, original_fn in non_causal.items():
            assert lu.LOSS_MAPPING[key] is original_fn, (
                f"LOSS_MAPPING['{key}'] was overwritten by the sweep; "
                f"expected {original_fn}, got {lu.LOSS_MAPPING[key]}"
            )
            assert lu.LOSS_MAPPING[key] is not unsloth_loss, (
                f"LOSS_MAPPING['{key}'] incorrectly replaced with the Unsloth kernel."
            )
    finally:
        # Always undo the global patch, even when an assertion fails.
        _restore(lu.LOSS_MAPPING, saved)


def test_loss_mapping_sweep_idempotent():
    """Pin that calling ``patch_loss_functions`` twice leaves the mapping stable.

    The mapping is snapshotted after the first and the second call; both
    snapshots must have the same keys and, for each key, the identical
    function object. The original mapping is restored when the test finishes.
    """
    lu = pytest.importorskip("transformers.loss.loss_utils")
    # Skip, rather than fail, when unsloth_zoo itself cannot be imported (for
    # example a partial install without the separate unsloth package).
    zoo_loss = pytest.importorskip("unsloth_zoo.loss_utils")

    saved = dict(lu.LOSS_MAPPING)
    try:
        # Plain cross-entropy stand-in so the test does not need torch.compile.
        def _fast_ce(logits, labels, n_items=None, **kw):
            import torch
            return torch.nn.functional.cross_entropy(logits.float(), labels, ignore_index=-100)
        zoo_loss.patch_loss_functions(_fast_ce, torch_compile=False)
        first = dict(lu.LOSS_MAPPING)
        zoo_loss.patch_loss_functions(_fast_ce, torch_compile=False)
        second = dict(lu.LOSS_MAPPING)
        # Compare key sets first: iterating only over `first` would miss a key
        # added by the second call, and a removed key would surface as a bare
        # KeyError instead of a readable assertion message.
        assert first.keys() == second.keys(), (
            "LOSS_MAPPING keys changed on second patch_loss_functions call: "
            f"{sorted(first.keys() ^ second.keys())}"
        )
        for k in first:
            assert first[k] is second[k], (
                f"LOSS_MAPPING['{k}'] mutated on second patch_loss_functions call."
            )
    finally:
        # Always undo the global patch, even when an assertion fails.
        _restore(lu.LOSS_MAPPING, saved)
10 changes: 9 additions & 1 deletion unsloth_zoo/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,15 @@ def UnslothForCausalLMLoss(
# Now patch the losses!
import transformers.modeling_utils
LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING
LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss
# Patch every key still aliased to the stock ForCausalLMLoss. PreTrainedModel
# resolves loss_type via regex on the class name, so classes like
# Qwen3_5ForConditionalGeneration land on LOSS_MAPPING["ForConditionalGeneration"]
# (and CsmForConditionalGeneration on its own key), both of which point at the
# stock ForCausalLMLoss. Without this sweep those models keep the un-patched
# loss and OOM via logits.float() at large vocab sizes.
for _key, _fn in list(LOSS_MAPPING.items()):
if getattr(_fn, "__name__", "") == "ForCausalLMLoss":
LOSS_MAPPING[_key] = UnslothForCausalLMLoss
Comment on lines +146 to +148

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation makes patch_loss_functions "sticky" for all aliases of ForCausalLMLoss. Because it only replaces functions with the exact name "ForCausalLMLoss", subsequent calls to patch_loss_functions (e.g., with different torch_compile settings or a different _fast_cross_entropy_loss) will not update the mapping if it has already been patched. This is a regression from the previous behavior where LOSS_MAPPING["ForCausalLM"] was updated unconditionally on every call.

To maintain the ability to re-configure the patch while still sweeping aliases, consider also checking if the function is the one currently assigned to "ForCausalLM" or if it matches the patched name. This identity check ensures we are modifying the correct function instance.

Suggested change
for _key, _fn in list(LOSS_MAPPING.items()):
if getattr(_fn, "__name__", "") == "ForCausalLMLoss":
LOSS_MAPPING[_key] = UnslothForCausalLMLoss
current_causal_loss = LOSS_MAPPING.get("ForCausalLM")
for _key, _fn in list(LOSS_MAPPING.items()):
if _fn is current_causal_loss or getattr(_fn, "__name__", "") == "ForCausalLMLoss":
LOSS_MAPPING[_key] = UnslothForCausalLMLoss
References
  1. When unpatching or updating a patched function, perform an identity check to ensure the function being replaced is the one originally patched by your code. This is more robust than relying on the state of other modules.


# Remove @property and @lru_cache
if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget") and \
Expand Down
Loading