Skip to content

Auto-install fused lm_head + cross_entropy forward (opt-in)#21

Open
danielhanchen wants to merge 5 commits into
mainfrom
pr-657-head
Open

Auto-install fused lm_head + cross_entropy forward (opt-in)#21
danielhanchen wants to merge 5 commits into
mainfrom
pr-657-head

Conversation

@danielhanchen

Copy link
Copy Markdown
Collaborator

Staging mirror of unslothai#657

Original PR: unslothai#657
Author: danielhanchen

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

Opt-in (UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites the canonical lm_head + self.loss_function triplet on every transformers *ForCausalLM / *ForConditionalGeneration whose forward matches the shape used from transformers 4.56 onwards. Skipping logits.float() over (seq_len x vocab_size) avoids the OOM that surfaced in unslothai/unsloth#5441 and shaves the bf16 logits tensor as well.

Layers

unsloth_zoo/fused_losses/forward_adapter.py
Maps the HF self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs) calling convention onto unsloth_fused_ce_loss. Pops num_items_in_batch -> n_items, threads ignore_index / label_smoothing / logit_softcapping / logit_scale_multiply / logit_scale_divide, and falls back to a stock CE if the caller passes a pre-shifted shift_labels tensor (unsupported by the chunked kernel today).

unsloth_zoo/fused_losses/ast_rewriter.py
NodeTransformer that recognises the canonical triplet:

<NAME> = self.<HEAD>(<HIDDEN_EXPR>[...])
loss = None              # optional
if labels is not None:
    <LOSS> = self.loss_function(<NAME>, labels, vocab_size=..., **kwargs)

