fix(mlx): make_baseline_loss_fn byte-identical to mlx-lm default_loss when labels=None#673
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 22c1c5b4a7
|
|
||
| def _labels_none_block(): | ||
| """Return CODE LINES (comments stripped) for the labels=None fast path.""" | ||
| from unsloth_zoo.mlx import utils |
Avoid package-level import in the parity test
In the CPU-only test harness, where the unsloth package is not installed, the `from unsloth_zoo.mlx import utils` form can execute `unsloth_zoo/__init__.py` and raise `ImportError: Please install Unsloth...` before the shimmed MLX submodule is loaded; running `pytest tests/test_mlx_baseline_loss_parity.py` then fails on the first test. The existing MLX shim tests avoid this path by importing the full submodule name, so this source-level test should do the same or load the file source directly.
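One way to "load the file source directly" is to read the file and pull the function out with `ast`, which never runs the package's `__init__.py`. The sketch below is only an illustration under assumptions: `function_source` and the throwaway `fake_pkg` layout are hypothetical and are not part of this PR or of unsloth_zoo.

```python
import ast
import tempfile
from pathlib import Path


def function_source(module_path: Path, func_name: str) -> str:
    """Return the source text of a top-level function without importing the module."""
    source = module_path.read_text(encoding="utf-8")
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == func_name:
            segment = ast.get_source_segment(source, node)
            if segment is not None:
                return segment
    raise LookupError(f"{func_name} not found in {module_path}")


# Hypothetical package whose __init__.py refuses to import, mimicking the
# failure mode described above. Nothing here is executed, only parsed.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "fake_pkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text('raise ImportError("Please install Unsloth...")\n')
    (pkg / "utils.py").write_text(
        "def make_baseline_loss_fn():\n"
        "    if labels is None:\n"
        "        pass\n"
    )
    extracted = function_source(pkg / "utils.py", "make_baseline_loss_fn")
```

Because the file is parsed rather than imported, a source-pinning test built this way does not depend on whether the parent package can be imported in the harness.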
Correction: this PR is a code simplification, NOT a parity fix. Round BP probe 38 ran the mlx-lm manual loop and zoo
So the three "no-op math" differences (fp32 mask, What ACTUALLY explains the Round BO step-2 divergence between This PR is still defensible as a cleanup: the labels=None fast path now mirrors |
danielhanchen left a comment
Thank you for the PR! The goal of this PR is to make make_baseline_loss_fn's labels=None branch byte-identical to mlx_lm.tuner.trainer.default_loss so callers of MLXTrainer with no per-token label mask see value-for-value identical loss/grad sequences to the mlx-lm CLI. As a summary, this PR splits the previously-shared loss body into a fast path (labels=None: bool mask, raw ntoks division, no safe_targets mx.where) and a labels-aware path (unchanged — keeps safe_targets, fp32 mask cast, and _safe_token_denominator for train_on_responses_only).
Two independent Opus reviewers were run in parallel on this PR.
| Reviewers | Severity | Finding |
|---|---|---|
| 2/2 | Med | Test is source-pin only (regex over the function body). No numerical/shape smoke. A semantic-breaking refactor that preserves the regex slides through. |
| 2/2 | Med | The new labels=None branch divides by raw ntoks (no _safe_token_denominator guard). mlx-lm has the same divide-by-zero hazard on empty-mask batches, so byte parity is preserved — but the previous code defensively handled it. Worth a docstring note. |
| 1/2 | Med | Existing labels=None callers see ~0.005-0.05 step-level loss drift vs prior versions. Per the PR's own table this only converges at the end; checkpoint-resume / regression-baseline comparisons will pick up the drift. Worth a release-note mention. |
| 2/2 | Low | The _labels_none_block regex in the test depends on the literal comment # labels-aware path: to delimit the block. Rename or reword the comment and all five tests silently extract the wrong block (or empty). |
| 1/2 | Nit | The "byte-identical to mlx_lm.tuner.trainer.default_loss" claim is true today but ties to mlx-lm's main. Pin a mlx-lm version range or commit SHA in the comment so a future maintainer can re-verify deterministically. |
| 1/2 | Nit | test_fast_path_returns_ce_and_ntoks_in_that_order accepts ntoks/n_toks/ntokens; tighten to ntoks since byte-identity is the explicit goal. |
Overall: APPROVE_WITH_NITS.
Verified: make_cce_loss_fn (line 266) and make_vlm_cce_loss_fn (line 1456) are not touched by this PR and share no state with the modified function — the MLX CCE kernel path is unaffected.
See inline comments for details and suggested fixes.
| if labels is None: | ||
| inputs, targets = batch[:, :-1], batch[:, 1:] | ||
| else: | ||
| # byte-identical to mlx_lm.tuner.trainer.default_loss |
[2/2 reviewers] Med: ce.astype(mx.float32).sum() / ntoks matches mlx-lm exactly (no _safe_token_denominator guard), so byte parity is preserved — but an empty-mask batch (lengths[:,0] > lengths[:,1] for every row) now produces NaN/Inf instead of being smoothed by max(1, ntoks). Realistic only with pathological inputs; mlx-lm has the same hazard. Add a one-line docstring note so the contract is explicit:
| # byte-identical to mlx_lm.tuner.trainer.default_loss | |
| # byte-identical to mlx_lm.tuner.trainer.default_loss | |
| # NOTE: caller must ensure ntoks > 0 (empty-mask batches | |
| # divide by zero). mlx-lm has the same precondition. | |
| inputs = batch[:, :-1] |
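The empty-mask hazard described in this thread can be illustrated without MLX. This is a plain-Python sketch, not the project's code: `raw_mean` and `guarded_mean` are hypothetical names, and `raw_mean` emulates the NaN/Inf that array frameworks return for division by zero, since plain Python would raise `ZeroDivisionError` instead.

```python
import math


def raw_mean(total: float, ntoks: float) -> float:
    """Raw total / ntoks, emulating IEEE-754 array semantics when ntoks == 0."""
    if ntoks == 0:
        # Plain Python raises ZeroDivisionError here; array frameworks
        # return NaN (0/0) or a signed Inf (x/0) instead.
        return math.nan if total == 0 else math.copysign(math.inf, total)
    return total / ntoks


def guarded_mean(total: float, ntoks: float) -> float:
    """Division with a max(1, ntoks) floor, the kind of guard the review describes."""
    return total / max(1.0, float(ntoks))


empty = (0.0, 0)    # every position masked out
normal = (7.5, 3)   # three unmasked tokens

raw_empty = raw_mean(*empty)
guarded_empty = guarded_mean(*empty)
```

For any batch with at least one unmasked token the two functions agree, so the guard only changes behavior on the pathological empty-mask case.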
| When labels is provided, uses labels[:,1:] for targets with | ||
| (targets != -100) as the loss mask. | ||
| (targets != -100) as the loss mask. The labels=None branch is | ||
| byte-identical to ``mlx_lm.tuner.trainer.default_loss``. |
[1/2 reviewers] Nit: "byte-identical to mlx_lm.tuner.trainer.default_loss" is true today but tied to mlx-lm's main. Pin a known-good commit so a future maintainer can re-verify deterministically:
| byte-identical to ``mlx_lm.tuner.trainer.default_loss``. | |
| byte-identical to ``mlx_lm.tuner.trainer.default_loss`` (verified | |
| against mlx-lm 0.x.y commit abc1234; rerun the byte-parity check | |
| if upstream changes). |
| simulate_mlx_on_torch() | ||
|
|
||
|
|
||
| def _labels_none_block(): |
[2/2 reviewers] Low: this helper's regex relies on either an else: or the literal comment # labels-aware path to mark the end of the fast-path block. If a future refactor renames that comment, all five tests below silently extract an empty or wrong block — assert m still fires (or the substring checks pass vacuously). Pin with a sentinel comment that's hard to rename casually:
| def _labels_none_block(): | |
| def _labels_none_block(): | |
| """Return the labels=None code block. Relies on the sentinel comment | |
| `# --- baseline-loss-fast-path-end ---` in the source. If you rename | |
| that comment, update this regex AND the sentinel in utils.py.""" | |
| from unsloth_zoo.mlx import utils | |
| src = inspect.getsource(utils.make_baseline_loss_fn) | |
| m = re.search( | |
| r"if labels is None:\s*\n(.*?)# --- baseline-loss-fast-path-end ---", | |
| src, | |
| flags=re.DOTALL, | |
| ) | |
| assert m, "sentinel comment missing from make_baseline_loss_fn fast path" | |
| return m.group(1) |
(and add # --- baseline-loss-fast-path-end --- immediately before the labels-aware branch in utils.py.)
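To see why a sentinel fails loudly where a free-form comment might not, the extraction can be exercised against inline source strings. This is a self-contained sketch under assumptions: `extract_fast_path`, `GOOD_SRC`, and `RENAMED_SRC` are hypothetical stand-ins, not the real `utils.py` or test file.

```python
import re

SENTINEL = "# --- baseline-loss-fast-path-end ---"
PATTERN = re.compile(
    r"if labels is None:\s*\n(.*?)" + re.escape(SENTINEL),
    flags=re.DOTALL,
)


def extract_fast_path(src: str) -> str:
    """Return the text between `if labels is None:` and the sentinel comment."""
    m = PATTERN.search(src)
    if m is None:
        raise AssertionError("sentinel comment missing from fast path")
    return m.group(1)


# Hypothetical function body used only to drive the regex.
GOOD_SRC = '''\
def loss_fn(model, batch, lengths, labels=None):
    if labels is None:
        inputs, targets = batch[:, :-1], batch[:, 1:]
        # --- baseline-loss-fast-path-end ---
    else:
        targets = labels[:, 1:]
'''
RENAMED_SRC = GOOD_SRC.replace(SENTINEL, "# end of fast path")

block = extract_fast_path(GOOD_SRC)
try:
    extract_fast_path(RENAMED_SRC)
    renamed_detected = False
except AssertionError:
    renamed_detected = True
```

With the sentinel renamed, the helper raises instead of returning an empty or wrong block, so the dependent tests cannot pass vacuously.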
| assert "mx.where" in src, ( | ||
| "make_baseline_loss_fn must still call mx.where on the labels-aware path" | ||
| ) | ||
| assert "safe_targets" in src, "labels-aware path must keep safe_targets" |
[2/2 reviewers] Med: every test in this file is source-pin only — none actually executes the loss function on a tensor. The mlx_simulation shim is available; even a shape-only smoke (call loss_fn(stub_model, batch, lengths, labels=None) and assert the return is a 2-tuple of (ce_scalar, ntoks_scalar)) would catch refactors that preserve the regex but break the math. Add at the bottom of the file:
| assert "safe_targets" in src, "labels-aware path must keep safe_targets" | |
| def test_fast_path_runtime_shape_smoke(): | |
| """Shape-only end-to-end smoke. Catches refactors that pass every | |
| regex test but produce the wrong tensors at runtime.""" | |
| import torch | |
| from unsloth_zoo.mlx.utils import make_baseline_loss_fn | |
| # ... build stub model whose __call__ returns logits of shape (B, T, V) | |
| # ... call loss_fn(stub, batch, lengths, labels=None) | |
| # ... assert tuple of length 2, both 0-dim tensors |
`make_baseline_loss_fn` previously routed both the labels=None case
and the labels-aware case through the same code, with three small
differences from `mlx_lm.tuner.trainer.default_loss`:
1. mask cast to fp32 (`length_mask.astype(mx.float32)`) instead of
leaving the bool mask alone.
2. `safe_targets = mx.where(targets == -100, 0, targets)` even when
there are no -100 entries in the labels=None path.
3. division by `_safe_token_denominator(ntoks)` (fp32 cast + maximum
with 1.0) instead of raw `ce.sum() / ntoks`.
Mathematically all three differences are no-ops for the labels=None
case, but they introduce extra nodes into the MLX autodiff graph and
change rounding inside the backward pass. Empirically (paired-seed
data from danielhanchen/unsloth-staging-2 Round BO probes 31 vs 37):
step 1 loss : identical to mlx-lm (9.769231 in both, forward only)
step 2 loss : diverges by ~0.01 to 0.06 from mlx-lm
step 30 loss: 0 in both (converges either way)
The divergence is small and never blocks training (cf_loss = 0 across
all configs), but it means callers of `MLXTrainer` cannot reproduce
mlx-lm CLI's loss curve numerically. This commit:
- Restructures `make_baseline_loss_fn` so the labels=None branch is
byte-for-byte identical to mlx-lm's default_loss (bool mask, raw
division, no mx.where). The labels-aware branch keeps the existing
-100 / safe_targets / fp32-mask machinery — that path needs it.
Tests: tests/test_mlx_baseline_loss_parity.py pins the source so a
future refactor cannot silently re-introduce the divergent code
patterns. The harness uses the torch-based MLX shim so genuine
numerical parity stays the responsibility of the probe matrix on
Apple Silicon; the source-pinning guards against drift between
mlx-lm and zoo in the meantime.
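The claim that the three differences are forward-pass no-ops when no -100 labels are present can be checked with a small pure-Python stand-in. This is a sketch under assumptions, not the MLX implementation: `fast_path_loss` and `labels_aware_loss` are hypothetical names, logits are flattened to one row per token, and the example says nothing about rounding in the backward pass.

```python
import math

IGNORE = -100


def cross_entropy(logits_row, target):
    """Per-token cross entropy: log-sum-exp of the row minus the target logit."""
    m = max(logits_row)
    lse = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return lse - logits_row[target]


def fast_path_loss(logits, targets, length_mask):
    """labels=None style: bool mask, raw division, no target rewriting."""
    ce = [cross_entropy(row, t) * keep for row, t, keep in zip(logits, targets, length_mask)]
    ntoks = sum(length_mask)
    return sum(ce) / ntoks, ntoks


def labels_aware_loss(logits, targets, length_mask):
    """Labels-aware style: rewrite -100 targets, float mask, guarded denominator."""
    safe_targets = [0 if t == IGNORE else t for t in targets]
    mask = [float(keep and t != IGNORE) for keep, t in zip(length_mask, targets)]
    ce = [cross_entropy(row, st) * m for row, st, m in zip(logits, safe_targets, mask)]
    ntoks = sum(mask)
    return sum(ce) / max(ntoks, 1.0), ntoks


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, 1.0, 1.0], [0.0, 3.0, 0.0]]
targets = [0, 2, 1, 1]
length_mask = [True, True, True, False]  # last position is padding

fast = fast_path_loss(logits, targets, length_mask)
aware = labels_aware_loss(logits, targets, length_mask)
```

Both functions return `(loss, ntoks)` in that order. With no -100 entries the two losses coincide; once a target is -100, only the labels-aware variant is applicable, and its token count drops accordingly.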
Per code-comment policy: keep WHY (mlx-lm parity fast path); drop the empirical step-2 divergence numbers and probe references — those live in commit 22c1c5b's message.
df3b531 to
3e8dce8
Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13 conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's behavior where it conflicts on semantics:
- half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=`
- `if labels is None` branch preserved (PR unslothai#684 generality) alongside main's `_normalize_cce_label_dtype` dtype widening
- `_get_image_token_ids` legacy wrapper kept alongside main's new `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels`
- `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so image masking honors main's uint-widening contract
- HEAD's `_expand_token_replacements` dropped; main's three-function split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences` + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed
- `_collate_vlm_prompt_completion_batch` reads back the masked labels in int64 so image + attention masking survives without narrowing
- prompt-completion VLM collator routes through `_apply_vlm_label_masks` after dtype normalisation so ignore_token_ids and wide invalid ids both reach runtime CCE intact
- `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling
- `_unsloth_*` prefix filter preserved so the new collated_position_ids flag and main's raw-input-ids carrier both get stripped
152 MLX tests pass post-merge.
Summary
`make_baseline_loss_fn` previously routed both the `labels=None` case and the labels-aware case through the same code, with three small differences from `mlx_lm.tuner.trainer.default_loss`:
1. mask cast to fp32 (`length_mask.astype(mx.float32)`) instead of leaving the bool mask alone.
2. `safe_targets = mx.where(targets == -100, 0, targets)` even when there are no -100 entries in the `labels=None` path.
3. division by `_safe_token_denominator(ntoks)` (fp32 cast + maximum with 1.0) instead of raw `ce.sum() / ntoks`.
This PR restructures the function so the `labels=None` branch is byte-for-byte identical to mlx-lm's `default_loss`; the labels-aware branch keeps the existing -100 / safe_targets / fp32-mask machinery (it needs all three).
Why
Mathematically all three differences are no-ops when there are no -100 entries (i.e. on the `labels=None` path), but each adds an extra node to the MLX autodiff graph and changes rounding inside the backward pass.
Empirical (`unsloth/gemma-3-270m-it` single-row LoRA memorization, paired-seed data from `danielhanchen/unsloth-staging-2` Round BO probes 31 (mlx-lm manual loop) vs 35 / 37 (zoo MLXTrainer), 5 representative seeds):
Step 1's forward loss is identical (same initial weights, same data), but the gradient applied at step 1 differs between mlx-lm and zoo, so from step 2 onward the loss curves are 0.01-0.06 apart until both converge to 0. The model still memorizes either way (`cf_loss = 0` in 15/15 seeds for every config). But callers cannot reproduce mlx-lm CLI's loss curve numerically, which makes diagnostic side-by-side comparisons useless.
After this PR the `labels=None` fast path produces an autodiff graph value-for-value identical to `mlx_lm.tuner.trainer.default_loss`, so probes against the same fixture should converge value-for-value with mlx-lm CLI.
Behavior
- `loss_fn(model, batch, lengths)` (no labels): matches `mlx_lm.tuner.trainer.default_loss`. Changed (parity).
- `loss_fn(model, batch, lengths, labels)` (labels-aware, e.g. `train_on_responses_only`): keeps `safe_targets`, `_safe_token_denominator`. Unchanged.
Test plan
- `tests/test_mlx_baseline_loss_parity.py`: pins the source against `length_mask.astype(mx.float32)`, `mx.where` (`safe_targets`), or `_safe_token_denominator` on the `labels=None` branch. Verifies the labels-aware path still uses `safe_targets`.
- `pytest tests/test_mlx_baseline_loss_parity.py -v`
- `pytest tests/test_pr_a_*.py tests/test_mlx_*.py`
Related
Sixth PR in the MLX vs `mlx-lm` parity series:
- #669 `finetune_last_n_layers` knob (zoo).
- unslothai/unsloth#5564
- #670
- #671 `max_grad_value=None` default + HF/TRL parity (closes #662).
- #672 `_create_labeled_batches` padding matches mlx-lm.
- `make_baseline_loss_fn` `labels=None` fast path matches mlx-lm.