Skip to content

fix(mlx): make_baseline_loss_fn byte-identical to mlx-lm default_loss when labels=None#673

Merged
danielhanchen merged 2 commits into
mainfrom
fix-mlx-baseline-loss-parity
May 19, 2026
Merged

fix(mlx): make_baseline_loss_fn byte-identical to mlx-lm default_loss when labels=None#673
danielhanchen merged 2 commits into
mainfrom
fix-mlx-baseline-loss-parity

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

make_baseline_loss_fn previously routed both the labels=None case and the labels-aware case through the same code, with three small differences from mlx_lm.tuner.trainer.default_loss:

  1. mask cast to fp32 (length_mask.astype(mx.float32)) instead of leaving the bool mask alone.
  2. safe_targets = mx.where(targets == -100, 0, targets) even when there are no -100 entries in the labels=None path.
  3. division by _safe_token_denominator(ntoks) (fp32 cast + maximum with 1.0) instead of raw ce.sum() / ntoks.

This PR restructures the function so the labels=None branch is byte-for-byte identical to mlx-lm's default_loss; the labels-aware branch keeps the existing -100 / safe_targets / fp32-mask machinery (it needs all three).

Why

Mathematically all three differences are no-ops when there are no -100 entries (i.e. on the labels=None path), but each adds an extra node to the MLX autodiff graph and changes rounding inside the backward pass.

Empirical (unsloth/gemma-3-270m-it single-row LoRA memorization, paired-seed data from danielhanchen/unsloth-staging-2 Round BO probes 31 (mlx-lm manual loop) vs 35 / 37 (zoo MLXTrainer), 5 representative seeds):

step probe 31 (mlx-lm) probe 35 (zoo, compile=True) probe 37 (zoo, compile=False, no clip) $\Delta$ probe 31 - probe 35
1 9.769231 9.769231 9.769231 0
2 5.254807 5.276443 5.276443 -0.021635
3 3.206731 3.221154 3.221154 -0.014423
4 1.894231 1.947115 1.947115 -0.052885
5 1.163462 1.168269 1.180288 -0.004808
28+ 0 0 0 0

Step 1's forward loss is identical (same initial weights, same data), but the gradient applied at step 1 differs between mlx-lm and zoo — so step 2 onward the loss curves are 0.01-0.06 apart until both converge to 0. The model still memorizes either way (cf_loss = 0 in 15/15 seeds for every config). But callers cannot reproduce mlx-lm CLI's loss curve numerically, which makes diagnostic side-by-side comparisons useless.

After this PR the labels=None fast path produces an autodiff graph value-for-value identical to mlx_lm.tuner.trainer.default_loss, so probes against the same fixture should converge value-for-value with mlx-lm CLI.

Behavior

  • loss_fn(model, batch, lengths) (no labels) $\to$ identical numerical path to mlx_lm.tuner.trainer.default_loss. Changed (parity).
  • loss_fn(model, batch, lengths, labels) (labels-aware, e.g. train_on_responses_only) $\to$ existing path with fp32 mask, safe_targets, _safe_token_denominator. Unchanged.

Test plan

  • tests/test_mlx_baseline_loss_parity.py $\to$ 5 source-level assertions pinning the fast-path code patterns. Guards against future refactors silently re-introducing length_mask.astype(mx.float32), mx.where(safe_targets), or _safe_token_denominator on the labels=None branch. Verifies the labels-aware path still uses safe_targets.
  • Local: pytest tests/test_mlx_baseline_loss_parity.py -v $\to$ 5 passed.
  • Broader regression: pytest tests/test_pr_a_*.py tests/test_mlx_*.py $\to$ 58 passed, 1 skipped, no regressions.
  • Round BP probe 38 on Apple Silicon (in flight) captures per-step loss + grad_norm for both paths to confirm value-for-value parity at all steps after this fix.

Related

