fast_generate: unify legacy/new logits kwarg + fix Mistral merge site #5543
Two related issues caught by review on PR #5538:

1. unsloth_fast_generate (models/llama.py)

The previous patch promoted num_logits_to_keep -> logits_to_keep unconditionally whenever the caller supplied num_logits_to_keep, and only popped num_logits_to_keep (not logits_to_keep). On transformers older than 4.50 (legacy spelling is the only one the model forward accepts), the promotion broke things; symmetrically, a caller supplying logits_to_keep on those older transformers also went unchecked. Switch to the unified normalize-then-inspect pattern from the review:

    _provided_num = kwargs.pop("num_logits_to_keep", None)
    _provided_logits = kwargs.pop("logits_to_keep", None)
    _provided = _provided_logits if _provided_logits is not None else _provided_num
    _fwd_params = inspect.signature(self.forward).parameters
    if "logits_to_keep" in _fwd_params:
        kwargs["logits_to_keep"] = _provided if _provided is not None else 1
    elif "num_logits_to_keep" in _fwd_params:
        kwargs["num_logits_to_keep"] = _provided if _provided is not None else 1

Inspect the runtime forward signature first, then choose the spelling it actually accepts, then route either user-supplied value under that spelling. Backward-compatible in both directions.

2. MistralForCausalLM_fast_forward (models/mistral.py)

The max(num_logits_to_keep, logits_to_keep) merge was inside the `if UNSLOTH_RETURN_HIDDEN_STATES:` block, so it only fired on the GRPO hidden-states path. On the normal generation path the elif at line 316 only checked num_logits_to_keep, so a caller (including unsloth_fast_generate itself) passing logits_to_keep=1 ended up computing full prompt logits instead of slicing to the last token. For long prompts that reintroduces the large prefill logits allocation the default keep=1 was avoiding. Move the max() merge above the env-var branching so the normal generation path slices correctly too. Llama already did this merge at the top (unsloth/models/llama.py:1501); Mistral now matches.

No behaviour change on the default GRPO / SFT paths. Targets only the edge cases the review flagged.
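
For illustration, the normalize-then-inspect routing from item 1 can be exercised outside the model code. The sketch below is a standalone approximation, not the code in `unsloth/models/llama.py`: `route_logits_kwarg`, `new_forward`, and `old_forward` are hypothetical names, and the two toy forwards only stand in for a model forward on newer and older transformers.

```python
import inspect


def route_logits_kwarg(forward, kwargs, default=1):
    """Return a copy of kwargs using the logits spelling `forward` accepts."""
    kwargs = dict(kwargs)
    # Pop both spellings so neither leaks through under the wrong name.
    provided_num = kwargs.pop("num_logits_to_keep", None)
    provided_logits = kwargs.pop("logits_to_keep", None)
    # Prefer the new spelling when the caller supplied both.
    provided = provided_logits if provided_logits is not None else provided_num
    value = provided if provided is not None else default

    # Re-emit the value under whichever name the forward signature declares.
    params = inspect.signature(forward).parameters
    if "logits_to_keep" in params:
        kwargs["logits_to_keep"] = value
    elif "num_logits_to_keep" in params:
        kwargs["num_logits_to_keep"] = value
    return kwargs


# Hypothetical stand-ins for a forward on new and old transformers.
def new_forward(input_ids, logits_to_keep=0):
    return logits_to_keep


def old_forward(input_ids, num_logits_to_keep=0):
    return num_logits_to_keep


print(route_logits_kwarg(new_forward, {"num_logits_to_keep": 4}))
print(route_logits_kwarg(old_forward, {"logits_to_keep": 4}))
print(route_logits_kwarg(old_forward, {}))
```

In this sketch a legacy caller on the new-style forward ends up with `logits_to_keep`, a modern caller on the old-style forward ends up with `num_logits_to_keep`, and a caller that passes neither gets the default of 1 under the accepted name.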
Code Review
This pull request updates the handling of num_logits_to_keep and logits_to_keep parameters to maintain compatibility across different versions of the transformers library. In unsloth/models/llama.py, the unsloth_fast_generate function now uses reflection to determine the correct parameter name for the model's forward call. In unsloth/models/mistral.py, the parameter merging logic was moved to ensure it applies to the standard generation path rather than just the GRPO-specific path. I have no feedback to provide as there were no review comments.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bd19ba3112
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
    except (TypeError, ValueError):
        _fwd_params = {}

Preserve logits kwarg when signature inspection fails
In this exception path both num_logits_to_keep and logits_to_keep have already been removed from kwargs, and setting _fwd_params = {} means neither branch below re-adds the caller's value. For opaque or compiled forwards where inspect.signature(self.forward) raises TypeError/ValueError, an explicit logits_to_keep/num_logits_to_keep is silently dropped before _old_generate, so generation no longer honors the requested logits slice.
Good catch, fixed in 2f39c00. The except path now sets _has_new / _has_old from the caller's supplied spelling (defaulting to logits_to_keep when neither was given), so the popped value is re-emitted instead of silently dropped. Verified across 11 case combinations (caller {none, new, old, both} x signature {new, old, neither, opaque}) in temp/pr5543_sim/sim_signature_fallback.py locally.
If `inspect.signature(self.forward)` raises TypeError/ValueError (opaque
C-extension or compiled wrappers), the previous fix set `_fwd_params = {}`
which silently dropped the caller-supplied `logits_to_keep` /
`num_logits_to_keep`. Fall back to the spelling the caller used (default
`logits_to_keep=1` when neither was supplied) so generation still honors
the requested logits slice.
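
A minimal sketch of this fallback, assuming the behaviour described in the commit message above. `route_with_fallback` is a hypothetical name, a non-callable `object()` is used only because it reliably makes `inspect.signature` raise `TypeError`, and preferring the new spelling when the caller supplies both is an assumption the thread does not spell out.

```python
import inspect


def route_with_fallback(forward, kwargs, default=1):
    """Route the logits kwarg, surviving an uninspectable forward."""
    kwargs = dict(kwargs)
    provided_old = kwargs.pop("num_logits_to_keep", None)
    provided_new = kwargs.pop("logits_to_keep", None)
    provided = provided_new if provided_new is not None else provided_old
    value = provided if provided is not None else default

    try:
        params = inspect.signature(forward).parameters
        has_new = "logits_to_keep" in params
        has_old = "num_logits_to_keep" in params
    except (TypeError, ValueError):
        # Opaque signature: fall back to the spelling the caller used.
        # Assumption: the new spelling wins when both or neither were given.
        has_old = provided_new is None and provided_old is not None
        has_new = not has_old

    if has_new:
        kwargs["logits_to_keep"] = value
    elif has_old:
        kwargs["num_logits_to_keep"] = value
    return kwargs


# Stand-in for an opaque forward; inspect.signature raises TypeError on it.
opaque_forward = object()

print(route_with_fallback(opaque_forward, {"num_logits_to_keep": 3}))
print(route_with_fallback(opaque_forward, {}))
```

The point of the sketch is only that the popped value is re-emitted on the exception path instead of being dropped.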
💡 Codex Review
Reviewed commit: 2f39c001e4
    # decode-time last-token slice fires on the normal generation path
    # too, not just the hidden-states path. transformers 4.50 renamed
    # num_logits_to_keep -> logits_to_keep; callers may supply either.
    num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)

Handle tensor logits_to_keep before max
For calls that use the Transformers-supported tensor form of logits_to_keep (explicit token positions for packed logits), this unconditional max(num_logits_to_keep, logits_to_keep) now runs on Mistral's normal forward path and attempts to compare an int with a multi-element tensor, raising RuntimeError: Boolean value of Tensor with more than one value is ambiguous before logits are produced. This used to be confined to the hidden-states branch; now a regular Mistral forward with tensor logits_to_keep fails, so the merge needs to branch on tensor vs integer and use tensor indexing semantics instead of max.
Good catch, fixed in e711767. Both the new mistral.py merge and the pre-existing llama.py merge now branch on type: when either side is a torch.Tensor we skip the int max() and leave num_logits_to_keep == 0, so the downstream int-slice path falls through identically to pre-merge behavior and no implicit bool cast happens. Tensor callers were already not honored by the int slice path before this PR, so this is a regression-free, no-crash fix. Verified across the int/tensor combinations in temp/pr5543_sim/sim_tensor_logits_to_keep.py.
HF accepts logits_to_keep as a 1-D LongTensor of positions for selective decode. The merge in mistral.py (added by this PR) and the pre-existing one in llama.py both run max(int, Tensor), which casts the comparison to a bool and raises on multi-element tensors. Branch on type and skip the merge when either argument is a tensor; downstream int-slice path is unchanged, so tensor callers fall through with num_logits_to_keep == 0, matching pre-merge behavior.
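
The type branch can be sketched without importing torch. This is an approximation rather than the PR's code: the fix described above checks for `torch.Tensor`, while the hypothetical `merge_keep` below checks for plain ints instead so that it runs anywhere, and a Python list stands in for a tensor of positions.

```python
def merge_keep(num_logits_to_keep=0, logits_to_keep=0):
    """Merge the two spellings only when both are plain ints."""
    both_ints = isinstance(num_logits_to_keep, int) and isinstance(logits_to_keep, int)
    if both_ints:
        return max(num_logits_to_keep, logits_to_keep)
    # A non-int (for example a tensor of positions) is present: skip the
    # merge and return the legacy value untouched, so no comparison or
    # implicit bool cast is attempted.
    return num_logits_to_keep


print(merge_keep(0, 1))        # int path: the merge applies
print(merge_keep(0, [3, 7]))   # list stands in for a tensor: merge skipped
```

With both arguments as ints the larger value wins; with the tensor stand-in the legacy value is returned unchanged, which mirrors the fall-through described above.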
💡 Codex Review
Reviewed commit: a5d6958e04
    if _has_new:
        kwargs["logits_to_keep"] = _provided if _provided is not None else 1
    elif _has_old:
        kwargs["num_logits_to_keep"] = _provided if _provided is not None else 1

Preserve logits kwargs when forward signature is generic
unsloth_fast_generate now pops both num_logits_to_keep and logits_to_keep, but only re-adds one when inspect.signature(self.forward) explicitly contains those parameter names. For wrapped/compiled forwards that introspect as generic *args/**kwargs (no explicit logits keys), both _has_new and _has_old are false, so the caller-provided value is silently dropped before _old_generate runs. This regresses behavior from the previous code (where an explicit logits_to_keep survived) and can re-enable full prompt-logit computation, causing avoidable VRAM/perf spikes on long prompts.
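
The situation described in this comment is easy to reproduce in isolation. The sketch below only demonstrates the problem and one possible way to detect it; the thread does not show how (or whether) it was resolved, so `accepts_var_keyword`, `real_forward`, and `generic_wrapper` are hypothetical names and the `**kwargs` check is an assumption, not the PR's fix.

```python
import inspect


def real_forward(input_ids, logits_to_keep=0):
    return logits_to_keep


def generic_wrapper(*args, **kwargs):
    # Hypothetical wrapper that hides the real parameter names.
    return real_forward(*args, **kwargs)


def accepts_var_keyword(fn):
    """True when fn's signature has a **kwargs catch-all."""
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD
        for p in inspect.signature(fn).parameters.values()
    )


params = inspect.signature(generic_wrapper).parameters
# The explicit name is invisible through the wrapper...
print("logits_to_keep" in params)
# ...even though the catch-all would forward it to the real forward.
print(accepts_var_keyword(generic_wrapper))
```

A name lookup on the signature misses the parameter here, so a caller-supplied value would need to be re-added by some other rule, for example by treating a `**kwargs` catch-all like the opaque-signature case.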
Summary
Two related issues raised on the merged #5538 (one by gemini-code-assist[bot], two by chatgpt-codex-connector[bot]):

Issue 1: backward-compat broken when user passes a name the runtime forward does not accept

The patch landed in #5538 had:

Two problems:
- `logits_to_keep` was never popped, so if a caller passes `logits_to_keep` on an older transformers (< 4.50) where the runtime forward only accepts `num_logits_to_keep`, validation still rejects it.
- `if` unconditionally promoted `num_logits_to_keep` -> `logits_to_keep`, regardless of whether the runtime forward actually has `logits_to_keep` in its signature. On transformers < 4.50 that breaks legacy callers.

Fix: switch to the unified normalize-then-inspect pattern from gemini's review:

Inspect the runtime forward signature first, then choose the spelling it actually accepts, then route whichever value the caller supplied under that spelling. Backward-compatible in both directions: works for legacy callers on new transformers (>= 4.52), modern callers on old transformers (< 4.50), or anything in between.

Issue 2: Mistral's max() merge sits in the wrong branch

`MistralForCausalLM_fast_forward` (`unsloth/models/mistral.py`) had:

A caller passing `logits_to_keep=1` (which is what `unsloth_fast_generate` itself does on modern transformers per fix #1) made the merge fire only inside the `UNSLOTH_RETURN_HIDDEN_STATES` branch. On the normal generation path, `num_logits_to_keep` stayed `0` and `logits_to_keep=1` was ignored, falling into the else clause and computing full prompt logits. For long prompts that re-introduces the large prefill logits allocation the default `keep=1` was avoiding.

Fix: move the `max()` merge above the env-var branching so the decode slice fires on every path. Llama already does this at the top (`unsloth/models/llama.py:1501`); Mistral now matches.

Test plan
- `model.generate(max_new_tokens=16)` still works on transformers 4.57.6 (same smoke that drove Fix num_logits_to_keep regression on transformers >= 4.52 #5538).
- [...] `=1`. No spelling is silently dropped.
- [...] `unsloth/models/llama.py` (signature-aware kwarg handling) + 7 lines on `unsloth/models/mistral.py` (max() merge moved). Net +22, -18.

Compatibility
- [...] `logits_to_keep=1`).
- `num_logits_to_keep=N` on transformers >= 4.52: promoted to `logits_to_keep=N`, not dropped.
- `logits_to_keep=N` on transformers < 4.50: promoted to `num_logits_to_keep=N`, not dropped. (Fix num_logits_to_keep regression on transformers >= 4.52 #5538 silently broke this.)
- [...] `logits_to_keep=1` is supplied. (Fix num_logits_to_keep regression on transformers >= 4.52 #5538 effectively disabled the slice for Mistral.)

Addresses review comments on #5538: gemini-code-assist[bot]#3257327268, chatgpt-codex-connector[bot]#3257287094, chatgpt-codex-connector[bot]#3257768585.
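
To make Issue 2 concrete, here is a toy model of the branch structure, assuming integer arguments only. It is not the code in `unsloth/models/mistral.py`: `keep_before`, `keep_after`, and the string labels for each path are hypothetical, and the boolean argument stands in for the `UNSLOTH_RETURN_HIDDEN_STATES` check.

```python
def keep_before(return_hidden_states, num_logits_to_keep=0, logits_to_keep=0):
    """Old placement: the merge lives inside the hidden-states branch."""
    if return_hidden_states:
        num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
        return "hidden_states", num_logits_to_keep
    elif num_logits_to_keep != 0:
        return "sliced", num_logits_to_keep
    else:
        # logits_to_keep was never consulted on this path.
        return "full", 0


def keep_after(return_hidden_states, num_logits_to_keep=0, logits_to_keep=0):
    """New placement: merge first, then branch."""
    num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
    if return_hidden_states:
        return "hidden_states", num_logits_to_keep
    elif num_logits_to_keep != 0:
        return "sliced", num_logits_to_keep
    else:
        return "full", 0


print(keep_before(False, logits_to_keep=1))
print(keep_after(False, logits_to_keep=1))
```

In the toy version, the old placement sends `logits_to_keep=1` down the full-logits path on normal generation, while the new placement reaches the sliced path; the hidden-states path behaves the same in both.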