-
Notifications
You must be signed in to change notification settings - Fork 267
fix(mlx): match mlx-lm batch padding rule (1 + pad_to * ceil) #672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,90 @@ | ||||||||||||||||||||||||||||||||||
| # Unsloth Zoo - Utilities for Unsloth | ||||||||||||||||||||||||||||||||||
| # Pin MLXTrainer's batch padding to match mlx-lm's iterate_batches | ||||||||||||||||||||||||||||||||||
| # semantics: pad to `1 + _PAD_MULTIPLE * ceil(L / _PAD_MULTIPLE)`. | ||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||
| # Why this matters: | ||||||||||||||||||||||||||||||||||
| # mlx-lm's tuner trainer (`mlx_lm/tuner/trainer.py:158`) pads each | ||||||||||||||||||||||||||||||||||
| # batch to `1 + 32 * ceil(max_len / 32)`. The default loss then | ||||||||||||||||||||||||||||||||||
| # slices `inputs = batch[:, :-1]` / `targets = batch[:, 1:]`, so the | ||||||||||||||||||||||||||||||||||
| # effective per-position-attention length is `32 * ceil(max_len/32)`. | ||||||||||||||||||||||||||||||||||
| # unsloth_zoo's `create_text_batches` previously rounded WITHOUT the | ||||||||||||||||||||||||||||||||||
| # `+1` (just `32 * ceil(max_len/32)`), which dropped one token of | ||||||||||||||||||||||||||||||||||
| # input length after the autoregressive shift, putting the trainer | ||||||||||||||||||||||||||||||||||
| # one token shy of mlx-lm. On a single-row LoRA memorization fixture | ||||||||||||||||||||||||||||||||||
| # against gemma-3-270m-it, the one-token gap moved the run into a | ||||||||||||||||||||||||||||||||||
| # different convergence basin (probe 31 manual loop = 10/15 = 67% | ||||||||||||||||||||||||||||||||||
| # vs probe 33-37 MLXTrainer = 6-8/15 = 40-53% on paired seeds, see | ||||||||||||||||||||||||||||||||||
| # danielhanchen/unsloth-staging-2). | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @pytest.fixture(autouse=True, scope="module") | ||||||||||||||||||||||||||||||||||
| def _install_mlx_shim(): | ||||||||||||||||||||||||||||||||||
| from mlx_simulation import simulate_mlx_on_torch | ||||||||||||||||||||||||||||||||||
| simulate_mlx_on_torch() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _mlx_lm_padded_len(max_len, pad_to=32): | ||||||||||||||||||||||||||||||||||
| """mlx-lm's iterate_batches padding (mlx_lm/tuner/trainer.py:158).""" | ||||||||||||||||||||||||||||||||||
| return 1 + pad_to * ((max_len + pad_to - 1) // pad_to) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _zoo_padded_len(max_len, pad_multiple=32): | ||||||||||||||||||||||||||||||||||
| """Reproduce the in-source rule from create_text_batches so we can | ||||||||||||||||||||||||||||||||||
| assert it stays aligned with mlx-lm without standing up the full | ||||||||||||||||||||||||||||||||||
| tokenizer + dataset pipeline. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| return 1 + ((max_len + pad_multiple - 1) // pad_multiple) * pad_multiple | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @pytest.mark.parametrize( | ||||||||||||||||||||||||||||||||||
| "max_len, expected", | ||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||
| # Inside the first 32-token bucket -> 33. | ||||||||||||||||||||||||||||||||||
| (1, 33), | ||||||||||||||||||||||||||||||||||
| (14, 33), # the probe fixture's TRAIN_TEXT length | ||||||||||||||||||||||||||||||||||
| (31, 33), | ||||||||||||||||||||||||||||||||||
| (32, 33), | ||||||||||||||||||||||||||||||||||
| # Second bucket -> 65. | ||||||||||||||||||||||||||||||||||
| (33, 65), | ||||||||||||||||||||||||||||||||||
| (63, 65), | ||||||||||||||||||||||||||||||||||
| (64, 65), | ||||||||||||||||||||||||||||||||||
| # Third bucket -> 97. | ||||||||||||||||||||||||||||||||||
| (65, 97), | ||||||||||||||||||||||||||||||||||
| # Larger buckets. | ||||||||||||||||||||||||||||||||||
| (97, 129), | ||||||||||||||||||||||||||||||||||
| (128, 129), | ||||||||||||||||||||||||||||||||||
| (129, 161), | ||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| def test_zoo_padding_matches_mlx_lm(max_len, expected): | ||||||||||||||||||||||||||||||||||
| """Zoo's pad rule must equal mlx-lm's pad rule, value-for-value.""" | ||||||||||||||||||||||||||||||||||
| assert _zoo_padded_len(max_len) == expected | ||||||||||||||||||||||||||||||||||
| assert _mlx_lm_padded_len(max_len) == expected | ||||||||||||||||||||||||||||||||||
| assert _zoo_padded_len(max_len) == _mlx_lm_padded_len(max_len) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def test_source_padding_formula_includes_plus_one(): | ||||||||||||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [2/2 reviewers] Med: this test pins the formula as a literal substring (
Suggested change
|
||||||||||||||||||||||||||||||||||
| """Guard against a future refactor that drops the +1 again.""" | ||||||||||||||||||||||||||||||||||
| import inspect | ||||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx import trainer | ||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In a zoo-only test environment where Useful? React with 👍 / 👎. |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| src = inspect.getsource(trainer) | ||||||||||||||||||||||||||||||||||
| # The exact line we care about. If someone rewrites the formula | ||||||||||||||||||||||||||||||||||
| # they must preserve the +1 contract or add a new test alongside. | ||||||||||||||||||||||||||||||||||
| needle = "1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE" | ||||||||||||||||||||||||||||||||||
| assert needle in src, ( | ||||||||||||||||||||||||||||||||||
| f"create_text_batches must use `{needle}` to match mlx-lm's " | ||||||||||||||||||||||||||||||||||
| "1 + pad_to*ceil(L/pad_to). Dropping the +1 leaves the input " | ||||||||||||||||||||||||||||||||||
| "one token shorter than mlx-lm after the autoregressive shift " | ||||||||||||||||||||||||||||||||||
| "and changes the convergence basin on small fixtures." | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def test_pad_multiple_constant_still_32(): | ||||||||||||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] Nit: no parametrized case covers
Suggested change
|
||||||||||||||||||||||||||||||||||
| """mlx-lm uses pad_to=32; we must too.""" | ||||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx import trainer | ||||||||||||||||||||||||||||||||||
| assert trainer._PAD_MULTIPLE == 32 | ||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1532,8 +1532,9 @@ def _process_text(text): | |||||||||||||||
| if not batch_items: | ||||||||||||||||
| continue | ||||||||||||||||
| max_len = max(len(ids) for ids, _ in batch_items) | ||||||||||||||||
| # Round up to nearest multiple of _PAD_MULTIPLE (matching mlx-lm) | ||||||||||||||||
| padded_len = ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE | ||||||||||||||||
| # Match mlx-lm iterate_batches: +1 gives the autoregressive | ||||||||||||||||
| # shift headroom so post-shift length is a clean _PAD_MULTIPLE. | ||||||||||||||||
| padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE | ||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [2/2 reviewers] Correctness verified — trace for [1/2 reviewers] Med caveat: the
Suggested change
|
||||||||||||||||
| padded_len = min(padded_len, max_seq_length) | ||||||||||||||||
|
|
||||||||||||||||
| batch_ids = [] | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[1/2 reviewers] Med + Nit: (a) this test file is not added to any CI workflow — the regression guard never fires on PRs. Add it to
mlx-ci.ymlor the consolidatedrepo-tests-cpujob since it uses the torch-based MLX shim and is CPU-safe. (b) The header referencescreate_text_batchesbut the function is_create_labeled_batches.