Sixth PR in the MLX vs mlx-lm parity series:

  • #669 $\to$ finetune_last_n_layers knob (zoo).
  • unslothai/unsloth#5564 $\to$ same knob, CUDA path.
  • #670 $\to$ warn on bf16$\to$fp16 downcast.
  • #671 $\to$ max_grad_value=None default + HF/TRL parity (closes #662).
  • #672 $\to$ _create_labeled_batches padding matches mlx-lm.
  • This PR $\to$ make_baseline_loss_fn labels=None fast path matches mlx-lm.

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 22c1c5b4a7

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".


def _labels_none_block():
"""Return CODE LINES (comments stripped) for the labels=None fast path."""
from unsloth_zoo.mlx import utils

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid package-level import in the parity test

In the CPU-only test harness where the unsloth package is not installed, this from unsloth_zoo.mlx import utils form can execute unsloth_zoo/__init__.py and raise ImportError: Please install Unsloth... before the shimmed MLX submodule is loaded; running pytest tests/test_mlx_baseline_loss_parity.py then fails on the first test. The existing MLX shim tests avoid this path by importing the full submodule name, so this source-level test should do the same or load the file source directly.

Useful? React with 👍 / 👎.

@danielhanchen

Copy link
Copy Markdown
Member Author

Correction: this PR is a code simplification, NOT a parity fix.

Round BP probe 38 ran the mlx-lm manual loop and zoo MLXTrainer back-to-back in the same process, capturing per-step loss with the OLD make_baseline_loss_fn (fp32 mask cast + mx.where(safe_targets) + _safe_token_denominator). Result on 5 seeds (1, 42, 999, 3407, 22222):

step mlxlm loss zoo loss $\Delta$ loss
1 9.769231 9.769231 0
2-30 converging converging 0 (every step, every seed)

So the three "no-op math" differences (fp32 mask, mx.where(safe_targets), _safe_token_denominator) ARE truly no-ops in MLX's autodiff. They add graph nodes but produce value-for-value identical gradients. My earlier diagnosis from Round BO data was wrong.

