Skip to content

Fix num_logits_to_keep regression on transformers >= 4.52#5538

Merged
danielhanchen merged 2 commits into
mainfrom
daniel/three-followups-from-pr665
May 18, 2026
Merged

Fix num_logits_to_keep regression on transformers >= 4.52#5538
danielhanchen merged 2 commits into
mainfrom
daniel/three-followups-from-pr665

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

@danielhanchen danielhanchen commented May 18, 2026

Summary

Follow-up to unslothai/unsloth-zoo#665. Fixes the num_logits_to_keep regression in unsloth_fast_generate that surfaced on transformers >= 4.52. Pairs with unslothai/unsloth-zoo#666 (single-matmul opt-in in the compiler templates).

What was wrong

unsloth/models/llama.py unsloth_fast_generate unconditionally set kwargs[\"num_logits_to_keep\"] = 1. Transformers 4.50 renamed that forward parameter to logits_to_keep (with a @deprecate_kwarg shim that lived through 4.51.x), and removed the shim entirely in 4.52. From 4.52 onwards, _validate_model_kwargs raises:

ValueError: The following `model_kwargs` are not used by the model: ['num_logits_to_keep']
(note: typos in the generate arguments will also show up in this list)

blocking model.generate(...) on Llama / Mistral. Reproducible with a bare FastLanguageModel.from_pretrained(\"unsloth/Llama-3.2-1B-Instruct\") + model.generate(...).

Change

unsloth/models/llama.py (in unsloth_fast_generate):

num_logits_to_keep = kwargs.pop(\"num_logits_to_keep\", None)
logits_to_keep = kwargs.get(\"logits_to_keep\", None)
if num_logits_to_keep is not None and logits_to_keep is None:
    kwargs[\"logits_to_keep\"] = num_logits_to_keep
    logits_to_keep = num_logits_to_keep
if num_logits_to_keep is None and logits_to_keep is None:
    try:
        _fwd_params = inspect.signature(self.forward).parameters
    except (TypeError, ValueError):
        _fwd_params = {}
    if \"logits_to_keep\" in _fwd_params:
        kwargs[\"logits_to_keep\"] = 1
    elif \"num_logits_to_keep\" in _fwd_params:
        kwargs[\"num_logits_to_keep\"] = 1

Inspects the actual runtime forward signature and uses whichever spelling it accepts. A caller still passing the legacy num_logits_to_keep=N gets it promoted to logits_to_keep=N (not dropped).

Transformers version sweep

Verified the rename window directly against the GitHub source for transformers/models/llama/modeling_llama.py:

Versions Forward parameter Accepts num_logits_to_keep?
4.50.0, 4.51.0, 4.51.3 logits_to_keep yes (@deprecate_kwarg(version=\"4.50\") shim)
4.52.0 - 4.55.x logits_to_keep no (shim removed)
4.56.0, 4.56.1, 4.56.2 logits_to_keep no
4.57.0, 4.57.1, 4.57.2, 4.57.3, 4.57.4, 4.57.5, 4.57.6 logits_to_keep no
main (dev) logits_to_keep no

All currently shipping transformers releases >= 4.52 use logits_to_keep. The fix's inspect.signature(self.forward) path picks the correct spelling on every one of them.

Caveat

_validate_model_kwargs (transformers/generation/utils.py:1599) merges inspect.signature(self.forward).parameters only when prepare_inputs_for_generation accepts **kwargs. On PEFT-wrapped models, self.forward may resolve to PEFT's wrapper with (*args, **kwargs) and hide the underlying signature. In that case my fix sets neither kwarg -- still correct for validation, but the decode-time keep-only-last-logit optimisation is lost. Pre-existing limitation; the old unconditional kwargs[\"num_logits_to_keep\"] = 1 was what 4.52+ rejected. A follow-up can walk self.get_base_model() / self.base_model.model before inspecting if we want to recover the PEFT decode speedup.

What dropped from this PR

Earlier revisions of this PR also flipped patch_loss_functions(torch_compile=False) -> True at loader.py:1381. Reverted in 82c2e69. The loss function bottoms out at a Triton autograd.Function (Fast_CrossEntropyLoss.apply), which torch.compile treats as an opaque op and breaks the graph at. The only thing it can actually compile is three elementwise prep ops around the Triton call, and the per-call dynamo overhead is in the same order. Empirical Gemma3 1B GRPO smoke showed no meaningful delta (415s vs 409s, within noise) and risked dragging the outer compiled training step into recompiles. Keeping torch_compile=False.

Test plan

  • Llama 3.2 1B + model.generate(max_new_tokens=16) no longer raises. Output: \"The capital of France is Paris. The Eiffel Tower is located in Paris. The Eiffel\".
  • Gemma3 1B GRPO smoke (max_steps=3) returns bit-identical losses (0.256 / 0.4393 / 0.2031) vs pre-fix. Default path unchanged at the training-loop level.
  • unsloth-zoo regression suites pass on this combination: 96 passed across test_fused_forward_install + test_compiler_rewriter_exhaustive.

Compatibility

  • Existing callers that pass num_logits_to_keep continue to work: translated to logits_to_keep on transformers >= 4.50 rather than dropped.
  • Defaults unchanged on transformers < 4.50 (legacy spelling still set when only num_logits_to_keep is in the runtime forward signature).
  • No loss function compile flip in this revision.

Pairs with: unslothai/unsloth-zoo daniel/three-followups-from-pr665 (#666, compiler: single-matmul opt-in for UNSLOTH_RETURN_LOGITS=1).

Two follow-ups to the fused-forward work landed in unsloth-zoo PR #665.

1. unsloth_fast_generate (models/llama.py): transformers 4.51 renamed
   num_logits_to_keep to logits_to_keep. Previously we unconditionally
   set kwargs['num_logits_to_keep'] = 1, which transformers 4.57's
   _validate_model_kwargs rejects with:
     ValueError: The following `model_kwargs` are not used by the
     model: ['num_logits_to_keep']
   blocking model.generate() on Llama / Mistral. Now we inspect the
   runtime forward signature and use whichever spelling it accepts;
   if a caller still passes the legacy name we promote it to the new
   spelling instead of stripping it.

2. patch_loss_functions (models/loader.py): the single internal call
   site passed torch_compile=False. UnslothForCausalLMLoss is small
   (label shift + Triton CE), so torch.compile folds the elementwise
   prep into one launch and removes per-step Python overhead. The
   < 2.4 fallback inside patch_loss_functions still routes through
   torch._disable_dynamo so older torches are unaffected.

Verified:
- Llama 3.2 1B + model.generate() no longer raises; emits a sensible
  16-token continuation.
- Gemma3 1B GRPO smoke (max_steps=3) returns bit-identical losses
  0.256 / 0.4393 / 0.2031 vs pre-fix; train_runtime 409s (vs 415s
  pre-fix, within noise).
- unsloth-zoo test_compiler_rewriter_exhaustive + test_fused_forward_install
  pass (96 passed) on this combination.

Related: unslothai/unsloth-zoo PR for the compiler.py single-matmul
backport.
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 482258fe99

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/llama.py
Comment on lines +2117 to +2118
if num_logits_to_keep is not None and logits_to_keep is None:
kwargs["logits_to_keep"] = num_logits_to_keep
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve legacy logits kwarg on old transformers

When a caller supplies num_logits_to_keep on transformers versions whose model forward still only accepts that legacy spelling, this code always pops it and forwards logits_to_keep instead. Generation then hits HF's model-kwarg validation with an unexpected logits_to_keep, so existing callers on pre-rename transformers regress; the translation needs to be based on the inspected forward signature before replacing the kwarg.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates unsloth_fast_generate to handle the renaming of num_logits_to_keep to logits_to_keep in newer transformers versions and enables torch_compile for loss functions. A review comment identifies a bug where the logic breaks backward compatibility for older transformers versions by unconditionally renaming the parameter, which can cause validation errors. A more robust approach using signature inspection for both parameter names was suggested.

Comment thread unsloth/models/llama.py
Comment on lines +2115 to +2128
num_logits_to_keep = kwargs.pop("num_logits_to_keep", None)
logits_to_keep = kwargs.get("logits_to_keep", None)
if num_logits_to_keep is not None and logits_to_keep is None:
kwargs["logits_to_keep"] = num_logits_to_keep
logits_to_keep = num_logits_to_keep
if num_logits_to_keep is None and logits_to_keep is None:
kwargs["num_logits_to_keep"] = 1
try:
_fwd_params = inspect.signature(self.forward).parameters
except (TypeError, ValueError):
_fwd_params = {}
if "logits_to_keep" in _fwd_params:
kwargs["logits_to_keep"] = 1
elif "num_logits_to_keep" in _fwd_params:
kwargs["num_logits_to_keep"] = 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for migrating num_logits_to_keep to logits_to_keep breaks backward compatibility for older transformers versions when the parameter is explicitly passed by the user.

Specifically, if a user provides num_logits_to_keep on a version of transformers that does not yet support logits_to_keep, the code at lines 2117-2119 will unconditionally rename the key in kwargs. This will cause transformers' _validate_model_kwargs to raise a ValueError because logits_to_keep is not in the model's forward signature.

Furthermore, logits_to_keep is not popped from kwargs at line 2116, which could lead to similar validation errors if the model only accepts the legacy name.

A unified approach that normalizes both inputs and then uses signature inspection to decide which key to use is more robust.

    provided_num = kwargs.pop("num_logits_to_keep", None)
    provided_logits = kwargs.pop("logits_to_keep", None)
    val = provided_logits if provided_logits is not None else provided_num

    try:
        _fwd_params = inspect.signature(self.forward).parameters
    except (TypeError, ValueError):
        _fwd_params = {}

    if "logits_to_keep" in _fwd_params:
        kwargs["logits_to_keep"] = val if val is not None else 1
    elif "num_logits_to_keep" in _fwd_params:
        kwargs["num_logits_to_keep"] = val if val is not None else 1

Drop the patch_loss_functions(torch_compile=True) flip. Tracing the
loss call chain:

  UnslothForCausalLMLoss
    -> unsloth_fixed_cross_entropy
      -> _fast_cross_entropy_loss
         -> Fast_CrossEntropyLoss.apply  (torch.autograd.Function wrapping Triton)

torch.compile treats custom autograd.Function.apply as an opaque op and
breaks the graph at the boundary. The only Python it can actually
compile in the loss function is the label-shift + ignore-fill prep
(three elementwise ops), and the per-call dynamo guard overhead is in
the same order as that prep. Empirical Gemma3 1B GRPO smoke (max_steps=3)
showed no meaningful runtime delta (415s vs 409s, within noise) and
risked dragging the outer compiled training step into recompiles when
the inner guards drift. Keep torch_compile=False; the Triton kernel is
the work, and it is unchanged either way.

Also: the inline comment in unsloth_fast_generate said the kwarg rename
landed in transformers 4.51. The actual decorator (@deprecate_kwarg)
was tagged version="4.50" and present through 4.51.x, then removed in
4.52+. Correct the comment. No behaviour change.
@danielhanchen danielhanchen changed the title Fix num_logits_to_keep on transformers >= 4.51 + compile loss_function Fix num_logits_to_keep regression on transformers >= 4.52 May 18, 2026
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 82c2e69815

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/llama.py
Comment on lines +2126 to +2127
if "logits_to_keep" in _fwd_params:
kwargs["logits_to_keep"] = 1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve Mistral's logits-slicing kwarg

When generate() runs on FastMistralModel with no explicit logits kwarg, this branch now chooses logits_to_keep because Mistral's patched forward signature contains both names. However MistralForCausalLM_fast_forward only uses num_logits_to_keep in its normal logits path (unsloth/models/mistral.py:316) and only merges the two names for the hidden-states env path, so Mistral generation falls back to computing full prompt logits instead of slicing to the last token. For long prompts this reintroduces the large prefill logits allocation that the default =1 was avoiding and can cause major slowdown/OOMs.

Useful? React with 👍 / 👎.

@danielhanchen danielhanchen merged commit 61878c7 into main May 18, 2026
43 checks passed
@danielhanchen danielhanchen deleted the daniel/three-followups-from-pr665 branch May 18, 2026 09:32
danielhanchen added a commit that referenced this pull request May 18, 2026
…#5543)

* fast_generate: unify legacy/new logits kwarg + fix Mistral merge site

Two related issues caught by review on PR #5538:

1. unsloth_fast_generate (models/llama.py)

   The previous patch promoted num_logits_to_keep -> logits_to_keep
   unconditionally whenever the caller supplied num_logits_to_keep,
   and only popped num_logits_to_keep (not logits_to_keep). On
   transformers older than 4.50 (legacy spelling is the only one the
   model forward accepts), the promotion broke things; symmetrically,
   a caller supplying logits_to_keep on those older transformers also
   went unchecked.

   Switch to the unified normalize-then-inspect pattern from the
   review:

     _provided_num    = kwargs.pop("num_logits_to_keep", None)
     _provided_logits = kwargs.pop("logits_to_keep",     None)
     _provided = _provided_logits if _provided_logits is not None else _provided_num
     _fwd_params = inspect.signature(self.forward).parameters
     if "logits_to_keep" in _fwd_params:
         kwargs["logits_to_keep"] = _provided if _provided is not None else 1
     elif "num_logits_to_keep" in _fwd_params:
         kwargs["num_logits_to_keep"] = _provided if _provided is not None else 1

   Inspect the runtime forward signature first, then choose the
   spelling it actually accepts, then route either user-supplied value
   under that spelling. Backward-compatible in both directions.

2. MistralForCausalLM_fast_forward (models/mistral.py)

   The max(num_logits_to_keep, logits_to_keep) merge was inside the
   `if UNSLOTH_RETURN_HIDDEN_STATES:` block, so it only fired on the
   GRPO hidden-states path. On the normal generation path the elif at
   line 316 only checked num_logits_to_keep, so a caller (including
   unsloth_fast_generate itself) passing logits_to_keep=1 ended up
   computing full prompt logits instead of slicing to the last token.
   For long prompts that reintroduces the large prefill logits
   allocation the default keep=1 was avoiding.

   Move the max() merge above the env-var branching so the normal
   generation path slices correctly too. Llama already did this
   merge at the top (unsloth/models/llama.py:1501); Mistral now
   matches.

No behaviour change on the default GRPO / SFT paths. Targets only the
edge cases the review flagged.

* fast_generate: preserve caller logits kwarg when signature inspect fails

If `inspect.signature(self.forward)` raises TypeError/ValueError (opaque
C-extension or compiled wrappers), the previous fix set `_fwd_params = {}`
which silently dropped the caller-supplied `logits_to_keep` /
`num_logits_to_keep`. Fall back to the spelling the caller used (default
`logits_to_keep=1` when neither was supplied) so generation still honors
the requested logits slice.

* fast_forward: do not max() int against tensor logits_to_keep

HF accepts logits_to_keep as a 1-D LongTensor of positions for
selective decode. The merge in mistral.py (added by this PR) and
the pre-existing one in llama.py both run max(int, Tensor), which
casts the comparison to a bool and raises on multi-element tensors.
Branch on type and skip the merge when either argument is a tensor;
downstream int-slice path is unchanged, so tensor callers fall
through with num_logits_to_keep == 0, matching pre-merge behavior.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fast_generate/forward: shorten kwarg-merge comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant