Fix num_logits_to_keep regression on transformers >= 4.52 #5538
Two follow-ups to the fused-forward work landed in unsloth-zoo PR #665.

1. unsloth_fast_generate (models/llama.py): transformers 4.51 renamed num_logits_to_keep to logits_to_keep. Previously we unconditionally set kwargs['num_logits_to_keep'] = 1, which transformers 4.57's _validate_model_kwargs rejects with:

       ValueError: The following `model_kwargs` are not used by the model: ['num_logits_to_keep']

   blocking model.generate() on Llama / Mistral. Now we inspect the runtime forward signature and use whichever spelling it accepts; if a caller still passes the legacy name we promote it to the new spelling instead of stripping it.

2. patch_loss_functions (models/loader.py): the single internal call site passed torch_compile=False; this change flips it to True. UnslothForCausalLMLoss is small (label shift + Triton CE), so torch.compile folds the elementwise prep into one launch and removes per-step Python overhead. The < 2.4 fallback inside patch_loss_functions still routes through torch._disable_dynamo, so older torch versions are unaffected.

Verified:
- Llama 3.2 1B + model.generate() no longer raises; emits a sensible 16-token continuation.
- Gemma3 1B GRPO smoke (max_steps=3) returns bit-identical losses 0.256 / 0.4393 / 0.2031 vs pre-fix; train_runtime 409s (vs 415s pre-fix, within noise).
- unsloth-zoo test_compiler_rewriter_exhaustive + test_fused_forward_install pass (96 passed) on this combination.

Related: unslothai/unsloth-zoo PR for the compiler.py single-matmul backport.
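The failure mode in item 1 can be reproduced without transformers. The sketch below is a toy stand-in, not the library's `_validate_model_kwargs`: `validate_model_kwargs`, `pick_keep_kwarg`, `old_forward` and `new_forward` are hypothetical names introduced for illustration, and the real check has more conditions than this.

```python
import inspect


def validate_model_kwargs(forward, model_kwargs):
    """Toy check (assumption: simplified): reject kwargs that `forward` does not name."""
    accepted = set(inspect.signature(forward).parameters)
    unused = [key for key in model_kwargs if key not in accepted]
    if unused:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused}"
        )


def pick_keep_kwarg(forward):
    """Return the logits-slicing kwarg name that `forward` accepts, or None."""
    params = inspect.signature(forward).parameters
    for name in ("logits_to_keep", "num_logits_to_keep"):
        if name in params:
            return name
    return None


# Hypothetical forwards standing in for pre-rename and post-rename models.
def old_forward(input_ids, num_logits_to_keep=0): ...
def new_forward(input_ids, logits_to_keep=0): ...


# Choosing the name from the signature passes the check on both forwards.
for fwd in (old_forward, new_forward):
    validate_model_kwargs(fwd, {pick_keep_kwarg(fwd): 1})
```

Hard-coding `num_logits_to_keep` against `new_forward` is the case that raises; selecting the name from the signature avoids it in both directions.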
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 482258fe99
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
    if num_logits_to_keep is not None and logits_to_keep is None:
        kwargs["logits_to_keep"] = num_logits_to_keep
Preserve legacy logits kwarg on old transformers
When a caller supplies num_logits_to_keep on transformers versions whose model forward still only accepts that legacy spelling, this code always pops it and forwards logits_to_keep instead. Generation then hits HF's model-kwarg validation with an unexpected logits_to_keep, so existing callers on pre-rename transformers regress; the translation needs to be based on the inspected forward signature before replacing the kwarg.
Code Review
This pull request updates unsloth_fast_generate to handle the renaming of num_logits_to_keep to logits_to_keep in newer transformers versions and enables torch_compile for loss functions. A review comment identifies a bug where the logic breaks backward compatibility for older transformers versions by unconditionally renaming the parameter, which can cause validation errors. A more robust approach using signature inspection for both parameter names was suggested.
    num_logits_to_keep = kwargs.pop("num_logits_to_keep", None)
    logits_to_keep = kwargs.get("logits_to_keep", None)
    if num_logits_to_keep is not None and logits_to_keep is None:
        kwargs["logits_to_keep"] = num_logits_to_keep
        logits_to_keep = num_logits_to_keep
    if num_logits_to_keep is None and logits_to_keep is None:
        kwargs["num_logits_to_keep"] = 1
        try:
            _fwd_params = inspect.signature(self.forward).parameters
        except (TypeError, ValueError):
            _fwd_params = {}
        if "logits_to_keep" in _fwd_params:
            kwargs["logits_to_keep"] = 1
        elif "num_logits_to_keep" in _fwd_params:
            kwargs["num_logits_to_keep"] = 1
The current logic for migrating num_logits_to_keep to logits_to_keep breaks backward compatibility for older transformers versions when the parameter is explicitly passed by the user.
Specifically, if a user provides num_logits_to_keep on a version of transformers that does not yet support logits_to_keep, the code at lines 2117-2119 will unconditionally rename the key in kwargs. This will cause transformers' _validate_model_kwargs to raise a ValueError because logits_to_keep is not in the model's forward signature.
Furthermore, logits_to_keep is not popped from kwargs at line 2116, which could lead to similar validation errors if the model only accepts the legacy name.
A unified approach that normalizes both inputs and then uses signature inspection to decide which key to use is more robust.
    provided_num = kwargs.pop("num_logits_to_keep", None)
    provided_logits = kwargs.pop("logits_to_keep", None)
    val = provided_logits if provided_logits is not None else provided_num
    try:
        _fwd_params = inspect.signature(self.forward).parameters
    except (TypeError, ValueError):
        _fwd_params = {}
    if "logits_to_keep" in _fwd_params:
        kwargs["logits_to_keep"] = val if val is not None else 1
    elif "num_logits_to_keep" in _fwd_params:
        kwargs["num_logits_to_keep"] = val if val is not None else 1

Drop the patch_loss_functions(torch_compile=True) flip. Tracing the
loss call chain:
UnslothForCausalLMLoss
-> unsloth_fixed_cross_entropy
-> _fast_cross_entropy_loss
-> Fast_CrossEntropyLoss.apply (torch.autograd.Function wrapping Triton)
torch.compile treats custom autograd.Function.apply as an opaque op and
breaks the graph at the boundary. The only Python it can actually
compile in the loss function is the label-shift + ignore-fill prep
(three elementwise ops), and the per-call dynamo guard overhead is of
the same order as that prep. An empirical Gemma3 1B GRPO smoke (max_steps=3)
showed no meaningful runtime delta (415s vs 409s, within noise), while the
flip risked dragging the outer compiled training step into recompiles when
the inner guards drift. Keep torch_compile=False; the Triton kernel is
the work, and it is unchanged either way.
Also: the inline comment in unsloth_fast_generate said the kwarg rename
landed in transformers 4.51. The actual decorator (@deprecate_kwarg)
was tagged version="4.50" and present through 4.51.x, then removed in
4.52+. Correct the comment. No behaviour change.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 82c2e69815
    if "logits_to_keep" in _fwd_params:
        kwargs["logits_to_keep"] = 1
Preserve Mistral's logits-slicing kwarg
When generate() runs on FastMistralModel with no explicit logits kwarg, this branch now chooses logits_to_keep because Mistral's patched forward signature contains both names. However MistralForCausalLM_fast_forward only uses num_logits_to_keep in its normal logits path (unsloth/models/mistral.py:316) and only merges the two names for the hidden-states env path, so Mistral generation falls back to computing full prompt logits instead of slicing to the last token. For long prompts this reintroduces the large prefill logits allocation that the default =1 was avoiding and can cause major slowdown/OOMs.
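The slicing issue described here can be sketched in plain Python. This is an illustrative model only, not the code in `unsloth/models/mistral.py`: `resolve_keep` and `slice_for_lm_head` are hypothetical helpers, a list stands in for the hidden-states tensor, and a keep count of 0 is assumed to mean "compute logits for every position".

```python
def resolve_keep(num_logits_to_keep=0, logits_to_keep=0):
    """Merge the legacy and new spellings into a single integer keep count."""
    return max(num_logits_to_keep, logits_to_keep)


def slice_for_lm_head(hidden_states, num_logits_to_keep=0, logits_to_keep=0):
    """Return the positions the LM head would be applied to.

    `hidden_states` is a list of per-position items (stand-in for a tensor).
    The merge runs before any branching, so either spelling triggers slicing.
    """
    keep = resolve_keep(num_logits_to_keep, logits_to_keep)
    if keep != 0:
        return hidden_states[-keep:]
    return hidden_states


prompt_positions = list(range(10))
last_only = slice_for_lm_head(prompt_positions, logits_to_keep=1)
```

If the branch consulted only `num_logits_to_keep`, a caller passing `logits_to_keep=1` would fall through to the full-sequence path, which is the regression the comment describes.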
…#5543)

* fast_generate: unify legacy/new logits kwarg + fix Mistral merge site

  Two related issues caught by review on PR #5538:

  1. unsloth_fast_generate (models/llama.py)

     The previous patch promoted num_logits_to_keep -> logits_to_keep unconditionally whenever the caller supplied num_logits_to_keep, and only popped num_logits_to_keep (not logits_to_keep). On transformers older than 4.50 (legacy spelling is the only one the model forward accepts), the promotion broke things; symmetrically, a caller supplying logits_to_keep on those older transformers also went unchecked. Switch to the unified normalize-then-inspect pattern from the review:

         _provided_num = kwargs.pop("num_logits_to_keep", None)
         _provided_logits = kwargs.pop("logits_to_keep", None)
         _provided = _provided_logits if _provided_logits is not None else _provided_num
         _fwd_params = inspect.signature(self.forward).parameters
         if "logits_to_keep" in _fwd_params:
             kwargs["logits_to_keep"] = _provided if _provided is not None else 1
         elif "num_logits_to_keep" in _fwd_params:
             kwargs["num_logits_to_keep"] = _provided if _provided is not None else 1

     Inspect the runtime forward signature first, then choose the spelling it actually accepts, then route either user-supplied value under that spelling. Backward-compatible in both directions.

  2. MistralForCausalLM_fast_forward (models/mistral.py)

     The max(num_logits_to_keep, logits_to_keep) merge was inside the `if UNSLOTH_RETURN_HIDDEN_STATES:` block, so it only fired on the GRPO hidden-states path. On the normal generation path the elif at line 316 only checked num_logits_to_keep, so a caller (including unsloth_fast_generate itself) passing logits_to_keep=1 ended up computing full prompt logits instead of slicing to the last token. For long prompts that reintroduces the large prefill logits allocation the default keep=1 was avoiding. Move the max() merge above the env-var branching so the normal generation path slices correctly too.

     Llama already did this merge at the top (unsloth/models/llama.py:1501); Mistral now matches.

  No behaviour change on the default GRPO / SFT paths. Targets only the edge cases the review flagged.

* fast_generate: preserve caller logits kwarg when signature inspect fails

  If `inspect.signature(self.forward)` raises TypeError/ValueError (opaque C-extension or compiled wrappers), the previous fix set `_fwd_params = {}` which silently dropped the caller-supplied `logits_to_keep` / `num_logits_to_keep`. Fall back to the spelling the caller used (default `logits_to_keep=1` when neither was supplied) so generation still honors the requested logits slice.

* fast_forward: do not max() int against tensor logits_to_keep

  HF accepts logits_to_keep as a 1-D LongTensor of positions for selective decode. The merge in mistral.py (added by this PR) and the pre-existing one in llama.py both run max(int, Tensor), which casts the comparison to a bool and raises on multi-element tensors. Branch on type and skip the merge when either argument is a tensor; downstream int-slice path is unchanged, so tensor callers fall through with num_logits_to_keep == 0, matching pre-merge behavior.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fast_generate/forward: shorten kwarg-merge comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
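As a rough, self-contained illustration of the normalize-then-inspect routing plus the fallback for an uninspectable forward, consider the sketch below. It is not the code merged in the PR: `route_keep_kwarg` is a hypothetical helper, the two forwards are stand-ins, and preferring `logits_to_keep` when the caller supplies both names is an assumption.

```python
import inspect


def route_keep_kwarg(forward, kwargs, default=1):
    """Return a copy of `kwargs` with the keep count under the accepted spelling."""
    kwargs = dict(kwargs)
    provided_num = kwargs.pop("num_logits_to_keep", None)
    provided_new = kwargs.pop("logits_to_keep", None)
    value = provided_new if provided_new is not None else provided_num
    if value is None:
        value = default
    try:
        params = inspect.signature(forward).parameters
    except (TypeError, ValueError):
        # Signature unavailable: keep the spelling the caller used, and
        # default to the new name when neither was supplied.
        legacy_only = provided_new is None and provided_num is not None
        kwargs["num_logits_to_keep" if legacy_only else "logits_to_keep"] = value
        return kwargs
    if "logits_to_keep" in params:
        kwargs["logits_to_keep"] = value
    elif "num_logits_to_keep" in params:
        kwargs["num_logits_to_keep"] = value
    return kwargs


# Hypothetical forwards standing in for post-rename and pre-rename models.
def new_forward(input_ids, logits_to_keep=0): ...
def old_forward(input_ids, num_logits_to_keep=0): ...


promoted = route_keep_kwarg(new_forward, {"num_logits_to_keep": 4})
demoted = route_keep_kwarg(old_forward, {"logits_to_keep": 4})
```

Returning a copy rather than mutating the caller's dict is a choice made for this sketch only; the snippets in the thread modify `kwargs` in place.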
Summary
Follow-up to unslothai/unsloth-zoo#665. Fixes the `num_logits_to_keep` regression in `unsloth_fast_generate` that surfaced on transformers >= 4.52. Pairs with unslothai/unsloth-zoo#666 (single-matmul opt-in in the compiler templates).

What was wrong
`unsloth/models/llama.py` `unsloth_fast_generate` unconditionally set `kwargs["num_logits_to_keep"] = 1`. Transformers 4.50 renamed that forward parameter to `logits_to_keep` (with a `@deprecate_kwarg` shim that lived through 4.51.x), and removed the shim entirely in 4.52. From 4.52 onwards, `_validate_model_kwargs` raises:

    ValueError: The following `model_kwargs` are not used by the model: ['num_logits_to_keep']

blocking `model.generate(...)` on Llama / Mistral. Reproducible with a bare `FastLanguageModel.from_pretrained("unsloth/Llama-3.2-1B-Instruct")` + `model.generate(...)`.

Change
`unsloth/models/llama.py` (in `unsloth_fast_generate`): inspects the actual runtime forward signature and uses whichever spelling it accepts. A caller still passing the legacy `num_logits_to_keep=N` gets it promoted to `logits_to_keep=N` (not dropped).

Transformers version sweep
Verified the rename window directly against the GitHub source for `transformers/models/llama/modeling_llama.py`:

[Table: forward parameter name per transformers release. The version column did not survive extraction; the surviving cells read, in order: `num_logits_to_keep`, `logits_to_keep` (`@deprecate_kwarg(version="4.50")` shim), then `logits_to_keep` for the four remaining rows.]

All currently shipping `transformers` releases >= 4.52 use `logits_to_keep`. The fix's `inspect.signature(self.forward)` path picks the correct spelling on every one of them.

Caveat
`_validate_model_kwargs` (`transformers/generation/utils.py:1599`) merges `inspect.signature(self.forward).parameters` only when `prepare_inputs_for_generation` accepts `**kwargs`. On PEFT-wrapped models, `self.forward` may resolve to PEFT's wrapper with `(*args, **kwargs)` and hide the underlying signature. In that case my fix sets neither kwarg: still correct for validation, but the decode-time keep-only-last-logit optimisation is lost. Pre-existing limitation; the old unconditional `kwargs["num_logits_to_keep"] = 1` was what 4.52+ rejected. A follow-up can walk `self.get_base_model()` / `self.base_model.model` before inspecting if we want to recover the PEFT decode speedup.

What dropped from this PR
Earlier revisions of this PR also flipped `patch_loss_functions(torch_compile=False)` -> `True` at `loader.py:1381`. Reverted in 82c2e69. The loss function bottoms out at a Triton autograd.Function (`Fast_CrossEntropyLoss.apply`), which `torch.compile` treats as an opaque op and breaks the graph at. The only thing it can actually compile is three elementwise prep ops around the Triton call, and the per-call dynamo overhead is of the same order. An empirical Gemma3 1B GRPO smoke showed no meaningful delta (415s vs 409s, within noise), and the flip risked dragging the outer compiled training step into recompiles. Keeping `torch_compile=False`.

Test plan
- `model.generate(max_new_tokens=16)` no longer raises. Output: "The capital of France is Paris. The Eiffel Tower is located in Paris. The Eiffel".
- `test_fused_forward_install` + `test_compiler_rewriter_exhaustive`.

Compatibility
- Callers passing `num_logits_to_keep` continue to work: translated to `logits_to_keep` on transformers >= 4.50 rather than dropped.
- … `num_logits_to_keep` is in the runtime forward signature).

Pairs with: unslothai/unsloth-zoo `daniel/three-followups-from-pr665` (#666, `compiler: single-matmul opt-in for UNSLOTH_RETURN_LOGITS=1`).