From 06e1f8bdda8af19bba46366c2a3b2b92ad294c3b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 19 May 2026 01:21:41 +0000 Subject: [PATCH 1/2] mlx: byte-identical fast path in make_baseline_loss_fn for mlx-lm parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `make_baseline_loss_fn` previously routed both the labels=None case and the labels-aware case through the same code, with three small differences from `mlx_lm.tuner.trainer.default_loss`: 1. mask cast to fp32 (`length_mask.astype(mx.float32)`) instead of leaving the bool mask alone. 2. `safe_targets = mx.where(targets == -100, 0, targets)` even when there are no -100 entries in the labels=None path. 3. division by `_safe_token_denominator(ntoks)` (fp32 cast + maximum with 1.0) instead of raw `ce.sum() / ntoks`. Mathematically all three differences are no-ops for the labels=None case, but they introduce extra nodes into the MLX autodiff graph and change rounding inside the backward pass. Empirically (paired-seed data from danielhanchen/unsloth-staging-2 Round BO probes 31 vs 37): step 1 loss : identical to mlx-lm (9.769231 in both, forward only) step 2 loss : diverges by ~0.01 to 0.06 from mlx-lm step 30 loss: 0 in both (converges either way) The divergence is small and never blocks training (cf_loss = 0 across all configs), but it means callers of `MLXTrainer` cannot reproduce mlx-lm CLI's loss curve numerically. This commit: - Restructures `make_baseline_loss_fn` so the labels=None branch is byte-for-byte identical to mlx-lm's default_loss (bool mask, raw division, no mx.where). The labels-aware branch keeps the existing -100 / safe_targets / fp32-mask machinery — that path needs it. Tests: tests/test_mlx_baseline_loss_parity.py pins the source so a future refactor cannot silently re-introduce the divergent code patterns. 
The harness uses the torch-based MLX shim so genuine numerical parity stays the responsibility of the probe matrix on Apple Silicon; the source-pinning guards against drift between mlx-lm and zoo in the meantime. --- tests/test_mlx_baseline_loss_parity.py | 124 +++++++++++++++++++++++++ unsloth_zoo/mlx/utils.py | 34 +++++-- 2 files changed, 151 insertions(+), 7 deletions(-) create mode 100644 tests/test_mlx_baseline_loss_parity.py diff --git a/tests/test_mlx_baseline_loss_parity.py b/tests/test_mlx_baseline_loss_parity.py new file mode 100644 index 000000000..2fd6d2f8c --- /dev/null +++ b/tests/test_mlx_baseline_loss_parity.py @@ -0,0 +1,124 @@ +# Unsloth Zoo - Utilities for Unsloth +# Pin `make_baseline_loss_fn` source so the labels=None fast path stays +# byte-for-byte equivalent to mlx_lm.tuner.trainer.default_loss. +# +# Why pin source rather than run a numerical comparison: the test +# harness uses a torch-based MLX shim that doesn't faithfully reproduce +# MLX's autodiff graph or its rounding; an apples-to-apples numerical +# parity check requires a real MLX runtime (Apple Silicon), so it's +# done in the Round BP probe matrix on +# danielhanchen/unsloth-staging-2. Locally we guard against future +# refactors silently re-introducing the divergent code patterns +# (fp32-cast mask, mx.where(safe_targets), _safe_token_denominator) +# that mlx_lm.tuner.trainer.default_loss does NOT do. 
+ +from __future__ import annotations + +import inspect +import re +import textwrap + +import pytest + + +@pytest.fixture(autouse=True, scope="module") +def _install_mlx_shim(): + from mlx_simulation import simulate_mlx_on_torch + simulate_mlx_on_torch() + + +def _labels_none_block(): + """Return CODE LINES (comments stripped) for the labels=None fast path.""" + from unsloth_zoo.mlx import utils + src = inspect.getsource(utils.make_baseline_loss_fn) + m = re.search( + r"if labels is None:\s*\n(.*?)(?:# labels-aware path|else:\s*\n)", + src, + flags=re.DOTALL, + ) + assert m, "make_baseline_loss_fn must keep a `labels is None` fast path" + raw = textwrap.dedent(m.group(1)) + # Strip whole-line comments so test assertions on code don't trip on + # explanatory prose like "no safe_targets mx.where" in comments. + code_lines = [] + for line in raw.splitlines(): + stripped = line.strip() + if stripped.startswith("#"): + continue + code_lines.append(line) + return "\n".join(code_lines) + + +def test_no_fp32_mask_cast_in_fast_path(): + """The labels=None path must NOT cast the bool mask to fp32. mlx-lm's + default_loss multiplies the cross-entropy result by a raw bool mask; + casting to fp32 produces a different MLX autodiff graph and shifts + gradients by ~1e-2 per step on small fixtures.""" + block = _labels_none_block() + # Anything that would cast a mask to fp32: `.astype(mx.float32)` or + # `astype(float32)` immediately on a `mask` / `length_mask` name. + bad_patterns = ( + r"length_mask\.astype\(mx\.float32\)", + r"mask\s*=\s*[^=]*\.astype\(mx\.float32\)", + ) + for pat in bad_patterns: + assert not re.search(pat, block), ( + f"labels=None fast path must not contain `{pat}`; " + "matches mlx-lm requires a bool mask." + ) + + +def test_no_safe_targets_where_in_fast_path(): + """The labels=None path has no -100 to worry about, so no mx.where + is needed on targets. 
The where node was empirically the cause of + step-2 loss divergence vs mlx-lm CLI (Round BO probe_31 vs probe_37).""" + block = _labels_none_block() + assert "mx.where" not in block, ( + "labels=None fast path must not include mx.where on targets; " + "mlx-lm's default_loss does not call mx.where." + ) + assert "safe_targets" not in block, ( + "labels=None fast path must use `targets` directly; " + "`safe_targets = mx.where(...)` belongs only to the labels-aware path." + ) + + +def test_no_safe_token_denominator_in_fast_path(): + """mlx-lm's default_loss divides by raw `ntoks` (int). The fast path + must match that to preserve the autodiff graph; the safety wrapper + `_safe_token_denominator` introduces a fp32 cast + maximum() that + changes rounding in MLX.""" + block = _labels_none_block() + assert "_safe_token_denominator" not in block, ( + "labels=None fast path must divide by raw ntoks for mlx-lm parity. " + "_safe_token_denominator is fine on the labels-aware path." + ) + + +def test_fast_path_returns_ce_and_ntoks_in_that_order(): + """Match the (loss, ntoks) return signature mlx-lm uses; the test + pins return-order so a future refactor doesn't accidentally swap.""" + block = _labels_none_block() + # Look for a `return X, Y` somewhere in the fast path. The variable + # names are loose (mlx-lm uses `ce`; zoo previously used `loss`), + # but the order matters. 
+ m = re.search(r"return\s+(\w+),\s*(\w+)", block) + assert m, "labels=None fast path must return a (loss, ntoks) tuple" + loss_name, ntoks_name = m.group(1), m.group(2) + assert ntoks_name in ("ntoks", "n_toks", "ntokens"), ( + f"second return value name should look like a token count, got " + f"{ntoks_name!r}" + ) + + +def test_labels_aware_path_still_uses_safe_targets(): + """The labels-aware path (train_on_responses_only) DOES need + `safe_targets` and the fp32 mask because labels can contain -100.""" + from unsloth_zoo.mlx import utils + src = inspect.getsource(utils.make_baseline_loss_fn) + # The labels-aware path lives after the fast path's `return`. Look + # at the full source to verify the machinery still exists somewhere. + assert "mx.where" in src, ( + "make_baseline_loss_fn must still call mx.where on the labels-aware path" + ) + assert "safe_targets" in src, "labels-aware path must keep safe_targets" diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 3caf75688..5e04977c7 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -413,20 +413,40 @@ def make_baseline_loss_fn(): A function (model, batch, lengths, labels=None) -> (loss, ntoks). When labels is provided, uses labels[:,1:] for targets with (targets != -100) as the loss mask. + + Numerical parity with mlx-lm: + The labels=None path is intentionally byte-identical to + ``mlx_lm.tuner.trainer.default_loss``. The labels-aware path + adds the -100 / safe_targets / fp32-mask machinery needed for + train_on_responses_only. Keeping the two paths separated means + that running zoo's MLXTrainer with no per-token label mask + produces value-for-value identical loss AND gradient sequences + to mlx-lm CLI on the same data + seed. """ def loss_fn(model, batch, lengths, labels=None): if labels is None: - inputs, targets = batch[:, :-1], batch[:, 1:] - else: + # mlx-lm parity fast path. 
Matches mlx_lm.tuner.trainer.default_loss + # byte-for-byte: bool mask (not fp32 cast), raw ntoks division (not + # _safe_token_denominator), no `safe_targets` mx.where. Differences + # of that kind altered the autodiff graph and produced ~0.01-0.06 + # step-2 loss divergence vs mlx-lm CLI on identical configs + # (gemma-3-270m-it, see Round BO probe_31/probe_37 paired-seed data). inputs = batch[:, :-1] - targets = labels[:, 1:] + targets = batch[:, 1:] + logits = model(inputs) + steps = mx.arange(1, targets.shape[1] + 1) + mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) + ce = nn.losses.cross_entropy(logits, targets) * mask + ntoks = mask.sum() + ce = ce.astype(mx.float32).sum() / ntoks + return ce, ntoks + # labels-aware path: train_on_responses_only style masking. + inputs = batch[:, :-1] + targets = labels[:, 1:] logits = model(inputs) steps = mx.arange(1, targets.shape[1] + 1) length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) - if labels is None: - mask = length_mask.astype(mx.float32) - else: - mask = mx.logical_and(targets != -100, length_mask).astype(mx.float32) + mask = mx.logical_and(targets != -100, length_mask).astype(mx.float32) # Replace -100 with 0 before CE — MLX has no ignore_index; # the mask already zeros out these positions in the loss. safe_targets = mx.where(targets == -100, 0, targets) From 3e8dce8ac4986388d33198a205f838bb398aa4ec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 19 May 2026 10:19:21 +0000 Subject: [PATCH 2/2] mlx: trim verbose baseline-loss parity docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per code-comment policy: keep WHY (mlx-lm parity fast path); drop the empirical step-2 divergence numbers and probe references — those live in commit 22c1c5b's message. 
--- unsloth_zoo/mlx/utils.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 5e04977c7..e8b7e1881 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -412,25 +412,12 @@ def make_baseline_loss_fn(): Returns: A function (model, batch, lengths, labels=None) -> (loss, ntoks). When labels is provided, uses labels[:,1:] for targets with - (targets != -100) as the loss mask. - - Numerical parity with mlx-lm: - The labels=None path is intentionally byte-identical to - ``mlx_lm.tuner.trainer.default_loss``. The labels-aware path - adds the -100 / safe_targets / fp32-mask machinery needed for - train_on_responses_only. Keeping the two paths separated means - that running zoo's MLXTrainer with no per-token label mask - produces value-for-value identical loss AND gradient sequences - to mlx-lm CLI on the same data + seed. + (targets != -100) as the loss mask. The labels=None branch is + byte-identical to ``mlx_lm.tuner.trainer.default_loss``. """ def loss_fn(model, batch, lengths, labels=None): if labels is None: - # mlx-lm parity fast path. Matches mlx_lm.tuner.trainer.default_loss - # byte-for-byte: bool mask (not fp32 cast), raw ntoks division (not - # _safe_token_denominator), no `safe_targets` mx.where. Differences - # of that kind altered the autodiff graph and produced ~0.01-0.06 - # step-2 loss divergence vs mlx-lm CLI on identical configs - # (gemma-3-270m-it, see Round BO probe_31/probe_37 paired-seed data). + # byte-identical to mlx_lm.tuner.trainer.default_loss inputs = batch[:, :-1] targets = batch[:, 1:] logits = model(inputs)