
fix(mlx): match mlx-lm batch padding rule (1 + pad_to * ceil)#672

Merged: danielhanchen merged 2 commits into main from fix-mlx-pad-multiple on May 19, 2026

Conversation

@danielhanchen

Member

Summary

unsloth_zoo.mlx.trainer.create_text_batches previously padded each batch to _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE). mlx-lm's iterate_batches (mlx_lm/tuner/trainer.py:158) pads to 1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE), i.e. one extra token. Both default_loss (mlx-lm) and make_baseline_loss_fn (zoo) apply the same autoregressive shift (inputs = batch[:, :-1] / targets = batch[:, 1:]), so after the shift zoo's inputs were one token SHORTER than mlx-lm's.

This PR adds the +1 so zoo's padding matches mlx-lm value-for-value.
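The two padding rules can be compared side by side. The sketch below is illustrative only: the function names are hypothetical and not taken from either codebase, and `PAD_MULTIPLE = 32` is the value stated in this description.

```python
import math

PAD_MULTIPLE = 32  # value stated in the PR description


def old_padded_len(max_len: int) -> int:
    """Pre-PR rule: round max_len up to a multiple of PAD_MULTIPLE."""
    return PAD_MULTIPLE * math.ceil(max_len / PAD_MULTIPLE)


def new_padded_len(max_len: int) -> int:
    """Post-PR rule: the same rounding plus one token of shift headroom."""
    return 1 + PAD_MULTIPLE * math.ceil(max_len / PAD_MULTIPLE)


def post_shift_len(padded_len: int) -> int:
    """Length of inputs = batch[:, :-1] (and of targets = batch[:, 1:])."""
    return padded_len - 1


# (max_len, old padded length, new padded length) for three lengths.
rows = [(n, old_padded_len(n), new_padded_len(n)) for n in (14, 32, 33)]
for max_len, old, new in rows:
    print(max_len, old, new, post_shift_len(old), post_shift_len(new))
```

Under the old rule the post-shift length is one short of a multiple of 32; under the new rule it is an exact multiple.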

Why

Empirical on gemma-3-270m-it, single-row LoRA memorization, 15 paired seeds (danielhanchen/unsloth-staging-2 mlx-parity-probes, Round BM run 26050214501):

| probe | loader | trainer | num_layers | clip | pass |
|---|---|---|---|---|---|
| 31 | mlx_lm.load | manual @mx.compile | 16 | none | 10/15 = 67% |
| 33 | mlx_lm.load | MLXTrainer, compile=False | 16 | None (silent rebind to 1.0) | 8/15 = 53% |
| 34 | FastMLXModel(dtype=None) | MLXTrainer, compile=False | 16 | None | 7/15 = 47% |
| 35 | mlx_lm.load | MLXTrainer, compile=True | 16 | None | 8/15 = 53% |
| 36 | FastMLXModel(dtype=None) | MLXTrainer, compile=True | 16 | None | 7/15 = 47% |
| 37 | mlx_lm.load | MLXTrainer, compile=False | 16 | 0.0 (explicit off) | 6/15 = 40% |

Bisecting:

  • Same loader, same nl, same loss, same optimizer, same seed list, varying only trainer flips probe 31's 67% down to MLXTrainer's 40-53%.
  • The single user-visible difference upstream of the model forward is the input shape per batch: mlx-lm's (B, 33) vs zoo's (B, 32) on this 14-token fixture.

Teacher-forced completion loss is 0 in 15/15 seeds across every probe, so the model fully memorizes either way; the basin is purely a greedy-decode argmax artifact. cf_loss < 0.5 smoke gating per unslothai/unsloth#5537 stays green regardless. But the basin gap IS a parity defect against mlx-lm CLI, which is the reference implementation for the MLX path.

Behavior

  • _PAD_MULTIPLE stays at 32. Unchanged.
  • max_len=14 -> batch shape (B, 33) (was (B, 32)). Changed by +1 token.
  • max_len=32 -> batch shape (B, 33) (was (B, 32)).
  • max_len=33 -> batch shape (B, 65) (was (B, 64)).
  • All boundaries shift up by exactly one token. Memory bump is O(B * 1) per batch, negligible.

After this PR, make_baseline_loss_fn's inputs = batch[:, :-1] slice yields the same effective length as mlx-lm's default_loss, so identical hyperparameters produce the same per-step graph shape.
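A minimal sketch of the pad-then-shift step, using plain Python lists instead of MLX arrays so it runs anywhere. `pad_batch`, `shift`, and `PAD_ID = 0` are hypothetical stand-ins for illustration, not the zoo or mlx-lm implementations.

```python
PAD_MULTIPLE = 32
PAD_ID = 0  # hypothetical pad token id, chosen only for this sketch


def pad_batch(rows, pad_id=PAD_ID):
    """Right-pad every row to 1 + PAD_MULTIPLE * ceil(max_len / PAD_MULTIPLE)."""
    max_len = max(len(row) for row in rows)
    padded_len = 1 + ((max_len + PAD_MULTIPLE - 1) // PAD_MULTIPLE) * PAD_MULTIPLE
    return [row + [pad_id] * (padded_len - len(row)) for row in rows]


def shift(batch):
    """Autoregressive shift: inputs drop the last token, targets drop the first."""
    inputs = [row[:-1] for row in batch]
    targets = [row[1:] for row in batch]
    return inputs, targets


# A 14-token row (the fixture length quoted above) plus a shorter row.
batch = pad_batch([list(range(1, 15)), list(range(1, 10))])
inputs, targets = shift(batch)
print(len(batch[0]), len(inputs[0]), len(targets[0]))
```

With the +1 in place, the padded batch is 33 wide and both shifted views are 32 wide, which is the bucket size the model actually sees.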

Test plan

  • tests/test_mlx_batch_padding.py -> 13 cases: 11 boundary lengths value-for-value against mlx-lm; source-string assertion guarding the +1; _PAD_MULTIPLE == 32 invariant.
  • Local: pytest tests/test_mlx_batch_padding.py -v -> 13 passed.
  • Will rerun the Round BM probe matrix against this branch to confirm probes 33-37 close the gap to ~67%.

Related

Part of the MLX vs mlx-lm parity bisection:

  • #669 -> finetune_last_n_layers (layer-selection mismatch).
  • unslothai/unsloth#5564 -> same knob on the CUDA path.
  • #670 -> bf16 -> fp16 downcast warning.
  • #671 -> honor max_grad_value=None AND default to None for HF parity (closes #662).
  • This PR -> batch padding matches mlx-lm CLI; closes the last bisected gap in the MLXTrainer path.

create_text_batches previously padded each batch to
`_PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)`. mlx-lm's
`iterate_batches` (mlx_lm/tuner/trainer.py:158) pads to
`1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)` -- one extra
token, which after the autoregressive shift in `default_loss`
(`inputs = batch[:, :-1]` / `targets = batch[:, 1:]`) gives the
model exactly `_PAD_MULTIPLE * ceil(...)` attention positions.

unsloth_zoo's `make_baseline_loss_fn` does the same shift, but
since zoo's padded_len dropped the +1, the inputs were one token
SHORTER than mlx-lm's after the shift. On small fixtures (the
single-row LoRA memorization smoke against `gemma-3-270m-it`)
that one-token gap moved the run into a different basin of
attraction:

  probe 31 (manual mlx-lm loop, nl=16, no clip): 10/15 = 67%
  probe 33 (zoo MLXTrainer, nl=16, None silent-clip)         :  8/15 = 53%
  probe 34 (zoo FastMLXModel + MLXTrainer, nl=16)            :  7/15 = 47%
  probe 37 (zoo MLXTrainer, nl=16, explicit max_grad_value=0):  6/15 = 40%

  (See `danielhanchen/unsloth-staging-2` mlx-parity-probes
   matrix, Round BM run 26050214501.)

Teacher-forced completion loss is 0 in 15/15 seeds across every
probe -- the model fully memorizes either way. The greedy-decode
basin is what shifts. cf_loss < 0.5 smoke gating per
unslothai/unsloth#5537 stays green regardless, so this is not a
training-quality defect -- but it IS a parity defect against
mlx-lm CLI, which is the primary reference implementation for
the MLX path.

Tests: tests/test_mlx_batch_padding.py pins the padding rule
against mlx-lm's value-for-value across 11 length boundaries
plus a source-string assertion guarding against future refactor
drift, plus a check that _PAD_MULTIPLE stays at 32 (mlx-lm uses
the same constant).
@gemini-code-assist

Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@danielhanchen

Member Author

On closer audit: _create_labeled_batches is only used by the train_on_responses_only code path, not by the default text-mode MLXTrainer.train() flow. The default text path goes create_batches -> mlx_lm.tuner.trainer.iterate_batches directly, which already does 1 + 32 * ceil(L/32) correctly.

So this PR fixes a real, latent off-by-one in _create_labeled_batches (train_on_responses_only users hit it) but it does NOT close the basin-selection gap observed in the parity probes. The probes don't exercise this function.

Round BO is running against this branch to confirm; expectation now is that probes 33-37 stay at 40-53% (pad-fix doesn't apply to their code path) and the trainer-side gap remains unresolved.

Keeping the PR open as a stand-alone correctness fix: the docstring comment promised "matching mlx-lm" but the formula didn't, so users who switched to train_on_responses_only got input tensors one token shorter than mlx-lm's after the autoregressive shift.

@danielhanchen

Member Author

Round BO final results

Pinned ZOO_SPEC=b265d99 (this PR's branch) and re-ran the probe matrix. Probes 34/36 failed because the pad-fix branch is based on origin/main which doesn't have finetune_last_n_layers yet (#669 still open). The other probes finished:

| probe | loader | trainer | nl | clip | pass |
|---|---|---|---|---|---|
| 30 | mlx_lm.load | manual @mx.compile | all 18 | none | 7/15 = 47% |
| 31 | mlx_lm.load | manual @mx.compile | 16 | none | 10/15 = 67% |
| 35 | mlx_lm.load | MLXTrainer, compile=True | 16 | None (silent 1.0) | 8/15 = 53% |
| 37 | mlx_lm.load | MLXTrainer, compile=False | 16 | 0.0 (explicit off) | 6/15 = 40% |

Identical to Round BM and Round BL within seed noise. So as expected, this PR's +1 fix to _create_labeled_batches does NOT close the basin gap, because the probes don't exercise _create_labeled_batches (they use create_batches -> mlx_lm.tuner.trainer.iterate_batches which already had the +1).

What the trainer-side gap actually is

Sampling generations across seeds (cf_loss = 0 in 15/15 for every probe — the model has memorized perfectly):

| seed | probe 31 (manual) | probe 35 (zoo, silent clip) | probe 37 (zoo, no clip) |
|---|---|---|---|
| 1 | `<<HELLO>> My name is Unsloth!` PASS | `3 UNsloth!` FAIL | `1 Unsloth<<<<<<HELLO` PASS |
| 42 | `42<< My name is Unsloth!` PASS | `41<<HELLO!!` FAIL | `응ம்ம 비룡 égale cavité<pad>...` FAIL |
| 999 | `<<HELLO>> My Unsloth!` PASS | `Unsloth!` PASS | `Unsloth!` PASS |
| 3407 | `admissions clerk Unsloth!` PASS | `4<<HELLObtnUnsloth>> My Unsloth!` PASS | `42<< MY UNsloth` FAIL |
| 22222 | `Unsloth!` PASS | `23!!` FAIL | `24<<HELLO!!>> My name is Unsloth!` PASS |

cf_loss is 0 across every cell, so the model genuinely knows the training string in every case. Greedy decode just wanders: sometimes through <<HELLO>> boilerplate, sometimes through random tokens, sometimes case-folded (UNsloth vs Unsloth). The trainer-side ~20pp swing is a real numerical-divergence-at-convergence artifact, but its consequence is exclusively greedy-decode argmax brittleness, not a "training quality" defect.

Net take-aways for this PR series

  1. The basin gap that triggered this investigation isn't a training defect: teacher-forced completion loss reaches 0 in every config. unslothai/unsloth#5537's cf_loss < 0.5 gate is the right CI gate and is bulletproof.
  2. Greedy-decode CI on a single-row memorization fixture is inherently fragile; the existing PR #5537 already replaced that with cf_loss.
  3. This PR still ships a real correctness fix to _create_labeled_batches (used by train_on_responses_only): the comment promised mlx-lm parity but the formula dropped the +1. Keeping the PR open as a stand-alone cleanup.
  4. The MLX vs mlx-lm CLI parity work landed across five PRs:
    • #669 -> finetune_last_n_layers knob (zoo).
    • unslothai/unsloth#5564 -> same knob, CUDA path.
    • #670 -> warn on bf16 -> fp16 downcast (Gemma3 silent precision loss).
    • #671 -> max_grad_value=None honors disable + default to None for HF/TRL parity (closes #662).
    • This PR -> _create_labeled_batches padding matches mlx-lm.

Per code-comment policy: keep WHY (autoregressive shift headroom),
drop empirical probe results — those live in commit b265d99's message.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a3cd4d874b


    def test_source_padding_formula_includes_plus_one():
        """Guard against a future refactor that drops the +1 again."""
        import inspect
        from unsloth_zoo.mlx import trainer


P2: Avoid importing the package in the new test

In a zoo-only test environment where unsloth is not installed, this import executes unsloth_zoo/__init__.py before the shimmed MLX trainer module is loaded, and that top-level init raises ImportError("Please install Unsloth..."); pytest tests/test_mlx_batch_padding.py -q currently fails here before checking the padding formula. This test should inspect/load unsloth_zoo/mlx/trainer.py without going through the package init, or otherwise provide the same stub/setup expected by that import path.
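One way to read a submodule's source without executing the package `__init__` is to locate the package with `importlib.util.find_spec`, which does not import a top-level package, and then read the file from disk. The sketch below is illustrative: `read_submodule_source` and the throwaway `fake_zoo` package are hypothetical and exist only to reproduce an `__init__.py` that refuses to import.

```python
import importlib
import importlib.util
import sys
import tempfile
from pathlib import Path


def read_submodule_source(package: str, relative_path: str) -> str:
    """Return the text of a file inside `package` without importing the package."""
    # For a top-level name, find_spec locates the package but does not run
    # its __init__.py (only dotted names trigger a parent import).
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"cannot locate package {package!r}")
    root = Path(next(iter(spec.submodule_search_locations)))
    return (root / relative_path).read_text(encoding="utf-8")


# Demo: a throwaway package whose __init__ raises, like the case described above.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "fake_zoo"
    (pkg / "mlx").mkdir(parents=True)
    (pkg / "__init__.py").write_text('raise ImportError("Please install Unsloth")\n')
    (pkg / "mlx" / "trainer.py").write_text("padded_len = 1 + 32 * 2\n")
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        source = read_submodule_source("fake_zoo", "mlx/trainer.py")
    finally:
        sys.path.remove(tmp)

print(source)
```

The returned text can then be checked with a string or structural assertion, and the package is never added to `sys.modules`.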


@danielhanchen danielhanchen left a comment

Member Author


Thank you for the PR! The goal of this PR is to align _create_labeled_batches's padding rule with mlx_lm.tuner.trainer.iterate_batches so that after the autoregressive shift (inputs = batch[:,:-1] / targets = batch[:,1:]) the effective sequence length is a clean multiple of _PAD_MULTIPLE. As a summary, this PR changes padded_len = _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE) to padded_len = 1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE). The +1 gives the post-shift inputs/targets one token of headroom. The function is only invoked from train_on_responses_only, so this is correctness-only for that path and does not affect the main MLXTrainer flow.

Two independent Opus reviewers were run in parallel on this PR.

| Reviewers | Severity | Finding |
|---|---|---|
| 1/2 | Med | The new test file tests/test_mlx_batch_padding.py does not appear in any CI workflow (e.g. .github/workflows/mlx-ci.yml only runs the smoke shim). The regression guard never fires on PRs. |
| 2/2 | Med | test_source_padding_formula_includes_plus_one is a literal-substring pin on the formula. Any black/ruff reformat (operand order swap, intermediate variable rename, import math; math.ceil(...) rewrite) breaks the test even though the math is unchanged. |
| 1/2 | Med | padded_len = min(padded_len, max_seq_length) can clamp the +1 back off when max_len == max_seq_length, silently restoring the pre-PR (mlx-lm-aligned) edge behavior, but the comment claims "post-shift length is a clean multiple" unconditionally. Document the boundary. |
| 1/2 | Nit | No parametrized case at max_len == max_seq_length to lock in the clamp interaction. |
| 1/2 | Nit | The test re-implements _zoo_padded_len rather than importing the rule from trainer.py. If create_text_batches exposed _padded_len(max_len), the test could call it directly and the source-string guard would be unnecessary. |
| 1/2 | Nit | Test file header says create_text_batches but the function is _create_labeled_batches. |

Overall: APPROVE_WITH_NITS.

See inline comments for details and suggested fixes.

    - padded_len = ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE
    + # Match mlx-lm iterate_batches: +1 gives the autoregressive
    + # shift headroom so post-shift length is a clean _PAD_MULTIPLE.
    + padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE

Member Author


[2/2 reviewers] Correctness verified — trace for max_len=63, _PAD_MULTIPLE=32: 1 + 32*ceil(63/32) = 1 + 32*2 = 65, which after the shift gives inputs.shape[1] = 64 = 32*2 (clean bucket). Matches mlx_lm/tuner/trainer.py:158.
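The trace uses the `ceil` form while the diff uses integer arithmetic. A quick check that the two agree, as a sketch with hypothetical function names and the multiple fixed at 32:

```python
import math


def ceil_form(max_len: int, multiple: int = 32) -> int:
    """1 + multiple * ceil(max_len / multiple), as written in the trace."""
    return 1 + multiple * math.ceil(max_len / multiple)


def int_form(max_len: int, multiple: int = 32) -> int:
    """1 + ((max_len + multiple - 1) // multiple) * multiple, as written in the diff."""
    return 1 + ((max_len + multiple - 1) // multiple) * multiple


# Compare both spellings over a range of sequence lengths.
mismatches = [n for n in range(1, 4097) if ceil_form(n) != int_form(n)]
print(ceil_form(63), int_form(63), mismatches)
```

Both forms give 65 for max_len=63, and no length in the checked range disagrees.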

[1/2 reviewers] Med caveat: the min(padded_len, max_seq_length) clamp on the next line can silently revert this when max_len == max_seq_length — both pre-PR and post-PR end at padded_len = max_seq_length, so the +1 is dropped without an alignment guarantee. mlx-lm clamps the same way, so this matches upstream behavior, but the comment overstates the invariant. Tighten to:

Suggested change
    - padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE
    + # Match mlx-lm iterate_batches: +1 gives the autoregressive
    + # shift headroom so post-shift length is a clean _PAD_MULTIPLE.
    + # Note: the subsequent min(., max_seq_length) clamp can drop the
    + # +1 when max_len == max_seq_length; this matches mlx-lm's own
    + # behavior at that boundary.
    + padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE
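The clamp interaction is easier to see when the rule lives in one small helper, along the lines of the `_padded_len` nit in the review table. This is a sketch under assumptions: `padded_len_with_clamp` is a hypothetical name, and the clamp is modeled as a plain `min` as described in the comment above.

```python
from typing import Optional


def padded_len_with_clamp(
    max_len: int,
    max_seq_length: Optional[int] = None,
    pad_multiple: int = 32,
) -> int:
    """Apply the +1 padding rule, then clamp to max_seq_length if one is given."""
    raw = 1 + ((max_len + pad_multiple - 1) // pad_multiple) * pad_multiple
    if max_seq_length is None:
        return raw
    # The clamp can remove the +1 headroom when the raw length overshoots.
    return min(raw, max_seq_length)


# Well below the clamp the +1 survives; at the boundary it is dropped.
print(padded_len_with_clamp(32, 128))  # +1 retained
print(padded_len_with_clamp(64, 64))   # clamped back to max_seq_length
```

A test can then call the helper directly instead of re-implementing the formula or pinning its source text.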

assert _zoo_padded_len(max_len) == _mlx_lm_padded_len(max_len)


def test_source_padding_formula_includes_plus_one():

Member Author


[2/2 reviewers] Med: this test pins the formula as a literal substring ("1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE"). Any cosmetic reformat (operand swap to mlx-lm style _PAD_MULTIPLE * (...), intermediate variable, math.ceil-rewrite) breaks it without breaking the math. The parametrized value-equivalence test above already pins correctness. Drop this or relax to a structural check:

Suggested change
    - def test_source_padding_formula_includes_plus_one():
    + def test_source_padding_formula_includes_plus_one():
    +     """Defense in depth: ensure the function source mentions a `+ 1`
    +     near the padded_len computation so a future refactor cannot quietly
    +     drop the autoregressive-shift headroom. Pin the intent, not the
    +     exact spelling."""
    +     src = inspect.getsource(trainer_mod._create_labeled_batches)
    +     # Pull out the padded_len assignment block and require a literal +1.
    +     block = re.search(r"padded_len\s*=\s*[^\n]+", src)
    +     assert block is not None
    +     assert "+ 1" in block.group(0) or "1 +" in block.group(0), (
    +         f"expected '+1' in padded_len assignment, got: {block.group(0)!r}"
    +     )
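An alternative to substring or regex matching is a structural check on the parsed source, which is insensitive to spacing and operand order. The sketch below uses the standard `ast` module; `assigns_plus_one` is a hypothetical helper, not part of the test file under review.

```python
import ast


def assigns_plus_one(source: str, target: str = "padded_len") -> bool:
    """True if some assignment to `target` contains an addition with the literal 1."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        if not any(isinstance(t, ast.Name) and t.id == target for t in node.targets):
            continue
        # Look for `1 + expr` or `expr + 1` anywhere in the right-hand side.
        for sub in ast.walk(node.value):
            if isinstance(sub, ast.BinOp) and isinstance(sub.op, ast.Add):
                for operand in (sub.left, sub.right):
                    if isinstance(operand, ast.Constant) and type(operand.value) is int and operand.value == 1:
                        return True
    return False


old = "padded_len = ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE\n"
new = "padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE\n"
swapped = "padded_len = _PAD_MULTIPLE * ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) + 1\n"
print(assigns_plus_one(old), assigns_plus_one(new), assigns_plus_one(swapped))
```

The `- 1` inside the ceiling expression is a subtraction node, so the pre-PR formula is rejected while both spellings of the +1 are accepted. It would still miss a rewrite that hides the +1 in an intermediate variable, so the value-equivalence test remains the primary guard.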

)


def test_pad_multiple_constant_still_32():

Member Author


[1/2 reviewers] Nit: no parametrized case covers max_len == max_seq_length (where the clamp engages and the +1 is silently dropped). Add it to document the intentional asymmetry at the boundary:

Suggested change
    - def test_pad_multiple_constant_still_32():
    + def test_pad_multiple_constant_still_32():
    +     """mlx-lm uses _PAD_MULTIPLE=32. Pin it so a future refactor cannot
    +     silently drift."""
    +     assert trainer_mod._PAD_MULTIPLE == 32
    +
    +
    + @pytest.mark.parametrize("max_len, max_seq_length, expected", [
    +     (63, 64, 64),   # +1 would be 65; clamped to 64
    +     (64, 64, 64),   # already at clamp boundary; +1 dropped
    +     (32, 128, 33),  # well below clamp; +1 retained
    + ])
    + def test_padded_len_respects_max_seq_length_clamp(max_len, max_seq_length, expected):
    +     """Document the clamp interaction with the +1 rule."""
    +     raw = 1 + ((max_len + 32 - 1) // 32) * 32
    +     assert min(raw, max_seq_length) == expected

@@ -0,0 +1,90 @@
# Unsloth Zoo - Utilities for Unsloth

Member Author


[1/2 reviewers] Med + Nit: (a) this test file is not added to any CI workflow — the regression guard never fires on PRs. Add it to mlx-ci.yml or the consolidated repo-tests-cpu job since it uses the torch-based MLX shim and is CPU-safe. (b) The header references create_text_batches but the function is _create_labeled_batches.

Suggested change
    - # Unsloth Zoo - Utilities for Unsloth
    + # Unsloth Zoo - Utilities for Unsloth
    + # Pin _create_labeled_batches's padding rule against mlx-lm's
    + # iterate_batches (mlx_lm/tuner/trainer.py:158).
    + # REMINDER: add this file to .github/workflows/mlx-ci.yml so the
    + # regression guard runs in CI.

@danielhanchen danielhanchen merged commit 833ad01 into main May 19, 2026
15 checks passed