fix(mlx): match mlx-lm batch padding rule (1 + pad_to * ceil)#672
create_text_batches previously padded each batch to `_PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)`. mlx-lm's `iterate_batches` (mlx_lm/tuner/trainer.py:158) pads to `1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)`: one extra token, which after the autoregressive shift in `default_loss` (`inputs = batch[:, :-1]` / `targets = batch[:, 1:]`) gives the model exactly `_PAD_MULTIPLE * ceil(...)` attention positions. unsloth_zoo's `make_baseline_loss_fn` does the same shift, but since zoo's padded_len dropped the +1, the inputs were one token SHORTER than mlx-lm's after the shift.
On small fixtures (the single-row LoRA memorization smoke against `gemma-3-270m-it`) that one-token gap moved the run into a different basin of attraction:
- probe 31 (manual mlx-lm loop, nl=16, no clip): 10/15 = 67%
- probe 33 (zoo MLXTrainer, nl=16, None silent-clip): 8/15 = 53%
- probe 34 (zoo FastMLXModel + MLXTrainer, nl=16): 7/15 = 47%
- probe 37 (zoo MLXTrainer, nl=16, explicit max_grad_value=0): 6/15 = 40%
(See `danielhanchen/unsloth-staging-2` mlx-parity-probes matrix, Round BM run 26050214501.)
Teacher-forced completion loss is 0 in 15/15 seeds across every probe, so the model fully memorizes either way. The greedy-decode basin is what shifts. cf_loss < 0.5 smoke gating per unslothai/unsloth#5537 stays green regardless, so this is not a training-quality defect, but it IS a parity defect against mlx-lm CLI, which is the primary reference implementation for the MLX path.
Tests: tests/test_mlx_batch_padding.py pins the padding rule against mlx-lm's value-for-value across 11 length boundaries, adds a source-string assertion guarding against future refactor drift, and checks that _PAD_MULTIPLE stays at 32 (mlx-lm uses the same constant).
On closer audit: So this PR fixes a real, latent off-by-one in Round BO is running against this branch to confirm; expectation now is that probes 33-37 stay at 40-53% (pad-fix doesn't apply to their code path) and the trainer-side gap remains unresolved. Keeping the PR open as a stand-alone correctness fix: the docstring comment promised "matching mlx-lm" but the formula didn't, so users who switched to |
Round BO final results
Identical to Round BM and Round BL within seed noise. So as expected, this PR's
What the trainer-side gap actually is
Sampling generations across seeds (cf_loss = 0 in 15/15 for every probe; the model has memorized perfectly):
cf_loss is 0 across every cell, so the model genuinely knows the training string in every case. Greedy decode just wanders: sometimes through
Net take-aways for this PR series
Per code-comment policy: keep WHY (autoregressive shift headroom), drop empirical probe results — those live in commit b265d99's message.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a3cd4d874b
| def test_source_padding_formula_includes_plus_one(): | ||
| """Guard against a future refactor that drops the +1 again.""" | ||
| import inspect | ||
| from unsloth_zoo.mlx import trainer |
Avoid importing the package in the new test
In a zoo-only test environment where unsloth is not installed, this import executes unsloth_zoo/__init__.py before the shimmed MLX trainer module is loaded, and that top-level init raises ImportError("Please install Unsloth..."); pytest tests/test_mlx_batch_padding.py -q currently fails here before checking the padding formula. This test should inspect/load unsloth_zoo/mlx/trainer.py without going through the package init, or otherwise provide the same stub/setup expected by that import path.
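One way to inspect a single file without going through the package init is to load it by path with `importlib.util`. The sketch below is an assumption about how that could look, not the fix adopted in the PR: it builds a throwaway stand-in package (`fake_zoo`, hypothetical) whose `__init__.py` raises, and the helper `load_module_from_path` is likewise hypothetical. A real `trainer.py` with relative imports or an MLX dependency would still need the stubs the comment mentions.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_module_from_path(name: str, path: Path):
    """Execute one .py file as a module, bypassing any package __init__."""
    spec = importlib.util.spec_from_file_location(name, str(path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Stand-in package: importing it normally would raise, like the zoo init
# does when unsloth is missing. The submodule itself is harmless.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "fake_zoo"
    pkg.mkdir()
    (pkg / "__init__.py").write_text(
        'raise ImportError("Please install Unsloth")\n'
    )
    (pkg / "trainer.py").write_text("_PAD_MULTIPLE = 32\n")

    # Loading by file path never runs fake_zoo/__init__.py.
    trainer_mod = load_module_from_path("fake_zoo_trainer", pkg / "trainer.py")

print(trainer_mod._PAD_MULTIPLE)
```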
danielhanchen
left a comment
Thank you for the PR! The goal of this PR is to align _create_labeled_batches's padding rule with mlx_lm.tuner.trainer.iterate_batches so that after the autoregressive shift (inputs = batch[:,:-1] / targets = batch[:,1:]) the effective sequence length is a clean multiple of _PAD_MULTIPLE. As a summary, this PR changes padded_len = _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE) to padded_len = 1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE). The +1 gives the post-shift inputs/targets one token of headroom. The function is only invoked from train_on_responses_only, so this is correctness-only for that path and does not affect the main MLXTrainer flow.
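To make the shift concrete at the batch level, here is a small sketch using nested Python lists in place of MLX arrays. The names `pad_batch`, `shift` and `PAD_ID` are illustrative assumptions and do not mirror the actual body of `_create_labeled_batches`.

```python
_PAD_MULTIPLE = 32
PAD_ID = 0  # hypothetical pad token id, chosen only for this sketch


def pad_batch(rows, max_seq_length=2048):
    """Right-pad every row to 1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)."""
    max_len = max(len(r) for r in rows)
    padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE
    padded_len = min(padded_len, max_seq_length)
    batch = []
    for r in rows:
        row = list(r)[:padded_len]  # truncate if the clamp is shorter
        batch.append(row + [PAD_ID] * (padded_len - len(row)))
    return batch


def shift(batch):
    """Autoregressive shift: inputs = batch[:, :-1], targets = batch[:, 1:]."""
    inputs = [row[:-1] for row in batch]
    targets = [row[1:] for row in batch]
    return inputs, targets


rows = [list(range(1, 15)), list(range(1, 10))]  # 14 and 9 tokens
batch = pad_batch(rows)
inputs, targets = shift(batch)
print(len(batch[0]), len(inputs[0]), len(targets[0]))
```

With a 14-token longest row the batch is 33 wide, and both `inputs` and `targets` come out 32 wide, a clean multiple of `_PAD_MULTIPLE`.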
Two independent Opus reviewers were run in parallel on this PR.
| Reviewers | Severity | Finding |
|---|---|---|
| 1/2 | Med | The new test file tests/test_mlx_batch_padding.py does not appear in any CI workflow (e.g. .github/workflows/mlx-ci.yml only runs the smoke shim). The regression guard never fires on PRs. |
| 2/2 | Med | test_source_padding_formula_includes_plus_one is a literal-substring pin on the formula. Any black/ruff reformat (operand order swap, intermediate variable rename, import math; math.ceil(...) rewrite) breaks the test even though the math is unchanged. |
| 1/2 | Med | padded_len = min(padded_len, max_seq_length) can clamp the +1 back off when max_len == max_seq_length, silently restoring the pre-PR (mlx-lm-aligned) edge behavior — but the comment claims "post-shift length is a clean multiple" unconditionally. Document the boundary. |
| 1/2 | Nit | No parametrized case at max_len == max_seq_length to lock in the clamp interaction. |
| 1/2 | Nit | The test re-implements _zoo_padded_len rather than importing the rule from trainer.py. If create_text_batches exposed _padded_len(max_len), the test could call it directly and the source-string guard would be unnecessary. |
| 1/2 | Nit | Test file header says create_text_batches but the function is _create_labeled_batches. |
Overall: APPROVE_WITH_NITS.
See inline comments for details and suggested fixes.
| padded_len = ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE | ||
| # Match mlx-lm iterate_batches: +1 gives the autoregressive | ||
| # shift headroom so post-shift length is a clean _PAD_MULTIPLE. | ||
| padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE |
[2/2 reviewers] Correctness verified — trace for max_len=63, _PAD_MULTIPLE=32: 1 + 32*ceil(63/32) = 1 + 32*2 = 65, which after the shift gives inputs.shape[1] = 64 = 32*2 (clean bucket). Matches mlx_lm/tuner/trainer.py:158.
[1/2 reviewers] Med caveat: the min(padded_len, max_seq_length) clamp on the next line can silently revert this when max_len == max_seq_length — both pre-PR and post-PR end at padded_len = max_seq_length, so the +1 is dropped without an alignment guarantee. mlx-lm clamps the same way, so this matches upstream behavior, but the comment overstates the invariant. Tighten to:
| padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE | |
| # Match mlx-lm iterate_batches: +1 gives the autoregressive | |
| # shift headroom so post-shift length is a clean _PAD_MULTIPLE. | |
| # Note: the subsequent min(., max_seq_length) clamp can drop the | |
| # +1 when max_len == max_seq_length; this matches mlx-lm's own | |
| # behavior at that boundary. | |
| padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE |
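The clamp interaction can be checked numerically. The sketch below assumes the clamp is a plain `min(padded_len, max_seq_length)` as quoted in the review; `padded_len` and `post_shift_is_aligned` are hypothetical helpers, not functions in `trainer.py`.

```python
_PAD_MULTIPLE = 32


def padded_len(max_len: int, max_seq_length: int) -> int:
    """+1 rule followed by the max_seq_length clamp (assumed to be a plain min)."""
    raw = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE
    return min(raw, max_seq_length)


def post_shift_is_aligned(max_len: int, max_seq_length: int) -> bool:
    """True when the post-shift length is a multiple of _PAD_MULTIPLE."""
    return (padded_len(max_len, max_seq_length) - 1) % _PAD_MULTIPLE == 0


cases = [(32, 128), (63, 64), (64, 64), (64, 2048)]
for max_len, max_seq_length in cases:
    print(
        max_len,
        max_seq_length,
        padded_len(max_len, max_seq_length),
        post_shift_is_aligned(max_len, max_seq_length),
    )
```

Under these assumptions the alignment is lost whenever the unclamped value exceeds `max_seq_length`, which also covers `max_len=63` with `max_seq_length=64`, not only the exact `max_len == max_seq_length` case.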
| assert _zoo_padded_len(max_len) == _mlx_lm_padded_len(max_len) | ||
| def test_source_padding_formula_includes_plus_one(): |
[2/2 reviewers] Med: this test pins the formula as a literal substring ("1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE"). Any cosmetic reformat (operand swap to mlx-lm style _PAD_MULTIPLE * (...), intermediate variable, math.ceil-rewrite) breaks it without breaking the math. The parametrized value-equivalence test above already pins correctness. Drop this or relax to a structural check:
| def test_source_padding_formula_includes_plus_one(): | |
| def test_source_padding_formula_includes_plus_one(): | |
| """Defense in depth: ensure the function source mentions a `+ 1` | |
| near the padded_len computation so a future refactor cannot quietly | |
| drop the autoregressive-shift headroom. Pin the intent, not the | |
| exact spelling.""" | |
| src = inspect.getsource(trainer_mod._create_labeled_batches) | |
| # Pull out the padded_len assignment block and require a literal +1. | |
| block = re.search(r"padded_len\s*=\s*[^\n]+", src) | |
| assert block is not None | |
| assert "+ 1" in block.group(0) or "1 +" in block.group(0), ( | |
| f"expected '+1' in padded_len assignment, got: {block.group(0)!r}" | |
| ) |
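The relaxed check can be exercised on its own against a few candidate spellings. The source strings below are hypothetical examples written for this sketch, not lines taken from `trainer.py`, and `mentions_plus_one` is an illustrative wrapper around the same regex and substring test.

```python
import re


def mentions_plus_one(src: str) -> bool:
    """Same structural check as suggested above, applied to a source string."""
    block = re.search(r"padded_len\s*=\s*[^\n]+", src)
    if block is None:
        return False
    line = block.group(0)
    return "+ 1" in line or "1 +" in line


# Hypothetical spellings of the padded_len assignment.
PRE_PR = "padded_len = ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE"
POST_PR = "padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE"
REORDERED = "padded_len = _PAD_MULTIPLE * ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) + 1"
# Pre-PR math with operands reordered: no headroom, yet "1 +" appears.
NO_HEADROOM_REORDERED = "padded_len = ((max_len - 1 + _PAD_MULTIPLE) // _PAD_MULTIPLE) * _PAD_MULTIPLE"

for name, src in [
    ("pre-PR", PRE_PR),
    ("post-PR", POST_PR),
    ("reordered", REORDERED),
    ("no headroom, reordered", NO_HEADROOM_REORDERED),
]:
    print(name, mentions_plus_one(src))
```

The last case shows a limit of any substring check: a reordered ceil-division without the headroom still contains `1 +`, so the parametrized value-equivalence test remains the actual correctness guard.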
| ) | ||
| def test_pad_multiple_constant_still_32(): |
[1/2 reviewers] Nit: no parametrized case covers max_len == max_seq_length (where the clamp engages and the +1 is silently dropped). Add it to document the intentional asymmetry at the boundary:
| def test_pad_multiple_constant_still_32(): | |
| def test_pad_multiple_constant_still_32(): | |
| """mlx-lm uses _PAD_MULTIPLE=32. Pin it so a future refactor cannot | |
| silently drift.""" | |
| assert trainer_mod._PAD_MULTIPLE == 32 | |
| @pytest.mark.parametrize("max_len, max_seq_length, expected", [ | |
| (63, 64, 64), # +1 would be 65; clamped to 64 | |
| (64, 64, 64), # already at clamp boundary; +1 dropped | |
| (32, 128, 33), # well below clamp; +1 retained | |
| ]) | |
| def test_padded_len_respects_max_seq_length_clamp(max_len, max_seq_length, expected): | |
| """Document the clamp interaction with the +1 rule.""" | |
| raw = 1 + ((max_len + 32 - 1) // 32) * 32 | |
| assert min(raw, max_seq_length) == expected |
| @@ -0,0 +1,90 @@ | |||
| # Unsloth Zoo - Utilities for Unsloth | |||
[1/2 reviewers] Med + Nit: (a) this test file is not added to any CI workflow — the regression guard never fires on PRs. Add it to mlx-ci.yml or the consolidated repo-tests-cpu job since it uses the torch-based MLX shim and is CPU-safe. (b) The header references create_text_batches but the function is _create_labeled_batches.
| # Unsloth Zoo - Utilities for Unsloth | |
| # Unsloth Zoo - Utilities for Unsloth | |
| # Pin _create_labeled_batches's padding rule against mlx-lm's | |
| # iterate_batches (mlx_lm/tuner/trainer.py:158). | |
| # REMINDER: add this file to .github/workflows/mlx-ci.yml so the | |
| # regression guard runs in CI. |
Summary
`unsloth_zoo.mlx.trainer.create_text_batches` previously padded each batch to `_PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)`. mlx-lm's `iterate_batches` (`mlx_lm/tuner/trainer.py:158`) pads to `1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)`. After the autoregressive shift in `default_loss` (mlx-lm) and `make_baseline_loss_fn` (zoo) (`inputs = batch[:, :-1]` / `targets = batch[:, 1:]`), zoo's inputs were one token SHORTER than mlx-lm's.
This PR adds the `+1` so zoo's padding matches mlx-lm value-for-value.
Why
Empirical on `gemma-3-270m-it`, single-row LoRA memorization, 15 paired seeds (`danielhanchen/unsloth-staging-2` mlx-parity-probes, Round BM run `26050214501`):
| Loader | Trainer | `max_grad_value` |
|---|---|---|
| `mlx_lm.load` | `@mx.compile` | |
| `mlx_lm.load` | `MLXTrainer, compile=False` | `None` (silent rebind to 1.0) |
| `FastMLXModel(dtype=None)` | `MLXTrainer, compile=False` | `None` |
| `mlx_lm.load` | `MLXTrainer, compile=True` | `None` |
| `FastMLXModel(dtype=None)` | `MLXTrainer, compile=True` | `None` |
| `mlx_lm.load` | `MLXTrainer, compile=False` | `0.0` (explicit off) |
Bisecting:
- `trainer` flips probe 31's 67% down to MLXTrainer's 40-53%.
- `(B, 33)` vs zoo's `(B, 32)` on this 14-token fixture.
Teacher-forced completion loss is `0` in 15/15 seeds across every probe, so the model fully memorizes either way; the basin is purely a greedy-decode argmax artifact. `cf_loss < 0.5` smoke gating per `unslothai/unsloth#5537` stays green regardless. But the basin gap IS a parity defect against mlx-lm CLI, which is the reference implementation for the MLX path.
Behavior
- `_PAD_MULTIPLE` stays at `32`. Unchanged.
- `max_len=14`: `(B, 33)` (was `(B, 32)`). Changed by +1 token.
- `max_len=32`: `(B, 33)` (was `(B, 32)`).
- `max_len=33`: `(B, 65)` (was `(B, 64)`).
- `O(B * 1)` per batch, negligible.
After this PR, `make_baseline_loss_fn`'s `inputs = batch[:, :-1]` slice yields the same effective length as mlx-lm's `default_loss`, so identical hyperparameters produce the same per-step graph shape.
Test plan
- `tests/test_mlx_batch_padding.py`: `+1`; `_PAD_MULTIPLE == 32` invariant.
- `pytest tests/test_mlx_batch_padding.py -v`
Related
Part of the MLX vs mlx-lm parity bisection:
- #669 `finetune_last_n_layers` (layer-selection mismatch).
- unslothai/unsloth#5564
- #670
- #671 `max_grad_value=None` AND default to None for HF parity (closes #662).