and rewrites it to call unsloth_fused_lm_head_loss(<HIDDEN_EXPR>, self.<HEAD>, labels, ...). Tolerates keyword vs positional vocab_size, .float() / [slice] chains around the lm_head call, and detects logits re-binding (e.g. Cohere's logits = logits * self.logit_scale) as a refuse signal so we never produce a broken forward.

unsloth_zoo/fused_losses/forward_install.py
Two-tier installer: (1) hash-allowlist fast path via register_canonical(hash, forward_fn) (for future hand-written canonical forwards); (2) AST triplet rewrite. Driven by a meta-path import hook that intercepts transformers.models.<X>.modeling_<X> imports and patches eligible classes as their module loads. Soft floor at transformers >= 4.56. audit() returns a JSON-safe dict of patched / unmatched / failed classes for observability.

Kernel updates

unsloth_zoo/fused_losses/cross_entropy_loss.py
compute_fused_ce_loss + UnslothFusedLoss.forward now thread ignore_index (default -100) into the label-shift step and the inner F.cross_entropy call. compute_fused_ce_loss also accepts label_smoothing. Matches HF ForCausalLMLoss semantics so callers that override either no longer silently regress. (logit_softcapping, logit_scale_multiply, logit_scale_divide were already supported.)

Test plan

  • 14 cases in tests/test_fused_forward_install.py:
    • AST rewriter accepts keyword form, positional vocab_size, .float() wrapper. Declines non-canonical, declines on logits rebinding.
    • install_for_class: noop when disabled, skips ineligible names, patches canonical, idempotent, function-override fast path, audit() snapshot.
    • Numerical equivalence on a toy CUDA model: fused loss within bf16 -> fp32 rounding noise of the reference.
    • Kernel respects ignore_index and `label_sm

This PR tracks the moving review branch (pr-657-head). Iteration fix commits land here directly. Review-added tests are in a separate PR.

Changed files:

  • .github/workflows/consolidated-tests-ci.yml
  • .github/workflows/lint-ci.yml
  • .github/workflows/mlx-ci.yml
  • .github/workflows/security-audit.yml
  • .github/workflows/stale.yml
  • .github/workflows/wheel-smoke.yml
  • unsloth_zoo/__init__.py
  • unsloth_zoo/fused_losses/__init__.py
  • unsloth_zoo/fused_losses/ast_rewriter.py
  • unsloth_zoo/fused_losses/cross_entropy_loss.py
  • unsloth_zoo/fused_losses/forward_adapter.py
  • unsloth_zoo/fused_losses/forward_install.py
  • tests/test_fused_forward_install.py

Adds an opt-in (UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites
the canonical lm_head + self.loss_function triplet on every transformers
`*ForCausalLM` / `*ForConditionalGeneration` whose forward matches the
shape used from transformers 4.56 onwards. Skipping logits.float() over
(seq_len x vocab_size) avoids the OOM that surfaced in #5441 and shaves
the bf16 logits tensor as well.

Layers:

  unsloth_zoo/fused_losses/forward_adapter.py
    Maps the HF self.loss_function(logits=..., labels=..., vocab_size=...,
    **kwargs) calling convention onto unsloth_fused_ce_loss. Pops
    num_items_in_batch -> n_items, threads ignore_index / label_smoothing /
    logit_softcapping / logit_scale_multiply / logit_scale_divide, and
    falls back to a stock CE if the caller passes a pre-shifted
    shift_labels tensor (unsupported by the chunked kernel today).

  unsloth_zoo/fused_losses/ast_rewriter.py
    NodeTransformer that recognises the canonical triplet:
        <NAME> = self.<HEAD>(<HIDDEN_EXPR>[...])
        loss = None              (optional)
        if labels is not None:
            <LOSS> = self.loss_function(<NAME>, labels, vocab_size=..., **kwargs)
    and rewrites it to call unsloth_fused_lm_head_loss(<HIDDEN_EXPR>,
    self.<HEAD>, labels, ...). Tolerates keyword vs positional vocab_size,
    `.float()` / `[slice]` chains around the lm_head call, and detects
    logits re-binding (e.g. Cohere's `logits = logits * self.logit_scale`)
    as a refuse signal so we never produce a broken forward.

  unsloth_zoo/fused_losses/forward_install.py
    Two-tier installer: (1) hash-allowlist fast path via
    register_canonical(hash, forward_fn); (2) AST triplet rewrite.
    Driven by a meta-path import hook that intercepts
    transformers.models.<X>.modeling_<X> imports and patches eligible
    classes as their module loads. Soft floor at transformers >= 4.56.
    audit() returns a JSON-safe dict of patched / unmatched / failed
    classes for observability.

Kernel updates:

  unsloth_zoo/fused_losses/cross_entropy_loss.py
    compute_fused_ce_loss + UnslothFusedLoss.forward now thread
    ignore_index (default -100) into the label-shift step and the inner
    F.cross_entropy call. compute_fused_ce_loss also accepts
    label_smoothing. Matches HF ForCausalLMLoss semantics so callers
    that override either no longer silently regress.

Tests (tests/test_fused_forward_install.py, 14 cases):
  - AST rewriter accepts keyword form, positional vocab_size, `.float()`
    wrapper. Declines non-canonical, declines on logits rebinding.
  - install_for_class: noop when disabled, skips ineligible names,
    patches canonical, idempotent, function-override fast path,
    audit() snapshot.
  - Numerical equivalence on a toy CUDA model: fused loss within
    bf16 -> fp32 rounding noise of the reference.
  - Kernel respects ignore_index and label_smoothing kwargs.

End-to-end equivalence on Llama-3.2-1B + alpaca-cleaned (seed 3407,
max_steps 10): identical step-1 loss + grad_norm, max |loss delta| =
0.005, max |grad_norm delta| = 0.025 across the run. Audit reported
19 classes patched, 0 failed when UNSLOTH_FUSED_FORWARD=1 (LlamaForCausalLM,
Qwen3ForCausalLM, MistralForCausalLM, Gemma2/3 / GemmaForCausalLM,
Mllama, DeepseekV3, Qwen3MoE / Qwen3Next, Bloom, FalconH1, etc.).

Off by default. Set UNSLOTH_FUSED_FORWARD=1 to opt in.
Forwards routed through unsloth_compiled_cache see __globals__ for the
cached module, which does not always re-import the HF output dataclass
the original modeling file referenced (e.g. Gemma3ForCausalLM's return
statement uses CausalLMOutputWithPast). Populate the exec namespace
with everything from transformers.modeling_outputs as a fallback so
the rewritten forward links cleanly.

Caught during multi-model equivalence run (Gemma3-1B fused) which now
matches the stock path bit-for-bit alongside Llama, Qwen3, Phi3, and
Mistral.
@danielhanchen

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in auto-installer for fused lm_head and cross_entropy losses, targeting Hugging Face ForCausalLM and ForConditionalGeneration models. The implementation uses an AST-level rewriter to replace standard projection and loss calls with a fused kernel, integrated via import hooks for automatic patching. The fused cross-entropy loss has also been updated to support ignore_index and label_smoothing. Feedback was provided regarding a redundant import torch statement in the forward adapter.

Comment on lines +83 to +84
import torch
logits = torch.nn.functional.linear(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

The import torch statement is redundant here as torch is already imported at the top of the file (line 37).

        logits = torch.nn.functional.linear(

forward_adapter.py
- shift_labels fallback now uses reduction=sum and divides by n_items
  when num_items_in_batch is supplied, matching HF ForCausalLMLoss
  gradient-accumulation scaling.
- shift_labels=False (bool) now routes to the same stock-CE fallback
  instead of leaking through to the always-shifting fused kernel.
- Removed redundant inner import torch.

cross_entropy_loss.py
- Promote a non-tensor n_items divisor (HF trainers pass a Python int
  via gradient accumulation) to a scalar tensor before the existing
  DataParallel .numel()/.ravel() guard, which is preserved verbatim.
  Without the promotion an int n_items raises AttributeError inside
  the autograd forward.

ast_rewriter.py
- Capture the full lm_head RHS (e.g. .float()/.contiguous()/[slice])
  and emit it in the else-branch so the inference path keeps its
  original dtype/shape semantics.
- Only strip docstring-only decorators (auto_docstring,
  add_start_docstrings*, add_end_docstrings, replace_return_docstrings);
  @can_return_tuple carries return_dict=False semantics and stays.
- Reject forwards with non-empty else, multi-statement labels branches,
  or aliased labels arguments (CSM-style depth-decoder loss survives
  intact rather than being silently dropped).
- Reject forwards where any statement between the lm_head assign and
  the labels-if mutates or reads logits (Gemma3 final_logit_softcapping
  used to be silently bypassed by the fused-loss path).
- Forward explicit loss_function keywords beyond vocab_size (Bloom
  passes num_items_in_batch=kwargs.get(...) without a **kwargs unpack).
- _find_loss_function_call / _find_loss_assign_target now inspect only
  the direct if-body, so a nested guard inside the labels branch is
  not silently dropped.

forward_install.py
- Drop *ForConditionalGeneration from auto-install eligibility (the
  fused kernel hardcodes a causal shift; aligned-label seq2seq losses
  would be off-by-one).
- Skip composite/non-linear heads via a _LINEAR_HEAD_ATTRS allowlist
  so BigBird-style self.cls(...) (BigBirdOnlyMLMHead) is not patched.
- install_for_class / install_for_module now also gate on the
  transformers version floor, matching install_modeling_import_hook.
- Inject transformers.utils.generic.can_return_tuple into the exec
  namespace so the preserved decorator resolves at runtime.
Compress narrative docstrings and inline rationale blocks across
fused_losses/* and the __init__.py opt-in stanza. Load-bearing notes
(@can_return_tuple semantics, Gemma3 softcap reasoning, BigBird
composite-head guard, transformers >= 4.56 floor, ForCausalLM-only
eligibility) are preserved; only WHAT-restating prose was removed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant