Skip to content

Auto-install fused lm_head + cross_entropy forward (opt-in)#657

Merged
danielhanchen merged 10 commits into
mainfrom
daniel/fused-forward-installer
May 17, 2026
Merged

Auto-install fused lm_head + cross_entropy forward (opt-in)#657
danielhanchen merged 10 commits into
mainfrom
daniel/fused-forward-installer

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

Opt-in (UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites the canonical lm_head + self.loss_function triplet on every transformers *ForCausalLM / *ForConditionalGeneration whose forward matches the shape used from transformers 4.56 onwards. Skipping logits.float() over (seq_len x vocab_size) avoids the OOM that surfaced in unslothai/unsloth#5441 and shaves the bf16 logits tensor as well.

Layers

unsloth_zoo/fused_losses/forward_adapter.py
Maps the HF self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs) calling convention onto unsloth_fused_ce_loss. Pops num_items_in_batch -> n_items, threads ignore_index / label_smoothing / logit_softcapping / logit_scale_multiply / logit_scale_divide, and falls back to a stock CE if the caller passes a pre-shifted shift_labels tensor (unsupported by the chunked kernel today).

unsloth_zoo/fused_losses/ast_rewriter.py
NodeTransformer that recognises the canonical triplet:

<NAME> = self.<HEAD>(<HIDDEN_EXPR>[...])
loss = None              # optional
if labels is not None:
    <LOSS> = self.loss_function(<NAME>, labels, vocab_size=..., **kwargs)

