Skip to content

fast_generate: unify legacy/new logits kwarg + fix Mistral merge site#5543

Merged
danielhanchen merged 5 commits into
mainfrom
daniel/fast-generate-kwarg-signature-aware
May 18, 2026
Merged

fast_generate: unify legacy/new logits kwarg + fix Mistral merge site#5543
danielhanchen merged 5 commits into
mainfrom
daniel/fast-generate-kwarg-signature-aware

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary

Two related issues raised on the merged #5538 (one by gemini-code-assist[bot], two by chatgpt-codex-connector[bot]):

Issue 1: backward-compat broken when user passes a name the runtime forward does not accept

The patch landed in #5538 had:

num_logits_to_keep = kwargs.pop(\"num_logits_to_keep\", None)
logits_to_keep = kwargs.get(\"logits_to_keep\", None)
if num_logits_to_keep is not None and logits_to_keep is None:
    kwargs[\"logits_to_keep\"] = num_logits_to_keep
    ...

Two problems:

  • logits_to_keep was never popped, so if a caller passes logits_to_keep on an older transformers (< 4.50) where the runtime forward only accepts num_logits_to_keep, validation still rejects it.
  • The first if unconditionally promoted num_logits_to_keep -> logits_to_keep, regardless of whether the runtime forward actually has logits_to_keep in its signature. On transformers < 4.50 that breaks legacy callers.

Fix: switch to the unified normalize-then-inspect pattern from gemini's review:

_provided_num    = kwargs.pop(\"num_logits_to_keep\", None)
_provided_logits = kwargs.pop(\"logits_to_keep\",     None)
_provided = _provided_logits if _provided_logits is not None else _provided_num

try:
    _fwd_params = inspect.signature(self.forward).parameters
except (TypeError, ValueError):
    _fwd_params = {}

if \"logits_to_keep\" in _fwd_params:
    kwargs[\"logits_to_keep\"] = _provided if _provided is not None else 1
elif \"num_logits_to_keep\" in _fwd_params:
    kwargs[\"num_logits_to_keep\"] = _provided if _provided is not None else 1

Inspect the runtime forward signature first, then choose the spelling it actually accepts, then route whichever value the caller supplied under that spelling. Backward-compatible in both directions: works for legacy callers on new transformers (>= 4.52), modern callers on old transformers (< 4.50), or anything in between.

Issue 2: Mistral's max() merge sits in the wrong branch

MistralForCausalLM_fast_forward (unsloth/models/mistral.py) had:

if os.environ.get(\"UNSLOTH_RETURN_HIDDEN_STATES\", \"0\") == \"1\":
    num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)   # <-- only here
    if num_logits_to_keep != 0:
        hidden_states = hidden_states[:, -num_logits_to_keep:, :]
    return ...

if bsz == 1 and q_len == 1:
    ...
elif num_logits_to_keep != 0:   # <-- only num_logits_to_keep checked here
    logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))

A caller passing logits_to_keep=1 (which is what unsloth_fast_generate itself does on modern transformers per fix #1) made the merge fire only inside the UNSLOTH_RETURN_HIDDEN_STATES branch. On the normal generation path, num_logits_to_keep stayed 0 and logits_to_keep=1 was ignored, falling into the else clause and computing full prompt logits. For long prompts that re-introduces the large prefill logits allocation the default keep=1 was avoiding.

Fix: move the max() merge above the env-var branching so the decode slice fires on every path. Llama already does this at the top (unsloth/models/llama.py:1501); Mistral now matches.

# Merge the legacy and new spellings before any branching so the
# decode-time last-token slice fires on the normal generation path
# too, not just the hidden-states path.
num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)

if os.environ.get(\"UNSLOTH_RETURN_HIDDEN_STATES\", \"0\") == \"1\":
    if num_logits_to_keep != 0:
        ...

Test plan

  • Llama 3.2 1B + model.generate(max_new_tokens=16) still works on transformers 4.57.6 (same smoke that drove Fix num_logits_to_keep regression on transformers >= 4.52 #5538).
  • Static review: the new pattern inspects the runtime forward signature once and routes the user-supplied value (if any) under whichever spelling the model accepts; otherwise defaults to =1. No spelling is silently dropped.
  • Diff vs main: 33 lines changed across unsloth/models/llama.py (signature-aware kwarg handling) + 7 lines on unsloth/models/mistral.py (max() merge moved). Net +22, -18.

Compatibility

Addresses review comments on #5538: gemini-code-assist[bot]#3257327268, chatgpt-codex-connector[bot]#3257287094, chatgpt-codex-connector[bot]#3257768585.

Two related issues caught by review on PR #5538:

1. unsloth_fast_generate (models/llama.py)

   The previous patch promoted num_logits_to_keep -> logits_to_keep
   unconditionally whenever the caller supplied num_logits_to_keep,
   and only popped num_logits_to_keep (not logits_to_keep). On
   transformers older than 4.50 (legacy spelling is the only one the
   model forward accepts), the promotion broke things; symmetrically,
   a caller supplying logits_to_keep on those older transformers also
   went unchecked.

   Switch to the unified normalize-then-inspect pattern from the
   review:

     _provided_num    = kwargs.pop("num_logits_to_keep", None)
     _provided_logits = kwargs.pop("logits_to_keep",     None)
     _provided = _provided_logits if _provided_logits is not None else _provided_num
     _fwd_params = inspect.signature(self.forward).parameters
     if "logits_to_keep" in _fwd_params:
         kwargs["logits_to_keep"] = _provided if _provided is not None else 1
     elif "num_logits_to_keep" in _fwd_params:
         kwargs["num_logits_to_keep"] = _provided if _provided is not None else 1

   Inspect the runtime forward signature first, then choose the
   spelling it actually accepts, then route either user-supplied value
   under that spelling. Backward-compatible in both directions.

2. MistralForCausalLM_fast_forward (models/mistral.py)

   The max(num_logits_to_keep, logits_to_keep) merge was inside the
   `if UNSLOTH_RETURN_HIDDEN_STATES:` block, so it only fired on the
   GRPO hidden-states path. On the normal generation path the elif at
   line 316 only checked num_logits_to_keep, so a caller (including
   unsloth_fast_generate itself) passing logits_to_keep=1 ended up
   computing full prompt logits instead of slicing to the last token.
   For long prompts that reintroduces the large prefill logits
   allocation the default keep=1 was avoiding.

   Move the max() merge above the env-var branching so the normal
   generation path slices correctly too. Llama already did this
   merge at the top (unsloth/models/llama.py:1501); Mistral now
   matches.

No behaviour change on the default GRPO / SFT paths. Targets only the
edge cases the review flagged.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the handling of num_logits_to_keep and logits_to_keep parameters to maintain compatibility across different versions of the transformers library. In unsloth/models/llama.py, the unsloth_fast_generate function now uses reflection to determine the correct parameter name for the model's forward call. In unsloth/models/mistral.py, the parameter merging logic was moved to ensure it applies to the standard generation path rather than just the GRPO-specific path. I have no feedback to provide as there were no review comments.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bd19ba3112

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/llama.py Outdated
Comment on lines +2123 to +2124
except (TypeError, ValueError):
_fwd_params = {}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve logits kwarg when signature inspection fails

In this exception path both num_logits_to_keep and logits_to_keep have already been removed from kwargs, and setting _fwd_params = {} means neither branch below re-adds the caller's value. For opaque or compiled forwards where inspect.signature(self.forward) raises TypeError/ValueError, an explicit logits_to_keep/num_logits_to_keep is silently dropped before _old_generate, so generation no longer honors the requested logits slice.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, fixed in 2f39c00. The except path now sets _has_new / _has_old from the caller's supplied spelling (defaulting to logits_to_keep when neither was given), so the popped value is re-emitted instead of silently dropped. Verified across 11 case combinations (caller {none, new, old, both} x signature {new, old, neither, opaque}) in temp/pr5543_sim/sim_signature_fallback.py locally.

If `inspect.signature(self.forward)` raises TypeError/ValueError (opaque
C-extension or compiled wrappers), the previous fix set `_fwd_params = {}`
which silently dropped the caller-supplied `logits_to_keep` /
`num_logits_to_keep`. Fall back to the spelling the caller used (default
`logits_to_keep=1` when neither was supplied) so generation still honors
the requested logits slice.
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2f39c001e4

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/mistral.py Outdated
# decode-time last-token slice fires on the normal generation path
# too, not just the hidden-states path. transformers 4.50 renamed
# num_logits_to_keep -> logits_to_keep; callers may supply either.
num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Handle tensor logits_to_keep before max

For calls that use the Transformers-supported tensor form of logits_to_keep (explicit token positions for packed logits), this unconditional max(num_logits_to_keep, logits_to_keep) now runs on Mistral's normal forward path and attempts to compare an int with a multi-element tensor, raising RuntimeError: Boolean value of Tensor with more than one value is ambiguous before logits are produced. This used to be confined to the hidden-states branch; now a regular Mistral forward with tensor logits_to_keep fails, so the merge needs to branch on tensor vs integer and use tensor indexing semantics instead of max.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, fixed in e711767. Both the new mistral.py merge and the pre-existing llama.py merge now branch on type: when either side is a torch.Tensor we skip the int max() and leave num_logits_to_keep == 0, so the downstream int-slice path falls through identically to pre-merge behavior and no implicit bool cast happens. Tensor callers were already not honored by the int slice path before this PR, so this is a regression-free, no-crash fix. Verified across the int/tensor combinations in temp/pr5543_sim/sim_tensor_logits_to_keep.py.

danielhanchen and others added 2 commits May 18, 2026 11:10
HF accepts logits_to_keep as a 1-D LongTensor of positions for
selective decode. The merge in mistral.py (added by this PR) and
the pre-existing one in llama.py both run max(int, Tensor), which
casts the comparison to a bool and raises on multi-element tensors.
Branch on type and skip the merge when either argument is a tensor;
downstream int-slice path is unchanged, so tensor callers fall
through with num_logits_to_keep == 0, matching pre-merge behavior.
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a5d6958e04

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/models/llama.py
Comment on lines +2139 to +2142
if _has_new:
kwargs["logits_to_keep"] = _provided if _provided is not None else 1
elif _has_old:
kwargs["num_logits_to_keep"] = _provided if _provided is not None else 1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve logits kwargs when forward signature is generic

unsloth_fast_generate now pops both num_logits_to_keep and logits_to_keep, but only re-adds one when inspect.signature(self.forward) explicitly contains those parameter names. For wrapped/compiled forwards that introspect as generic *args/**kwargs (no explicit logits keys), both _has_new and _has_old are false, so the caller-provided value is silently dropped before _old_generate runs. This regresses behavior from the previous code (where an explicit logits_to_keep survived) and can re-enable full prompt-logit computation, causing avoidable VRAM/perf spikes on long prompts.

Useful? React with 👍 / 👎.

@danielhanchen danielhanchen merged commit 3ebe17f into main May 18, 2026
8 of 43 checks passed
@danielhanchen danielhanchen deleted the daniel/fast-generate-kwarg-signature-aware branch May 18, 2026 11:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Question about readme benchmarks

1 participant