What ACTUALLY explains the Round BO step-2 divergence between probe_31 (manual loop) and probe_33/35/37 (zoo MLXTrainer) is the mx.random.seed(seed) order: probe 31 reseeds AFTER mlx_load; probes 33/35/37 didn't reseed and so their LoRA init used advanced RNG state. Probe 38's zoo path reseeds after load (matching mlx-lm CLI's lora.py:223) and the basins align step-for-step.

This PR is still defensible as a cleanup: the labels=None fast path now mirrors mlx_lm.tuner.trainer.default_loss line-for-line, removes three unused graph nodes per forward/backward, and keeps the train_on_responses_only path's -100 / safe_targets machinery on its own branch. No behavior change but a clearer contract. Re-framing the PR title + body to match — it's not closing a parity defect, just simplifying the labels=None path.

@danielhanchen danielhanchen left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! The goal of this PR is to make make_baseline_loss_fn's labels=None branch byte-identical to mlx_lm.tuner.trainer.default_loss so callers of MLXTrainer with no per-token label mask see value-for-value identical loss/grad sequences to the mlx-lm CLI. As a summary, this PR splits the previously-shared loss body into a fast path (labels=None: bool mask, raw ntoks division, no safe_targets mx.where) and a labels-aware path (unchanged — keeps safe_targets, fp32 mask cast, and _safe_token_denominator for train_on_responses_only).

Two independent Opus reviewers were run in parallel on this PR.

Reviewers Severity Finding
2/2 Med Test is source-pin only (regex over the function body). No numerical/shape smoke. A semantic-breaking refactor that preserves the regex slides through.
2/2 Med The new labels=None branch divides by raw ntoks (no _safe_token_denominator guard). mlx-lm has the same divide-by-zero hazard on empty-mask batches, so byte parity is preserved — but the previous code defensively handled it. Worth a docstring note.
1/2 Med Existing labels=None callers see ~0.005-0.05 step-level loss drift vs prior versions. Per the PR's own table this only converges at the end; checkpoint-resume / regression-baseline comparisons will pick up the drift. Worth a release-note mention.
2/2 Low The _labels_none_block regex in the test depends on the literal comment # labels-aware path: to delimit the block. Rename or reword the comment and all five tests silently extract the wrong block (or empty).
1/2 Nit The "byte-identical to mlx_lm.tuner.trainer.default_loss" claim is true today but ties to mlx-lm's main. Pin a mlx-lm version range or commit SHA in the comment so a future maintainer can re-verify deterministically.
1/2 Nit test_fast_path_returns_ce_and_ntoks_in_that_order accepts ntoks/n_toks/ntokens; tighten to ntoks since byte-identity is the explicit goal.

Overall: APPROVE_WITH_NITS.

Verified: make_cce_loss_fn (line 266) and make_vlm_cce_loss_fn (line 1456) are not touched by this PR and share no state with the modified function — the MLX CCE kernel path is unaffected.

See inline comments for details and suggested fixes.

Comment thread unsloth_zoo/mlx/utils.py
if labels is None:
inputs, targets = batch[:, :-1], batch[:, 1:]
else:
# byte-identical to mlx_lm.tuner.trainer.default_loss

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Med: ce.astype(mx.float32).sum() / ntoks matches mlx-lm exactly (no _safe_token_denominator guard), so byte parity is preserved — but an empty-mask batch (lengths[:,0] > lengths[:,1] for every row) now produces NaN/Inf instead of being smoothed by max(1, ntoks). Realistic only with pathological inputs; mlx-lm has the same hazard. Add a one-line docstring note so the contract is explicit:

Suggested change
# byte-identical to mlx_lm.tuner.trainer.default_loss
# byte-identical to mlx_lm.tuner.trainer.default_loss
# NOTE: caller must ensure ntoks > 0 (empty-mask batches
# divide by zero). mlx-lm has the same precondition.
inputs = batch[:, :-1]

Comment thread unsloth_zoo/mlx/utils.py
When labels is provided, uses labels[:,1:] for targets with
(targets != -100) as the loss mask.
(targets != -100) as the loss mask. The labels=None branch is
byte-identical to ``mlx_lm.tuner.trainer.default_loss``.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: "byte-identical to mlx_lm.tuner.trainer.default_loss" is true today but tied to mlx-lm's main. Pin a known-good commit so a future maintainer can re-verify deterministically:

Suggested change
byte-identical to ``mlx_lm.tuner.trainer.default_loss``.
byte-identical to ``mlx_lm.tuner.trainer.default_loss`` (verified
against mlx-lm 0.x.y commit abc1234; rerun the byte-parity check
if upstream changes).

simulate_mlx_on_torch()


def _labels_none_block():

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Low: this helper's regex relies on either an else: or the literal comment # labels-aware path to mark the end of the fast-path block. If a future refactor renames that comment, all five tests below silently extract an empty or wrong block — assert m still fires (or the substring checks pass vacuously). Pin with a sentinel comment that's hard to rename casually:

Suggested change
def _labels_none_block():
def _labels_none_block():
"""Return the labels=None code block. Relies on the sentinel comment
`# --- baseline-loss-fast-path-end ---` in the source. If you rename
that comment, update this regex AND the sentinel in utils.py."""
from unsloth_zoo.mlx import utils
src = inspect.getsource(utils.make_baseline_loss_fn)
m = re.search(
r"if labels is None:\s*\n(.*?)# --- baseline-loss-fast-path-end ---",
src,
flags=re.DOTALL,
)
assert m, "sentinel comment missing from make_baseline_loss_fn fast path"
return m.group(1)

(and add # --- baseline-loss-fast-path-end --- immediately before the labels-aware branch in utils.py.)

assert "mx.where" in src, (
"make_baseline_loss_fn must still call mx.where on the labels-aware path"
)
assert "safe_targets" in src, "labels-aware path must keep safe_targets"

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Med: every test in this file is source-pin only — none actually executes the loss function on a tensor. The mlx_simulation shim is available; even a shape-only smoke (call loss_fn(stub_model, batch, lengths, labels=None) and assert the return is a 2-tuple of (ce_scalar, ntoks_scalar)) would catch refactors that preserve the regex but break the math. Add at the bottom of the file:

Suggested change
assert "safe_targets" in src, "labels-aware path must keep safe_targets"
def test_fast_path_runtime_shape_smoke():
"""Shape-only end-to-end smoke. Catches refactors that pass every
regex test but produce the wrong tensors at runtime."""
import torch
from unsloth_zoo.mlx.utils import make_baseline_loss_fn
# ... build stub model whose __call__ returns logits of shape (B, T, V)
# ... call loss_fn(stub, batch, lengths, labels=None)
# ... assert tuple of length 2, both 0-dim tensors

`make_baseline_loss_fn` previously routed both the labels=None case
and the labels-aware case through the same code, with three small
differences from `mlx_lm.tuner.trainer.default_loss`:

  1. mask cast to fp32 (`length_mask.astype(mx.float32)`) instead of
     leaving the bool mask alone.
  2. `safe_targets = mx.where(targets == -100, 0, targets)` even when
     there are no -100 entries in the labels=None path.
  3. division by `_safe_token_denominator(ntoks)` (fp32 cast + maximum
     with 1.0) instead of raw `ce.sum() / ntoks`.

Mathematically all three differences are no-ops for the labels=None
case, but they introduce extra nodes into the MLX autodiff graph and
change rounding inside the backward pass. Empirically (paired-seed
data from danielhanchen/unsloth-staging-2 Round BO probes 31 vs 37):

  step 1 loss : identical to mlx-lm (9.769231 in both, forward only)
  step 2 loss : diverges by ~0.01 to 0.06 from mlx-lm
  step 30 loss: 0 in both (converges either way)

The divergence is small and never blocks training (cf_loss = 0 across
all configs), but it means callers of `MLXTrainer` cannot reproduce
mlx-lm CLI's loss curve numerically. This commit:

- Restructures `make_baseline_loss_fn` so the labels=None branch is
  byte-for-byte identical to mlx-lm's default_loss (bool mask, raw
  division, no mx.where). The labels-aware branch keeps the existing
  -100 / safe_targets / fp32-mask machinery — that path needs it.

Tests: tests/test_mlx_baseline_loss_parity.py pins the source so a
future refactor cannot silently re-introduce the divergent code
patterns. The harness uses the torch-based MLX shim so genuine
numerical parity stays the responsibility of the probe matrix on
Apple Silicon; the source-pinning guards against drift between
mlx-lm and zoo in the meantime.
Per code-comment policy: keep WHY (mlx-lm parity fast path); drop the
empirical step-2 divergence numbers and probe references — those live
in commit 22c1c5b's message.
@danielhanchen danielhanchen force-pushed the fix-mlx-baseline-loss-parity branch from df3b531 to 3e8dce8 Compare May 19, 2026 12:35
@danielhanchen danielhanchen merged commit c278d60 into main May 19, 2026
1 of 14 checks passed
@danielhanchen danielhanchen deleted the fix-mlx-baseline-loss-parity branch May 19, 2026 12:58
danielhanchen pushed a commit to mmathew23/unsloth-zoo that referenced this pull request May 27, 2026
Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save
metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13
conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's
behavior where it conflicts on semantics:

  - half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=`
  - `if labels is None` branch preserved (PR unslothai#684 generality) alongside
    main's `_normalize_cce_label_dtype` dtype widening
  - `_get_image_token_ids` legacy wrapper kept alongside main's new
    `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels`
  - `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so
    image masking honors main's uint-widening contract
  - HEAD's `_expand_token_replacements` dropped; main's three-function
    split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences`
    + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed
  - `_collate_vlm_prompt_completion_batch` reads back the masked labels
    in int64 so image + attention masking survives without narrowing
  - prompt-completion VLM collator routes through `_apply_vlm_label_masks`
    after dtype normalisation so ignore_token_ids and wide invalid ids
    both reach runtime CCE intact
  - `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels
    while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling
  - `_unsloth_*` prefix filter preserved so the new collated_position_ids
    flag and main's raw-input-ids carrier both get stripped

152 MLX tests pass post-merge.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant