gpt-oss: align eager KV length to the attention mask#691
Merged
Conversation
Fixes a "The size of tensor a (N) must match the size of tensor b (M)"
RuntimeError in gpt-oss inference when generation crosses the sliding
window (e.g. short prompt with max_new_tokens that pushes the sequence
past 128 tokens) on transformers 5.x running below its required torch
(< 2.11). In that case the KV cache returns more positions for a
full-attention layer than the causal mask covers (pre-allocated cache
slots), so the eager path's `attn_weights += attention_mask` crashes.
The surplus key positions are masked out anyway, so attend only the
overlap: trim key/value (or the mask) to the shorter length in both
eager attention variants. This keeps inference correct and shape
consistent across transformers 4.57.x / 5.5.x and torch 2.9 - 2.11.
Also drop the dead singular `past_key_value` cache_position-free forward
variant: the singular naming predates transformers dropping
`cache_position`, so that signature combination does not exist in any
release. Only the `past_key_values` variant is reachable.
Verified on gpt-oss-20b across the full matrix (transformers
{4.57.6, 5.5.0} x torch {2.9.1, 2.10.0, 2.11.0}): all reasoning efforts
generate coherent output, and greedy continuations match across torch
versions for a given transformers version.
Contributor
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
4 tasks
danielhanchen
pushed a commit
to mmathew23/unsloth-zoo
that referenced
this pull request
May 24, 2026
unsloth_zoo/mlx/trainer.py
Clean up the cast_norm_output_to_input_dtype monkey patch in
`train()`'s finally block. Previously `_set_norm_output_cast_to_input_dtype(True, model)`
patched MLX norm classes globally at line 847 but the cleanup
only undid gradient checkpointing and memory limits; subsequent
inference or trainers in the same Python process inherited the
cast-back wrapper. Wrap the restore in try/except so a partial
patch state never prevents `finally` from completing.
Raise on the text-streaming + dataset_order combination. The new
preserve_dataset_order / dataset_order config fields are honored
for non-streaming text and for VLM (streaming + materialized),
but `iterate_training_batches(...)` has no ordering argument, so
text streaming silently ignored the user-requested order. Raise
ValueError instead so the asymmetry is explicit and Studio /
CUDA parity stays loud.
unsloth_zoo/mlx/utils.py
Apply the Qwen3-VL full-sequence forward fix to the baseline CE
path too. _vlm_cce_forward already forwards the full multimodal
sequence and shifts `hidden[:, :-1]` afterwards because Qwen3-VL
image / mRoPE / deepstack state depends on the complete sequence.
make_vlm_baseline_loss_fn was still trimming `input_ids[:, :-1]`
pre-forward, so users who set `use_cce=False` saw a different
loss than `use_cce=True` for the same input. Forward the full
sequence and drop the final logits position afterwards to match.
tests/test_pr_a_deep_components.py
Fix the linear-no-warmup scheduler expectation. The previous
expected `[0.0, lr, lr*6/7, ...]` would have step 0 run at zero
LR and is inconsistent with `transformers.get_scheduler("linear",
num_warmup_steps=0, num_training_steps=n)`. Replace with the
HF-compatible `lr * (n - step) / n` series so the existing
_build_schedule() implementation passes the test.
(Note: this commit also includes the previous merge of origin/main
which restored mainline unslothai#690 / unslothai#691 gpt-oss eager-attention fixes
that the stale branch was about to revert.)
Merged
danielhanchen
pushed a commit
to shimmyshimmer/unsloth-zoo-staging-2
that referenced
this pull request
May 24, 2026
Fixes a "The size of tensor a (N) must match the size of tensor b (M)"
RuntimeError in gpt-oss inference when generation crosses the sliding
window (e.g. short prompt with max_new_tokens that pushes the sequence
past 128 tokens) on transformers 5.x running below its required torch
(< 2.11). In that case the KV cache returns more positions for a
full-attention layer than the causal mask covers (pre-allocated cache
slots), so the eager path's `attn_weights += attention_mask` crashes.
The surplus key positions are masked out anyway, so attend only the
overlap: trim key/value (or the mask) to the shorter length in both
eager attention variants. This keeps inference correct and shape
consistent across transformers 4.57.x / 5.5.x and torch 2.9 - 2.11.
Also drop the dead singular `past_key_value` cache_position-free forward
variant: the singular naming predates transformers dropping
`cache_position`, so that signature combination does not exist in any
release. Only the `past_key_values` variant is reachable.
Verified on gpt-oss-20b across the full matrix (transformers
{4.57.6, 5.5.0} x torch {2.9.1, 2.10.0, 2.11.0}): all reasoning efforts
generate coherent output, and greedy continuations match across torch
versions for a given transformers version.
LeoBorcherding
pushed a commit
to LeoBorcherding/unsloth-zoo
that referenced
this pull request
May 28, 2026
…i#693) * Add CPU-only regression tests for gpt-oss attention-mask patches Catches both recent gpt-oss inference regressions deterministically on a CI runner without a GPU: PR 690: when config._attn_implementation == "flex_attention" but inference runs through eager attention, create_causal_mask must not return a BlockMask. Otherwise the eager forward crashes with TypeError: unsupported operand type(s) for +=: 'Tensor' and 'BlockMask' PR 691: the eager_attention_forward closures must trim KV (or mask) to the shorter length, otherwise pre-allocated cache slots crash with RuntimeError: The size of tensor a (N) must match the size of tensor b (M) at non-singleton dimension 3 Each fix gets both a runtime invariant check and an AST source check, so behavioural regressions and accidental guard deletions both fail CI. Runtime test runs the actual transformers create_causal_mask call against a tiny GptOssConfig in a clean subprocess (so we can replace torch.compile with identity before unsloth_zoo loads -- the wrap captures _torch_compile via functools.partial at import time). Verified locally: * 4/4 PASS on real cpu-only torch 2.9.1+cpu, transformers 4.57.6 and 5.9.0, python 3.11. * 2/4 FAIL when the PR 690 wrap guard is removed (runtime catches the BlockMask, AST catches the missing literal). * 1/4 FAIL when the PR 691 _align_kv_to_mask call sites are removed. Added CI workflow runs the matrix: python {3.11, 3.12} x transformers {4.57.6, 5.9.0} on ubuntu-latest with the cpu torch wheel. * Move gpt-oss attention-mask guards into test_zoo_history_regressions The static AST / source checks for PR unslothai#690 (BlockMask leaking into the eager inference path) and PR unslothai#691 (eager KV length vs mask kv length alignment) belong in tests/test_zoo_history_regressions.py: each is a shipped fix on main and exactly matches that file's "pin past zoo bugs" brief. That file is already wired into the consolidated Tests CI job (repo-tests-cpu step + core-upstream-matrix invocation across HF/TRL cells), so dropping the guards in there gets them onto CI without touching .github/workflows. Added a fourth static guard while moving: the patched GptOssModel.forward must filter mask kwargs by the actual create_causal_mask signature (input_embeds + cache_position on transformers 4.57.6, inputs_embeds and no cache_position on transformers 5.x). Without this filter the patch silently breaks across the 4.x / 5.x boundary. tests/test_gpt_oss_attention_mask.py is now just the runtime subprocess invariant -- the call where torch.compile has to be swapped to identity BEFORE unsloth_zoo loads to avoid the CPU-torch kwarg-stripping bug masking the real BlockMask check. Verified across the full transformers matrix on cpu-only torch 2.9.1+cpu, python 3.11: - 4.57.6: PASS - 5.0.0 5.1.0 5.2.0 5.3.0 5.4.0: PASS - 5.5.0 5.5.1 5.5.2 5.5.3 5.5.4: PASS - 5.6.0 5.6.1 5.6.2: PASS - 5.7.0 5.8.0 5.8.1 5.9.0: PASS 18/18 versions, 6 tests per version (5 static + 1 runtime subprocess). Negative-controlled: - Reverting PR unslothai#690 wrap guard -> wrap-guard + runtime test FAIL. - Reverting PR unslothai#691 _align_kv_to_mask call sites -> align test FAIL. - Reverting inspect-driven mask kwargs filter -> filter test FAIL. * Tighten gpt-oss model-forward and KV-align AST guards Two follow-up adjustments to the new static guards so they catch the exact regression class they target, with no false negatives on a benign refactor: 1. test_gpt_oss_patched_model_forward_has_flex_attention_guard now walks the AST and asserts the literal lives in the patched GptOssModel.forward FunctionDef body, not anywhere in patch_GptOssModel's source. Previously the wrap closure's inner "flex_attention" literal satisfied the test on its own, so a revert of just the forward-level _swap_attn_impl block slipped through. Negative-controlled: removing the forward-level literal while keeping wrap's now correctly fails this test. 2. test_gpt_oss_eager_attention_aligns_kv_to_mask now slices the source of each eager forward FunctionDef individually (eager_attention_forward, inplace_eager_attention_forward) and asserts _align_kv_to_mask( appears in each body, instead of counting whole-file callsites >= 2. A future refactor that consolidates the two eager paths into a shared closure still passes as long as the alignment runs on both routes. Removing the call from either path still fails the test. Verified locally: 5/5 PASS on the existing positive run; the full negative-control matrix (revert wrap literal, strip _align callsites, strip inspect-filter helpers, strip outer _swap_attn_impl) now bites exactly the matching test for each revert, and the previously-blind outer-only revert (case 3e) is caught. --------- Co-authored-by: danielhanchen <danielhanchen@users.noreply.github.com>
danielhanchen
added a commit
that referenced
this pull request
Jun 14, 2026
* Tighten MLX VLM training parity diagnostics * Match Qwen3-VL rotary precision in MLX * Disable Qwen3-VL MLX compile verification * Match HF AdamW decay filtering in MLX * Preserve Qwen3-VL residual dtype in MLX vision block * update * udpate vlm * bring back correct loss curves * update textdataset * dataset ordering fix, lr fix * use proportional MLX grad value clipping * cast norm activation output back to original input dtype * address mlx training review feedback * fix(mlx): cast custom norm outputs * feat: auto discover custom norm from model * fix(mlx): harden norm output cast discovery * fix(mlx): preserve custom norm keyword calls * harden mlx custom norm output casting * Fix four loose ends for PR #684 unsloth_zoo/compiler.py Restore the dedicated UNSLOTH_RETURN_LOGITS=1 elif branch in cross_entropy_replacement_2 (originally added by #666, commit f45c31e). Without it the regex template fallback path under UNSLOTH_FUSED_FORWARD=0 ran self.lm_head twice on the UNSLOTH_RETURN_LOGITS=1 path: once via the prepended materialise and once again in the final else branch. The AST rewriter at fused_losses/forward_install.py is unaffected. unsloth_zoo/mlx/trainer.py + unsloth_zoo/mlx/loader.py Expand the AdamW weight-decay filter and the fp32 norm-parameter filter to also match GPT-2 style ln_1 / ln_2 / ln_f names. Previous "norm" substring missed them; _ensure_lora_frozen already treated those as norm fragments, so the filters were inconsistent. unsloth_zoo/mlx/compile.py Honor cast_norm_output_to_input_dtype=False on the Qwen3-VL vision-block patch. Added a module-level _QWEN3_VISION_NORM_CAST_OUTPUT flag with a setter; the trainer's _set_norm_output_cast_to_input_dtype flips it so the generic norm patcher and the Qwen3 specialized norm patch agree. unsloth_zoo/mlx/utils.py Reseed the iterate_vlm_training_batches default branch per epoch via np.random.default_rng(base_seed + epoch). The torch_randperm branch already did this; the default branch's order previously depended on global numpy RNG state. * Address reviewer round 1 P1 findings on PR #684 unsloth_zoo/mlx/trainer.py Clean up the cast_norm_output_to_input_dtype monkey patch in `train()`'s finally block. Previously `_set_norm_output_cast_to_input_dtype(True, model)` patched MLX norm classes globally at line 847 but the cleanup only undid gradient checkpointing and memory limits; subsequent inference or trainers in the same Python process inherited the cast-back wrapper. Wrap the restore in try/except so a partial patch state never prevents `finally` from completing. Raise on the text-streaming + dataset_order combination. The new preserve_dataset_order / dataset_order config fields are honored for non-streaming text and for VLM (streaming + materialized), but `iterate_training_batches(...)` has no ordering argument, so text streaming silently ignored the user-requested order. Raise ValueError instead so the asymmetry is explicit and Studio / CUDA parity stays loud. unsloth_zoo/mlx/utils.py Apply the Qwen3-VL full-sequence forward fix to the baseline CE path too. _vlm_cce_forward already forwards the full multimodal sequence and shifts `hidden[:, :-1]` afterwards because Qwen3-VL image / mRoPE / deepstack state depends on the complete sequence. make_vlm_baseline_loss_fn was still trimming `input_ids[:, :-1]` pre-forward, so users who set `use_cce=False` saw a different loss than `use_cce=True` for the same input. Forward the full sequence and drop the final logits position afterwards to match. tests/test_pr_a_deep_components.py Fix the linear-no-warmup scheduler expectation. The previous expected `[0.0, lr, lr*6/7, ...]` would have step 0 run at zero LR and is inconsistent with `transformers.get_scheduler("linear", num_warmup_steps=0, num_training_steps=n)`. Replace with the HF-compatible `lr * (n - step) / n` series so the existing _build_schedule() implementation passes the test. (Note: this commit also includes the previous merge of origin/main which restored mainline #690 / #691 gpt-oss eager-attention fixes that the stale branch was about to revert.) * Preserve embedder position_ids in _vlm_cce_forward `_unpack_embed_result` can return a `position_ids` adjusted for the merged multimodal sequence (e.g. Qwen-VL family adjusts mRoPE / 3D position_ids during get_input_embeddings). The previous code unconditionally overwrote backbone_kwargs["position_ids"] with the raw batch position_ids, discarding the embedder's corrected version. Only inject the raw position_ids when the embedder did not produce its own. * Address reviewer round 2 findings on PR #684 create_ordered_batches no longer mixes the last partial batch of one epoch with the first samples of the next. Take a contiguous slice of the current epoch order, emit a partial batch if the tail is short, and start the next batch fresh at epoch+1. Matches the VLM ordered path at utils.py:2539 and SequentialSampler(drop_last=False). Baseline labels=None mask in make_baseline_loss_fn now uses `<` (exclusive end) so the unlabeled path agrees with the CCE (utils.py:360,:393) and labels-aware (utils.py:439) masks; pre-fix it was `<=` and trained on one extra padded position when the row hit max_seq_length. _prepare_dataset, create_batches, create_ordered_batches, iterate_training_batches and _create_labeled_batches all gained an `append_eos` parameter that the trainer plumbs from MLXTrainingConfig (default True). Direct MLX text fine-tuning callers (raw {"text": str} rows) again get mlx-lm parity EOS appending; Studio passes False because its chat template already renders EOS. _create_labeled_batches now honors dataset_order / preserve_dataset_order: skips the length-based sort and per-batch shuffle when the caller has asked for sequential or torch_randperm order. Without this the train_on_responses_only path silently rewrote the sample order set by the new Studio CUDA parity flags. _create_labeled_batches emits lengths as right-half-open `[1, L]` to match the new exclusive-end mask convention; pre-fix it was `[1, L - 1]` which paired with the old `<=` mask and now would drop the final supervised token. test_mlx_text_dataset_does_not_append_eos updated: Studio explicitly passes append_eos=False, default callers still receive EOS. * Extend HF parity decoupled weight decay to SGD/Muon/Lion for PR #684 AdamW already used a manual decoupled bias/norm-aware decay with `weight_decay=0.0` on the underlying MLX optimizer. SGD, Muon, and Lion still passed `weight_decay=wd` directly to the MLX optimizer, which applies wd uniformly across every trainable parameter (including bias and norm leaves) and uses MLX's internal coupling semantics rather than HF's decoupled per-step `param *= 1 - lr * wd`. Mirror the AdamW pattern for the other three optimizers: set the underlying MLX optimizer's `weight_decay` to zero and let the existing manual helper own the decoupled decay term. `_manual_adamw_weight_decay` renamed to `_manual_weight_decay` (and the helper to `_apply_manual_weight_decay`) since it now covers four optimizers. Tests updated for the rename and a parametrized SGD/Muon/Lion case added asserting the manual decay scalar is set and the optimizer itself carries `weight_decay=0.0`. * Address remaining reviewer round 2 findings on PR #684 max_grad_value: restore elementwise clip semantics ============================================ The PR replaced the existing `mx.clip(g, -v, +v)` with a per-leaf L2 norm rescale, so the field's name no longer matched its behavior and existing tests / docstrings describe elementwise semantics. Four reviewers flagged this as a public-API regression. Switch back to `mx.clip(g, -max_grad_value, max_grad_value)` (still per-leaf, no cross-leaf reduction). The function is renamed `_clip_grad_by_value` to match the contract. VLM iterable streaming: refuse dataset_order instead of dropping it ============================================ `iterate_vlm_training_batches` honored `dataset_order="torch_randperm"` on sized datasets but silently streamed source order on unsized / iterable ones. The text streaming path already raises in this asymmetry (`trainer.py:1758`); mirror that here so users get a clear error rather than a silent CUDA-parity regression. Qwen3-VL vision norm-cast flag: restore prior state in finally ============================================ `_set_norm_output_cast_to_input_dtype(False, model)` in the train() finally also toggles the Qwen3 vision flag via `set_qwen3_vision_norm_cast_output(False)`. The module-level default is True, so post-training inference in the same process would see the flag stuck at False. Capture the previous flag value before training and restore it explicitly in finally. create_ordered_batches: pad with the tokenizer's pad id ============================================ Padded positions used literal `0` rather than `tokenizer.pad_token_id`, which can collide with a regular vocabulary token for tokenizers whose pad id is not 0. Fall back to 0 only when the tokenizer has no pad id. * Address reviewer round 3 P1/P2 findings on PR #684 Labeled batches: torch_randperm at the sample level ============================================ `_create_labeled_batches` had `dataset_order="torch_randperm"` shuffle batches after sequential batching, while the unlabeled path (`create_ordered_batches` / `_torch_randperm_order`) shuffles samples before batching. Studio + CUDA `RandomSampler` is sample-granular, so `train_on_responses_only(..., dataset_order="torch_randperm")` ended up grouping rows differently from non-completion MLX training. Now both paths apply `_torch_randperm_order(seed)` to the sample list before batching; the legacy `default` / `None` path still length-sorts and shuffles batches. Unsupported values now raise instead of silently falling back to default shuffle. create_ordered_batches: honor num_epochs when num_batches is None ============================================ When `_prepare_data()` selects `create_ordered_batches(num_epochs=N)` for `max_steps <= 0` + `num_train_epochs > 0`, the loop previously exited at the first epoch boundary (`break` on `num_batches is None`). The intent is to emit `N * len(dataset)` samples worth of batches; extended the boundary check to stop only when the requested sample total has been emitted or no batches were requested. train() norm-patch lifecycle hardening ============================================ Moved `_set_norm_output_cast_to_input_dtype(cast_norm_output, model)` INSIDE the `try` block so a raise from `normalize_mlx_patch_mode`, `_configure_memory_limits`, compile policy, gradient checkpointing, or Qwen3.5 preflight does not leak the patched RMSNorm / LayerNorm class globals across train() boundaries. VLM train_on_completions in the labels-aware branch ============================================ `_collate_vlm_batch` now always attaches `batch["labels"]`, so `_vlm_cce_forward` and `make_vlm_baseline_loss_fn` always take their labels-aware branches for ordinary VLM SFT. Those branches were missing the `_mask_prompt_tokens(...)` call that the labels=None branches already perform, so `train_on_completions=True` silently trained on prompt tokens. Added the call to both branches. Ruff lint: clean up new E741 / F401 / F841 in changed files ============================================ Removed unused `import mlx.core as mx`, `mask = kwargs.get("mask",...)` that was never read, dead `original_sanitize` / `wanted` / `hidden_dim` locals, and renamed ambiguous `l` loop variables in `mx.eval(...)` batch flushes. Ruff exits clean on `compile.py / loader.py / trainer.py / utils.py` after this commit. * Restore mask = kwargs.get for 4 patched VLM get_input_embeddings A round-3 replace_all that targeted unused mask extractions in two patched_qwen2 get_input_embeddings variants also stripped the same line from four other patched VLM get_input_embeddings functions where mask is actually passed into self.language_model.get_rope_index. That broke qwen3, qwen35, glm, and a generic VLM get_input_embeddings: NameError on first call. Restored the mask = kwargs.get('mask', None) line in the four functions that use it; the two qwen2 callers (where mask is truly unused) remain stripped. * Materialize multiple epochs of labeled batches when num_epochs>1 for PR #684 The unlabeled torch_randperm path in create_ordered_batches materializes N*len(dataset)/batch_size batches with a per-epoch reseeded permutation when num_epochs is set. The labeled train_on_responses_only path stored trainer._batches at one epoch worth of permutations, and the trainer loop at line 1466 then cycled batches[batch_idx % len(batches)], so num_train_epochs=2 trained on the same row order in epoch 1 and epoch 2. _create_labeled_batches now accepts num_epochs and emits one block of permuted batches per requested epoch, each with seed+epoch_idx, matching the unlabeled path. Wired the trainer call site to pass num_train_epochs when set, and to set _prepared_batches_include_epochs so the existing total_steps math at trainer.py:1035 does not multiply through again. * Rename PR-numbered tests and shorten verbose comments * Add MLX max grad leaf norm clipping * Restore collated VLM position ids for parity * Scope VLM position id override to collated ids * Preserve returned VLM position ids * Fix MLX VLM parity masking edge cases * Route text-only VLM loads through text trainer * Match BNB nested NF4 scale quantization * Match CUDA VLM resize-min behavior in MLX * Match Gemma3 vision post norm epsilon * Run Gemma3 vision SDPA in fp32 on MLX * Match Gemma3 image feature scaling on MLX * Use Gemma3 image attention mask in MLX VLM CCE * Clip MLX global grad norms in fp32 * Match Gemma3 vision fp32 norm and activation math * Disable Gemma3 MLX training compile pending parity * Match Gemma3 text RMSNorm fp32 math * Preserve VLM hidden stack activation dtype Rationale / guardrails for the local Gemma3 parity stack: This is the last local-only zoo commit before push, so this body documents the changes that should not be accidentally flipped back during review. Do not restore the broader Daniel position-id override. VLM CCE should prefer collator-built position_ids only when _unsloth_collated_position_ids is set, preserve position_ids explicitly returned by InputEmbeddingsFeatures, and otherwise fall back to model-stashed or sequential ids. The broad override moved Qwen/Gemma-style VLM runs away from CUDA collation semantics. Do not re-add global pad_token_id masking to the VLM loss. Padding is masked by labels/attention masks; globally ignoring pad ids also suppresses legitimate target ids for custom datasets. Image/video placeholder token ids are the only global ignore ids needed for VLM CCE. Do not mark Gemma3 training compile verified yet. Fixed-fixture Gemma3 showed compiled loss differing from eager before optimizer update, so best-effort must fall back to eager until real training parity is proven. Do not remove the Gemma3 MLX-vLM patches as cosmetic. The current patches fix concrete CUDA parity mismatches: SigLIP post-layernorm eps, vision SDPA fp32 math with cast-back, vision LayerNorm/GELU fp32 math with cast-back, text RMSNorm fp32 math with cast-back, image feature scaling by text embedding width, image-token attention masking in CCE, and preserving merged VLM inputs_embeds dtype instead of promoting activations to fp32 because norm weights are fp32. Do not switch MLX grad clipping back to bf16 reductions. Global grad norm clipping should reduce in fp32; bf16 reductions changed clipping behavior. Validation summary: focused MLX/Gemma3/VLM tests pass, and the remaining Gemma3 VLM delta was isolated to cumulative bf16/backend drift through the 27-layer SigLIP tower rather than labels, preprocessing, position ids, projector, final post-LN, block-0 attention backward, or weight mapping. * Restore Gemma3 MLX training compile qualification Keep Gemma3 in the verified MLX training compile set. The observed eager-vs-compiled loss deltas are small enough that Gemma3 should continue using compile rather than falling back to eager by policy. Update the regression test to assert the intended compile qualification so this does not get accidentally demoted again. * Handle quantized CCE layer modes * Rename grad-clip test to reflect three-mode scope `test_mlx_max_grad_value_none.py` now covers max_grad_leaf_norm and max_grad_norm too. Rename to test_mlx_grad_clip_resolution.py and update the docstring to list all three knobs. * Add Metal training e2e and VJP regression tests for PR #684 staging (cherry picked from commit 9891b6d) * Save adapters at end of training to honor the save_steps=0 contract (cherry picked from commit fb046c2) * Fix review findings in MLX trainer, VLM utils, and NF4 dequant for PR #684 * Add deep Metal validation tests for resume, completion-only, epochs, SGD, and VLM training for PR #684 * Unwrap non-callable tokenizer wrappers in train_on_responses_only and use Qwen2-VL for the VLM e2e test * Disable fused MRoPE for Qwen2-VL family training too * Run the deep MLX validation tests in Mac CI for PR #684 * Document MLX per-leaf grad clip default and CUDA max_grad_norm tradeoff for PR #684 --------- Co-authored-by: Lyxot <longyixing331@gmail.com> Co-authored-by: Daniel Han-Chen <info@unsloth.ai> Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Follow-up to #690. Fixes a
RuntimeError: The size of tensor a (N) must match the size of tensor b (M) at non-singleton dimension 3in gpt-oss inference that triggers when generation crosses the sliding window (e.g. a short prompt withmax_new_tokens=64that pushes the sequence past 128 tokens).Root cause
On transformers 5.x running below its required torch (
< 2.11), the KV cache hands a full-attention layer more positions than the causal mask covers (pre-allocated cache slots). During prefill of a 98-token prompt with 64 new tokens, a full-attention layer seeskey_statesof length 161 while the mask is only 128 wide, so the eager path'sattn_weights += attention_maskcrashes:The surplus key positions are masked out anyway (causal mask sets them to
-inf), so they contribute nothing.Fix
Attend only the overlap: trim key/value (or the mask) to the shorter of the two lengths in both eager attention variants (
inplace_eager_attention_forwardandeager_attention_forward). This keeps the path correct and shape-consistent regardless of how the cache pre-allocates across versions.Also drops the dead singular
past_key_valuecache_position-free forward variant added in #690: the singular naming predates transformers droppingcache_position, so that signature combination does not exist in any release. Only thepast_key_valuesvariant is reachable.Test plan
Verified on
unsloth/gpt-oss-20b(4bit) + LoRA across the full matrix:max_new_tokens=64(crosses the 128 sliding window)All 12+ combinations generate coherent output (previously crashed on tf5.x + torch < 2.11), and greedy continuations match across torch versions for a given transformers version.