Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions tests/test_mlx_batch_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Unsloth Zoo - Utilities for Unsloth

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Med + Nit: (a) this test file is not added to any CI workflow — the regression guard never fires on PRs. Add it to mlx-ci.yml or the consolidated repo-tests-cpu job since it uses the torch-based MLX shim and is CPU-safe. (b) The header references create_text_batches but the function is _create_labeled_batches.

Suggested change
# Unsloth Zoo - Utilities for Unsloth
# Unsloth Zoo - Utilities for Unsloth
# Pin _create_labeled_batches's padding rule against mlx-lm's
# iterate_batches (mlx_lm/tuner/trainer.py:158).
# REMINDER: add this file to .github/workflows/mlx-ci.yml so the
# regression guard runs in CI.

# Pin MLXTrainer's batch padding to match mlx-lm's iterate_batches
# semantics: pad to `1 + _PAD_MULTIPLE * ceil(L / _PAD_MULTIPLE)`.
#
# Why this matters:
# mlx-lm's tuner trainer (`mlx_lm/tuner/trainer.py:158`) pads each
# batch to `1 + 32 * ceil(max_len / 32)`. The default loss then
# slices `inputs = batch[:, :-1]` / `targets = batch[:, 1:]`, so the
# sequence length the model actually sees is `32 * ceil(max_len/32)`.
# unsloth_zoo's `create_text_batches` previously rounded WITHOUT the
# `+1` (just `32 * ceil(max_len/32)`), which dropped one token of
# input length after the autoregressive shift, putting the trainer
# one token shy of mlx-lm. On a single-row LoRA memorization fixture
# against gemma-3-270m-it, the one-token gap moved the run into a
# different convergence basin (probe 31 manual loop = 10/15 = 67%
# vs probe 33-37 MLXTrainer = 6-8/15 = 40-53% on paired seeds, see
# danielhanchen/unsloth-staging-2).

from __future__ import annotations

import pytest


@pytest.fixture(autouse=True, scope="module")
def _install_mlx_shim():
    """Install the torch-based MLX simulation before this module's tests run.

    Declared ``autouse`` with module scope, so pytest invokes it once for the
    module without any test having to request it. The import lives inside the
    fixture, so ``mlx_simulation`` is only loaded at fixture setup rather than
    when this file is imported.
    """
    from mlx_simulation import simulate_mlx_on_torch

    simulate_mlx_on_torch()


def _mlx_lm_padded_len(max_len: int, pad_to: int = 32) -> int:
    """Return mlx-lm's iterate_batches padded length (mlx_lm/tuner/trainer.py:158).

    The rule is ``1 + pad_to * ceil(max_len / pad_to)``: round ``max_len`` up
    to a multiple of ``pad_to`` and add one.

    Args:
        max_len: Longest sequence length in the batch.
        pad_to: Bucket size to round up to.

    Returns:
        The padded length, before any ``max_seq_length`` clamp.
    """
    # Integer ceiling division: number of ``pad_to``-sized buckets needed.
    bucket_count = (max_len + pad_to - 1) // pad_to
    return 1 + pad_to * bucket_count


def _zoo_padded_len(max_len: int, pad_multiple: int = 32) -> int:
    """Reproduce the ``padded_len`` rule from ``unsloth_zoo/mlx/trainer.py``.

    This is a local copy of the in-source formula, so the tests can assert it
    stays aligned with mlx-lm without standing up the full tokenizer +
    dataset pipeline. It does not apply the ``min(padded_len,
    max_seq_length)`` clamp that follows the formula in the trainer.

    Args:
        max_len: Longest sequence length in the batch.
        pad_multiple: Bucket size to round up to (the trainer's
            ``_PAD_MULTIPLE``).

    Returns:
        ``1 + ceil(max_len / pad_multiple) * pad_multiple``.
    """
    # Integer ceiling division, then scale back up to a whole bucket length.
    bucket_count = (max_len + pad_multiple - 1) // pad_multiple
    return 1 + bucket_count * pad_multiple


@pytest.mark.parametrize(
    "max_len, expected",
    [
        # Inside the first 32-token bucket -> 33.
        (1, 33),
        (14, 33),  # the probe fixture's TRAIN_TEXT length
        (31, 33),
        (32, 33),
        # Second bucket -> 65.
        (33, 65),
        (63, 65),
        (64, 65),
        # Third bucket -> 97.
        (65, 97),
        # Larger buckets.
        (97, 129),
        (128, 129),
        (129, 161),
    ],
)
def test_zoo_padding_matches_mlx_lm(max_len, expected):
    """Zoo's pad rule must equal mlx-lm's pad rule, value-for-value.

    Each rule is checked against the hand-computed ``expected`` length and
    then against the other, so a failure shows which side drifted.
    """
    zoo_len = _zoo_padded_len(max_len)
    mlx_lm_len = _mlx_lm_padded_len(max_len)
    assert zoo_len == expected
    assert mlx_lm_len == expected
    assert zoo_len == mlx_lm_len


