-
Notifications
You must be signed in to change notification settings - Fork 267
fix(mlx): make_baseline_loss_fn byte-identical to mlx-lm default_loss when labels=None #673
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,124 @@ | ||||||||||||||||||||
| # Unsloth Zoo - Utilities for Unsloth | ||||||||||||||||||||
| # Pin `make_baseline_loss_fn` source so the labels=None fast path stays | ||||||||||||||||||||
| # byte-for-byte equivalent to mlx_lm.tuner.trainer.default_loss. | ||||||||||||||||||||
| # | ||||||||||||||||||||
| # Why pin source rather than run a numerical comparison: the test | ||||||||||||||||||||
| # harness uses a torch-based MLX shim that doesn't faithfully reproduce | ||||||||||||||||||||
| # MLX's autodiff graph or its rounding; an apples-to-apples numerical | ||||||||||||||||||||
| # parity check requires a real MLX runtime (Apple Silicon), so it's | ||||||||||||||||||||
| # done in the Round BP probe matrix on | ||||||||||||||||||||
| # danielhanchen/unsloth-staging-2. Locally we guard against future | ||||||||||||||||||||
| # refactors silently re-introducing the divergent code patterns | ||||||||||||||||||||
| # (fp32-cast mask, mx.where(safe_targets), _safe_token_denominator) | ||||||||||||||||||||
| # that mlx_lm.tuner.trainer.default_loss does NOT do. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import inspect | ||||||||||||||||||||
| import re | ||||||||||||||||||||
| import textwrap | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import pytest | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| @pytest.fixture(autouse=True, scope="module") | ||||||||||||||||||||
| def _install_mlx_shim(): | ||||||||||||||||||||
| from mlx_simulation import simulate_mlx_on_torch | ||||||||||||||||||||
| simulate_mlx_on_torch() | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def _labels_none_block(): | ||||||||||||||||||||
| """Return CODE LINES (comments stripped) for the labels=None fast path.""" | ||||||||||||||||||||
| from unsloth_zoo.mlx import utils | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the CPU-only test harness where the Useful? React with 👍 / 👎. |
||||||||||||||||||||
| src = inspect.getsource(utils.make_baseline_loss_fn) | ||||||||||||||||||||
| m = re.search( | ||||||||||||||||||||
| r"if labels is None:\s*\n(.*?)(?:# labels-aware path|else:\s*\n)", | ||||||||||||||||||||
| src, | ||||||||||||||||||||
| flags=re.DOTALL, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| assert m, "make_baseline_loss_fn must keep a `labels is None` fast path" | ||||||||||||||||||||
| raw = textwrap.dedent(m.group(1)) | ||||||||||||||||||||
| # Strip whole-line comments so test assertions on code don't trip on | ||||||||||||||||||||
| # explanatory prose like "no safe_targets mx.where" in docstrings. | ||||||||||||||||||||
| code_lines = [] | ||||||||||||||||||||
| for line in raw.splitlines(): | ||||||||||||||||||||
| stripped = line.strip() | ||||||||||||||||||||
| if stripped.startswith("#"): | ||||||||||||||||||||
| continue | ||||||||||||||||||||
| code_lines.append(line) | ||||||||||||||||||||
| return "\n".join(code_lines) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def test_no_fp32_mask_cast_in_fast_path(): | ||||||||||||||||||||
| """The labels=None path must NOT cast the bool mask to fp32. mlx-lm's | ||||||||||||||||||||
| default_loss multiplies the cross-entropy result by a raw bool mask; | ||||||||||||||||||||
| casting to fp32 produces a different MLX autodiff graph and shifts | ||||||||||||||||||||
| gradients by ~1e-2 per step on small fixtures.""" | ||||||||||||||||||||
| block = _labels_none_block() | ||||||||||||||||||||
| # Anything that would cast a mask to fp32: `.astype(mx.float32)` or | ||||||||||||||||||||
| # `astype(float32)` immediately on a `mask` / `length_mask` name. | ||||||||||||||||||||
| bad_patterns = ( | ||||||||||||||||||||
| r"length_mask\.astype\(mx\.float32\)", | ||||||||||||||||||||
| r"mask\s*=\s*[^=]*\.astype\(mx\.float32\)", | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| for pat in bad_patterns: | ||||||||||||||||||||
| assert not re.search(pat, block), ( | ||||||||||||||||||||
| f"labels=None fast path must not contain `{pat}`; " | ||||||||||||||||||||
| "matches mlx-lm requires a bool mask." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def test_no_safe_targets_where_in_fast_path(): | ||||||||||||||||||||
| """The labels=None path has no -100 to worry about, so no mx.where | ||||||||||||||||||||
| is needed on targets. The where node was empirically the cause of | ||||||||||||||||||||
| step-2 loss divergence vs mlx-lm CLI (Round BO probe_31 vs probe_37).""" | ||||||||||||||||||||
| block = _labels_none_block() | ||||||||||||||||||||
| assert "mx.where" not in block, ( | ||||||||||||||||||||
| "labels=None fast path must not include mx.where on targets; " | ||||||||||||||||||||
| "mlx-lm's default_loss does not call mx.where." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| assert "safe_targets" not in block, ( | ||||||||||||||||||||
| "labels=None fast path must use `targets` directly; " | ||||||||||||||||||||
| "`safe_targets = mx.where(...)` belongs only to the labels-aware path." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def test_no_safe_token_denominator_in_fast_path(): | ||||||||||||||||||||
| """mlx-lm's default_loss divides by raw `ntoks` (int). The fast path | ||||||||||||||||||||
| must match that to preserve the autodiff graph; the safety wrapper | ||||||||||||||||||||
| `_safe_token_denominator` introduces a fp32 cast + maximum() that | ||||||||||||||||||||
| changes rounding in MLX.""" | ||||||||||||||||||||
| block = _labels_none_block() | ||||||||||||||||||||
| assert "_safe_token_denominator" not in block, ( | ||||||||||||||||||||
| "labels=None fast path must divide by raw ntoks for mlx-lm parity. " | ||||||||||||||||||||
| "_safe_token_denominator is fine on the labels-aware path." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def test_fast_path_returns_ce_and_ntoks_in_that_order(): | ||||||||||||||||||||
| """Match the (loss, ntoks) return signature mlx-lm uses; the test | ||||||||||||||||||||
| pins return-order so a future refactor doesn't accidentally swap.""" | ||||||||||||||||||||
| block = _labels_none_block() | ||||||||||||||||||||
| # Look for a `return X, Y` somewhere in the fast path. The variable | ||||||||||||||||||||
| # names are loose (mlx-lm uses `ce`; zoo previously used `loss`), | ||||||||||||||||||||
| # but the order matters. | ||||||||||||||||||||
| m = re.search(r"return\s+(\w+),\s*(\w+)", block) | ||||||||||||||||||||
| assert m, "labels=None fast path must return a (loss, ntoks) tuple" | ||||||||||||||||||||
| loss_name, ntoks_name = m.group(1), m.group(2) | ||||||||||||||||||||
| assert ntoks_name in ("ntoks", "n_toks", "ntokens"), ( | ||||||||||||||||||||
| f"second return value name should look like a token count, got " | ||||||||||||||||||||
| f"{ntoks_name!r}" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def test_labels_aware_path_still_uses_safe_targets(): | ||||||||||||||||||||
| """The labels-aware path (train_on_responses_only) DOES need | ||||||||||||||||||||
| `safe_targets` and the fp32 mask because labels can contain -100.""" | ||||||||||||||||||||
| from unsloth_zoo.mlx import utils | ||||||||||||||||||||
| src = inspect.getsource(utils.make_baseline_loss_fn) | ||||||||||||||||||||
| # The labels-aware path lives after the fast path's `return`. Look | ||||||||||||||||||||
| # at the full source to verify the machinery still exists somewhere. | ||||||||||||||||||||
| assert "mx.where" in src, ( | ||||||||||||||||||||
| "make_baseline_loss_fn must still call mx.where on the labels-aware path" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| assert "safe_targets" in src, "labels-aware path must keep safe_targets" | ||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [2/2 reviewers] Med: every test in this file is source-pin only — none actually executes the loss function on a tensor. The
Suggested change
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -412,21 +412,28 @@ def make_baseline_loss_fn(): | |||||||||||
| Returns: | ||||||||||||
| A function (model, batch, lengths, labels=None) -> (loss, ntoks). | ||||||||||||
| When labels is provided, uses labels[:,1:] for targets with | ||||||||||||
| (targets != -100) as the loss mask. | ||||||||||||
| (targets != -100) as the loss mask. The labels=None branch is | ||||||||||||
| byte-identical to ``mlx_lm.tuner.trainer.default_loss``. | ||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] Nit: "byte-identical to mlx_lm.tuner.trainer.default_loss" is true today but tied to mlx-lm's
Suggested change
|
||||||||||||
| """ | ||||||||||||
| def loss_fn(model, batch, lengths, labels=None): | ||||||||||||
| if labels is None: | ||||||||||||
| inputs, targets = batch[:, :-1], batch[:, 1:] | ||||||||||||
| else: | ||||||||||||
| # byte-identical to mlx_lm.tuner.trainer.default_loss | ||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [2/2 reviewers] Med:
Suggested change
|
||||||||||||
| inputs = batch[:, :-1] | ||||||||||||
| targets = labels[:, 1:] | ||||||||||||
| targets = batch[:, 1:] | ||||||||||||
| logits = model(inputs) | ||||||||||||
| steps = mx.arange(1, targets.shape[1] + 1) | ||||||||||||
| mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) | ||||||||||||
| ce = nn.losses.cross_entropy(logits, targets) * mask | ||||||||||||
| ntoks = mask.sum() | ||||||||||||
| ce = ce.astype(mx.float32).sum() / ntoks | ||||||||||||
| return ce, ntoks | ||||||||||||
| # labels-aware path: train_on_responses_only style masking. | ||||||||||||
| inputs = batch[:, :-1] | ||||||||||||
| targets = labels[:, 1:] | ||||||||||||
| logits = model(inputs) | ||||||||||||
| steps = mx.arange(1, targets.shape[1] + 1) | ||||||||||||
| length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) | ||||||||||||
| if labels is None: | ||||||||||||
| mask = length_mask.astype(mx.float32) | ||||||||||||
| else: | ||||||||||||
| mask = mx.logical_and(targets != -100, length_mask).astype(mx.float32) | ||||||||||||
| mask = mx.logical_and(targets != -100, length_mask).astype(mx.float32) | ||||||||||||
| # Replace -100 with 0 before CE — MLX has no ignore_index; | ||||||||||||
| # the mask already zeros out these positions in the loss. | ||||||||||||
| safe_targets = mx.where(targets == -100, 0, targets) | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[2/2 reviewers] Low: this helper's regex relies on either an
else:or the literal comment# labels-aware pathto mark the end of the fast-path block. If a future refactor renames that comment, all five tests below silently extract an empty or wrong block —assert mstill fires (or the substring checks pass vacuously). Pin with a sentinel comment that's hard to rename casually:(and add
# --- baseline-loss-fast-path-end ---immediately before the labels-aware branch inutils.py.)