Auto-install fused lm_head + cross_entropy forward (opt-in)#21
Open
danielhanchen wants to merge 5 commits into
Open
Auto-install fused lm_head + cross_entropy forward (opt-in)#21danielhanchen wants to merge 5 commits into
danielhanchen wants to merge 5 commits into
Conversation
Adds an opt-in (UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites
the canonical lm_head + self.loss_function triplet on every transformers
`*ForCausalLM` / `*ForConditionalGeneration` whose forward matches the
shape used from transformers 4.56 onwards. Skipping logits.float() over
(seq_len x vocab_size) avoids the OOM that surfaced in #5441 and shaves
the bf16 logits tensor as well.
Layers:
unsloth_zoo/fused_losses/forward_adapter.py
Maps the HF self.loss_function(logits=..., labels=..., vocab_size=...,
**kwargs) calling convention onto unsloth_fused_ce_loss. Pops
num_items_in_batch -> n_items, threads ignore_index / label_smoothing /
logit_softcapping / logit_scale_multiply / logit_scale_divide, and
falls back to a stock CE if the caller passes a pre-shifted
shift_labels tensor (unsupported by the chunked kernel today).
unsloth_zoo/fused_losses/ast_rewriter.py
NodeTransformer that recognises the canonical triplet:
<NAME> = self.<HEAD>(<HIDDEN_EXPR>[...])
loss = None (optional)
if labels is not None:
<LOSS> = self.loss_function(<NAME>, labels, vocab_size=..., **kwargs)
and rewrites it to call unsloth_fused_lm_head_loss(<HIDDEN_EXPR>,
self.<HEAD>, labels, ...). Tolerates keyword vs positional vocab_size,
`.float()` / `[slice]` chains around the lm_head call, and detects
logits re-binding (e.g. Cohere's `logits = logits * self.logit_scale`)
as a refuse signal so we never produce a broken forward.
unsloth_zoo/fused_losses/forward_install.py
Two-tier installer: (1) hash-allowlist fast path via
register_canonical(hash, forward_fn); (2) AST triplet rewrite.
Driven by a meta-path import hook that intercepts
transformers.models.<X>.modeling_<X> imports and patches eligible
classes as their module loads. Soft floor at transformers >= 4.56.
audit() returns a JSON-safe dict of patched / unmatched / failed
classes for observability.
Kernel updates:
unsloth_zoo/fused_losses/cross_entropy_loss.py
compute_fused_ce_loss + UnslothFusedLoss.forward now thread
ignore_index (default -100) into the label-shift step and the inner
F.cross_entropy call. compute_fused_ce_loss also accepts
label_smoothing. Matches HF ForCausalLMLoss semantics so callers
that override either no longer silently regress.
Tests (tests/test_fused_forward_install.py, 14 cases):
- AST rewriter accepts keyword form, positional vocab_size, `.float()`
wrapper. Declines non-canonical, declines on logits rebinding.
- install_for_class: noop when disabled, skips ineligible names,
patches canonical, idempotent, function-override fast path,
audit() snapshot.
- Numerical equivalence on a toy CUDA model: fused loss within
bf16 -> fp32 rounding noise of the reference.
- Kernel respects ignore_index and label_smoothing kwargs.
End-to-end equivalence on Llama-3.2-1B + alpaca-cleaned (seed 3407,
max_steps 10): identical step-1 loss + grad_norm, max |loss delta| =
0.005, max |grad_norm delta| = 0.025 across the run. Audit reported
19 classes patched, 0 failed when UNSLOTH_FUSED_FORWARD=1 (LlamaForCausalLM,
Qwen3ForCausalLM, MistralForCausalLM, Gemma2/3 / GemmaForCausalLM,
Mllama, DeepseekV3, Qwen3MoE / Qwen3Next, Bloom, FalconH1, etc.).
Off by default. Set UNSLOTH_FUSED_FORWARD=1 to opt in.
Forwards routed through unsloth_compiled_cache see __globals__ for the cached module, which does not always re-import the HF output dataclass the original modeling file referenced (e.g. Gemma3ForCausalLM's return statement uses CausalLMOutputWithPast). Populate the exec namespace with everything from transformers.modeling_outputs as a fallback so the rewritten forward links cleanly. Caught during multi-model equivalence run (Gemma3-1B fused) which now matches the stock path bit-for-bit alongside Llama, Qwen3, Phi3, and Mistral.
Collaborator
Author
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces an opt-in auto-installer for fused lm_head and cross_entropy losses, targeting Hugging Face ForCausalLM and ForConditionalGeneration models. The implementation uses an AST-level rewriter to replace standard projection and loss calls with a fused kernel, integrated via import hooks for automatic patching. The fused cross-entropy loss has also been updated to support ignore_index and label_smoothing. Feedback was provided regarding a redundant import torch statement in the forward adapter.
Comment on lines
+83
to
+84
| import torch | ||
| logits = torch.nn.functional.linear( |
forward_adapter.py - shift_labels fallback now uses reduction=sum and divides by n_items when num_items_in_batch is supplied, matching HF ForCausalLMLoss gradient-accumulation scaling. - shift_labels=False (bool) now routes to the same stock-CE fallback instead of leaking through to the always-shifting fused kernel. - Removed redundant inner import torch. cross_entropy_loss.py - Promote a non-tensor n_items divisor (HF trainers pass a Python int via gradient accumulation) to a scalar tensor before the existing DataParallel .numel()/.ravel() guard, which is preserved verbatim. Without the promotion an int n_items raises AttributeError inside the autograd forward. ast_rewriter.py - Capture the full lm_head RHS (e.g. .float()/.contiguous()/[slice]) and emit it in the else-branch so the inference path keeps its original dtype/shape semantics. - Only strip docstring-only decorators (auto_docstring, add_start_docstrings*, add_end_docstrings, replace_return_docstrings); @can_return_tuple carries return_dict=False semantics and stays. - Reject forwards with non-empty else, multi-statement labels branches, or aliased labels arguments (CSM-style depth-decoder loss survives intact rather than being silently dropped). - Reject forwards where any statement between the lm_head assign and the labels-if mutates or reads logits (Gemma3 final_logit_softcapping used to be silently bypassed by the fused-loss path). - Forward explicit loss_function keywords beyond vocab_size (Bloom passes num_items_in_batch=kwargs.get(...) without a **kwargs unpack). - _find_loss_function_call / _find_loss_assign_target now inspect only the direct if-body, so a nested guard inside the labels branch is not silently dropped. forward_install.py - Drop *ForConditionalGeneration from auto-install eligibility (the fused kernel hardcodes a causal shift; aligned-label seq2seq losses would be off-by-one). - Skip composite/non-linear heads via a _LINEAR_HEAD_ATTRS allowlist so BigBird-style self.cls(...) (BigBirdOnlyMLMHead) is not patched. - install_for_class / install_for_module now also gate on the transformers version floor, matching install_modeling_import_hook. - Inject transformers.utils.generic.can_return_tuple into the exec namespace so the preserved decorator resolves at runtime.
Compress narrative docstrings and inline rationale blocks across fused_losses/* and the __init__.py opt-in stanza. Load-bearing notes (@can_return_tuple semantics, Gemma3 softcap reasoning, BigBird composite-head guard, transformers >= 4.56 floor, ForCausalLM-only eligibility) are preserved; only WHAT-restating prose was removed.
7a6a8fe to
eb4f63f
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Staging mirror of unslothai#657
Original PR: unslothai#657
Author: danielhanchen
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Opt-in (
UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites the canonical lm_head +self.loss_functiontriplet on every transformers*ForCausalLM/*ForConditionalGenerationwhose forward matches the shape used from transformers 4.56 onwards. Skippinglogits.float()over(seq_len x vocab_size)avoids the OOM that surfaced in unslothai/unsloth#5441 and shaves the bf16 logits tensor as well.Layers
unsloth_zoo/fused_losses/forward_adapter.pyMaps the HF
self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs)calling convention ontounsloth_fused_ce_loss. Popsnum_items_in_batch->n_items, threadsignore_index/label_smoothing/logit_softcapping/logit_scale_multiply/logit_scale_divide, and falls back to a stock CE if the caller passes a pre-shiftedshift_labelstensor (unsupported by the chunked kernel today).unsloth_zoo/fused_losses/ast_rewriter.pyNodeTransformer that recognises the canonical triplet:
and rewrites it to call
unsloth_fused_lm_head_loss(<HIDDEN_EXPR>, self.<HEAD>, labels, ...). Tolerates keyword vs positionalvocab_size,.float()/[slice]chains around the lm_head call, and detects logits re-binding (e.g. Cohere'slogits = logits * self.logit_scale) as a refuse signal so we never produce a broken forward.unsloth_zoo/fused_losses/forward_install.pyTwo-tier installer: (1) hash-allowlist fast path via
register_canonical(hash, forward_fn)(for future hand-written canonical forwards); (2) AST triplet rewrite. Driven by a meta-path import hook that interceptstransformers.models.<X>.modeling_<X>imports and patches eligible classes as their module loads. Soft floor at transformers >= 4.56.audit()returns a JSON-safe dict of patched / unmatched / failed classes for observability.Kernel updates
unsloth_zoo/fused_losses/cross_entropy_loss.pycompute_fused_ce_loss+UnslothFusedLoss.forwardnow threadignore_index(default-100) into the label-shift step and the innerF.cross_entropycall.compute_fused_ce_lossalso acceptslabel_smoothing. Matches HFForCausalLMLosssemantics so callers that override either no longer silently regress. (logit_softcapping,logit_scale_multiply,logit_scale_dividewere already supported.)Test plan
tests/test_fused_forward_install.py:vocab_size,.float()wrapper. Declines non-canonical, declines on logits rebinding.install_for_class: noop when disabled, skips ineligible names, patches canonical, idempotent, function-override fast path,audit()snapshot.ignore_indexand `label_smThis PR tracks the moving review branch (pr-657-head). Iteration fix commits land here directly. Review-added tests are in a separate PR.
Changed files:
.github/workflows/consolidated-tests-ci.yml.github/workflows/lint-ci.yml.github/workflows/mlx-ci.yml.github/workflows/security-audit.yml.github/workflows/stale.yml.github/workflows/wheel-smoke.ymlunsloth_zoo/__init__.pyunsloth_zoo/fused_losses/__init__.pyunsloth_zoo/fused_losses/ast_rewriter.pyunsloth_zoo/fused_losses/cross_entropy_loss.pyunsloth_zoo/fused_losses/forward_adapter.pyunsloth_zoo/fused_losses/forward_install.pytests/test_fused_forward_install.py