diff --git a/tests/test_fused_forward_install.py b/tests/test_fused_forward_install.py new file mode 100644 index 000000000..5148e06d9 --- /dev/null +++ b/tests/test_fused_forward_install.py @@ -0,0 +1,681 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +"""Tests for the fused lm_head + cross_entropy auto-installer. + +Covers: + - AST rewriter recognises the canonical HF triplet shape (keyword form, + positional vocab_size, `.float()` wrapper, no-`loss = None` initialiser). + - AST rewriter declines on non-matching forwards (no triplet, missing + if-labels block, missing loss_function call). + - install_for_class: + * no-op when UNSLOTH_FUSED_FORWARD is off + * patches a synthetic *ForCausalLM whose forward matches the triplet + * leaves a hand-crafted bespoke forward in _UNMATCHED + * is idempotent + - Numerical equivalence of the rewritten forward vs the original at + small shapes (mean MSE under 1e-4 on bf16 -> fp32). +""" + +from __future__ import annotations + +import os +import sys +import types + +import pytest + + +# Reset module state between tests so install registries don't bleed. +@pytest.fixture +def fresh_install(): + from unsloth_zoo.fused_losses import forward_install as fi + with fi._REGISTRY_LOCK: + fi._PATCHED.clear() + fi._UNMATCHED.clear() + fi._FAILED.clear() + fi._CANONICAL_FORWARDS.clear() + yield fi + with fi._REGISTRY_LOCK: + fi._PATCHED.clear() + fi._UNMATCHED.clear() + fi._FAILED.clear() + fi._CANONICAL_FORWARDS.clear() + + +@pytest.fixture +def enable_env(monkeypatch): + monkeypatch.setenv("UNSLOTH_FUSED_FORWARD", "1") + + +# --------------------------------------------------------------------------- +# AST rewriter unit tests +# --------------------------------------------------------------------------- + + +CANONICAL_KW_SRC = """ +def forward(self, input_ids=None, labels=None, logits_to_keep=0, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + return (loss, logits) +""" + +CANONICAL_POS_SRC = """ +def forward(self, input_ids=None, labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + lm_logits = self.lm_head(hidden_states).float() + loss = None + if labels is not None: + loss = self.loss_function(lm_logits, labels, self.config.vocab_size, **kwargs) + return (loss, lm_logits) +""" + +NON_CANONICAL_SRC = """ +def forward(self, input_ids=None, labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + if labels is not None: + loss_fct = object() # legacy CrossEntropyLoss path + loss = loss_fct(logits, labels) + else: + loss = None + return (loss, logits) +""" + + +def test_ast_rewriter_matches_keyword_form(): + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(CANONICAL_KW_SRC) + assert new_src is not None + assert cap is not None + assert cap.head_attr == "lm_head" + assert cap.logits_name == "logits" + assert "unsloth_fused_lm_head_loss" in new_src + assert "EMPTY_LOGITS" in new_src + # The original self.loss_function call must be gone from the rewritten src. + assert "self.loss_function" not in new_src + + +def test_ast_rewriter_matches_positional_with_float_wrapper(): + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(CANONICAL_POS_SRC) + assert new_src is not None + assert cap is not None + assert cap.head_attr == "lm_head" + assert cap.logits_name == "lm_logits" + assert "unsloth_fused_lm_head_loss" in new_src + + +def test_ast_rewriter_declines_non_canonical(): + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(NON_CANONICAL_SRC) + assert new_src is None + assert cap is None + + +COHERE_REBINDING_SRC = """ +def forward(self, input_ids=None, labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + logits = logits * self.logit_scale + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + return (loss, logits) +""" + + +def test_ast_rewriter_declines_when_logits_rebound(): + # Cohere-style `logits = logits * self.logit_scale` between lm_head and + # the if-labels block: removing the lm_head call would leave the + # rebinding referencing an undefined `logits`. The rewriter must refuse. + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(COHERE_REBINDING_SRC) + assert new_src is None + assert cap is None + + +GEMMA_SOFTCAP_SRC = """ +def forward(self, input_ids=None, labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + if self.final_logit_softcapping is not None: + logits = logits / self.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.final_logit_softcapping + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + return (loss, logits) +""" + + +def test_ast_rewriter_declines_when_intermediate_touches_logits(): + # Gemma-style softcap mutates logits between lm_head and the labels-if. + # Wholesale rewriting would skip that step and feed un-softcapped logits + # to the fused loss; refuse and let the backstop handle it. + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(GEMMA_SOFTCAP_SRC) + assert new_src is None + assert cap is None + + +CSM_ALIASED_LABELS_SRC = """ +def forward(self, input_ids=None, labels=None, backbone_labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs) + return (loss, logits) +""" + + +def test_ast_rewriter_declines_when_labels_aliased(): + # CSM-style: gates on `labels is not None` but passes a different + # aliased name to loss_function. Wholesale rewrite would forward the + # wrong tensor; refuse. + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(CSM_ALIASED_LABELS_SRC) + assert new_src is None + assert cap is None + + +MULTISTMT_LABEL_BRANCH_SRC = """ +def forward(self, input_ids=None, labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + aux_loss = self.aux_loss_coef * compute_aux(outputs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + loss = loss + aux_loss + return (loss, logits) +""" + + +def test_ast_rewriter_declines_non_trivial_labels_branch(): + # MoE-style auxiliary loss inside the labels branch would be silently + # dropped by a wholesale rewrite. The rewriter must refuse. + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(MULTISTMT_LABEL_BRANCH_SRC) + assert new_src is None + assert cap is None + + +EXTRA_LOSS_KW_SRC = """ +def forward(self, input_ids=None, labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.vocab_size, + num_items_in_batch=kwargs.get("num_items_in_batch"), + ) + return (loss, logits) +""" + + +def test_ast_rewriter_forwards_explicit_extra_kwargs(): + # Bloom-style: loss_function gets explicit `num_items_in_batch=...` even + # though there's no **kwargs unpack. The rewriter must preserve that + # kwarg in the call to unsloth_fused_lm_head_loss. + from unsloth_zoo.fused_losses.ast_rewriter import rewrite_forward_source + new_src, cap = rewrite_forward_source(EXTRA_LOSS_KW_SRC) + assert new_src is not None + assert cap is not None + assert ("num_items_in_batch", ) == tuple(name for name, _ in cap.extra_loss_kws) + assert "num_items_in_batch=" in new_src + + +# --------------------------------------------------------------------------- +# install_for_class +# --------------------------------------------------------------------------- + + +_SYNTH_COUNTER = 0 + + +def _make_synthetic_class(forward_src: str, name: str = "SyntheticForCausalLM"): + """Build a class whose forward source is recoverable via inspect.getsource. + + `inspect.getsource` relies on `linecache`. Exec'd functions without a + real file backing return OSError, which is what the installer falls + back on. To exercise the rewriter we register a unique synthetic file + name with `linecache` and compile through it. + """ + import linecache + global _SYNTH_COUNTER + _SYNTH_COUNTER += 1 + fake_path = f"" + src = forward_src.lstrip("\n") + linecache.cache[fake_path] = ( + len(src), None, [line + "\n" for line in src.splitlines()], fake_path, + ) + namespace = {} + code = compile(src, fake_path, "exec") + exec(code, namespace) + forward_fn = namespace["forward"] + cls = type(name, (), {"forward": forward_fn}) + cls.__module__ = "transformers.models.synthetic.modeling_synthetic" + return cls + + +def test_install_noop_when_disabled(fresh_install, monkeypatch): + # On by default; UNSLOTH_FUSED_FORWARD=0 is the explicit opt-out. + monkeypatch.setenv("UNSLOTH_FUSED_FORWARD", "0") + cls = _make_synthetic_class(CANONICAL_KW_SRC) + original = cls.forward + assert fresh_install.install_for_class(cls) is False + assert cls.forward is original + + +def test_install_default_is_on(fresh_install, monkeypatch): + # With no env var set, the installer must be active. + monkeypatch.delenv("UNSLOTH_FUSED_FORWARD", raising=False) + cls = _make_synthetic_class(CANONICAL_KW_SRC) + assert fresh_install.is_enabled() is True + assert fresh_install.install_for_class(cls) is True + + +def test_install_skips_ineligible_name(fresh_install, enable_env): + cls = _make_synthetic_class(CANONICAL_KW_SRC, name="SyntheticModel") + original = cls.forward + assert fresh_install.install_for_class(cls) is False + assert cls.forward is original + + +def test_install_skips_for_conditional_generation(fresh_install, enable_env): + # *ForConditionalGeneration uses aligned labels (seq2seq); the fused + # kernel hardcodes a causal shift and would produce off-by-one losses. + # Such classes must be skipped. + cls = _make_synthetic_class(CANONICAL_KW_SRC, name="SyntheticForConditionalGeneration") + original = cls.forward + assert fresh_install.install_for_class(cls) is False + assert cls.forward is original + + +COMPOSITE_HEAD_SRC = """ +def forward(self, input_ids=None, labels=None, **kwargs): + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs.last_hidden_state + logits = self.cls(hidden_states) + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + return (loss, logits) +""" + + +def test_install_skips_composite_head(fresh_install, enable_env): + # BigBird-style `self.cls(...)` (BigBirdOnlyMLMHead) is a composite head, + # not a Linear; the adapter would crash on `lm_head.weight`. The + # installer must reject heads that aren't in the _LINEAR_HEAD_ATTRS + # allowlist. + cls = _make_synthetic_class(COMPOSITE_HEAD_SRC, name="SyntheticForCausalLM") + original = cls.forward + assert fresh_install.install_for_class(cls) is False + assert cls.forward is original + assert cls.__qualname__ in fresh_install._UNMATCHED + assert "non-linear-head" in fresh_install._UNMATCHED[cls.__qualname__] + + +def test_install_patches_canonical_forward(fresh_install, enable_env): + cls = _make_synthetic_class(CANONICAL_KW_SRC) + original = cls.forward + ok = fresh_install.install_for_class(cls) + assert ok is True + assert cls.forward is not original + rep = fresh_install._PATCHED[cls.__qualname__] + assert rep["tier"] == "2-ast-triplet" + assert rep["head_attr"] == "lm_head" + + +def test_install_idempotent(fresh_install, enable_env): + cls = _make_synthetic_class(CANONICAL_KW_SRC) + first = fresh_install.install_for_class(cls) + patched_fn = cls.forward + second = fresh_install.install_for_class(cls) + assert first is True and second is True + assert cls.forward is patched_fn + + +def test_install_leaves_non_canonical_in_unmatched(fresh_install, enable_env): + cls = _make_synthetic_class(NON_CANONICAL_SRC) + ok = fresh_install.install_for_class(cls) + assert ok is False + assert cls.__qualname__ in fresh_install._UNMATCHED + + +def test_install_function_override_fast_path(fresh_install, enable_env): + from unsloth_zoo.fused_losses.forward_install import _structural_hash + cls = _make_synthetic_class(CANONICAL_KW_SRC) + target_hash = _structural_hash(cls.forward) + assert target_hash is not None + + sentinel = [] + def _replacement(self, *a, **kw): + sentinel.append(True) + return (None, None) + + fresh_install.register_canonical(target_hash, _replacement) + ok = fresh_install.install_for_class(cls) + assert ok is True + rep = fresh_install._PATCHED[cls.__qualname__] + assert rep["tier"] == "1-function-override" + cls.forward(object()) + assert sentinel == [True] + + +def test_audit_dump(fresh_install, enable_env): + cls = _make_synthetic_class(CANONICAL_KW_SRC) + fresh_install.install_for_class(cls) + out = fresh_install.audit() + assert out["enabled"] is True + assert out["n_patched"] >= 1 + assert cls.__qualname__ in out["patched"] + + +# --------------------------------------------------------------------------- +# Numerical equivalence on a small toy model +# --------------------------------------------------------------------------- + + +def _toy_forward_src(): + # Mirrors the canonical HF template enough to be rewriter-eligible. + return """ +def forward(self, hidden_states, labels=None, **kwargs): + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + return (loss, logits) +""" + + +def test_rewritten_forward_loss_matches_reference(fresh_install, enable_env): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("fused CE kernel requires a CUDA device") + + cls = _make_synthetic_class(_toy_forward_src(), name="ToyForCausalLM") + + # Wire a config + lm_head + reference loss_function. + B, T, H, V = 2, 8, 32, 64 + + class _Config: + vocab_size = V + + instance = cls() + instance.config = _Config() + instance.lm_head = torch.nn.Linear(H, V, bias=False).cuda().to(torch.bfloat16) + + def _reference_loss(logits, labels, vocab_size, **kw): + # unsloth_fused_ce_loss shifts labels by one (causal LM convention). + # Mirror that here so the two losses are apples-to-apples. + shifted = labels.clone() + shifted[..., :-1] = labels[..., 1:] + shifted[..., -1] = -100 + return torch.nn.functional.cross_entropy( + logits.float().view(-1, vocab_size), + shifted.view(-1), + ignore_index=-100, + ) + instance.loss_function = _reference_loss + + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + labels = torch.randint(0, V, (B, T), device="cuda") + + ref_loss, _ = instance.forward(hidden, labels=labels) + ref_loss_value = float(ref_loss.detach().cpu().item()) + + # Install fused forward. + ok = fresh_install.install_for_class(cls) + assert ok is True + + # The instance still binds the old forward (Python attribute lookup hits + # the class on call), so we re-fetch from the class. + fused_loss, fused_logits = cls.forward(instance, hidden, labels=labels) + fused_loss_value = float(fused_loss.detach().cpu().item()) + + # Loss should match the reference to within bf16 -> fp32 rounding noise. + assert abs(fused_loss_value - ref_loss_value) < 0.05, ( + f"fused loss {fused_loss_value} diverged from reference {ref_loss_value}" + ) + # logits slot becomes the EMPTY_LOGITS sentinel under fused path. + assert fused_logits.numel() == 0 + + +def test_fused_kernel_respects_ignore_index(): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("fused CE kernel requires a CUDA device") + from unsloth_zoo.fused_losses import unsloth_fused_ce_loss + + B, T, H, V = 1, 16, 8, 32 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + weight = torch.randn(V, H, device="cuda", dtype=torch.float32, requires_grad=True) + labels = torch.randint(0, V, (B, T), device="cuda") + labels[0, 0] = 99 # would be a CUDA-side assert if not masked out + + loss = unsloth_fused_ce_loss( + trainer=None, + hidden_states=hidden, + lm_head_weight=weight, + lm_head_bias=None, + labels=labels, + torch_compile=False, + ignore_index=99, + ) + assert torch.isfinite(loss), f"loss not finite with ignore_index=99: {loss}" + + +def test_fused_kernel_accepts_int_n_items(): + # HF Trainer / gradient accumulation passes a Python int for + # num_items_in_batch. The kernel must promote it to a tensor before + # the DataParallel .numel()/.ravel() guard. + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("fused CE kernel requires a CUDA device") + from unsloth_zoo.fused_losses import unsloth_fused_ce_loss + + B, T, H, V = 1, 8, 8, 16 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + weight = torch.randn(V, H, device="cuda", dtype=torch.float32, requires_grad=True) + labels = torch.randint(0, V, (B, T), device="cuda") + + loss = unsloth_fused_ce_loss( + trainer=None, hidden_states=hidden, lm_head_weight=weight, lm_head_bias=None, + labels=labels, torch_compile=False, n_items=3, # int, not tensor + ) + assert torch.isfinite(loss), f"loss not finite with int n_items: {loss}" + + +def _ce_reference(hidden, lm_head, labels, shift_labels=None, n_items=None, + ignore_index=-100, label_smoothing=0.0): + """Reference: F.cross_entropy on materialised logits. Mirrors HF + ForCausalLMLoss when shift_labels is supplied, otherwise does the + canonical causal shift itself.""" + import torch + logits = torch.nn.functional.linear(hidden, lm_head.weight, + getattr(lm_head, "bias", None)) + if shift_labels is None: + # Standard causal shift: predict token t+1 from position t. + target = torch.full_like(labels, ignore_index) + target[..., :-1] = labels[..., 1:] + else: + target = shift_labels + reduction = "sum" if n_items is not None else "mean" + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.shape[-1]).float(), + target.reshape(-1).to(logits.device), + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction=reduction, + ) + if n_items is not None: + loss = loss / float(n_items) + return loss + + +def test_adapter_auto_shift_matches_F_cross_entropy(): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("requires CUDA") + from unsloth_zoo.fused_losses import unsloth_fused_lm_head_loss + + torch.manual_seed(0) + B, T, H, V = 2, 32, 64, 128 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + lm_head = torch.nn.Linear(H, V, bias=False).cuda().float() + labels = torch.randint(0, V, (B, T), device="cuda") + labels[0, 5:8] = -100 # sprinkle ignore_index + + fused = unsloth_fused_lm_head_loss(hidden, lm_head, labels, vocab_size=V) + ref = _ce_reference(hidden, lm_head, labels) + assert torch.allclose(fused, ref, atol=1e-5, rtol=1e-5), ( + f"fused auto-shift {fused.item()} != reference {ref.item()}" + ) + + +def test_adapter_pre_shifted_tensor_matches_F_cross_entropy(): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("requires CUDA") + from unsloth_zoo.fused_losses import unsloth_fused_lm_head_loss + + torch.manual_seed(1) + B, T, H, V = 2, 32, 64, 128 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + lm_head = torch.nn.Linear(H, V, bias=False).cuda().float() + labels = torch.randint(0, V, (B, T), device="cuda") + # Simulate trl padding_free pre-shifted target: shift labels left by 1, + # last position becomes ignore_index. Same shape as logits. + shift = torch.full_like(labels, -100) + shift[..., :-1] = labels[..., 1:] + + fused = unsloth_fused_lm_head_loss( + hidden, lm_head, labels=labels, vocab_size=V, shift_labels=shift, + ) + ref = _ce_reference(hidden, lm_head, labels=None, shift_labels=shift) + assert torch.allclose(fused, ref, atol=1e-5, rtol=1e-5), ( + f"fused pre-shifted {fused.item()} != reference {ref.item()}" + ) + + +def test_adapter_shift_labels_false_matches_F_cross_entropy(): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("requires CUDA") + from unsloth_zoo.fused_losses import unsloth_fused_lm_head_loss + + torch.manual_seed(2) + B, T, H, V = 2, 32, 64, 128 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + lm_head = torch.nn.Linear(H, V, bias=False).cuda().float() + # Caller hands us labels that are already pre-shifted (the bool=False + # contract: do not shift again, treat labels as the target tensor). + target = torch.randint(0, V, (B, T), device="cuda") + target[..., -1] = -100 # canonical pre-shift fills last position + fused = unsloth_fused_lm_head_loss( + hidden, lm_head, labels=target, vocab_size=V, shift_labels=False, + ) + ref = _ce_reference(hidden, lm_head, labels=None, shift_labels=target) + assert torch.allclose(fused, ref, atol=1e-5, rtol=1e-5), ( + f"fused shift_labels=False {fused.item()} != reference {ref.item()}" + ) + + +def test_adapter_num_items_in_batch_divides_correctly(): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("requires CUDA") + from unsloth_zoo.fused_losses import unsloth_fused_lm_head_loss + + torch.manual_seed(3) + B, T, H, V = 2, 16, 32, 64 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + lm_head = torch.nn.Linear(H, V, bias=False).cuda().float() + labels = torch.randint(0, V, (B, T), device="cuda") + labels[:, :2] = -100 # pad-like prefix + + # Effective token count after causal shift: only positions where the + # shifted target is not ignore_index count. + target = torch.full_like(labels, -100) + target[..., :-1] = labels[..., 1:] + n_items = int((target != -100).sum().item()) + + fused = unsloth_fused_lm_head_loss( + hidden, lm_head, labels, vocab_size=V, num_items_in_batch=n_items, + ) + ref = _ce_reference(hidden, lm_head, labels, n_items=n_items) + assert torch.allclose(fused, ref, atol=1e-5, rtol=1e-5), ( + f"fused (num_items={n_items}) {fused.item()} != reference {ref.item()}" + ) + + +def test_adapter_num_items_in_batch_as_int_and_tensor_equal(): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("requires CUDA") + from unsloth_zoo.fused_losses import unsloth_fused_lm_head_loss + + torch.manual_seed(4) + B, T, H, V = 2, 16, 32, 64 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + lm_head = torch.nn.Linear(H, V, bias=False).cuda().float() + labels = torch.randint(0, V, (B, T), device="cuda") + n_items_int = 17 + n_items_tensor = torch.tensor(17, device="cuda") + + fused_int = unsloth_fused_lm_head_loss( + hidden, lm_head, labels, vocab_size=V, num_items_in_batch=n_items_int, + ) + fused_tensor = unsloth_fused_lm_head_loss( + hidden, lm_head, labels, vocab_size=V, num_items_in_batch=n_items_tensor, + ) + assert torch.allclose(fused_int, fused_tensor, atol=1e-6), ( + f"int vs tensor n_items disagree: {fused_int.item()} vs {fused_tensor.item()}" + ) + + +def test_fused_kernel_label_smoothing_changes_loss(): + torch = pytest.importorskip("torch") + if not (hasattr(torch, "cuda") and torch.cuda.is_available()): + pytest.skip("fused CE kernel requires a CUDA device") + from unsloth_zoo.fused_losses import unsloth_fused_ce_loss + + B, T, H, V = 1, 8, 8, 16 + hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True) + weight = torch.randn(V, H, device="cuda", dtype=torch.float32, requires_grad=True) + labels = torch.randint(0, V, (B, T), device="cuda") + + loss_plain = unsloth_fused_ce_loss( + trainer=None, hidden_states=hidden, lm_head_weight=weight, lm_head_bias=None, + labels=labels, torch_compile=False, + ) + loss_smoothed = unsloth_fused_ce_loss( + trainer=None, hidden_states=hidden, lm_head_weight=weight, lm_head_bias=None, + labels=labels, torch_compile=False, label_smoothing=0.1, + ) + assert float(loss_plain.item()) != float(loss_smoothed.item()), ( + "label_smoothing kwarg was ignored: smoothed loss equals plain loss" + ) diff --git a/tests/test_upstream_source_patterns.py b/tests/test_upstream_source_patterns.py index 88a1c0806..bc0f4c386 100644 --- a/tests/test_upstream_source_patterns.py +++ b/tests/test_upstream_source_patterns.py @@ -316,9 +316,14 @@ def test_compiler_per_layer_projection_inplace_regex(): def test_compiler_cross_entropy_lm_head_pattern_present(): """``unsloth_zoo/compiler.py:1508-1525`` (cross_entropy_find_1) expects ``logits = self.lm_head(hidden_states`` at the head of the - loss block in every ForCausalLM forward.""" + loss block in every ForCausalLM forward. + + Read on-disk modeling source: the fused-forward installer rewrites + ``cls.forward`` at import time, but the upstream pattern compiler.py + pins still lives in the source file.""" pytest.importorskip("transformers") import importlib + import pathlib candidate_classes = [ "transformers.models.llama.modeling_llama.LlamaForCausalLM", "transformers.models.llama4.modeling_llama4.Llama4ForCausalLM", @@ -334,12 +339,12 @@ def test_compiler_cross_entropy_lm_head_pattern_present(): mod = importlib.import_module(mod_path) except ImportError: continue - cls = getattr(mod, cls_name, None) - if cls is None: + src_file = getattr(mod, "__file__", None) + if not src_file: continue try: - src = inspect.getsource(cls.forward) - except (OSError, TypeError): + src = pathlib.Path(src_file).read_text(encoding="utf-8") + except OSError: continue if needle in src: found = True @@ -357,9 +362,16 @@ def test_compiler_cross_entropy_lm_head_pattern_present(): def test_compiler_cross_entropy_find_2_loss_function_signature(): """``unsloth_zoo/compiler.py:1593-1600`` (cross_entropy_find_2) pins - ``loss = self.loss_function(...$LOGITS$, $LABELS$, $VOCABSIZE$...)``.""" + ``loss = self.loss_function(...$LOGITS$, $LABELS$, $VOCABSIZE$...)``. + + Read the modeling module's on-disk source directly. The fused-forward + installer (forward_install.py) replaces ``*ForCausalLM.forward`` at + import time, so ``inspect.getsource(cls.forward)`` would return the + rewritten body; the upstream pattern this test pins still lives on + disk untouched.""" pytest.importorskip("transformers") import importlib + import pathlib candidate_classes = [ "transformers.models.llama.modeling_llama.LlamaForCausalLM", "transformers.models.mistral.modeling_mistral.MistralForCausalLM", @@ -374,12 +386,12 @@ def test_compiler_cross_entropy_find_2_loss_function_signature(): mod = importlib.import_module(mod_path) except ImportError: continue - cls = getattr(mod, cls_name, None) - if cls is None: + src_file = getattr(mod, "__file__", None) + if not src_file: continue try: - src = inspect.getsource(cls.forward) - except (OSError, TypeError): + src = pathlib.Path(src_file).read_text(encoding="utf-8") + except OSError: continue if needle in src: return diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index d26684a1f..661f05bc9 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -382,6 +382,15 @@ def filter(self, x): return not (self.text in x.getMessage()) from .temporary_patches import ( encode_conversations_with_harmony, ) + + # Fused lm_head + cross_entropy auto-installer. On by default; set + # UNSLOTH_FUSED_FORWARD=0 to disable. + try: + from .fused_losses.forward_install import install_modeling_import_hook as _install_fused_forward + _install_fused_forward() + del _install_fused_forward + except Exception: + pass from .rl_environments import ( check_python_modules, create_locked_down_function, diff --git a/unsloth_zoo/fused_losses/__init__.py b/unsloth_zoo/fused_losses/__init__.py index 61b53727d..75c8e5746 100644 --- a/unsloth_zoo/fused_losses/__init__.py +++ b/unsloth_zoo/fused_losses/__init__.py @@ -15,3 +15,12 @@ # along with this program. If not, see . from .cross_entropy_loss import * +from .forward_adapter import EMPTY_LOGITS, unsloth_fused_lm_head_loss +from .forward_install import ( + install_modeling_import_hook, + install_for_module, + install_for_class, + register_canonical, + audit, + is_enabled, +) diff --git a/unsloth_zoo/fused_losses/ast_rewriter.py b/unsloth_zoo/fused_losses/ast_rewriter.py new file mode 100644 index 000000000..c58660da3 --- /dev/null +++ b/unsloth_zoo/fused_losses/ast_rewriter.py @@ -0,0 +1,338 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +"""AST-level rewriter for the canonical HF lm_head / loss_function triplet. + +Match: + = self.() # optional .float()/[slice]/.contiguous() wrappers + if labels is not None: + = self.loss_function(, labels, vocab_size=..., **kwargs) + +Rewrite the labels branch to call `unsloth_fused_lm_head_loss(, +self., labels, ...)` (skipping the bf16 logits + fp32 cast) and +substitute EMPTY_LOGITS for the returned logits. The else (generation) +branch keeps the original RHS verbatim. Forwards that miss the triplet +fall through to `_UNMATCHED`; the LOSS_MAPPING sweep is the backstop. +""" + +from __future__ import annotations + +__all__ = [ + "rewrite_forward_source", + "TripletCapture", +] + +import ast +import textwrap +from dataclasses import dataclass + + +@dataclass +class TripletCapture: + head_attr: str # e.g. "lm_head" + hidden_expr: ast.AST # the expression passed into self.(...) + logits_rhs_src: str # ast.unparse of the original `logits = ...` RHS + logits_name: str # the name the lm_head output was bound to + loss_name: str # the name the loss was bound to + vocab_expr: ast.AST | None + kwargs_name: str | None # name of the **kwargs param passed to loss_function + extra_loss_kws: list # [(name, ast.AST), ...] explicit kwargs beyond vocab_size + lm_head_assign_idx: int # index in the function body of the `logits = self.lm_head(...)` stmt + if_block_idx: int # index of the `if labels is not None:` stmt + loss_init_idx: int | None # index of the `loss = None` stmt that we delete (may be None) + + +def _is_self_attr_call(node: ast.AST) -> bool: + return ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "self" + ) + + +def _find_inner_self_call(value: ast.AST) -> ast.Call | None: + """First Call descendant whose func is `self.`. Lets us see through + `.float()` / `[slice]` / `.contiguous()` chains.""" + for node in ast.walk(value): + if _is_self_attr_call(node): + return node + return None + + +def _find_loss_function_call(if_block: ast.If) -> ast.Call | None: + # Only direct body statements -- nested ifs (guards) inside the labels + # branch would be silently dropped by the wholesale rewrite. + for stmt in if_block.body: + if isinstance(stmt, ast.Assign): + v = stmt.value + if ( + isinstance(v, ast.Call) + and isinstance(v.func, ast.Attribute) + and isinstance(v.func.value, ast.Name) + and v.func.value.id == "self" + and v.func.attr == "loss_function" + ): + return v + return None + + +def _find_loss_assign_target(if_block: ast.If, call: ast.Call) -> str | None: + for stmt in if_block.body: + if isinstance(stmt, ast.Assign) and stmt.value is call and len(stmt.targets) == 1: + tgt = stmt.targets[0] + if isinstance(tgt, ast.Name): + return tgt.id + return None + + +def _capture(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> TripletCapture | None: + body = fn.body + + if_idx = None + if_node = None + for i, stmt in enumerate(body): + if not isinstance(stmt, ast.If): + continue + t = stmt.test + if not (isinstance(t, ast.Compare) + and isinstance(t.left, ast.Name) and t.left.id == "labels" + and len(t.ops) == 1 and isinstance(t.ops[0], ast.IsNot) + and isinstance(t.comparators[0], ast.Constant) + and t.comparators[0].value is None): + continue + # Must contain a self.loss_function call + if _find_loss_function_call(stmt) is None: + continue + if_idx = i + if_node = stmt + break + if if_node is None: + return None + + # Reject non-trivial label branches: anything more than `[loss = self.loss_function(...)]` + # is silently lost by the wholesale rewrite (e.g. CSM auxiliary depth-decoder loss). + if if_node.orelse: + return None + if len(if_node.body) != 1: + return None + loss_assign = if_node.body[0] + if not (isinstance(loss_assign, ast.Assign) and len(loss_assign.targets) == 1 + and isinstance(loss_assign.targets[0], ast.Name)): + return None + loss_call = loss_assign.value + if not (isinstance(loss_call, ast.Call) + and isinstance(loss_call.func, ast.Attribute) + and isinstance(loss_call.func.value, ast.Name) + and loss_call.func.value.id == "self" + and loss_call.func.attr == "loss_function"): + return None + loss_name = loss_assign.targets[0].id + + # Locate logits-bearing arg: first positional or `logits=` kw. + logits_name = None + if loss_call.args: + a0 = loss_call.args[0] + if isinstance(a0, ast.Name): + logits_name = a0.id + if logits_name is None: + for kw in loss_call.keywords: + if kw.arg == "logits" and isinstance(kw.value, ast.Name): + logits_name = kw.value.id + break + if logits_name is None: + return None + + # Labels arg must be literally the plain `labels` name; aliased labels + # (e.g. CSM `labels=backbone_labels`) need bespoke handling. + labels_arg = None + if len(loss_call.args) >= 2: + if isinstance(loss_call.args[1], ast.Name): + labels_arg = loss_call.args[1].id + for kw in loss_call.keywords: + if kw.arg == "labels": + if isinstance(kw.value, ast.Name): + labels_arg = kw.value.id + else: + return None + if labels_arg != "labels": + return None + + # vocab_size: keyword preferred, else 3rd positional. + vocab_expr = None + for kw in loss_call.keywords: + if kw.arg == "vocab_size": + vocab_expr = kw.value + break + if vocab_expr is None and len(loss_call.args) >= 3: + vocab_expr = loss_call.args[2] + + # **kwargs unpack + any explicit kwargs beyond {logits, labels, vocab_size}. + kwargs_name = None + extra_loss_kws: list = [] + for kw in loss_call.keywords: + if kw.arg is None: + if isinstance(kw.value, ast.Name): + kwargs_name = kw.value.id + continue + if kw.arg in ("logits", "labels", "vocab_size"): + continue + extra_loss_kws.append((kw.arg, kw.value)) + + # Find the lm_head assignment for logits_name (walking upward from if_idx). + head_attr = None + hidden_expr = None + logits_rhs_src = None + lm_head_assign_idx = None + for j in range(if_idx - 1, -1, -1): + stmt = body[j] + if not (isinstance(stmt, ast.Assign) and len(stmt.targets) == 1): + continue + tgt = stmt.targets[0] + if not (isinstance(tgt, ast.Name) and tgt.id == logits_name): + continue + inner = _find_inner_self_call(stmt.value) + if inner is None: + # The logits-bearing name is re-assigned by a non-lm_head + # expression (e.g. `logits = logits * self.logit_scale` for + # Cohere). Removing the original lm_head call would leave the + # rebinding referencing an undefined `logits`. Bail out and + # let the LOSS_MAPPING patch handle this class. + return None + head_attr = inner.func.attr + if not inner.args: + return None + hidden_expr = inner.args[0] + logits_rhs_src = ast.unparse(stmt.value) + lm_head_assign_idx = j + break + if head_attr is None or hidden_expr is None or lm_head_assign_idx is None: + return None + + # Optional `loss = None` between the lm_head assign and the if block. + loss_init_idx = None + for j in range(lm_head_assign_idx + 1, if_idx): + stmt = body[j] + if (isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and stmt.targets[0].id == loss_name + and isinstance(stmt.value, ast.Constant) + and stmt.value.value is None): + loss_init_idx = j + break + + # Bail if any statement between lm_head and the labels-if touches + # logits (e.g. Gemma3 final_logit_softcapping): it would run on + # EMPTY_LOGITS in the labels branch, so fused loss would see + # un-softcapped logits. + for j in range(lm_head_assign_idx + 1, if_idx): + if j == loss_init_idx: + continue + for n in ast.walk(body[j]): + if isinstance(n, ast.Name) and n.id == logits_name: + return None + + return TripletCapture( + head_attr=head_attr, + hidden_expr=hidden_expr, + logits_rhs_src=logits_rhs_src, + logits_name=logits_name, + loss_name=loss_name, + vocab_expr=vocab_expr, + kwargs_name=kwargs_name, + extra_loss_kws=extra_loss_kws, + lm_head_assign_idx=lm_head_assign_idx, + if_block_idx=if_idx, + loss_init_idx=loss_init_idx, + ) + + +def _build_replacement(cap: TripletCapture) -> list[ast.stmt]: + """Build the AST nodes for the rewritten labels-branch / else-branch.""" + head_attr = cap.head_attr + logits = cap.logits_name + loss = cap.loss_name + vocab = ast.unparse(cap.vocab_expr) if cap.vocab_expr is not None else "None" + extra = "".join( + f", {name}={ast.unparse(value)}" for name, value in cap.extra_loss_kws + ) + kwargs_unpack = f", **{cap.kwargs_name}" if cap.kwargs_name else "" + hidden_src = ast.unparse(cap.hidden_expr) + logits_rhs = cap.logits_rhs_src or f"self.{head_attr}({hidden_src})" + + template = textwrap.dedent(f""" + if labels is not None: + {loss} = unsloth_fused_lm_head_loss( + {hidden_src}, self.{head_attr}, labels, + vocab_size={vocab}{extra}{kwargs_unpack}, + ) + {logits} = EMPTY_LOGITS + else: + {logits} = {logits_rhs} + {loss} = None + """).strip() + return ast.parse(template).body + + +def rewrite_forward_source(source: str) -> tuple[str | None, TripletCapture | None]: + """Rewrite a forward function source string. + + Returns (new_source, capture) on success, (None, None) if the canonical + triplet wasn't found (and the caller should leave the class alone). + """ + try: + tree = ast.parse(textwrap.dedent(source)) + except SyntaxError: + return (None, None) + if not tree.body or not isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)): + return (None, None) + fn = tree.body[0] + cap = _capture(fn) + if cap is None: + return (None, None) + + new_block = _build_replacement(cap) + body = fn.body + delete_indices = {cap.lm_head_assign_idx, cap.if_block_idx} + if cap.loss_init_idx is not None: + delete_indices.add(cap.loss_init_idx) + new_body = [] + inserted = False + for i, stmt in enumerate(body): + if i in delete_indices: + if not inserted: + new_body.extend(new_block) + inserted = True + continue + new_body.append(stmt) + fn.body = new_body + # @can_return_tuple carries return_dict=False semantics and must + # survive; only strip the docstring-only decorators below. + _DROP_DECORATORS = { + "auto_docstring", + "add_start_docstrings", + "add_start_docstrings_to_model_forward", + "add_end_docstrings", + "replace_return_docstrings", + } + fn.decorator_list = [ + d for d in fn.decorator_list if _decorator_name(d) not in _DROP_DECORATORS + ] + ast.fix_missing_locations(tree) + return (ast.unparse(tree), cap) + + +def _decorator_name(node: ast.AST) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + if isinstance(node, ast.Call): + return _decorator_name(node.func) + return None diff --git a/unsloth_zoo/fused_losses/cross_entropy_loss.py b/unsloth_zoo/fused_losses/cross_entropy_loss.py index 414c3177c..f5b16a473 100644 --- a/unsloth_zoo/fused_losses/cross_entropy_loss.py +++ b/unsloth_zoo/fused_losses/cross_entropy_loss.py @@ -85,13 +85,17 @@ def compute_fused_ce_loss( 1) logit_scale_multiply (X = X * logit_scale_multiply) 2) logit_scale_divide (X = X / logit_scale_divide) 3) logit_softcapping (X = tanh(X / logit_softcapping) * logit_softcapping) + 4) ignore_index (passed to F.cross_entropy; defaults to -100) + 5) label_smoothing (passed to F.cross_entropy; defaults to 0.0) """ + ignore_index = int(kwargs.get("ignore_index", -100)) + label_smoothing = float(kwargs.get("label_smoothing", 0.0)) device = lm_head_weight.device if shift_labels: # Get shifted labels first _labels = torch.empty_like(labels, device = device) _labels[..., :-1] = labels[..., 1:] - _labels[..., -1] = -100 + _labels[..., -1] = ignore_index labels = _labels pass @@ -121,6 +125,8 @@ def compute_fused_ce_loss( input = logits.view(-1, vocab_size).float().contiguous(), target = labels.view(-1).to(device).contiguous(), reduction = reduction, + ignore_index = ignore_index, + label_smoothing = label_smoothing, ) loss = loss / n_items if n_items is not None else loss # Scale loss if needed for mixed precision training @@ -192,6 +198,8 @@ def forward( """ device = lm_head_weight.device if extra_kwargs is None: extra_kwargs = {} + # Thread ignore_index through label-shift and the inner CE call. + ignore_index = int(extra_kwargs.get("ignore_index", -100)) # Get shifted labels first if shift_labels: @@ -200,15 +208,22 @@ def forward( # Also check mask if mask is not None: mask = mask.to(device = device) - _labels[..., :-1][mask[..., 1:] == 0] = -100 + _labels[..., :-1][mask[..., 1:] == 0] = ignore_index pass - _labels[..., -1] = -100 + _labels[..., -1] = ignore_index _labels = _labels.view(-1) labels = _labels + else: + # Caller already shifted (e.g. trl padding_free passes + # shift_labels=). Flatten so chunking aligns with + # hidden_states.reshape(-1, hd). + labels = labels.contiguous().view(-1).to(device = device) pass # N items divisor - divisor = n_items if n_items is not None else (labels != -100).sum() + divisor = n_items if n_items is not None else (labels != ignore_index).sum() + if not torch.is_tensor(divisor): + divisor = torch.tensor(divisor, dtype = torch.float32, device = device) # Counteract DataParallel having multiple items since it does scatter & gather if divisor.numel() != 1: divisor = divisor.ravel()[0] divisor = divisor.to(dtype = torch.float32, device = device) @@ -263,7 +278,7 @@ def accumulate_chunk( labels_j, divisor, scaling, - not shift_labels, # Already label shifted + False, # Outer pre-shifted (or caller did); inner skips **kwargs, ) grad_lm_head.add_(chunk_grad_lm_head) @@ -281,7 +296,7 @@ def accumulate_chunk( labels_j, divisor, scaling, - not shift_labels, # Already label shifted + False, # Outer pre-shifted (or caller did); inner skips **kwargs, ) grad_lm_head.add_(chunk_grad_lm_head) @@ -298,7 +313,7 @@ def accumulate_chunk( labels_j, divisor, scaling, - not shift_labels, # Already label shifted + False, # Outer pre-shifted (or caller did); inner skips **kwargs, ) grad_lm_head_bias.add_(chunk_grad_lm_head_bias) @@ -315,7 +330,7 @@ def accumulate_chunk( labels_j, divisor, scaling, - not shift_labels, # Already label shifted + False, # Outer pre-shifted (or caller did); inner skips **kwargs, ) pass @@ -524,12 +539,14 @@ def unsloth_fused_ce_loss( target_gb : Optional[int] = None, torch_compile : Optional[bool] = True, overwrite : Optional[bool] = False, + shift_labels : bool = True, **kwargs, ): """ Computes chunked fused cross_entropy_loss(chunk(X) @ W + b, chunk(labels)) * If n_items is not given, does mean(ce_loss), otherwise sum(ce_loss)/n_items - * Auto does shift of labels ie hidden_states[..., :-1] and labels[..., 1:] + * shift_labels=True (default) shifts internally: hidden_states[..., :-1] and labels[..., 1:]. + Set False when caller already pre-shifted (e.g. trl padding_free). * Allows scaling factor from mixed precision fp16, fp8 * target_gb specifies the max GB memory the fused loss can use - default detects VRAM left * Upcasts to float32 and allows kwargs to have: @@ -561,7 +578,7 @@ def unsloth_fused_ce_loss( mask = mask, n_items = n_items, scaling = scaling, - shift_labels = True, + shift_labels = shift_labels, target_gb = target_gb, torch_compile = torch_compile, overwrite = overwrite, diff --git a/unsloth_zoo/fused_losses/forward_adapter.py b/unsloth_zoo/fused_losses/forward_adapter.py new file mode 100644 index 000000000..2001317ec --- /dev/null +++ b/unsloth_zoo/fused_losses/forward_adapter.py @@ -0,0 +1,94 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""HF-call-convention adapter for unsloth_fused_ce_loss. + +The fused kernel takes hidden_states + lm_head.weight directly, skipping +both the lm_head matmul and the fp32 logits materialisation that the HF +template performs before `self.loss_function(...)`. `EMPTY_LOGITS` is the +0-element sentinel slotted into the `logits=` return field. +""" + +from __future__ import annotations + +__all__ = [ + "EMPTY_LOGITS", + "unsloth_fused_lm_head_loss", +] + +import torch + +from .cross_entropy_loss import unsloth_fused_ce_loss + + +EMPTY_LOGITS = torch.empty(0) + + +def unsloth_fused_lm_head_loss( + hidden_states, + lm_head, + labels, + vocab_size=None, + **kwargs, +): + """Replacement for the canonical `self.loss_function(logits=..., labels=..., + vocab_size=..., **kwargs)` call site. Routes through the chunked fused + cross-entropy kernel without materialising fp32 logits. + + Args: + hidden_states: the tensor that was about to be fed into `self.lm_head`. + lm_head: the lm_head module (Linear). Weight + bias pulled off it. + labels: integer label tensor. + vocab_size: ignored. Kernel reads it from `lm_head.weight.shape[0]`. + **kwargs: forwarded to the kernel. Accepts `num_items_in_batch` + (renamed to `n_items`), `logit_scale_multiply`, `logit_scale_divide`, + `logit_softcapping`, plus any other extras the original + `self.loss_function` would have ignored. + """ + n_items = kwargs.pop("num_items_in_batch", None) + if n_items is None: + n_items = kwargs.pop("n_items", None) + else: + kwargs.pop("n_items", None) + # vocab_size is read from lm_head_weight.shape[0]; drop the keyword. + kwargs.pop("vocab_size", None) + # If caller already shifted (e.g. trl padding_free with packing passes + # shift_labels=), route the pre-shifted target into the fused + # kernel with shift_labels=False so we still get chunked logits. + shift_labels_kw = kwargs.pop("shift_labels", None) + pre_shifted_tensor = ( + shift_labels_kw is not None and not isinstance(shift_labels_kw, bool) + ) + if pre_shifted_tensor: + target = shift_labels_kw + do_shift = False + elif shift_labels_kw is False: + target = labels + do_shift = False + else: + target = labels + do_shift = True + + return unsloth_fused_ce_loss( + trainer = None, + hidden_states = hidden_states, + lm_head_weight = lm_head.weight, + lm_head_bias = getattr(lm_head, "bias", None), + labels = target, + n_items = n_items, + shift_labels = do_shift, + **kwargs, + ) diff --git a/unsloth_zoo/fused_losses/forward_install.py b/unsloth_zoo/fused_losses/forward_install.py new file mode 100644 index 000000000..99ff25e59 --- /dev/null +++ b/unsloth_zoo/fused_losses/forward_install.py @@ -0,0 +1,377 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +"""Auto-installer for the fused lm_head + cross_entropy forward. + +Tier 1 swaps the forward via a hand-registered structural-hash allowlist +(empty by default; populate with `register_canonical`). Tier 2 falls back +to `ast_rewriter` which rewrites the canonical HF triplet in-place; misses +go to `_UNMATCHED` and the LOSS_MAPPING sweep stays as the backstop. + +On by default. Set `UNSLOTH_FUSED_FORWARD=0` to disable. Soft floor at +transformers >= 4.56, the release where every `*ForCausalLM` settled on +the `outputs.last_hidden_state` + `self.loss_function(logits, labels, +vocab_size, **kwargs)` shape we match against. +""" + +from __future__ import annotations + +__all__ = [ + "install_modeling_import_hook", + "install_for_module", + "install_for_class", + "register_canonical", + "audit", + "is_enabled", + "EMPTY_LOGITS", + "unsloth_fused_lm_head_loss", +] + +import ast +import hashlib +import importlib.abc +import importlib.util +import inspect +import linecache +import logging +import os +import sys +import textwrap +import threading +import warnings +from typing import Any + +from .ast_rewriter import rewrite_forward_source +from .forward_adapter import EMPTY_LOGITS, unsloth_fused_lm_head_loss + + +logger = logging.getLogger("unsloth_zoo.fused_forward") + +_MIN_TRANSFORMERS = (4, 56, 0) + +_REGISTRY_LOCK = threading.RLock() +_PATCHED: dict[str, dict[str, Any]] = {} # qualname -> {tier, kind, hash, module} +_UNMATCHED: dict[str, str] = {} # qualname -> reason +_FAILED: dict[str, str] = {} # qualname -> error +_CANONICAL_FORWARDS: dict[str, Any] = {} # forward_hash -> replacement forward fn + +_INSTALL_DONE = False # set once install_modeling_import_hook has run + + +def is_enabled() -> bool: + # On by default; opt out via UNSLOTH_FUSED_FORWARD=0. + return os.environ.get("UNSLOTH_FUSED_FORWARD", "1") != "0" + + +def register_canonical(forward_hash: str, replacement_forward) -> None: + """Register a hand-written canonical forward for a known structural hash. + Future installs that fingerprint to `forward_hash` get the replacement + directly without the AST rewrite step.""" + with _REGISTRY_LOCK: + _CANONICAL_FORWARDS[forward_hash] = replacement_forward + + +def audit() -> dict[str, Any]: + """Snapshot of what's been patched / left alone / errored. JSON-safe.""" + with _REGISTRY_LOCK: + return { + "enabled": is_enabled(), + "n_patched": len(_PATCHED), + "n_unmatched": len(_UNMATCHED), + "n_failed": len(_FAILED), + "patched": dict(_PATCHED), + "unmatched": dict(_UNMATCHED), + "failed": dict(_FAILED), + "canonical_hashes_registered": sorted(_CANONICAL_FORWARDS), + } + + +def _transformers_version_ok() -> bool: + try: + import transformers # noqa: PLC0415 + except Exception: + return False + v = getattr(transformers, "__version__", "0.0.0") + parts = [] + for chunk in v.split("+")[0].split("."): + try: + parts.append(int(chunk)) + except ValueError: + parts.append(0) + if len(parts) == 3: + break + while len(parts) < 3: + parts.append(0) + return tuple(parts) >= _MIN_TRANSFORMERS + + +def _structural_hash(fn) -> str | None: + try: + src = textwrap.dedent(inspect.getsource(fn)) + except (OSError, TypeError): + return None + try: + tree = ast.parse(src) + except SyntaxError: + return None + # Strip docstrings so cosmetic changes do not bust the hash. + _BODY_HOLDERS = (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) + for node in ast.walk(tree): + if not isinstance(node, _BODY_HOLDERS): + continue + body = node.body + if (body and isinstance(body[0], ast.Expr) + and isinstance(body[0].value, ast.Constant) + and isinstance(body[0].value.value, str)): + body.pop(0) + return hashlib.sha256( + ast.dump(tree, annotate_fields=False, include_attributes=False).encode() + ).hexdigest()[:16] + + +_LINEAR_HEAD_ATTRS = { + "lm_head", + "output_projection", + "embed_out", + "proj_out", + "generator_lm_head", + "head", + "logits_dense", + "codec_head", +} + + +def _is_eligible_class(cls) -> bool: + # ForConditionalGeneration uses aligned labels; fused kernel hardcodes + # a causal shift, so accept only ForCausalLM. + name = getattr(cls, "__name__", "") + if not name.endswith("ForCausalLM"): + return False + if not hasattr(cls, "forward"): + return False + return True + + +def install_for_class(cls) -> bool: + """Try to install the fused forward on `cls`. Returns True on success.""" + if not is_enabled(): + return False + if not _transformers_version_ok(): + return False + if not _is_eligible_class(cls): + return False + qn = getattr(cls, "__qualname__", cls.__name__) + with _REGISTRY_LOCK: + if qn in _PATCHED: + return True + + forward = cls.forward + fhash = _structural_hash(forward) + + # Tier 1: hash-allowlisted function override. + if fhash is not None: + replacement = _CANONICAL_FORWARDS.get(fhash) + if replacement is not None: + try: + replacement.__qualname__ = forward.__qualname__ + replacement.__module__ = forward.__module__ + except Exception: + pass + cls.forward = replacement + with _REGISTRY_LOCK: + _PATCHED[qn] = { + "tier": "1-function-override", + "kind": cls.__name__, + "hash": fhash, + "module": getattr(cls, "__module__", ""), + } + return True + + # Tier 2: AST triplet rewrite. + try: + src = textwrap.dedent(inspect.getsource(forward)) + except (OSError, TypeError) as exc: + with _REGISTRY_LOCK: + _UNMATCHED[qn] = f"source-unavailable: {exc}" + return False + + new_src, cap = rewrite_forward_source(src) + if new_src is None: + with _REGISTRY_LOCK: + _UNMATCHED[qn] = "no-canonical-triplet" + return False + # Composite heads (e.g. BigBird's BigBirdOnlyMLMHead via self.cls) lack + # .weight/.bias and would crash inside the adapter. + if cap.head_attr not in _LINEAR_HEAD_ATTRS: + with _REGISTRY_LOCK: + _UNMATCHED[qn] = f"non-linear-head: {cap.head_attr}" + return False + + ns = dict(getattr(forward, "__globals__", {})) + ns["unsloth_fused_lm_head_loss"] = unsloth_fused_lm_head_loss + ns["EMPTY_LOGITS"] = EMPTY_LOGITS + try: + from transformers.utils.generic import can_return_tuple + ns.setdefault("can_return_tuple", can_return_tuple) + except Exception: + pass + # Backfill transformers.modeling_outputs symbols; unsloth's compiled-cache + # forwards reference CausalLMOutputWithPast & friends in the return line. + try: + import transformers.modeling_outputs as _mo + for _name in dir(_mo): + if _name.startswith("_"): + continue + ns.setdefault(_name, getattr(_mo, _name)) + except Exception: + pass + # Register rewritten source with linecache so inspect.getsource and + # tracebacks see the installed body. + synthetic_path = f"" + linecache.cache[synthetic_path] = ( + len(new_src), None, + [line + "\n" for line in new_src.splitlines()], + synthetic_path, + ) + try: + code = compile(new_src, synthetic_path, "exec") + exec(code, ns) + except Exception as exc: + with _REGISTRY_LOCK: + _FAILED[qn] = f"compile-or-exec: {type(exc).__name__}: {exc}" + return False + + new_forward = ns.get(forward.__name__) + if not callable(new_forward): + with _REGISTRY_LOCK: + _FAILED[qn] = "rewritten-forward-missing" + return False + try: + new_forward.__qualname__ = forward.__qualname__ + new_forward.__module__ = forward.__module__ + new_forward.__doc__ = forward.__doc__ + except Exception: + pass + + cls.forward = new_forward + with _REGISTRY_LOCK: + _PATCHED[qn] = { + "tier": "2-ast-triplet", + "kind": cls.__name__, + "hash": fhash, + "module": getattr(cls, "__module__", ""), + "head_attr": cap.head_attr, + } + return True + + +def install_for_module(module) -> int: + """Scan a transformers `modeling_*` module and install where eligible. + Returns the number of classes newly patched.""" + if not is_enabled(): + return 0 + if not _transformers_version_ok(): + return 0 + name = getattr(module, "__name__", "") + if not (name.startswith("transformers.models.") and ".modeling_" in name): + return 0 + n = 0 + for attr in dir(module): + try: + obj = getattr(module, attr) + except Exception: + continue + if not isinstance(obj, type): + continue + if getattr(obj, "__module__", "") != name: + continue # skip re-exports + try: + if install_for_class(obj): + n += 1 + except Exception as exc: + qn = getattr(obj, "__qualname__", obj.__name__) + with _REGISTRY_LOCK: + _FAILED[qn] = f"install: {type(exc).__name__}: {exc}" + return n + + +class _ModelingLoader(importlib.abc.Loader): + """Wraps an inner loader, runs `install_for_module` after exec_module.""" + def __init__(self, inner): + self._inner = inner + + def create_module(self, spec): + if hasattr(self._inner, "create_module"): + return self._inner.create_module(spec) + return None + + def exec_module(self, module): + self._inner.exec_module(module) + try: + install_for_module(module) + except Exception as exc: + logger.debug( + "unsloth fused-forward install_for_module failed for %s: %s", + getattr(module, "__name__", "?"), exc, + ) + + +class _ModelingFinder(importlib.abc.MetaPathFinder): + """Intercepts `transformers.models..modeling_` imports.""" + PREFIX = "transformers.models." + + def find_spec(self, fullname, path, target=None): + if not (fullname.startswith(self.PREFIX) and ".modeling_" in fullname): + return None + if fullname in sys.modules: + return None + for finder in sys.meta_path: + if finder is self: + continue + try: + spec = finder.find_spec(fullname, path, target) + except Exception: + continue + if spec is None or spec.loader is None: + continue + spec.loader = _ModelingLoader(spec.loader) + return spec + return None + + +def install_modeling_import_hook() -> bool: + """Register the meta-path finder + scan already-imported modeling modules. + Returns True if the hook was installed (or already present); False if the + install was skipped (feature disabled, transformers missing, version too + old).""" + global _INSTALL_DONE + if _INSTALL_DONE: + return True + if not is_enabled(): + return False + if not _transformers_version_ok(): + warnings.warn( + "Unsloth fused-forward install skipped: requires transformers >= " + f"{'.'.join(map(str, _MIN_TRANSFORMERS))}.", + stacklevel=2, + ) + return False + if not any(isinstance(f, _ModelingFinder) for f in sys.meta_path): + sys.meta_path.insert(0, _ModelingFinder()) + # Catch modules already imported before zoo loaded. + for name in list(sys.modules): + if name.startswith("transformers.models.") and ".modeling_" in name: + mod = sys.modules.get(name) + if mod is None: + continue + try: + install_for_module(mod) + except Exception: + continue + _INSTALL_DONE = True + return True