and rewrites it to call unsloth_fused_lm_head_loss(<HIDDEN_EXPR>, self.<HEAD>, labels, ...). Tolerates keyword vs positional vocab_size, .float() / [slice] chains around the lm_head call, and detects logits re-binding (e.g. Cohere's logits = logits * self.logit_scale) as a refuse signal so we never produce a broken forward.

unsloth_zoo/fused_losses/forward_install.py
Two-tier installer: (1) hash-allowlist fast path via register_canonical(hash, forward_fn) (for future hand-written canonical forwards); (2) AST triplet rewrite. Driven by a meta-path import hook that intercepts transformers.models.<X>.modeling_<X> imports and patches eligible classes as their module loads. Soft floor at transformers >= 4.56. audit() returns a JSON-safe dict of patched / unmatched / failed classes for observability.

Kernel updates

unsloth_zoo/fused_losses/cross_entropy_loss.py
compute_fused_ce_loss + UnslothFusedLoss.forward now thread ignore_index (default -100) into the label-shift step and the inner F.cross_entropy call. compute_fused_ce_loss also accepts label_smoothing. Matches HF ForCausalLMLoss semantics so callers that override either no longer silently regress. (logit_softcapping, logit_scale_multiply, logit_scale_divide were already supported.)

Test plan

  • 14 cases in tests/test_fused_forward_install.py:
    • AST rewriter accepts keyword form, positional vocab_size, .float() wrapper. Declines non-canonical, declines on logits rebinding.
    • install_for_class: noop when disabled, skips ineligible names, patches canonical, idempotent, function-override fast path, audit() snapshot.
    • Numerical equivalence on a toy CUDA model: fused loss within bf16 -> fp32 rounding noise of the reference.
    • Kernel respects ignore_index and label_smoothing kwargs.
  • End-to-end equivalence run on unsloth/Llama-3.2-1B-Instruct + yahma/alpaca-cleaned, seed 3407, max_steps=10:
step loss (no fuse) loss (fuse) abs delta grad_norm (no fuse) grad_norm (fuse) abs delta
1 1.45730 1.45730 0.00000 2.79364 2.79364 0.00000
2 1.52950 1.52950 0.00000 1.29600 1.29597 0.00004
3 1.80210 1.79980 0.00230 1.88810 1.88224 0.00586
4 1.82610 1.82740 0.00130 2.77398 2.78684 0.01287
5 1.38150 1.38280 0.00130 1.43915 1.44090 0.00175
6 1.08200 1.08590 0.00390 1.66032 1.66319 0.00287
7 2.17880 2.18380 0.00500 3.54804 3.52294 0.02510
8 1.24400 1.24220 0.00180 1.34056 1.33508 0.00548
9 1.01660 1.01400 0.00260 2.00057 2.00310 0.00253
10 1.59860 1.59960 0.00100 1.65028 1.67345 0.02317

Step 1 loss and grad norm are bitwise identical. Across the run: max |loss delta| = 0.005, max |grad_norm delta| = 0.025 - both within bf16 -> fp32 chunked-CE rounding noise.

  • audit() after import with the flag on (Llama / Qwen3 / Mistral / Gemma3 / DeepseekV3 / Qwen3MoE / Bloom / FalconH1 / Mllama / Csm / Lfm2Vl / Qwen3VLMoe and 7 more): 19 classes patched, 0 failed, 6 unmatched (Cohere, Gemma3 VLM heads, GraniteMoeHybrid, CsmDepthDecoder - all expected outliers; LOSS_MAPPING patch in Patch every LOSS_MAPPING key aliased to ForCausalLMLoss #656 backstops them).

Activation

Off by default. Set UNSLOTH_FUSED_FORWARD=1 to opt in. When on, fused install runs at import unsloth_zoo; new transformers modeling modules imported afterwards are patched via a meta-path hook.

from unsloth_zoo.fused_losses import audit; audit() dumps the patched / unmatched / failed registry for debugging.

Related: unslothai/unsloth#5441, #656.

Adds an opt-in (UNSLOTH_FUSED_FORWARD=1) auto-installer that rewrites
the canonical lm_head + self.loss_function triplet on every transformers
`*ForCausalLM` / `*ForConditionalGeneration` whose forward matches the
shape used from transformers 4.56 onwards. Skipping logits.float() over
(seq_len x vocab_size) avoids the OOM that surfaced in #5441 and shaves
the bf16 logits tensor as well.

Layers:

  unsloth_zoo/fused_losses/forward_adapter.py
    Maps the HF self.loss_function(logits=..., labels=..., vocab_size=...,
    **kwargs) calling convention onto unsloth_fused_ce_loss. Pops
    num_items_in_batch -> n_items, threads ignore_index / label_smoothing /
    logit_softcapping / logit_scale_multiply / logit_scale_divide, and
    falls back to a stock CE if the caller passes a pre-shifted
    shift_labels tensor (unsupported by the chunked kernel today).

  unsloth_zoo/fused_losses/ast_rewriter.py
    NodeTransformer that recognises the canonical triplet:
        <NAME> = self.<HEAD>(<HIDDEN_EXPR>[...])
        loss = None              (optional)
        if labels is not None:
            <LOSS> = self.loss_function(<NAME>, labels, vocab_size=..., **kwargs)
    and rewrites it to call unsloth_fused_lm_head_loss(<HIDDEN_EXPR>,
    self.<HEAD>, labels, ...). Tolerates keyword vs positional vocab_size,
    `.float()` / `[slice]` chains around the lm_head call, and detects
    logits re-binding (e.g. Cohere's `logits = logits * self.logit_scale`)
    as a refuse signal so we never produce a broken forward.

  unsloth_zoo/fused_losses/forward_install.py
    Two-tier installer: (1) hash-allowlist fast path via
    register_canonical(hash, forward_fn); (2) AST triplet rewrite.
    Driven by a meta-path import hook that intercepts
    transformers.models.<X>.modeling_<X> imports and patches eligible
    classes as their module loads. Soft floor at transformers >= 4.56.
    audit() returns a JSON-safe dict of patched / unmatched / failed
    classes for observability.

Kernel updates:

  unsloth_zoo/fused_losses/cross_entropy_loss.py
    compute_fused_ce_loss + UnslothFusedLoss.forward now thread
    ignore_index (default -100) into the label-shift step and the inner
    F.cross_entropy call. compute_fused_ce_loss also accepts
    label_smoothing. Matches HF ForCausalLMLoss semantics so callers
    that override either no longer silently regress.

Tests (tests/test_fused_forward_install.py, 14 cases):
  - AST rewriter accepts keyword form, positional vocab_size, `.float()`
    wrapper. Declines non-canonical, declines on logits rebinding.
  - install_for_class: noop when disabled, skips ineligible names,
    patches canonical, idempotent, function-override fast path,
    audit() snapshot.
  - Numerical equivalence on a toy CUDA model: fused loss within
    bf16 -> fp32 rounding noise of the reference.
  - Kernel respects ignore_index and label_smoothing kwargs.

End-to-end equivalence on Llama-3.2-1B + alpaca-cleaned (seed 3407,
max_steps 10): identical step-1 loss + grad_norm, max |loss delta| =
0.005, max |grad_norm delta| = 0.025 across the run. Audit reported
19 classes patched, 0 failed when UNSLOTH_FUSED_FORWARD=1 (LlamaForCausalLM,
Qwen3ForCausalLM, MistralForCausalLM, Gemma2/3 / GemmaForCausalLM,
Mllama, DeepseekV3, Qwen3MoE / Qwen3Next, Bloom, FalconH1, etc.).

Off by default. Set UNSLOTH_FUSED_FORWARD=1 to opt in.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in auto-installer for fused lm_head and cross_entropy losses, utilizing an AST-level rewriter and import hooks to patch transformers models. The implementation includes updates to the fused cross-entropy kernel to support ignore_index and label_smoothing, along with a comprehensive test suite. Review feedback suggests refining the exception handling during installation to improve visibility and reconsidering the aggressive stripping of decorators in the AST rewriter to avoid potential side effects on model logic.

Comment thread unsloth_zoo/__init__.py
Comment on lines +391 to +396
try:
from .fused_losses.forward_install import install_modeling_import_hook as _install_fused_forward
_install_fused_forward()
del _install_fused_forward
except Exception:
pass

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The broad try...except Exception: pass block around the fused forward installation can make debugging difficult if the installer fails for unexpected reasons. It is recommended to at least print the exception to aid in troubleshooting, especially for visibility in 'studio' environments, as this is an opt-in feature that users might want to verify.

References
  1. Use print instead of logger.info for messages that must be visible in 'studio' when working with llama.cpp, as logger.info messages may be filtered out.

# that may not have them visible. The decorators only add docstring
# sugar / tuple-return handling and are not needed for the runtime
# forward we install.
fn.decorator_list = []

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Stripping all decorators from the rewritten forward function might lead to unexpected behavior if the model relies on functional decorators (e.g., for compilation hints, custom logic, or hooks). While many transformers decorators are docstring-related, a more selective approach or a clear justification for why this is safe for all supported models would be preferable. If the intent is to avoid issues with decorators not being present in the exec namespace, note that ns is already initialized with the original function's globals.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 680c9a3788

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

)
{logits} = EMPTY_LOGITS
else:
{logits} = self.{head_attr}({hidden_src})

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve wrappers in the labels-none branch

When the matched original assignment is wrapped, e.g. lm_logits = self.lm_head(hidden_states).float() as covered by the new positional test fixture, the rewritten labels is None path reconstructs only self.<head>(hidden) and drops the wrapper. That means opt-in patched models return different generation/eval logits (dtype/device/view transformations such as .float(), .contiguous(), or post-call slicing are lost) even though the installer claims the no-labels path is untouched; store and reuse the full original RHS for the else branch or decline these matches.

Useful? React with 👍 / 👎.

# that may not have them visible. The decorators only add docstring
# sugar / tuple-return handling and are not needed for the runtime
# forward we install.
fn.decorator_list = []

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep tuple-return decorators on rewritten forwards

Stripping all decorators removes runtime behavior from Transformers forwards, not just docstring sugar. In current Transformers, can_return_tuple pops/uses return_dict and converts a ModelOutput to output.to_tuple() when return_dict=False; after this rewrite, patched classes with that decorator will ignore the standard return_dict=False API and return a dataclass instead of a tuple. Since the exec namespace is copied from the original globals, preserve/reapply runtime decorators such as can_return_tuple or only strip known documentation-only decorators.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed

from __future__ import annotations

import os
from __future__ import annotations

import os
import sys

import os
import sys
import types
Comment thread unsloth_zoo/__init__.py Fixed
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Forwards routed through unsloth_compiled_cache see __globals__ for the
cached module, which does not always re-import the HF output dataclass
the original modeling file referenced (e.g. Gemma3ForCausalLM's return
statement uses CausalLMOutputWithPast). Populate the exec namespace
with everything from transformers.modeling_outputs as a fallback so
the rewritten forward links cleanly.

Caught during multi-model equivalence run (Gemma3-1B fused) which now
matches the stock path bit-for-bit alongside Llama, Qwen3, Phi3, and
Mistral.
@danielhanchen

Copy link
Copy Markdown
Member Author

Multi-model equivalence run

Ran the same max_steps=10, seed 3407, alpaca-cleaned LoRA SFT across five model families to validate the fused-forward path. Each row is a pair of runs with UNSLOTH_FUSED_FORWARD=0 vs 1 on identical hyperparameters.

Model step1 loss equal max abs loss delta max abs grad_norm delta audit n_patched
unsloth/Llama-3.2-1B-Instruct yes 0.005000 0.025100 19
unsloth/Qwen3-0.6B yes 0.007300 0.100783 19
unsloth/gemma-3-1b-it yes 0.000000 0.000000 19
unsloth/Phi-3-mini-4k-instruct yes 0.008000 0.019985 19
unsloth/Mistral-7B-Instruct-v0.3 yes 0.003900 0.047178 19

Notes:

  • Step 1 loss + grad_norm are bitwise identical for every model (sanity check that the rewrite was loaded and the seed is honoured before the first chunked-CE backward pass).
  • Llama / Qwen3 / Phi3 / Mistral show deltas within bf16 -> fp32 chunked-CE rounding noise on subsequent steps.
  • Gemma3 is 0.000 across the board because Gemma3's forward goes through unsloth_compiled_cache.unsloth_compiled_module_gemma3, which compiler.py already routes through fused CE. Our class-level patch is shadowed by the compiled-cache forward and is a clean no-op there.
  • Qwen3 step 7 had a single grad_norm spike (4.18 vs 4.08, delta 0.10) at a chunk boundary; loss delta on the same step is 0.0006. The model still converges to within 0.27% of the no-fuse path by step 10.

Also fixed a NameError: CausalLMOutputWithPast that surfaced on Gemma3 (forwards routed through unsloth_compiled_cache see a different __globals__ than the original modeling module). Commit 920bea4 backfills transformers.modeling_outputs into the exec namespace.

Artifacts:

  • outputs/fused_forward_multimodel_summary.json (per-pair table).
  • outputs/equiv/{model}_{off,on}.json (per-step loss + grad_norm + audit dump).
  • scripts/fused_forward_equivalence_run.py (parameterised runner: python scripts/fused_forward_equivalence_run.py outputs/run.json <model_name>).

Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
@danielhanchen

Copy link
Copy Markdown
Member Author

This PR appears to address open issue(s). The duplicate detector matched the following open issues with HIGH confidence:

  • unslothai/unsloth#5230@allcodernet — PR extends fused CE loss to honor ignore_index/masking, directly addressing zero gradients with sparse last-assistant-only labels.
  • unslothai/unsloth#2253@rupaut98 — Issue is in Unsloth fused cross-entropy backward/lm_head path; PR modifies fused CE plumbing and fused lm_head loss installation.
  • unslothai/unsloth#1801@RWTHEY — Issue reports eval/training VRAM spikes from logits memory; PR fuses lm_head plus CE to avoid materializing full logits.
  • unslothai/unsloth#1147@tommedema — PR’s fused lm_head/cross-entropy avoids materializing full eval logits, directly targeting validation CUDA OOM during SFTTrainer evaluation.

If this PR fixes any of them, consider adding closes #N / resolves #N to the description so the issue auto-closes on merge. If the match is wrong, ignore this comment.

@danielhanchen danielhanchen added the auto-reviewing Auto-review in progress label May 16, 2026
forward_adapter.py
- shift_labels fallback now uses reduction=sum and divides by n_items
  when num_items_in_batch is supplied, matching HF ForCausalLMLoss
  gradient-accumulation scaling.
- shift_labels=False (bool) now routes to the same stock-CE fallback
  instead of leaking through to the always-shifting fused kernel.
- Removed redundant inner import torch.

cross_entropy_loss.py
- Promote a non-tensor n_items divisor (HF trainers pass a Python int
  via gradient accumulation) to a scalar tensor before the existing
  DataParallel .numel()/.ravel() guard, which is preserved verbatim.
  Without the promotion an int n_items raises AttributeError inside
  the autograd forward.

ast_rewriter.py
- Capture the full lm_head RHS (e.g. .float()/.contiguous()/[slice])
  and emit it in the else-branch so the inference path keeps its
  original dtype/shape semantics.
- Only strip docstring-only decorators (auto_docstring,
  add_start_docstrings*, add_end_docstrings, replace_return_docstrings);
  @can_return_tuple carries return_dict=False semantics and stays.
- Reject forwards with non-empty else, multi-statement labels branches,
  or aliased labels arguments (CSM-style depth-decoder loss survives
  intact rather than being silently dropped).
- Reject forwards where any statement between the lm_head assign and
  the labels-if mutates or reads logits (Gemma3 final_logit_softcapping
  used to be silently bypassed by the fused-loss path).
- Forward explicit loss_function keywords beyond vocab_size (Bloom
  passes num_items_in_batch=kwargs.get(...) without a **kwargs unpack).
- _find_loss_function_call / _find_loss_assign_target now inspect only
  the direct if-body, so a nested guard inside the labels branch is
  not silently dropped.

forward_install.py
- Drop *ForConditionalGeneration from auto-install eligibility (the
  fused kernel hardcodes a causal shift; aligned-label seq2seq losses
  would be off-by-one).
- Skip composite/non-linear heads via a _LINEAR_HEAD_ATTRS allowlist
  so BigBird-style self.cls(...) (BigBirdOnlyMLMHead) is not patched.
- install_for_class / install_for_module now also gate on the
  transformers version floor, matching install_modeling_import_hook.
- Inject transformers.utils.generic.can_return_tuple into the exec
  namespace so the preserved decorator resolves at runtime.
Compress narrative docstrings and inline rationale blocks across
fused_losses/* and the __init__.py opt-in stanza. Load-bearing notes
(@can_return_tuple semantics, Gemma3 softcap reasoning, BigBird
composite-head guard, transformers >= 4.56 floor, ForCausalLM-only
eligibility) are preserved; only WHAT-restating prose was removed.
@danielhanchen danielhanchen added auto-approved Auto-review approved the PR and removed auto-reviewing Auto-review in progress labels May 16, 2026
@danielhanchen

Copy link
Copy Markdown
Member Author

Auto-review verdict: Approved

PR #657 adds an opt-in (UNSLOTH_FUSED_FORWARD=1) AST-level rewriter that intercepts transformers *ForCausalLM imports and routes loss through Unsloth's chunked fused lm_head+cross_entropy kernel, skipping the bf16 logits and fp32 cast for measurable VRAM savings during training. Review-hardened the rewriter against real model shapes (Bloom, CSM, Gemma3 softcap, BigBird composite head, T5Gemma2) and the adapter against integer num_items_in_batch and pre-shifted-label scaling, so it now correctly rewrites canonical forwards and cleanly refuses the rest.

Reason: All 12 real issues identified during review were fixed; tests pass; remaining items were design choices or hypothetical edge cases.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c33abf4d8e

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

tgt = stmt.targets[0]
if not (isinstance(tgt, ast.Name) and tgt.id == logits_name):
continue
inner = _find_inner_self_call(stmt.value)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Refuse transformed lm_head expressions before fusing

When a matching forward applies a logits transform in the same assignment, e.g. logits = self.lm_head(hidden_states) * self.logit_scale or an inline softcap, this still captures the inner self.lm_head call and rewrites the labels branch to call unsloth_fused_lm_head_loss on the unmodified linear output. The no-labels branch keeps the full original RHS, so only training loss silently drops the transform; this should either decline anything beyond known-safe wrappers such as .float()/.contiguous()/slicing or encode the transform in the fused call.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Comment thread unsloth_zoo/__init__.py
from .fused_losses.forward_install import install_modeling_import_hook as _install_fused_forward
_install_fused_forward()
del _install_fused_forward
except Exception:
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Comment thread unsloth_zoo/fused_losses/forward_install.py Fixed
Cover the eight semantic fixes that landed in commit db90fa1 so
regressions are caught at test time rather than at training time:

  - test_ast_rewriter_declines_when_intermediate_touches_logits
    Gemma final_logit_softcapping between lm_head and the labels-if
    must not be silently bypassed.
  - test_ast_rewriter_declines_when_labels_aliased
    CSM-style `loss = self.loss_function(..., labels=backbone_labels)`
    on an `if labels is not None:` gate must refuse.
  - test_ast_rewriter_declines_non_trivial_labels_branch
    MoE-style auxiliary loss inside the labels branch must refuse so
    aux_loss + router_aux_loss_coef stays intact.
  - test_ast_rewriter_forwards_explicit_extra_kwargs
    Bloom-style `num_items_in_batch=kwargs.get(...)` without **kwargs
    must reach the kernel.
  - test_install_skips_for_conditional_generation
    *ForConditionalGeneration uses aligned labels; auto-install must
    skip.
  - test_install_skips_composite_head
    BigBird-style `self.cls(...)` composite head must hit the
    _LINEAR_HEAD_ATTRS allowlist and log as non-linear-head.
  - test_fused_kernel_accepts_int_n_items
    HF Trainer grad-accum passes a Python int divisor; kernel must
    promote it to a scalar tensor before the DataParallel guard.
  - test_adapter_falls_back_when_shift_labels_false
    `shift_labels=False` bool must route through stock CE; the fused
    kernel always re-shifts.

All 22 tests pass (14 original + 8 new).

Multi-model end-to-end equivalence rerun against the post-review tree
(seed 3407, max_steps=10, alpaca-cleaned):

  model            s1 eq  max|loss d|  max|grad d|  n_patched
  Llama-3.2-1B      True      0.00450      0.01276         11
  Qwen3-0.6B        True      0.00490      0.07686         11
  Gemma-3-1B        True      0.00000      0.00000         11
  Mistral-7B-v.3    True      0.00370      0.03093         11

Step 1 loss + grad_norm are bitwise identical for every model; n_patched
dropped from 19 -> 11 because ConditionalGeneration + Gemma2/3 (logits
touched by softcap) + BigBird (composite head) are now correctly skipped.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ab283b9503

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

{hidden_src}, self.{head_attr}, labels,
vocab_size={vocab}{extra}{kwargs_unpack},
)
{logits} = EMPTY_LOGITS

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Honor UNSLOTH_RETURN_LOGITS in fused forward

When this opt-in installer patches a model, any training/eval call with labels now always returns EMPTY_LOGITS, even if the user set the repository's documented UNSLOTH_RETURN_LOGITS=1 escape hatch for metric computation or logit inspection. The existing compiler path conditionally materializes logits under that env var, so this new path makes the same configured run lose logits and can break Trainer prediction/metrics flows that require them; this assignment should mirror that conditional or otherwise skip the fused shortcut when logits are requested.

Useful? React with 👍 / 👎.

is_enabled() now returns True unless UNSLOTH_FUSED_FORWARD is explicitly
set to "0". Updated docstrings and the __init__.py comment to reflect the
new default. The two-tier installer + LOSS_MAPPING backstop in #656 means
the worst case for any class we touch is no-op (refused via _UNMATCHED
or composite-head guard) -- never a worse forward than the stock path.

Test suite (23 cases, was 22 + new test_install_default_is_on): all
green. Refresh of the multi-model equivalence rerun with no env var set
versus UNSLOTH_FUSED_FORWARD=0 (Llama-3.2-1B / Qwen3-0.6B / Gemma-3-1B /
Mistral-7B-v0.3, seed 3407, max_steps=10, alpaca-cleaned):

  model           off enabled  default enabled  s1 eq  max|loss d|  max|grad d|
  Llama-3.2-1B          False            True   True      0.00410      0.02336
  Qwen3-0.6B            False            True   True      0.00680      0.02561
  Gemma-3-1B            False            True   True      0.00000      0.00000
  Mistral-7B-v.3        False            True   True      0.00530      0.05310

Step 1 loss + grad_norm bitwise identical for every model; deltas across
the run stay within bf16 -> fp32 chunked-CE rounding noise. Audit
reports 11 classes patched at default and 0 patched when explicitly
disabled.
install_for_module(mod)
except Exception:
continue
_INSTALL_DONE = True
try:
replacement.__qualname__ = forward.__qualname__
replacement.__module__ = forward.__module__
except Exception:
try:
from transformers.utils.generic import can_return_tuple
ns.setdefault("can_return_tuple", can_return_tuple)
except Exception:
if _name.startswith("_"):
continue
ns.setdefault(_name, getattr(_mo, _name))
except Exception:
new_forward.__qualname__ = forward.__qualname__
new_forward.__module__ = forward.__module__
new_forward.__doc__ = forward.__doc__
except Exception:

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: db4e5ea3c5

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +66 to +68
def is_enabled() -> bool:
# On by default; opt out via UNSLOTH_FUSED_FORWARD=0.
return os.environ.get("UNSLOTH_FUSED_FORWARD", "1") != "0"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep fused-forward patching opt-in by default

When UNSLOTH_FUSED_FORWARD is unset, this returns True, and unsloth_zoo.__init__ immediately calls install_modeling_import_hook(), so importing the package monkey-patches every eligible Transformers modeling module even though this change is described as opt-in. That makes existing training/eval runs pick up the new AST-rewritten forward without consent, including any edge cases that the installer does not yet handle; default this to disabled unless the env var is explicitly set to 1.

Useful? React with 👍 / 👎.

trl 1.x padding_free passes shift_labels=<tensor> through the loss
function. The adapter previously fell back to a materialised-logits
F.cross_entropy in that case, which kept the OOM problem the chunked
kernel was supposed to fix.

Plumb shift_labels through unsloth_fused_ce_loss instead. The outer
UnslothFusedLoss.forward already handles label shifting; when the
caller pre-shifted we just flatten and skip the inner re-shift.

Files:
- cross_entropy_loss.py: unsloth_fused_ce_loss gains shift_labels arg
  (default True). Outer adds an else branch that flattens pre-shifted
  labels so chunking aligns with hidden_states. The four inner
  accumulate_chunk call sites pass False unconditionally now since
  the outer always normalises labels.
- forward_adapter.py: drop the F.cross_entropy fallback. Pick (target,
  do_shift) based on the shift_labels kwarg and call the fused kernel
  with shift_labels=do_shift.
- test_fused_forward_install.py: rename the stale fallback test and
  add five fp32-strict numerical checks (atol/rtol=1e-5):
    * auto-shift matches F.cross_entropy
    * pre-shifted tensor matches F.cross_entropy
    * shift_labels=False matches F.cross_entropy
    * num_items_in_batch divides correctly
    * int and 0-d tensor n_items produce equal loss

Empirical end-to-end checks (10 step Llama-3.2-1B LoRA, max_steps=10):

  trl 1.4.0 padding_free=True, fused vs off:
    step 1 loss: 1.45730 == 1.45730 (exact)
    max delta over 10 steps: 0.003 (bf16 noise)

  num_items_in_batch wiring (batch=2, grad_accum=4):
    HF passes a scalar tensor, consistent across the 4 micro-batches
    in each window. n_items equals sum(non_ignore_labels) - rows in
    every window (the per-row causal-shift drop), matching the
    post-shift count HF uses for the mean reduction.

27/27 unit tests pass.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ec66bc6cc6

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +271 to +273
{loss} = unsloth_fused_lm_head_loss(
{hidden_src}, self.{head_attr}, labels,
vocab_size={vocab}{extra}{kwargs_unpack},

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve custom loss_function dispatch before fusing

When a patched *ForCausalLM instance overrides self.loss_function (for example by setting a custom _loss_function/loss type for distillation or an experimental objective), the original forward would call that callable here, but the rewrite now always routes the labels path through unsloth_fused_lm_head_loss and computes standard fused CE. That silently changes the training objective for those models; either guard the fused branch on the default causal-LM loss at runtime or decline classes/instances whose loss_function is not the supported loss.

Useful? React with 👍 / 👎.

The fused-forward installer (forward_install.py) rewrites
*ForCausalLM.forward at import time. Two upstream-pattern tests used
inspect.getsource(cls.forward) and got the rewritten body, which no
longer contains the canonical upstream lines compiler.py pins.

Switch both probes to read the modeling module's on-disk source via
__file__ instead. That is the source compiler.py's rewriter actually
operates on, and it stays pristine regardless of any runtime patches.

Tests affected:
- test_compiler_cross_entropy_lm_head_pattern_present
- test_compiler_cross_entropy_find_2_loss_function_signature

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1d8bc08e1f

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

return unsloth_fused_ce_loss(
trainer = None,
hidden_states = hidden_states,
lm_head_weight = lm_head.weight,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard against adapter-wrapped lm_head modules

When lm_head is a PEFT/LoRA, quantized, or otherwise wrapped Linear (for example when users include lm_head in LoRA target_modules), its forward applies adapter deltas/dequantization/hooks on top of the base parameters. This fused path bypasses lm_head(hidden_states) and feeds only lm_head.weight/.bias to F.linear, so the labels branch computes loss without the wrapper behavior and those adapter parameters receive no gradient, while the no-label branch still uses the real head. Please either restrict this fast path to exact plain torch.nn.Linear heads at runtime or fold supported wrapper weights into the fused computation.

Useful? React with 👍 / 👎.

@danielhanchen danielhanchen merged commit 0bcbf69 into main May 17, 2026
15 checks passed
danielhanchen added a commit that referenced this pull request May 17, 2026
trl 1.x padding_free passes shift_labels=<tensor> through the loss
function. The adapter previously fell back to a materialised-logits
F.cross_entropy in that case, which kept the OOM problem the chunked
kernel was supposed to fix.

Plumb shift_labels through unsloth_fused_ce_loss instead. The outer
UnslothFusedLoss.forward already handles label shifting; when the
caller pre-shifted we just flatten and skip the inner re-shift.

Files:
- cross_entropy_loss.py: unsloth_fused_ce_loss gains shift_labels arg
  (default True). Outer adds an else branch that flattens pre-shifted
  labels so chunking aligns with hidden_states. The four inner
  accumulate_chunk call sites pass False unconditionally now since
  the outer always normalises labels.
- forward_adapter.py: drop the F.cross_entropy fallback. Pick (target,
  do_shift) based on the shift_labels kwarg and call the fused kernel
  with shift_labels=do_shift.
- test_fused_forward_install.py: rename the stale fallback test and
  add five fp32-strict numerical checks (atol/rtol=1e-5):
    * auto-shift matches F.cross_entropy
    * pre-shifted tensor matches F.cross_entropy
    * shift_labels=False matches F.cross_entropy
    * num_items_in_batch divides correctly
    * int and 0-d tensor n_items produce equal loss

Empirical end-to-end checks (10 step Llama-3.2-1B LoRA, max_steps=10):

  trl 1.4.0 padding_free=True, fused vs off:
    step 1 loss: 1.45730 == 1.45730 (exact)
    max delta over 10 steps: 0.003 (bf16 noise)

  num_items_in_batch wiring (batch=2, grad_accum=4):
    HF passes a scalar tensor, consistent across the 4 micro-batches
    in each window. n_items equals sum(non_ignore_labels) - rows in
    every window (the per-row causal-shift drop), matching the
    post-shift count HF uses for the mean reduction.

27/27 unit tests pass.
danielhanchen added a commit that referenced this pull request May 18, 2026
* Honor UNSLOTH_RETURN_HIDDEN_STATES / UNSLOTH_RETURN_LOGITS in fused forward

The AST-rewritten forward installed by PR #657 only had two branches:
labels-not-None (fused CE, EMPTY_LOGITS) and else (real logits, no loss).
It silently ignored both env vars that the compiler-rewritten forward in
unsloth_zoo/compiler.py honors. For GRPO the compiled forward overrides
the AST one so this never mattered in practice, but it left the AST
forward behaviourally different from the compiled one and not safe to
rely on standalone.

Expand the rewrite template to the same three-branch shape as the
compiled forward:

  1. UNSLOTH_RETURN_HIDDEN_STATES=1 -> hidden_states in the logits slot,
     no lm_head matmul, no loss. GRPO's hidden-states fast path.
  2. labels is not None -> fused CE for loss; logits = EMPTY_LOGITS
     unless UNSLOTH_RETURN_LOGITS=1, in which case the original lm_head
     expression runs so callers can train + collect logits in one
     forward.
  3. otherwise -> original RHS verbatim, loss = None.

forward_install.py: seed the rewritten forward's globals with os so the
env-var reads work on classes whose original forward did not import os.

Tests: ordering assertion on the rewriter output plus four CUDA-gated
behaviour tests covering each branch and the priority of return-hidden
over return-logits when both are set.

* Drop UNSLOTH_RETURN_HIDDEN_STATES handling from AST forward

The hidden-states fast path is owned by the compiler-rewritten forward
in unsloth_zoo/compiler.py, which already overrides the AST forward for
every *ForCausalLM class that GRPO actually runs on. Honoring the env
var in the AST forward as well was defence-in-depth that nobody hits.

Keep the UNSLOTH_RETURN_LOGITS opt-in (closes a real gap: lets callers
collect real logits + train via fused CE in one forward).

Template now goes back to two top-level branches with a nested if for
the logits opt-in:

  if labels is not None:
      <fused CE>
      if UNSLOTH_RETURN_LOGITS == '1':
          logits = <original RHS>
      else:
          logits = EMPTY_LOGITS
  else:
      logits = <original RHS>
      loss = None

Tests trimmed to match (29 passed). The ns.setdefault('os', os) seed in
forward_install.py stays -- the UNSLOTH_RETURN_LOGITS read still needs
os available in the rewritten forward's globals.

* Avoid double lm_head matmul on UNSLOTH_RETURN_LOGITS=1 path

Previous shape called both unsloth_fused_lm_head_loss (which chunks the
lm_head matmul internally to compute CE) and self.<head>(<hidden>) (the
full matmul) when the opt-in env var was set. Two matmuls for one
materialised tensor.

New shape splits the labels branch into two paths and picks the right
loss path for each:

  if labels is not None:
      if UNSLOTH_RETURN_LOGITS == '1':
          logits = <original RHS>                       # one matmul
          loss   = self.loss_function(logits, labels,   # same logits
                                      vocab_size=V, ...)
      else:
          loss   = unsloth_fused_lm_head_loss(...)      # chunked, fused
          logits = EMPTY_LOGITS
  else:
      logits = <original RHS>
      loss   = None

The opt-in path now routes through the model's own self.loss_function
on the already-materialised logits. Matches HF's standard CausalLM loss
shape and the conditional in unsloth_zoo/compiler.py:2074.

Tests assert single-matmul + single-self.loss_function on the opt-in
path; numerical equivalence holds bit-identically on the toy in this
sim (5.003798 vs 5.003798).
Brishen pushed a commit to Brishen/unsloth-zoo that referenced this pull request May 19, 2026
trl 1.x padding_free passes shift_labels=<tensor> through the loss
function. The adapter previously fell back to a materialised-logits
F.cross_entropy in that case, which kept the OOM problem the chunked
kernel was supposed to fix.

Plumb shift_labels through unsloth_fused_ce_loss instead. The outer
UnslothFusedLoss.forward already handles label shifting; when the
caller pre-shifted we just flatten and skip the inner re-shift.

Files:
- cross_entropy_loss.py: unsloth_fused_ce_loss gains shift_labels arg
  (default True). Outer adds an else branch that flattens pre-shifted
  labels so chunking aligns with hidden_states. The four inner
  accumulate_chunk call sites pass False unconditionally now since
  the outer always normalises labels.
- forward_adapter.py: drop the F.cross_entropy fallback. Pick (target,
  do_shift) based on the shift_labels kwarg and call the fused kernel
  with shift_labels=do_shift.
- test_fused_forward_install.py: rename the stale fallback test and
  add five fp32-strict numerical checks (atol/rtol=1e-5):
    * auto-shift matches F.cross_entropy
    * pre-shifted tensor matches F.cross_entropy
    * shift_labels=False matches F.cross_entropy
    * num_items_in_batch divides correctly
    * int and 0-d tensor n_items produce equal loss

Empirical end-to-end checks (10 step Llama-3.2-1B LoRA, max_steps=10):

  trl 1.4.0 padding_free=True, fused vs off:
    step 1 loss: 1.45730 == 1.45730 (exact)
    max delta over 10 steps: 0.003 (bf16 noise)

  num_items_in_batch wiring (batch=2, grad_accum=4):
    HF passes a scalar tensor, consistent across the 4 micro-batches
    in each window. n_items equals sum(non_ignore_labels) - rows in
    every window (the per-row causal-shift drop), matching the
    post-shift count HF uses for the mean reduction.

27/27 unit tests pass.
Brishen pushed a commit to Brishen/unsloth-zoo that referenced this pull request May 19, 2026
* Honor UNSLOTH_RETURN_HIDDEN_STATES / UNSLOTH_RETURN_LOGITS in fused forward

The AST-rewritten forward installed by PR unslothai#657 only had two branches:
labels-not-None (fused CE, EMPTY_LOGITS) and else (real logits, no loss).
It silently ignored both env vars that the compiler-rewritten forward in
unsloth_zoo/compiler.py honors. For GRPO the compiled forward overrides
the AST one so this never mattered in practice, but it left the AST
forward behaviourally different from the compiled one and not safe to
rely on standalone.

Expand the rewrite template to the same three-branch shape as the
compiled forward:

  1. UNSLOTH_RETURN_HIDDEN_STATES=1 -> hidden_states in the logits slot,
     no lm_head matmul, no loss. GRPO's hidden-states fast path.
  2. labels is not None -> fused CE for loss; logits = EMPTY_LOGITS
     unless UNSLOTH_RETURN_LOGITS=1, in which case the original lm_head
     expression runs so callers can train + collect logits in one
     forward.
  3. otherwise -> original RHS verbatim, loss = None.

forward_install.py: seed the rewritten forward's globals with os so the
env-var reads work on classes whose original forward did not import os.

Tests: ordering assertion on the rewriter output plus four CUDA-gated
behaviour tests covering each branch and the priority of return-hidden
over return-logits when both are set.

* Drop UNSLOTH_RETURN_HIDDEN_STATES handling from AST forward

The hidden-states fast path is owned by the compiler-rewritten forward
in unsloth_zoo/compiler.py, which already overrides the AST forward for
every *ForCausalLM class that GRPO actually runs on. Honoring the env
var in the AST forward as well was defence-in-depth that nobody hits.

Keep the UNSLOTH_RETURN_LOGITS opt-in (closes a real gap: lets callers
collect real logits + train via fused CE in one forward).

Template now goes back to two top-level branches with a nested if for
the logits opt-in:

  if labels is not None:
      <fused CE>
      if UNSLOTH_RETURN_LOGITS == '1':
          logits = <original RHS>
      else:
          logits = EMPTY_LOGITS
  else:
      logits = <original RHS>
      loss = None

Tests trimmed to match (29 passed). The ns.setdefault('os', os) seed in
forward_install.py stays -- the UNSLOTH_RETURN_LOGITS read still needs
os available in the rewritten forward's globals.

* Avoid double lm_head matmul on UNSLOTH_RETURN_LOGITS=1 path

Previous shape called both unsloth_fused_lm_head_loss (which chunks the
lm_head matmul internally to compute CE) and self.<head>(<hidden>) (the
full matmul) when the opt-in env var was set. Two matmuls for one
materialised tensor.

New shape splits the labels branch into two paths and picks the right
loss path for each:

  if labels is not None:
      if UNSLOTH_RETURN_LOGITS == '1':
          logits = <original RHS>                       # one matmul
          loss   = self.loss_function(logits, labels,   # same logits
                                      vocab_size=V, ...)
      else:
          loss   = unsloth_fused_lm_head_loss(...)      # chunked, fused
          logits = EMPTY_LOGITS
  else:
      logits = <original RHS>
      loss   = None

The opt-in path now routes through the model's own self.loss_function
on the already-materialised logits. Matches HF's standard CausalLM loss
shape and the conditional in unsloth_zoo/compiler.py:2074.

Tests assert single-matmul + single-self.loss_function on the opt-in
path; numerical equivalence holds bit-identically on the toy in this
sim (5.003798 vs 5.003798).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

auto-addresses-issue Pre-flight: appears to address an open issue auto-approved Auto-review approved the PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant