Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions tests/test_mlx_baseline_loss_parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Unsloth Zoo - Utilities for Unsloth
# Pin `make_baseline_loss_fn` source so the labels=None fast path stays
# byte-for-byte equivalent to mlx_lm.tuner.trainer.default_loss.
#
# Why pin source rather than run a numerical comparison: the test
# harness uses a torch-based MLX shim that doesn't faithfully reproduce
# MLX's autodiff graph or its rounding; an apples-to-apples numerical
# parity check requires a real MLX runtime (Apple Silicon), so it's
# done in the Round BP probe matrix on
# danielhanchen/unsloth-staging-2. Locally we guard against future
# refactors silently re-introducing the divergent code patterns
# (fp32-cast mask, mx.where(safe_targets), _safe_token_denominator)
# that mlx_lm.tuner.trainer.default_loss does NOT do.

from __future__ import annotations

import inspect
import re
import textwrap

import pytest


@pytest.fixture(autouse=True, scope="module")
def _install_mlx_shim():
from mlx_simulation import simulate_mlx_on_torch
simulate_mlx_on_torch()


def _labels_none_block():

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Low: this helper's regex relies on either an else: or the literal comment # labels-aware path to mark the end of the fast-path block. If a future refactor renames that comment, all five tests below silently extract an empty or wrong block — assert m still fires (or the substring checks pass vacuously). Pin with a sentinel comment that's hard to rename casually:

Suggested change
def _labels_none_block():
def _labels_none_block():
"""Return the labels=None code block. Relies on the sentinel comment
`# --- baseline-loss-fast-path-end ---` in the source. If you rename
that comment, update this regex AND the sentinel in utils.py."""
from unsloth_zoo.mlx import utils
src = inspect.getsource(utils.make_baseline_loss_fn)
m = re.search(
r"if labels is None:\s*\n(.*?)# --- baseline-loss-fast-path-end ---",
src,
flags=re.DOTALL,
)
assert m, "sentinel comment missing from make_baseline_loss_fn fast path"
return m.group(1)

