Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions tests/test_import_fixes_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,73 @@ def test_transformers_pretrained_model_has_get_input_embeddings():
# ===========================================================================


# ===========================================================================
# transformers LOSS_MAPPING -- patch_loss_functions() coverage
# Regression for https://github.com/unslothai/unsloth/issues/4188:
# Qwen3_5ForConditionalGeneration has loss_type='ForConditionalGeneration',
# a separate LOSS_MAPPING key that was never patched, leaving the model with
# the stock ForCausalLMLoss which does logits.float() and OOMs on <=24 GB GPUs.
# ===========================================================================


def _reset_loss_mapping(mapping, saved):
mapping.clear()
mapping.update(saved)


def test_patch_loss_functions_covers_conditional_generation():
    """After patch_loss_functions(), every LOSS_MAPPING key that was aliased
    to ForCausalLMLoss must also point at the Unsloth kernel -- not just
    LOSS_MAPPING['ForCausalLM'].

    Keys are discovered from the installed transformers rather than
    hard-coded, so releases that do not define
    LOSS_MAPPING['ForConditionalGeneration'] yet do not fail this test.
    """
    lu = pytest.importorskip("transformers.loss.loss_utils")
    cel = pytest.importorskip("unsloth.kernels.cross_entropy_loss")

    saved = dict(lu.LOSS_MAPPING)
    # Snapshot, before patching, which keys alias the stock causal-LM loss.
    keys_to_check = {
        key
        for key, loss_fn in saved.items()
        if getattr(loss_fn, "__name__", "") == "ForCausalLMLoss"
    }
    # The regression key is checked whenever this transformers release
    # defines it, even if the mapping was already patched before the
    # snapshot above was taken. Older supported releases lack the key and
    # have no alias to patch, so it must not be required unconditionally.
    if "ForConditionalGeneration" in saved:
        keys_to_check.add("ForConditionalGeneration")

    try:
        cel.patch_loss_functions(torch_compile = False)

        unsloth_loss = lu.LOSS_MAPPING.get("ForCausalLM")
        assert unsloth_loss is not None
        assert "Unsloth" in str(
            unsloth_loss
        ), f"LOSS_MAPPING['ForCausalLM'] was not replaced: {unsloth_loss}"

        for key in sorted(keys_to_check):
            patched = lu.LOSS_MAPPING.get(key)
            assert patched is unsloth_loss, (
                f"LOSS_MAPPING['{key}'] not patched: {patched}. "
                f"Models with this loss_type (e.g. "
                f"Qwen3_5ForConditionalGeneration) will silently use the stock "
                f"ForCausalLMLoss and OOM at large sequence lengths."
            )
    finally:
        _reset_loss_mapping(lu.LOSS_MAPPING, saved)


def test_patch_loss_functions_does_not_touch_other_loss_types():
    """patch_loss_functions() must not overwrite unrelated loss types
    (segmentation, detection, masked-LM, etc.) with the causal-LM kernel."""
    lu = pytest.importorskip("transformers.loss.loss_utils")
    cel = pytest.importorskip("unsloth.kernels.cross_entropy_loss")

    saved = dict(lu.LOSS_MAPPING)
    # If the mapping was already patched before this snapshot, the causal
    # entries no longer carry the stock ``ForCausalLMLoss`` name. Excluding
    # everything identical to the current 'ForCausalLM' entry keeps those
    # keys from being misclassified as "unrelated" and failing spuriously.
    causal_before = saved.get("ForCausalLM")
    non_causal_keys = {
        key
        for key, loss_fn in saved.items()
        if getattr(loss_fn, "__name__", "") != "ForCausalLMLoss"
        and loss_fn is not causal_before
    }

    try:
        cel.patch_loss_functions(torch_compile = False)

        unsloth_loss = lu.LOSS_MAPPING.get("ForCausalLM")
        for key in sorted(non_causal_keys):
            assert lu.LOSS_MAPPING.get(key) is not unsloth_loss, (
                f"patch_loss_functions() incorrectly overwrote "
                f"LOSS_MAPPING['{key}'] with the Unsloth ForCausalLM kernel."
            )
    finally:
        _reset_loss_mapping(lu.LOSS_MAPPING, saved)


def test_accelerate_utils_imports_module_present():
"""``disable_broken_wandb`` + ``fix_trl_vllm_ascend`` (import_fixes.py
493-516, 1320-1372). Both reach into accelerate.utils.imports."""
Expand Down
15 changes: 15 additions & 0 deletions unsloth/kernels/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,18 @@ def fast_cross_entropy_loss(
# Patch CE Losses in transformers
def patch_loss_functions(torch_compile = True):
    """Install the fast cross-entropy loss into transformers' loss functions.

    Delegates to ``_patch_loss_functions`` and then repoints every
    ``LOSS_MAPPING`` entry whose function is still named ``ForCausalLMLoss``
    at whatever ``LOSS_MAPPING['ForCausalLM']`` now holds.

    Args:
        torch_compile: Forwarded unchanged to ``_patch_loss_functions``.
    """
    _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)

    # Defense-in-depth sweep for LOSS_MAPPING aliases still pointing at the
    # stock ForCausalLMLoss (e.g. ForConditionalGeneration for Qwen3.5,
    # CsmForConditionalGeneration). unsloth_zoo also does this; remove once
    # the floor pin moves past unslothai/unsloth-zoo#656.
    try:
        import transformers.loss.loss_utils as loss_utils

        causal_lm_loss = loss_utils.LOSS_MAPPING.get("ForCausalLM")
        if causal_lm_loss is None:
            # No causal-LM entry to copy from, so there is nothing to sweep.
            return

        # Iterate over a snapshot because entries are reassigned in the loop.
        for loss_type, loss_fn in tuple(loss_utils.LOSS_MAPPING.items()):
            if getattr(loss_fn, "__name__", "") != "ForCausalLMLoss":
                continue
            loss_utils.LOSS_MAPPING[loss_type] = causal_lm_loss
    except (ImportError, AttributeError):
        # Best-effort: a transformers build without this module or mapping
        # simply skips the sweep.
        return