def test_source_padding_formula_includes_plus_one():

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Med: this test pins the formula as a literal substring ("1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE"). Any cosmetic reformat (operand swap to mlx-lm style _PAD_MULTIPLE * (...), intermediate variable, math.ceil-rewrite) breaks it without breaking the math. The parametrized value-equivalence test above already pins correctness. Drop this or relax to a structural check:

Suggested change
def test_source_padding_formula_includes_plus_one():
def test_source_padding_formula_includes_plus_one():
"""Defense in depth: ensure the function source mentions a `+ 1`
near the padded_len computation so a future refactor cannot quietly
drop the autoregressive-shift headroom. Pin the intent, not the
exact spelling."""
src = inspect.getsource(trainer_mod._create_labeled_batches)
# Pull out the padded_len assignment block and require a literal +1.
block = re.search(r"padded_len\s*=\s*[^\n]+", src)
assert block is not None
assert "+ 1" in block.group(0) or "1 +" in block.group(0), (
f"expected '+1' in padded_len assignment, got: {block.group(0)!r}"
)

"""Guard against a future refactor that drops the +1 again."""
import inspect
from unsloth_zoo.mlx import trainer

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid importing the package in the new test

In a zoo-only test environment where unsloth is not installed, this import executes unsloth_zoo/__init__.py before the shimmed MLX trainer module is loaded, and that top-level init raises ImportError("Please install Unsloth..."); pytest tests/test_mlx_batch_padding.py -q currently fails here before checking the padding formula. This test should inspect/load unsloth_zoo/mlx/trainer.py without going through the package init, or otherwise provide the same stub/setup expected by that import path.

Useful? React with 👍 / 👎.


src = inspect.getsource(trainer)
# The exact line we care about. If someone rewrites the formula
# they must preserve the +1 contract or add a new test alongside.
needle = "1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE"
assert needle in src, (
f"create_text_batches must use `{needle}` to match mlx-lm's "
"1 + pad_to*ceil(L/pad_to). Dropping the +1 leaves the input "
"one token shorter than mlx-lm after the autoregressive shift "
"and changes the convergence basin on small fixtures."
)


def test_pad_multiple_constant_still_32():

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: no parametrized case covers the clamp boundary (e.g. max_len == max_seq_length), where min(padded_len, max_seq_length) engages and the +1 is silently dropped. Add one to document the intentional asymmetry at the boundary:

Suggested change
def test_pad_multiple_constant_still_32():
def test_pad_multiple_constant_still_32():
"""mlx-lm uses _PAD_MULTIPLE=32. Pin it so a future refactor cannot
silently drift."""
assert trainer_mod._PAD_MULTIPLE == 32
@pytest.mark.parametrize("max_len, max_seq_length, expected", [
(63, 64, 64), # +1 would be 65; clamped to 64
(64, 64, 64), # already at clamp boundary; +1 dropped
(32, 128, 33), # well below clamp; +1 retained
])
def test_padded_len_respects_max_seq_length_clamp(max_len, max_seq_length, expected):
"""Document the clamp interaction with the +1 rule."""
raw = 1 + ((max_len + 32 - 1) // 32) * 32
assert min(raw, max_seq_length) == expected

"""mlx-lm uses pad_to=32; we must too."""
from unsloth_zoo.mlx import trainer
assert trainer._PAD_MULTIPLE == 32
5 changes: 3 additions & 2 deletions unsloth_zoo/mlx/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,8 +1532,9 @@ def _process_text(text):
if not batch_items:
continue
max_len = max(len(ids) for ids, _ in batch_items)
# Round up to nearest multiple of _PAD_MULTIPLE (matching mlx-lm)
padded_len = ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE
# Match mlx-lm iterate_batches: +1 gives the autoregressive
# shift headroom so post-shift length is a clean _PAD_MULTIPLE.
padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Correctness verified — trace for max_len=63, _PAD_MULTIPLE=32: 1 + 32*ceil(63/32) = 1 + 32*2 = 65, which after the shift gives inputs.shape[1] = 64 = 32*2 (clean bucket). Matches mlx_lm/tuner/trainer.py:158.

[1/2 reviewers] Med caveat: the min(padded_len, max_seq_length) clamp on the next line silently drops the +1 whenever the rounded-up length already reaches max_seq_length (for example max_len == max_seq_length, or max_len = 63 with max_seq_length = 64). In those cases both pre-PR and post-PR end at padded_len = max_seq_length, so there is no alignment guarantee. mlx-lm clamps the same way, so this matches upstream behavior, but the comment overstates the invariant. Tighten to:

Suggested change
padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE
# Match mlx-lm iterate_batches: +1 gives the autoregressive
# shift headroom so post-shift length is a clean _PAD_MULTIPLE.
# Note: the subsequent min(., max_seq_length) clamp drops the +1
# whenever the rounded-up length already reaches max_seq_length
# (e.g. max_len == max_seq_length); this matches mlx-lm's own
# behavior at that boundary.
padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE

padded_len = min(padded_len, max_seq_length)

batch_ids = []
Expand Down
Loading