(and add # --- baseline-loss-fast-path-end --- immediately before the labels-aware branch in utils.py.)

"""Return CODE LINES (comments stripped) for the labels=None fast path."""
from unsloth_zoo.mlx import utils

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid package-level import in the parity test

In the CPU-only test harness where the unsloth package is not installed, this from unsloth_zoo.mlx import utils form can execute unsloth_zoo/__init__.py and raise ImportError: Please install Unsloth... before the shimmed MLX submodule is loaded; running pytest tests/test_mlx_baseline_loss_parity.py then fails on the first test. The existing MLX shim tests avoid this path by importing the full submodule name, so this source-level test should do the same or load the file source directly.

Useful? React with 👍 / 👎.

src = inspect.getsource(utils.make_baseline_loss_fn)
m = re.search(
r"if labels is None:\s*\n(.*?)(?:# labels-aware path|else:\s*\n)",
src,
flags=re.DOTALL,
)
assert m, "make_baseline_loss_fn must keep a `labels is None` fast path"
raw = textwrap.dedent(m.group(1))
# Strip whole-line comments so test assertions on code don't trip on
# explanatory prose like "no safe_targets mx.where" in docstrings.
code_lines = []
for line in raw.splitlines():
stripped = line.strip()
if stripped.startswith("#"):
continue
code_lines.append(line)
return "\n".join(code_lines)


def test_no_fp32_mask_cast_in_fast_path():
"""The labels=None path must NOT cast the bool mask to fp32. mlx-lm's
default_loss multiplies the cross-entropy result by a raw bool mask;
casting to fp32 produces a different MLX autodiff graph and shifts
gradients by ~1e-2 per step on small fixtures."""
block = _labels_none_block()
# Anything that would cast a mask to fp32: `.astype(mx.float32)` or
# `astype(float32)` immediately on a `mask` / `length_mask` name.
bad_patterns = (
r"length_mask\.astype\(mx\.float32\)",
r"mask\s*=\s*[^=]*\.astype\(mx\.float32\)",
)
for pat in bad_patterns:
assert not re.search(pat, block), (
f"labels=None fast path must not contain `{pat}`; "
"matches mlx-lm requires a bool mask."
)


def test_no_safe_targets_where_in_fast_path():
"""The labels=None path has no -100 to worry about, so no mx.where
is needed on targets. The where node was empirically the cause of
step-2 loss divergence vs mlx-lm CLI (Round BO probe_31 vs probe_37)."""
block = _labels_none_block()
assert "mx.where" not in block, (
"labels=None fast path must not include mx.where on targets; "
"mlx-lm's default_loss does not call mx.where."
)
assert "safe_targets" not in block, (
"labels=None fast path must use `targets` directly; "
"`safe_targets = mx.where(...)` belongs only to the labels-aware path."
)


def test_no_safe_token_denominator_in_fast_path():
"""mlx-lm's default_loss divides by raw `ntoks` (int). The fast path
must match that to preserve the autodiff graph; the safety wrapper
`_safe_token_denominator` introduces a fp32 cast + maximum() that
changes rounding in MLX."""
block = _labels_none_block()
assert "_safe_token_denominator" not in block, (
"labels=None fast path must divide by raw ntoks for mlx-lm parity. "
"_safe_token_denominator is fine on the labels-aware path."
)


def test_fast_path_returns_ce_and_ntoks_in_that_order():
"""Match the (loss, ntoks) return signature mlx-lm uses; the test
pins return-order so a future refactor doesn't accidentally swap."""
block = _labels_none_block()
# Look for a `return X, Y` somewhere in the fast path. The variable
# names are loose (mlx-lm uses `ce`; zoo previously used `loss`),
# but the order matters.
m = re.search(r"return\s+(\w+),\s*(\w+)", block)
assert m, "labels=None fast path must return a (loss, ntoks) tuple"
loss_name, ntoks_name = m.group(1), m.group(2)
assert ntoks_name in ("ntoks", "n_toks", "ntokens"), (
f"second return value name should look like a token count, got "
f"{ntoks_name!r}"
)


def test_labels_aware_path_still_uses_safe_targets():
"""The labels-aware path (train_on_responses_only) DOES need
`safe_targets` and the fp32 mask because labels can contain -100."""
from unsloth_zoo.mlx import utils
src = inspect.getsource(utils.make_baseline_loss_fn)
# The labels-aware path lives after the fast path's `return`. Look
# at the full source to verify the machinery still exists somewhere.
assert "mx.where" in src, (
"make_baseline_loss_fn must still call mx.where on the labels-aware path"
)
assert "safe_targets" in src, "labels-aware path must keep safe_targets"

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Med: every test in this file is source-pin only — none actually executes the loss function on a tensor. The mlx_simulation shim is available; even a shape-only smoke (call loss_fn(stub_model, batch, lengths, labels=None) and assert the return is a 2-tuple of (ce_scalar, ntoks_scalar)) would catch refactors that preserve the regex but break the math. Add at the bottom of the file:

Suggested change
assert "safe_targets" in src, "labels-aware path must keep safe_targets"
def test_fast_path_runtime_shape_smoke():
"""Shape-only end-to-end smoke. Catches refactors that pass every
regex test but produce the wrong tensors at runtime."""
import torch
from unsloth_zoo.mlx.utils import make_baseline_loss_fn
# ... build stub model whose __call__ returns logits of shape (B, T, V)
# ... call loss_fn(stub, batch, lengths, labels=None)
# ... assert tuple of length 2, both 0-dim tensors

23 changes: 15 additions & 8 deletions unsloth_zoo/mlx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,21 +412,28 @@ def make_baseline_loss_fn():
Returns:
A function (model, batch, lengths, labels=None) -> (loss, ntoks).
When labels is provided, uses labels[:,1:] for targets with
(targets != -100) as the loss mask.
(targets != -100) as the loss mask. The labels=None branch is
byte-identical to ``mlx_lm.tuner.trainer.default_loss``.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: "byte-identical to mlx_lm.tuner.trainer.default_loss" is true today but tied to mlx-lm's main. Pin a known-good commit so a future maintainer can re-verify deterministically:

Suggested change
byte-identical to ``mlx_lm.tuner.trainer.default_loss``.
byte-identical to ``mlx_lm.tuner.trainer.default_loss`` (verified
against mlx-lm 0.x.y commit abc1234; rerun the byte-parity check
if upstream changes).

"""
def loss_fn(model, batch, lengths, labels=None):
if labels is None:
inputs, targets = batch[:, :-1], batch[:, 1:]
else:
# byte-identical to mlx_lm.tuner.trainer.default_loss

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Med: ce.astype(mx.float32).sum() / ntoks matches mlx-lm exactly (no _safe_token_denominator guard), so byte parity is preserved — but an empty-mask batch (lengths[:,0] > lengths[:,1] for every row) now produces NaN/Inf instead of being smoothed by max(1, ntoks). Realistic only with pathological inputs; mlx-lm has the same hazard. Add a one-line docstring note so the contract is explicit:

Suggested change
# byte-identical to mlx_lm.tuner.trainer.default_loss
# byte-identical to mlx_lm.tuner.trainer.default_loss
# NOTE: caller must ensure ntoks > 0 (empty-mask batches
# divide by zero). mlx-lm has the same precondition.
inputs = batch[:, :-1]

inputs = batch[:, :-1]
targets = labels[:, 1:]
targets = batch[:, 1:]
logits = model(inputs)
steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
ce = nn.losses.cross_entropy(logits, targets) * mask
ntoks = mask.sum()
ce = ce.astype(mx.float32).sum() / ntoks
return ce, ntoks
# labels-aware path: train_on_responses_only style masking.
inputs = batch[:, :-1]
targets = labels[:, 1:]
logits = model(inputs)
steps = mx.arange(1, targets.shape[1] + 1)
length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
if labels is None:
mask = length_mask.astype(mx.float32)
else:
mask = mx.logical_and(targets != -100, length_mask).astype(mx.float32)
mask = mx.logical_and(targets != -100, length_mask).astype(mx.float32)
# Replace -100 with 0 before CE — MLX has no ignore_index;
# the mask already zeros out these positions in the loss.
safe_targets = mx.where(targets == -100, 0, targets)
Expand Down
Loading