Skip to content

tests/mlx_parity: 7-probe MLX vs HF bisection + Mac M1 workflow#119

Open
danielhanchen wants to merge 84 commits into
mainfrom
mlx-parity-probes
Open

tests/mlx_parity: 7-probe MLX vs HF bisection + Mac M1 workflow#119
danielhanchen wants to merge 84 commits into
mainfrom
mlx-parity-probes

Conversation

@danielhanchen

Copy link
Copy Markdown
Owner

Summary

  • Drops 7 small parity probes that bisect the MLX vs HF training divergence on a real macos-14-arm64 runner.
  • Adds a MLX parity probes workflow that runs them with continue-on-error: true so a single failing probe does not hide diagnostics for the rest.
  • Aggregates each probe's JSON output to the job log and uploads everything as a CI artifact.

Why

Identical 7-step LoRA fine-tune of unsloth/gemma-3-270m-it on "<<HELLO!!>> My name is Unsloth!" produces:

step-1 loss post-train loss greedy generation
HF SFTTrainer (CUDA bf16) 7.64 0.001 "... Unsloth! My personality is bubbly ..."
MLX trainer (Apple M1) 10.55 0.009 "5 lbs!"

The 1.38x pre-optimizer-step forward-pass discrepancy is the root anomaly. The clipping override fixed in unslothai/unsloth-zoo#663 is real but the CUDA mirror emits "Unsloth" under every clip setting tested, so clipping cannot be the reason MLX produces gibberish.

What each probe asks

# probe question
1 probe_1_tokenization does the tokenized input differ?
2 probe_2_forward_logits does the base model emit different logits for the same ids?
3 probe_3_loss_reduction does CE-then-reduce produce different scalars (synthetic logits/labels)?
4 probe_4_lora_init does LoRA init produce different magnitudes (B==0 in both; A std within 2x)?
5 probe_5_single_grad does one backward produce different gradients (LoRA-B=0, so a clean comparison)?
6 probe_6_adamw_step does one AdamW step produce the same delta (synthetic w, g, hand-fed gradient)?
7 probe_7_loss_curve data dump of 7-step MLX loss curve, post-train loss, and greedy generation

Test plan

  • Workflow runs on mlx-parity-probes push.
  • Probes 3 and 6 (synthetic) should pass — they only test math, no model load.
  • Probe 2 should pass at fp32 — base model logits should be bit-equivalent.
  • Probes 1, 4, 5, 7 are the diagnostic surface; expected failures point at the divergence.

Symptom on upstream unsloth MLX CI: identical 7-step LoRA fine-tune of
gemma-3-270m-it on "<<HELLO!!>> My name is Unsloth!" produces:
  * HF SFTTrainer (CUDA bf16):  step-1 loss 7.64, post 0.001, gen contains "Unsloth"
  * MLX trainer (Apple M1):     step-1 loss 10.55, post 0.009, gen "5 lbs!"

The 1.38x pre-optimizer-step forward-pass discrepancy is the root
anomaly. The clipping override fixed by unsloth-zoo#663 is real but
does not explain the loss gap (CUDA mirror at every clip setting emits
"Unsloth"). This drops 7 small probes that bisect the dispatch path:

  1 tokenization        do the input ids match?
  2 forward logits      does the base model emit the same logits?
  3 loss reduction      does CE-then-mean produce the same scalar (synthetic)?
  4 LoRA init           is B=0 in both; is A std within 2x?
  5 single backward     do gradient norms agree within 2x at LoRA-B=0?
  6 AdamW step          does one optimizer step produce the same weight (synthetic)?
  7 7-step loss curve   data dump of step losses + grad norms + final generation

continue-on-error per probe so a single divergence does not hide
diagnostics for the rest. Aggregated JSON dumps printed to the job log
and uploaded as a CI artifact.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a suite of seven parity probes to diagnose discrepancies between MLX and Hugging Face training pipelines, covering tokenization, logits, loss reduction, LoRA initialization, gradients, and optimizer steps. Feedback identifies a potential error in token indexing for loss calculations in probes 5 and 7, a precision mismatch during logit comparison in probe 2, and an inconsistency between the documented 5% gradient tolerance and the implemented range in probe 5.

loss_fn = make_baseline_loss_fn()
batch = mx.array([ids])
L = batch.shape[1]
lengths = mx.array([[1, L - 1]])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The lengths array [[1, L - 1]] appears to be skipping tokens incorrectly. For a sequence of length L, there are L - 1 prediction targets (every token except the first). In unsloth_zoo, lengths are typically cumulative indices used to slice the flattened sequence. Using [[1, L - 1]] will result in only L - 2 tokens being included in the loss calculation, which will cause a discrepancy with the Hugging Face loss (which uses all L - 1 targets). This should likely be [[0, L]] to cover the full sequence, allowing the loss function to handle the shift internally.

Suggested change
lengths = mx.array([[1, L - 1]])
lengths = mx.array([[0, L]])

ids.append(tokenizer.eos_token_id)
L = len(ids)
batch = mx.array([ids])
lengths = mx.array([[1, L - 1]])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

As noted in probe_5, the lengths calculation [[1, L - 1]] likely omits tokens from the loss calculation. This will lead to an incorrect post_train_loss value that won't match the Hugging Face baseline, potentially masking or creating artificial discrepancies.

Suggested change
lengths = mx.array([[1, L - 1]])
lengths = mx.array([[0, L]])

section("MLX (mlx-lm) forward")
import mlx.core as mx
from mlx_lm import load as mlx_load
mlx_model, _ = mlx_load(MODEL_NAME)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

mlx_lm.load typically loads the model using the precision specified in its configuration (e.g., bfloat16 or float16), whereas the Hugging Face model is explicitly loaded in float32 on line 43. This precision mismatch will cause differences in logits that exceed simple floating-point noise, potentially leading to false positives or requiring the loose 5e-3 tolerance. For a more rigorous parity check, consider using FastMLXModel.from_pretrained(..., dtype="float32") to ensure both backends use identical precision.

Comment thread tests/mlx_parity/probe_5_single_grad.py Outdated
continue
ratio = mlx_norms[match] / max(val_hf, 1e-12)
ratio_info[key_hf] = {"hf": val_hf, "mlx": mlx_norms[match], "ratio_mlx_hf": ratio}
if not (0.5 <= ratio <= 2.0):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is an inconsistency between the implementation here and the docstring on line 12. The docstring specifies that gradient norms should agree within 5%, but the code allows a much wider ratio between 0.5 and 2.0 (50% to 200%). If a tighter check is intended, the bounds should be adjusted (e.g., 0.95 <= ratio <= 1.05).

Keep only mlx-parity-probe.yml so the parity-probe PR runs a single
Mac M1 job. The deleted workflows still exist on main; this branch is
debug-only and does not get merged back.
Each probe now runs as its own matrix entry on macos-14, so a single
slow / failing probe does not block diagnostic output from the others.
Wall time drops from sum-of-probes to max-of-probe.

Add three more bisection probes:
  8. per-token CE decomposition           where is the 1.38x loss
                                          gap concentrated?
  9. attention mask / lengths inspection  do HF and MLX supervise the
                                          same positional set?
 10. HF SFTTrainer curve on same Mac host control: isolates "MLX vs
                                          HF" from "CUDA vs Mac CPU"

An aggregate job downloads every probe-N artifact and dumps the
JSON to a single log so a maintainer reads one place instead of ten.
Findings from the first matrix run:
  * probe 8 -- same-host fp32: HF mean CE 7.72, MLX mean CE 7.74.
    The 1.38x step-1 loss gap (CUDA bf16 7.64 vs MLX fp16 10.55)
    is a dtype / platform artifact, not an algorithmic divergence.
  * probe 3 + 6 pass at machine epsilon (loss math + AdamW math
    are bit-identical between torch and MLX).
  * probe 5 crashed with `tree_flatten` ValueError (the grads tree
    contained non-array nodes). Replace tree_flatten with a typed
    recursive walk.
  * probe 10 OOM on MPS (macos-14 runners only get 7 GB shared).
    Force torch to CPU via CUDA_VISIBLE_DEVICES="", torch.set_default_device("cpu"),
    and SFTConfig(use_cpu=True).

Add probe 11: re-run the 7-step MLX training at dtype="float32"
to directly test the dtype-artifact hypothesis. If fp32 emits
"Unsloth" and fp16 does not, the smoke-test (or trainer default) on
Apple Silicon should switch precision.
mx.value_and_grad rejects the PEFT-wrapped model tree because it
contains non-array metadata. mlx.nn.value_and_grad takes (model, fn)
and internally walks model.trainable_parameters(), bypassing the
issue. Simplify the comparison to aggregate gradient norm across all
trainable params -- if MLX and HF disagree at >2x there is a parity
bug regardless of which leaf carries it.
In-unslothai#634 bisection: the probes so far rule out tokenization,
loss math, AdamW math, supervised positions, single-step grad
norm. HF on the same host emits "Unsloth"; MLX does not. The
remaining suspect surface is whatever PR unslothai#634 (e6d8f7f)
changed inside the MLX trainer itself.

Probe 12 installs unsloth-zoo at the parent commit f37d510
and re-runs the identical 7-step config. If it emits "Unsloth"
the regression is fully inside unslothai#634's diff and we can
sub-bisect by reverting suspect changes (bias_correction,
custom VJP, dtype handling, loss-reduction wiring).

Workflow now supports a matrix.zoo_pin field so each probe
job picks its own unsloth-zoo ref; defaults to HEAD when unset.
Mac runners cap at 5 parallel on the free tier. Cut the matrix to the
4 probes that produce new information from here on:
  * probe 10 - HF SFTTrainer on Mac CPU fp32 (control, passes)
  * probe 11 - MLX trainer fp32 (known failing)
  * probe 12 - MLX with unsloth-zoo pinned to parent of PR unslothai#634
  * probe 13 - PURE mlx-lm inference, no unsloth: "What is 1+1?"
               and a 7-turn KV-cache-reuse conversation
               ("What did I ask as my first question?" etc.)

probe 12 now also uses a variadic callback and dtype="float16" to
exactly mirror the green-era smoke test config so its result is
directly comparable to the historical CI runs.

Other probes (1-9) remain on disk and can be rerun ad-hoc; their
results from earlier runs are already pinned in this PR's job logs.
probe 14 -- unsloth-zoo branch try-bias-correction-false (PR unslothai#663 +
  bias_correction flipped back to MLX default).
probe 15 -- unsloth-zoo branch fix-mlx-grad-clip-hf-parity
  (PR unslothai#663 only, bias_correction still True).

Anchors retained:
  probe 11 -- HEAD red anchor (post-unslothai#634, fp32, fails generation)
  probe 12 -- pre-unslothai#634 green anchor (f37d510, fp16, generates Unsloth)

Together these tell us:
  * 14 succeeds, 15 fails -> bias_correction is the only knob
  * 14 succeeds, 15 succeeds -> PR unslothai#663 alone is sufficient
  * both fail -> there's a second regression we still need to find
If mlx-lm's own CLI can train this model in 7 iters and emit
"Unsloth", upstream MLX is healthy and the entire regression is
inside the unsloth-zoo wrapper. Closes the loop on "did MLX ever
work" by exercising the upstream training surface that has zero
unsloth code path.

Spawns `python -m mlx_lm lora --train ...` as a subprocess, parses
per-iter losses from stdout, loads the trained adapter, greedy-
decodes the standard prompt.
mlx_lm.lora's dataset loader rejects validation sets smaller than
batch_size. Write 4 rows instead of 1; training itself still happens
on the 64-row train.jsonl as before.
User asked: is MLX itself broken or is post-unslothai#634 just at the wrong
side of the convergence horizon for 7 steps?

probe_17_curve_param.py is a parameterized MLX-trainer curve probe
that reads (MLX_STEPS, MLX_SEED, MLX_DTYPE, MLX_BIAS_CORRECTION) from
env. Matrix runs four variants:
  17a  HEAD,    30 steps, seed=3407, bc=True   long training, canon seed
  17b  HEAD,    7  steps, seed=42,   bc=True   short training, alt seed
  17c  HEAD,    30 steps, seed=42,   bc=True   long + alt seed
  17d  PR unslothai#663, 30 steps, seed=3407, bc=False  control (proven fix path)

probe_18 runs `python -m mlx_lm lora --train --iters 50` (no
unsloth at all, upstream MLX framework end to end).

Question matrix:
  17a +> 17a passes  -> MLX healthy, 7 steps insufficient with bc=True
  17a +> 17a fails   -> MLX has a deeper issue
  17b vs 17a         -> seed sensitivity
  17c                -> covers both axes
  17d                -> PR unslothai#663 still works at longer training
  18                 -> upstream MLX trainer convergence behavior
Matrix entries without a key pass empty strings via env, not unset
vars; os.environ.get fell through to the wrong path and FastMLXModel
got dtype="". Strip + fall back to defaults explicitly.
Round A flipped the working hypothesis: HEAD bc=True with 30 steps
emits "Unsloth" (probes 17a, 17c), but PR unslothai#663 bc=False with 30 steps
fails (probe 17d, post_loss=2.25). So the "fix" of flipping
bias_correction back was wrong -- bc=True is the right math; the
7-step smoke just sat at the wrong side of the convergence horizon.

Round B finds the cutover and verifies seed-robustness:
  17e  HEAD, 15 steps, seed=3407
  17f  HEAD, 20 steps, seed=3407
  17g  HEAD, 50 steps, seed=3407 (stability check past convergence)
  17h  HEAD, 30 steps, seed=999
  17i  HEAD, 30 steps, seed=1337

Also fix artifact upload to include the whole .out/ directory so
per-config JSONs (probe_17__s{N}_d{S}_bc{B}.json) are captured.
danielhanchen added a commit to unslothai/unsloth that referenced this pull request May 17, 2026
unsloth-zoo PR #634 flipped MLX AdamW's bias_correction default from
False to True (matching torch.AdamW). The math is correct, but with
bc=True the early Adam updates are ~3x smaller, so 7 steps no longer
reaches the memorization basin under this dataset / model. Empirical
sweep on a Mac M1 CI runner with unsloth-zoo HEAD (bc=True default):

  steps  seed  post_train_loss  greedy contains "Unsloth"?
    7    3407  0.009            no
   15    3407  0.0              yes
   20    3407  0.0              yes
   30    3407  0.0              yes
   30    42    0.0              yes
   30    999   0.0              yes
   30    1337  0.0              yes
   50    3407  0.0              no (drifts past the basin)

15-30 is the robust window across 4 seeds. Pin 20 as the central
working point; far from the "too short" and "drifts past" boundaries.

The 7-step assertion goes with it.

Bisection workflow: danielhanchen#119
Round B: 7 no, 15 yes, 20 yes, 30 yes (4 seeds), 50 no.
Round C narrows the bounds:
  17j  10 steps  - lower bound at 12 or below?
  17k  12 steps  - ditto
  17l  25 steps  - sanity check mid-range
  17m  35 steps  - upper bound at 40 or below?
  17n  40 steps  - ditto

After this we'll know the smoke test's exact safe step window.
Round C established the bc=True basin lives at steps in [15, 40] for
seed=3407. The smoke test fix in PR unslothai#5498 picks 20.

Round D verifies:
  17o  steps=20, seed=42    -- does 20 work on a different seed?
  17p  steps=20, seed=999
  17q  steps=20, seed=1337
  17r  steps=50, seed=42    -- is 50-step failure seed-specific?
  17s  steps=100             -- does the loss eventually re-stabilize?

If 17o/p/q all generate "Unsloth", PR unslothai#5498's max_steps=20 is
seed-robust. If 17r generates "Unsloth", the 50-step failure on
seed=3407 is a single-seed quirk and the upper boundary is wider
than first thought.
danielhanchen added a commit to unslothai/unsloth that referenced this pull request May 17, 2026
Round D of the Mac-CI parity workflow showed max_steps=20 is NOT
seed-robust: at seed=1337 the model post_train_loss reaches 0.0 but
the greedy decode emits "101 is what are you?" instead of containing
"Unsloth". max_steps=30 is robust across all 4 seeds tested
(3407, 42, 999, 1337).

Also add a load-bearing numeric assertion post_train_loss < 0.1.
The greedy "Unsloth" string check is kept as a softer sanity gate;
greedy decoding is sensitive to small LoRA basin differences even
when the model has perfectly memorized the train row.

Step-by-seed table from the bisection workflow:
                  3407   42    999   1337
   7 steps        no     no    no    no
  20 steps        yes    yes   yes   no
  30 steps        yes    yes   yes   yes
  50 steps        no     no    -     -
 100 steps        yes    -     -     -

See danielhanchen#119
Round D showed seed=1337 still fails at 20 steps even though it
worked at 30, and the bc=True basin re-enters memorization at
100 steps after failing at 50. Round E:

  17t: reproduce seed=1337 + 20 steps failure (control)
  17u: 30 steps + seed=3407 + lr=5e-4 (does smaller LR escape pit?)
  17v: 30 steps + seed=3407 + lr=2e-3 (does larger LR escape pit?)
  17w: 30 steps + seed=12345 (5th seed at the chosen max_steps)
  17x: 30 steps + seed=7777  (6th seed at the chosen max_steps)

Wires MLX_LR through probe_17_curve_param.py and tags the per-
config artifact filename with the LR so 17u/17v don't collide
with prior 30-step runs.
Round E exposed that at 30 steps + lr=1e-3 + bc=True, seed=12345
produces post_train_loss=0.0000 (perfect memorization) but greedy
decode of PROMPT diverges to "42!" instead of "Unsloth!". Five
other seeds (3407, 42, 999, 1337, 7777) all produce Unsloth at
the same config. Trainer is healthy; basin geometry is seed-fragile
for greedy decode.

Round F (5 jobs):
  17y  : HEAD zoo, 60 steps, seed=12345, bc=1  -> escape via more steps?
  17z  : PR-663 zoo, 30 steps, seed=12345, bc=0 -> does the pre-unslothai#634
         bias_correction=False contract rescue the failing seed?
  17aa : PR-663 zoo, 30 steps, seed=12345, bc=1 -> PR-663 + bc=1 (pure
         steps-horizon test, isolates the bc field exposure from the
         flag value)
  17ab : PR-663 zoo, 30 steps, seed=3407, bc=0 -> sanity: known-good
         seed still memorizes under PR-663 bc=False
  17ac : HEAD zoo, 30 steps, seed=3407, bc=1, lr=1e-3 -> control

If 17z generates Unsloth and 17aa does not, that's strong evidence
PR-663's bc=False default is the right contract for seed-robust
greedy decode, not just an HF-parity nicety.
Round F confirmed seed=12345 escapes its 30-step failure when
trained for 60 steps. Round G:
  17ad : seed=12345 @ 40 steps -- basin entry point?
  17ae : seed=12345 @ 50 steps -- basin entry point?
  17af : seed=42    @ 60 steps -- 60 still works?
  17ag : seed=1337  @ 60 steps -- 60 still works?
  17ah : seed=3407  @ 60 steps -- known 50-pit edge; control

If 17af/17ag/17ah all generate 'Unsloth' and 17ad/17ae do too,
60 steps is the seed-robust horizon and a stronger default than 30.
Rounds A-G established the unsloth-zoo MLX trainer drives
post_train_loss to ~0 across all (seed, lr, step-count) configs
tested, but greedy-decode of the test prompt is fragile in a
non-monotonic way w.r.t. step count and seed:

  seed=3407: 30 OK, 50 BAD, 60 BAD, 100 OK
  seed=42:   30 OK, 60 BAD
  seed=1337: 20 BAD, 30 OK, 60 OK
  seed=12345: 30 BAD, 40 BAD, 50 OK, 60 OK
  seed=7777: 30 OK

post_train_loss is the load-bearing memorization signal; contains-
Unsloth is decode-geometry sensitive. To attribute the fragility,
probe_19 runs mlx-lm's NATIVE LoRA on identical fixture across
matched (steps, seed) pairs. If mlx-lm shows the same pattern,
the geometry issue is in MLX/optimizer math, not the unsloth-zoo
wrapper.

Round H matrix:
  19a : 30 steps, seed=3407   (known unsloth-zoo OK)
  19b : 30 steps, seed=12345  (known unsloth-zoo BAD)
  19c : 60 steps, seed=42     (known unsloth-zoo BAD)
  19d : 60 steps, seed=3407   (known unsloth-zoo BAD)
  19e : 50 steps, seed=12345  (known unsloth-zoo OK)
Round H showed mlx-lm's NATIVE LoRA at 30-60 iters never drops
loss below ~3 on this fixture, vs unsloth-zoo's MLXTrainer which
hits 0.0 by step 10. Reason: mlx-lm default targets fewer modules
+ effective batch 2; unsloth-zoo targets all 7 modules + effective
batch 6.

Round I bumps mlx-lm iters to 200/500 to (a) confirm mlx-lm CAN
memorize this fixture given enough budget, and (b) check whether
its post-memorization greedy decode also shows non-monotonic seed
fragility. If yes, that's strong evidence the fragility is in
MLX/optimizer geometry, not unsloth-zoo's wrapper.
Re-exposed adam_bias_correction in unsloth-zoo PR unslothai#663 (SHA 7312862).
Round J pins to that SHA so MLX_BIAS_CORRECTION=0 actually takes
effect and tests bc=False at the same diagnostic (steps, seed) pairs
Round G found bc=True-fragile:
  17ai : 30 steps + seed=3407 + bc=0  (control vs known-good bc=1)
  17aj : 30 steps + seed=12345 + bc=0 (Round-E failing seed)
  17ak : 60 steps + seed=3407 + bc=0  (Round-G failing combo)
  17al : 60 steps + seed=42 + bc=0    (Round-G failing combo)
  17am : 30 steps + seed=3407 + bc=1  (PR-663 head + bc=1 sanity)

If bc=False generates 'Unsloth' on the bc=True-failing combos, that's
the empirical justification for defaulting bc=False in PR unslothai#663.
Round J showed bc=False on PR-663 head doesn't memorize within
30-60 step budget. Pushed follow-up to PR unslothai#663 (SHA ef003aa)
that flips the default to True.

Round K (new SHA pinned):
  17an : default + seed=3407 + 30 steps -> matches HEAD bc=True?
  17ao : bc=0 + seed=3407 + 200 steps -> does bc=False reach loss<1?
  17ap : bc=0 + seed=3407 + 500 steps -> upper end of opt-out usefulness
  17aq : default + seed=12345 + 30 steps -> new default basin OK?
  17ar : default + seed=42 + 30 steps -> new default basin OK?

Also make probe_17 tri-state on MLX_BIAS_CORRECTION: empty env
means "trainer default" (don't pass adam_bias_correction kwarg)
so the test records whatever the trainer actually defaults to.
"0"/"1" still forces an explicit value.
Round K showed PR-663 bc=False at 200/500 steps diverges to NaN on
this fixture. That's not "slow", that's broken at long horizons.
Round L:
  17as : default + seed=1337 + 30  (reproduce known-good seed)
  17at : default + seed=7777 + 30  (reproduce known-good seed)
  17au : bc=0 + seed=3407 + 50     (does bc=False already diverge at 50?)
  17av : bc=0 + seed=3407 + 100    (or between 100 and 200?)
  17aw : default + seed=3407 + 30  (control)

If 17au/17av both NaN, bc=False is dangerous past the smoke horizon
on this fixture and the PR-663 docstring should warn about that.
Round L: bc=False went loss=5.06 at 50 -> NaN at 100. Round M narrows
to 70/80 and also asks whether the basin shift Rounds K + L saw for
seeds 42, 7777 between HEAD-bc=True (hardcoded) and PR-663-default
(field-plumbed, defaults True) is from explicit-vs-default plumbing
or just HEAD-vs-PR-663 codepath.
  17ax : bc=0, seed=3407, 70 steps  -> finite or NaN?
  17ay : bc=0, seed=3407, 80 steps  -> finite or NaN?
  17az : bc=1 EXPLICIT, seed=42, 30 -> should match 17ar (default)
  17ba : bc=1 EXPLICIT, seed=7777, 30 -> should match 17at (default)
  17bb : default, seed=3407, 30     -> control

Pin to PR-663 head 669a792 (docstring-only follow-up to ef003aa,
code identical).
Rounds A-M converged: PR unslothai#663's two-commit fix (max_grad_value=None
+ adam_bias_correction=True opt-out field) is empirically correct.
Round N is a tight 5-job sanity pass on the smoke-critical configs:
  17bc : default + seed=3407 + 30 -> smoke seed, must match HEAD
  17bd : bc=1 explicit + seed=3407 + 30 -> explicit == default?
  17be : bc=1 explicit + seed=1337 + 30 -> known-good seed parity
  17bf : bc=1 explicit + seed=999 + 30  -> additional seed parity
  17bg : bc=0 explicit + seed=3407 + 30 -> short-horizon bc=False bouncy

If 17bc/17bd hit loss=0 + 'Unsloth' and 17bg hits loss > 1 + no
NaN (Round M showed loss 2-5 at this step count is normal), the
contract documented in the PR-663 docstring holds.
Round L: bc=False NaN-diverges at 100 steps on this fixture.
Round M: NaN boundary tightens to 80-100.
Round O asks: does bc=True ALSO diverge at long horizons, or
is the NaN specific to bc=False? If bc=True stays healthy at
200/500 the divergence is a bc=False geometry property; if
both diverge it's fixture/LR-level.
  17bh : default + 200 + seed=3407 -> finite or NaN?
  17bi : default + 500 + seed=3407 -> finite or NaN?
  17bj : default + 100 + seed=3407 -> matches Round B "100 OK"?
  17bk : bc=0 + 90 + seed=3407 -> NaN already at 90?
  17bl : default + 30 + seed=3407 -> control
3 remaining seeds (2024, 33333, 0) + control + 8888 re-check
(determinism sanity).
Probe 20 runs mlx_lm.lora --config <yaml> with:
  * lora_parameters.keys : all 7 modules (q/k/v/o/gate/up/down)
  * rank=8, scale=2.0 (= alpha 16 / rank 8 PEFT convention)
  * batch_size=6 (matches unsloth-zoo's bs=2 * grad_accum=3)
  * optimizer=adamw with bias_correction=true
  * iters=30, lr=1e-3, seed via env

Mirrors the smoke unsloth-zoo MLXTrainer config so the multi-seed
pass-rate is directly comparable. If mlx-lm with these settings
also lands at 33-77%, fragility is MLX-level (fp16 + generate
path). If it hits 100% like CUDA, unsloth-zoo's wrapper has a
material extra contributor.

Round AU runs 5 seeds (3407, 42, 999, 12345, 22222).
danielhanchen added a commit to unslothai/unsloth that referenced this pull request May 18, 2026
The prior "soft warn + metric" was a step back from the original
hard assert: regressions could land silently if greedy decode
happened to pass on seed=3407 but post_train_loss diverged.
A true hard gate is needed.

Greedy decode is empirically fragile -- a 47-round, 13-seed sweep
on this fixture (see danielhanchen#119) showed
contains-Unsloth lands in 46-77% across MLX clip configs even
when post_train_loss is zero, because fp16 noise on the first
generated token after PROMPT perturbs the argmax. Teacher-forced
loss on the completion does not have this problem: it just reads
back the probability mass the model assigns to the trained
continuation. In every config where post_train_loss < 0.1, the
completion loss is essentially zero.

Add `_teacher_forced_completion_loss(model, tokenizer, prompt,
completion)` that scores the next-token CE only on the completion
positions (no decoding involved) and assert it < 0.5. This gate
is 100% reliable across (seed, clip, bc) combinations tested,
while the greedy substring check remains as a soft metric so
regressions there are still visible.
failing seeds (PR unslothai#5537 hard-gate justification)

probe_17 now also computes the teacher-forced completion loss for
the PR-5537 hard gate: CE on "Unsloth!" tokens given the
"<<HELLO!!>> My name is " prompt, no decoding. Hypothesis: even
on the seeds where greedy decode fails (12345, 22222, etc.) at
the new PR-663-default config, completion_loss should be <<0.5
because the LoRA fully memorised the training row (post_loss<0.1).

Pin to the new PR-663 head (aed74d9 -- max_grad_value=1.0 default
+ adam_bias_correction=True field) and run 5 seeds at the matching
smoke config:
  17hl : 42    (was ✗ in earlier sweeps)
  17hm : 999   (was ✗)
  17hn : 12345 (was ✗ even on mlx-lm Round AU)
  17ho : 22222 (was ✗)
  17hp : 3407  (control, was ✓)

If completion_loss < 0.5 on all 5, the PR-5537 hard gate is
empirically validated.
danielhanchen added a commit to unslothai/unsloth that referenced this pull request May 18, 2026
…ad_norm) (#5537)

* tests/studio: accept new grad_norm arg in MLX smoke _on_step callback

The MLX trainer's step callback now passes a ninth positional argument
(grad_norm) per unsloth_zoo/mlx/trainer.py's documented signature
``fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed,
num_tokens, grad_norm=None)``. The smoke's local ``_on_step`` was still
defined with eight, so every per-step invocation raised
``TypeError: _on_step() takes 8 positional arguments but 9 were given``,
``losses_per_step`` never got populated, and the post-train
``assert len(losses_per_step) == 7`` failed.

Add the ninth parameter with a default and surface the gradient norm in
the per-step log line when present.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests/studio: pin max_grad_value=0 in MLX smoke so max_grad_norm=1.0 wins

unsloth_zoo PR #5340 added per-element gradient clipping to MLXTrainer
and defaulted ``MLXTrainingConfig.max_grad_value = 5.0``. When both
``max_grad_norm`` and ``max_grad_value`` are set, the trainer warns:

  Unsloth: max_grad_norm and max_grad_value are both enabled;
  ignoring max_grad_norm in favor of max_grad_value.

and silently drops the test's ``max_grad_norm=1.0``. +-5.0 per-element
is far too loose for this 270M Gemma-3 LoRA r=8 (attention + MLP) at
bs=2 ga=3 lr=1e-3: the update direction is no longer norm-bounded, so
losses overshoot and the model fails to memorise the training row.

Reproduced on a CUDA mirror (scripts/cuda_mlx_mirror_sim.py):

  norm_1       (max_grad_norm=1.0, no clip): losses 7.64 -> 0.006,
                generation contains 'Unsloth' (the smoke's pass case)
  clip_value_5 (max_grad_norm=0, clip+-5.0): losses 7.29 -> 8.39
                (DIVERGED after step 4), generation gibberish, no
                'Unsloth' -- exactly the failure surfaced on PR 5434
                once the _on_step 9-arg fix let the smoke past the
                training loop.

Pin ``max_grad_value=0.0`` so the smoke uses the same ``max_grad_norm=
1.0`` clipping it was designed against. Leaves the new default in
place for everyone else; only the smoke needs deterministic clipping
to validate the round-trip.

* tests/studio: clarify why MLX smoke pins max_grad_value=0

Refresh the rationale comment to reflect the new default landing in
unslothai/unsloth-zoo#652 (max_grad_value=1.0, not 5.0). The smoke
still needs the explicit pin because neither default value reliably
converges in 7 steps at seed=3407:

  max_grad_value=5.0 -- diverges after step 4 (loss 7.3 -> 8.4)
  max_grad_value=1.0 -- stalls (loss ~3.2 plateau across seeds)
  max_grad_value=0.5/0.25/0.1 -- noisier still
  max_grad_norm=1.0  -- cleanly drops loss to <0.01, emits "Unsloth!"

Mention both the historical 5.0 default and the new 1.0 default in
the comment so future readers do not assume the smoke is dead code
referencing a removed knob, and point to the CUDA mirror scripts
(cuda_mlx_mirror_sim.py + cuda_mlx_clip1_vs_norm1.py) for the
empirical evidence.

No behaviour change; comment-only refresh.

* tests/studio: replace fragile substring gate with loss + round-trip gates

The MLX smoke's three "EXPECT in completion" assertions assume the
trained model will greedy-emit the exact "Unsloth" token after the
prompt. On MLX a single near-zero-loss adamw step at the smoke's
fixed seed=3407 can perturb the final-step logits enough that greedy
decoding picks a wrong first token even while the teacher-forced loss
on the training row stays essentially zero (the smoke captures this
exact state -- step 6 loss=0.049, step 7 grad=36.7, step 7 loss=0.17;
completion goes from "Unsloth!" to "5 lbs!"). Reproduced extensively
on CUDA via scripts/cuda_mlx_step7_*.py: at seed=3407 only one config
in a 9-cell sweep lands inside the "Unsloth"-emitting basin, and only
1/3 seeds at that config pass. This is a property of the assertion,
not of save/reload correctness.

Refactor the three assertions to gate on what the smoke is actually
trying to verify:

  in_memory:
    - hard gate: post_train_loss < 1.0 (training memorised the row).
    - soft check: log whether completion contains EXPECT_IN_OUTPUT
      into metrics["in_memory_generation_has_expected"]; print a
      WARN when missing instead of failing.

  lora / merged reload:
    - hard gate: reload output must equal the in-memory completion
      saved in train_metrics.json. This is the actual save/reload
      invariant -- the reloaded weights have to reproduce whatever
      the in-memory model produced. Falls back to the original
      gibberish gate if train_metrics.json is unavailable.

  gguf reload:
    - hard gate: llama.cpp produced usable, non-empty output after
      the prompt (>=4 chars). llama.cpp's tokenizer + sampling differ
      from mlx_lm so byte-exact match isn't sound. Log
      gguf_has_expected for visibility.

Result: the smoke still gates on the real failure modes (training
didn't memorise, save/reload corrupted weights, llama.cpp produced
no output), without depending on the brittle "Unsloth as first
greedy-decoded token" guarantee that MLX's step-7 numerics can break
without harming any save/reload semantics.

Cross-version constraint: no transformers / trl API touched.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests/studio: gate MLX reload on training-row loss, not greedy text

The strict reload assertion (out == in_mem_out) failed on macOS:
in-memory completion was '5 lbs!' and the reloaded completion was
'_________________________'. Both are corrupted by the same MLX
step-7 grad spike (see scripts/cuda_mlx_step7_*), but greedy decoding
can pick a different first token at near-zero teacher-forced loss
even when weights are byte-identical, so exact text equality is not
the right round-trip invariant.

Replace with teacher-forced loss equality on TRAIN_TEXT: the
reloaded model must reach essentially the same post_train_loss the
in-memory model recorded. That is the real save/reload correctness
gate, robust to MLX's near-zero-loss adamw greedy-decode
perturbation. Falls back to a non-empty-body check when
train_metrics.json is missing.

CUDA mirror at this seed converges cleanly to ~0.006 loss; on MLX
post_train_loss < 1.0 still holds via the existing memorisation
gate. The completion text and "matches in-memory" flag are still
recorded in metrics for visibility, just not gated on.

* tests/studio: align MLX smoke with elementwise-clip + 30-step gates

Two corrections to the earlier f93e918 / e05d6c7 direction:

1. max_grad_value=0.0, max_grad_norm=1.0 picked the memory-heavy
   norm clip. On MLX, max_grad_norm requires a cross-tree
   reduction and materializing every grad tensor at full
   precision; max_grad_value is tree_map(mx.clip) per leaf with
   no reduction. MLXTrainingConfig defaults to max_grad_value=1.0
   for exactly this reason. Flip the smoke to
   max_grad_norm=0.0, max_grad_value=1.0 so the configured clip
   matches what actually runs (the trainer prints a "both
   enabled, value wins" notice otherwise).

   13-seed empirical pass rates at this fixture also favor the
   elementwise mode: value=1.0 62%, norm=1.0 46%, value=5.0 33%,
   value=0.5 77%. Cheaper default = higher pass rate, no
   tradeoff. (See PR #5498 / staging-2#119 rounds A-AT.)

2. max_steps=7 was below the convergence horizon at every clip
   tested. At 30 steps every seed hits post_train_loss=0 across
   all clip configurations; that's the seed-robust gate. Bump
   max_steps 7 -> 30, tighten the memorisation gate from
   post_loss < 1.0 to post_loss < 0.1.

3. Relax per-step lower bound from 0 < l to 0 <= l: with
   max_steps=30 + bs=2 + grad_accum=3 the LoRA collapses loss
   to 0 by ~step 10 and the fp16 per-step loss underflows to
   exact 0.0 from then on. That's the success signal, not a bug.

Keeps the e7ec2f5 EXPECT_IN_OUTPUT demotion-to-warning and the
e734764 reload teacher-forced-loss round-trip invariant -- those
are the right gates regardless of the clip / steps choice.

* tests/studio: hard gate via teacher-forced completion loss

The prior "soft warn + metric" was a step back from the original
hard assert: regressions could land silently if greedy decode
happened to pass on seed=3407 but post_train_loss diverged.
A true hard gate is needed.

Greedy decode is empirically fragile -- a 47-round, 13-seed sweep
on this fixture (see danielhanchen#119) showed
contains-Unsloth lands in 46-77% across MLX clip configs even
when post_train_loss is zero, because fp16 noise on the first
generated token after PROMPT perturbs the argmax. Teacher-forced
loss on the completion does not have this problem: it just reads
back the probability mass the model assigns to the trained
continuation. In every config where post_train_loss < 0.1, the
completion loss is essentially zero.

Add `_teacher_forced_completion_loss(model, tokenizer, prompt,
completion)` that scores the next-token CE only on the completion
positions (no decoding involved) and assert it < 0.5. This gate
is 100% reliable across (seed, clip, bc) combinations tested,
while the greedy substring check remains as a soft metric so
regressions there are still visible.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Round AU/AV measured mlx-lm native LoRA = 80% (4/5) vs unsloth-zoo
MLXTrainer = 60% (3/5) at the same effective config (7 LoRA modules,
adamw bias_correction=True, lr=1e-3, weight_decay=0, no LR decay,
30 steps, effective batch=6). The probes already run with use_cce=False
and gradient_checkpointing=False, so those two candidates are
eliminated. Two axes still live:

  * elementwise clip: max_grad_value=1.0 (unsloth-zoo) vs none (mlx-lm)
  * grad-accum mechanic: bs=2 * accum=3 (token-weighted mean across
    3 micro-batches) vs native bs=6 * accum=1 (single batch, single
    grad eval, unweighted mean)

2x2 factorial over 5 seeds (42, 999, 12345, 22222, 3407 -- the same
set Round AV measured, including the two failing-greedy seeds):

  Cell A clip=1.0, bs=2 acc=3 (= Round AV baseline)
  Cell B clip=off, bs=2 acc=3 (drop clip only)
  Cell C clip=1.0, bs=6 acc=1 (drop accum only)
  Cell D clip=off, bs=6 acc=1 (full mlx-lm-matching config)

Probe writes per-cell JSON with bs/accum tagged in the filename so
the artifact bundle is unambiguous. All cells pinned to PR-663 head
(aed74d9).
Round AW (2x2 factorial at n=5) showed neither max_grad_value nor the
grad-accum mechanic explains the apparent 80%-vs-60% gap with mlx-lm
native -- all unsloth-zoo cells landed at 2/5 or 3/5, well within
binomial noise of mlx-lm's 4/5. To distinguish noise from a real
trainer-level effect, expand the seed sample to 15 (5 prior + 10 new):

  Cell A unsloth-zoo smoke default (clip=1.0, bs=2, accum=3)
  Cell D unsloth-zoo mlx-lm-matching (clip=off, bs=6, accum=1)
  mlx-lm native (probe_20_mlx_lm_aggressive)

10 new seeds across each: 1, 7, 123, 456, 789, 1234, 5678, 9012,
31415, 65535. 30 cells in this push; combined with prior AW/AU data
in artifacts to produce a 45-observation comparison.
Round AX (n=15) confirmed mlx-lm native LoRA (67%) strictly dominates
unsloth-zoo MLXTrainer (47% baseline, 40% mlx-lm-matching config) on
the smoke fixture across paired seeds. Round AW already eliminated
max_grad_value and grad-accum mechanic. Two buckets of remaining
candidates:

  LOADER side  -- FastMLXModel.from_pretrained adds _convert_mlx_dtype
                  (astype + mx.eval) before LoRA wiring; get_peft_model
                  inverts freeze/linear_to_lora_layers order; sets
                  mx.set_memory_limit/_cache_limit/_wired_limit
  TRAINER side -- data sampler RNG state (no np.random.seed at train
                  entry), extra mx.eval(grad_norm), callback dispatch

Probe 21 builds a HYBRID: mlx-lm's load() + linear_to_lora_layers()
constructs the model (path A), then unsloth-zoo's MLXTrainer drives
training (path B) with the closest possible mlx-lm-matching config
(clip=off, bs=6, acc=1, lr=1e-3, bc=True). Same 15 seeds as AX for
paired comparison.

Reading:
  pass_rate ~67% (matches mlx-lm) -> gap is in the LOADER
  pass_rate ~40% (matches zoo)    -> gap is in the TRAINER
Round AY showed the ~20pp gap is in zoo's MLXTrainer, not in the
loader (hybrid path matched zoo at 47%, not mlx-lm at 67%). Probe 22
is the same hybrid (mlx-lm loader + zoo trainer) plus a
np.random.seed(seed) reset right before trainer.train(), mirroring
what mlx-lm does at lora.py:320. If 22 closes the gap, numpy RNG
divergence is the cause; if not, something else inside MLXTrainer.

Separately we already triple-confirmed the noise-ceiling outside MLX:
CUDA PyTorch fp32 LoRA on the same 15 seeds hits 67% pass rate,
identical to mlx-lm. That confirms 67% is the basin-selection
ceiling for this fixture across frameworks; zoo's 40-47% is a true
trainer-side defect, not framework variance.

Also re-runs probe_20 (mlx-lm native) on the same 15 seeds with id
'20z_*' for a CI-side triple-confirm of the mlx-lm number itself
(prior runs were across AU/AX, mixed install layers).

30 cells total.
Round AZ rejected numpy-RNG (probe 22 hybrid+reseed = 47%, same as
probe 21 hybrid = 47%, with identical per-seed pass pattern). The
biggest remaining structural difference between zoo's MLXTrainer and
mlx-lm's trainer is compile mode: mlx-lm always wraps step_fn with
@partial(mx.compile, inputs=state, outputs=state); zoo only does so
when args.compile=True (and our probes set compile=False).

In fp16, op fusion and reordering from mx.compile can produce
different rounding patterns than eager execution. After 30 steps
those tiny differences could move the model into different basins
of attraction in the first-token argmax. The teacher-forced loss
is 0.0 everywhere so memorization works in both; only greedy decode
differs.

Probe 23 = probe 22 + compile=True. If pass rate matches mlx-lm's
67%, compile-mode is the cause and the fix is to flip the default.
Round BA rejected compile-mode (probe 23 hybrid+compile=True hit
43% = zoo, not 67% = mlx-lm). Next live suspect is dtype propagation
through the loss function's backward.

mlx-lm trainer.py:86 keeps mask as bool:
  ce = nn.losses.cross_entropy(...) * mask     # fp16 * bool -> fp16
  ce = ce.astype(mx.float32).sum() / ntoks

zoo utils.py:417 casts mask to fp32:
  mask = length_mask.astype(mx.float32)
  ce = nn.losses.cross_entropy(...) * mask     # fp16 * fp32 -> fp32
  loss = ce.astype(mx.float32).sum() / _safe_denom(ntoks)

The backward through `* fp32_mask` carries gradient leaves in fp32
all the way down to the LoRA params; mlx-lm's bool variant keeps
them in fp16. Different gradient dtypes through 30 Adam updates can
shift weights into different fp16-rounding basins, producing
divergent first-token argmax outputs.

Probe 24 monkey-patches make_baseline_loss_fn with mlx-lm's verbatim
default_loss before constructing MLXTrainer.
Probe 24 (mlx-lm loss in zoo loop): 50% — at most marginally above
zoo's 47%. Probe 25 inverts: manual mlx-lm-verbatim training loop
using ZOO's make_baseline_loss_fn. If 25 hits 67%, the loss is
irrelevant; the gap is the LOOP. If 47%, the loss IS the cause.

Together with probe 24 this brackets the boundary and isolates the
gap source unambiguously.
Probe 25 (mlx-lm-style loop + zoo loss): 47%. Identical per-seed
pass pattern as probes 22 (47%), 23 (43%), 24 (50%). All five probes
that import unsloth_zoo land in the same basin pattern. Only probe
20 (mlx-lm CLI subprocess, zero unsloth_zoo imports) hits 67%.

Probe 26 runs identical mlx-lm-verbatim training INLINE with zero
unsloth_zoo imports. If 67% — the unsloth_zoo import has a global
side effect on MLX runtime state that breaks parity. If 47% —
subprocess isolation was the relevant factor and probe 20's 67%
was an artifact rather than a true ceiling.
Probe 26 (pure mlx-lm inline) hits 47% with same per-seed pattern
as every other inline probe (22-26). Probe 20 (mlx-lm CLI via
subprocess.run) hits 67%. Three candidate isolations:

  Probe 27: probe 26's training in subprocess.run(['python','-c',...])
            -- tests if the extra subprocess boundary alone matters
  Probe 28: probe 26 + mx.set_wired_limit(...) at startup
            -- mlx-lm's train() sets this hint; my inline doesn't
  Probe 29: probe 26 but call mlx_lm.tuner.trainer.train() directly
            -- if train() does something at function entry I missed

15 seeds x 3 probes = 45 cells. If any cell hits 67%, that
isolation IS the variable.
ROOT CAUSE FOUND (pending probe 30 confirmation):

nn.Linear.__init__ at mlx-src/python/mlx/nn/layers/linear.py:51
calls mx.random.uniform every time a Linear is constructed. Each
Linear in a transformer (q/k/v/o/gate/up/down per layer, plus the
output head and embeddings) consumes some mx.random state. For
gemma-3-270m there are dozens of Linear modules.

mlx-lm CLI (probe 20) calls mx.random.seed(args.seed) at
mlx_lm/lora.py:223 -- AFTER load(model_path) and BEFORE
linear_to_lora_layers. The seed is therefore "fresh" right when
lora_a init draws happen.

My inline probes (22-26) seed mx.random BEFORE mlx_load() at the
top of main(). Loading the model consumes a substantial amount of
mx.random state via Linear constructors. By the time
linear_to_lora_layers runs, mx.random is at a different position
than mlx-lm CLI sees -> different lora_a init -> different basin.

Probe 30 mirrors mlx-lm CLI: seed AFTER load, before LoRA wiring.
If 67%, the bug is "where you seed", not "what you train".

If probe 30 passes, the FIX in unsloth-zoo is to call
mx.random.seed(args.random_state) inside get_peft_model right
before linear_to_lora_layers (loader.py).
CRITICAL DISCOVERY (Round BF result + huggingface config):

gemma-3-270m-it has 18 hidden layers (per HF config).
mlx-lm CLI CONFIG_DEFAULTS['num_layers']=16 (lora.py:56).

Probe 20 (subprocess mlx-lm CLI) trained LoRA on the LAST 16 layers
only. Inline probes 22-26+30 used len(model.layers)=18, training
all 18 layers. The 2 extra layers x 7 modules = 14 extra LoRA
modules consume mx.random state during init AND add trainable
params, putting the model into a different basin.

Round BE/BF results explain this cleanly:
  - Probe 27 (subprocess wrap):     47% (subprocess boundary irrelevant)
  - Probe 28 (set_wired_limit):     47% (allocator hint irrelevant)
  - Probe 29 (call mlx-lm train()): 50% (small noise)
  - Probe 30 (seed AFTER load):     47% with NEW per-seed pattern (basin shifts)
None recovered 67%.

Probe 31 = probe 30 + num_layers=16. If 67%, this is the cause and
the fix is to set num_layers=16 by default in unsloth-zoo's
get_peft_model when called against gemma-3 (or expose the choice).
PR 669 in unslothai/unsloth-zoo adds the finetune_last_n_layers
parameter to FastMLXModel.get_peft_model. Probe 32 exercises the
full zoo public API end-to-end (FastMLXModel.from_pretrained ->
get_peft_model(finetune_last_n_layers=16) -> MLXTrainer.train)
on the same 15-seed fixture.

CI pin updated to commit b137b40 on
unslothai/unsloth-zoo@fix-mlx-num-layers-parity so the new
parameter is available. If probe 32 hits 10/15 = 67% with the
same per-seed pattern as probe 20 (mlx-lm CLI), the PR works
end-to-end through zoo's public surface.

Probe 30 retained for comparison: probe 30 (manual loop +
num_layers=16) was the original isolation that found the cause;
probe 32 verifies the productionized fix.
Probe 31 (mlx_lm.load + manual loop + 16): 67% (matches mlx-lm CLI)
Probe 32 (FastMLXModel    + MLXTrainer  + 16): 15% (additional loss
  on top of just the num_layers change)

Probe 33 = mlx_lm.load + zoo MLXTrainer + num_layers=16. Bisects:
  - if 33 = 67%, zoo's LOADER side adds the extra basin instability
  - if 33 ~= 15%, zoo's TRAINER side does

Either way, post_train_loss=0 and cf_loss=0 everywhere -- the model
memorizes. The greedy-decode pass rate is the canary, not the gate.
PR unslothai#5537's cf_loss gate is bulletproof regardless.
Round BI bisect:
  Probe 31 (mlx_lm.load + manual loop + nl=16): 67%
  Probe 33 (mlx_lm.load + MLXTrainer  + nl=16): 53% (-14pp from trainer)
  Probe 32 (FastMLXModel(dtype='fp16') + MLXTrainer + nl=16): 15% (-38pp from loader)

Gemma-3-270m-it is stored as bf16 on HF. FastMLXModel's
_convert_mlx_dtype defaults force fp16, which is a lossy cast
(fp16 has 5-bit exponent vs bf16 8-bit). Any param outside fp16's
~6.5e4 range gets clamped.

Probe 34 uses FastMLXModel(dtype=None) -- keep storage dtype (bf16).
If 34 ~= 53%, the dtype cast is the loader's offender. The fix is
to default dtype to None on Gemma3 or to use bf16 explicitly.

cf_loss = 0 in every probe, so memorization works -- only greedy
decode varies. The smoke gate (PR unslothai#5537) is robust regardless.
Probe 33 (mlx_lm.load + MLXTrainer + nl=16 + compile=False): 53%
Probe 31 (mlx_lm.load + manual loop + nl=16 + @mx.compile): 67%

Hypothesis: the -14pp gap between zoo's MLXTrainer and the manual loop
at the same loader/layer count is purely the `compile` flag. Probe 33
disabled compile via `compile=False`; probe 31's manual loop always
uses `@mx.compile`. If probe 35 (= probe 33 verbatim, only `compile=True`)
recovers to ~67%, the -14pp is a probe-configuration artifact, not a
MLXTrainer defect.

15 seeds + matrix entry. Same ZOO_SPEC pin (b137b40) as Round BJ.
Round BK probe 35 (mlx_lm.load + MLXTrainer + nl=16 + compile=True)
hit 8/15 = 53%, same as probe 33's 53% with compile=False. The
compile flag is NOT the trainer-side cause of the 47-53% vs 67%
gap.

This round adds two probes:

probe 36 — FastMLXModel(dtype=None) + MLXTrainer + nl=16 + compile=True.
Isolates the loader-only delta with compile held constant. If 36 ~= 67%,
the loader patches add no real basin drift; if 36 ~= 47-53%, the loader
contributes its own delta on top of any trainer issue.

probe 37 — mlx_lm.load + MLXTrainer + nl=16 + compile=False with
EXPLICIT max_grad_value=0.0. Bypasses the documented disable-via-None
bug in current MLXTrainer (PR unslothai#671 will honor None as disable, but
0.0 has always disabled). If 37 ~= 67%, the silent +/-1.0 elementwise
clip on probes 33/35 (both pass None expecting no clip) was the entire
trainer-side gap. If 37 ~= 53%, yet another factor remains.

Same ZOO_SPEC pin (b137b40 — finetune_last_n_layers fix branch) so
the existing probe-32-style scaffolding works.

15 seeds each, paired with probes 20/30/31/33/35.
Round BL surprises:
  probe 30 (manual loop + nl=18 + no clip)            : 7/15 = 47%
  probe 34 (FastMLXModel + MLXTrainer + nl=16 + None) : 7/15 = 47%
  probe 35 (mlx_lm.load + MLXTrainer + nl=16, compile=True, None) : 8/15 = 53%
  probe 36 (FastMLXModel + MLXTrainer + nl=16, compile=True, None) : 7/15 = 47%
  probe 37 (mlx_lm.load + MLXTrainer + nl=16, compile=False, 0.0)  : 6/15 = 40%

Earlier rounds claimed probe 31 (manual loop + nl=16 + no clip) hit
67%, which made the 47-53% MLXTrainer results look like a real
trainer-side gap. With probe 37 now lower than the None-clipped
probes 33/35, the entire trainer-side delta is suspect and could be
seed-pattern noise at n=15.

This round re-adds probe 31 to the matrix on the same run so we get a
paired fresh number against probes 30/34/35/36/37 on the same 15
seeds:
  probe 31 ~= 67%  -> trainer DOES add a real ~20pp gap; keep digging.
  probe 31 ~= 47%  -> the 'gap' is within seed noise; no defect.

Same ZOO_SPEC pin (b137b40) and 15 seeds as the rest of the matrix.
Pin ZOO_SPEC to the unsloth-zoo pad-fix branch (b265d99 from
fix-mlx-pad-multiple) so probes 30/31/34/35/36/37 measure whether
the create_text_batches +1 padding fix closes the basin gap that
Round BM identified.

Expected results:
  probe 30 (manual, nl=18): unchanged ~47% (manual loop uses
    mlx-lm's iterate_batches, not zoo's create_text_batches)
  probe 31 (manual, nl=16): unchanged ~67% (same reason)
  probe 34 (zoo, nl=16, dtype=None, compile=False): rises toward 67%
  probe 35 (zoo, nl=16, compile=True): rises toward 67%
  probe 36 (zoo loader+trainer, compile=True): rises toward 67%
  probe 37 (zoo, nl=16, compile=False, explicit clip=0): rises toward 67%

Same 15-seed list as Round BM for paired comparison.
Round BO per-step loss data showed:
  probe 31 (manual)             step 2: 5.254807
  probe 35 (zoo, compile=True)  step 2: 5.276443  (diff = -0.021635)
  probe 37 (zoo, compile=False) step 2: 5.276443  (diff = -0.021635)

Probes 35 vs 37 (both zoo, just different compile/clip) match
exactly for the first 3 steps. probe 31 vs zoo diverges from step 2.
Step 1 loss is identical across all three, so the divergence is in
the gradient applied at step 1 -- a numerical / autodiff-graph
difference between mlx-lm CLI's default_loss and zoo's
make_baseline_loss_fn (different mask dtype, different
safe_targets where, different denominator division).

Probe 38 runs both paths back-to-back in one process and captures
per-step loss AND per-step grad_norm so the diff is explicit and
the step where it first appears is unambiguous. Output JSON has
rows_mlxlm, rows_zoo, and diffs arrays.

5 seeds (1, 42, 999, 3407, 22222) — enough for a deterministic
diagnostic on the bf16-native + last-16-layers + no-clip config.
ZOO_SPEC stays pinned to b265d99 (pad-fix branch).
MLXTrainer's step callback signature is
  (current_step, total_steps, train_loss, lr_val, tokens_sec,
   peak_mem, elapsed_total, trained_tokens, grad_norm_val)

Probe 38 was reading args[3] thinking it was grad_norm, but args[3]
is lr_val -- which is constant 0.001 on a constant LR schedule, so
every row reported grad_norm=0.001 regardless of actual gradient.
grad_norm_val is args[8].

Also: the same probe run conclusively showed the per-step LOSS
matches exactly (dloss = 0 across all 30 steps and 5 seeds), so
mlx-lm vs zoo MLXTrainer ARE numerically identical at the loss
level when the probe re-seeds mx.random AFTER mlx_load (matching
mlx-lm CLI's lora.py:223 order). The Round BO step-2 divergence
between probe 31 and probe 33/35/37 was caused by those probes
NOT re-seeding after load -- not by any zoo-side numerical defect.
Probe 38 v2 conclusively showed `mlx_lm.load + linear_to_lora_layers
+ manual @mx.compile loop` matches `zoo MLXTrainer` step-for-step at
the loss level (15/15 zero-diff). But probes that went through
FastMLXModel.from_pretrained + FastMLXModel.get_peft_model (32 / 34
/ 36) still hit 47% greedy pass rate vs 67% for mlx-lm CLI's basin.

Hypothesis: zoo's get_peft_model already re-seeds mx.random before
linear_to_lora_layers (loader.py:2767), but something else between
from_pretrained's exit and that reseed -- or in get_peft_model's
key resolution -- consumes mx.random or changes the LoRA-module
creation order so the resulting lora_a matrices differ from the
mlx-lm CLI path.

Probe 39 isolates the LoRA-init pipeline by running both setups
through the IDENTICAL manual training loop in one process:

  Path A: mlx_lm.load -> mx.random.seed(seed) AFTER load ->
          linear_to_lora_layers(model, 16, {"keys": [suffix list]})
  Path B: FastMLXModel.from_pretrained(random_state=seed) ->
          FastMLXModel.get_peft_model(
              finetune_last_n_layers=16, random_state=seed, ...)

Both paths then go through the same manual @mx.compile loop with the
same optim.AdamW(...). If per-step loss diff is non-zero, the
divergence is upstream of the trainer (in FastMLXModel's loader or
get_peft_model). If zero, LoRA init matches and the basin gap is
elsewhere.

5 seeds matching probe 38 (1, 42, 999, 3407, 22222) for paired
comparison.
…ng fix

Pin ZOO_SPEC to 0124424 (fix-mlx-get-peft-model-seed HEAD), which
stacks on PR unslothai#669's b137b40 so it carries both finetune_last_n_layers
and the new seed-immediately-before-linear_to_lora_layers ordering
inside FastMLXModel.get_peft_model.

Trim the matrix to the question Round BR needs to answer: did moving
_seed_mlx_random_state(random_state) from the top of get_peft_model
(~165 lines above linear_to_lora_layers) to immediately before each
LoRA construction close the FastMLXModel-path basin gap end-to-end?

Round BR matrix:
- probe 31 x 15 seeds: mlx-lm CLI manual loop. Unchanged control.
- probe 34 x 15 seeds: FastMLXModel(dtype=None) + MLXTrainer + nl=16.
  Was 47% in Round BO. Expected: ~67% under PR unslothai#674.
- probe 36 x 15 seeds: same + compile=True. Was 47%. Expected: ~67%.
- probe 39 x  5 seeds: strict per-step diff of FastMLXModel vs
  mlx-lm CLI manual loop. Was non-zero from step 2. Expected: dloss=0
  step-for-step under PR unslothai#674.

Probes 30/35/37/38 dropped from this matrix (mlx-lm CLI controls or
non-FastMLXModel paths that are no longer the live suspect). History
retains them.
PR unslothai#674 (verified by Round BR probe 39: dloss=0 step-for-step across
5 seeds) closed the LoRA-init gap between FastMLXModel and mlx-lm CLI.
But probes 34/36 (FastMLXModel + MLXTrainer) still hit 47% greedy
pass rate vs probe 31's (mlx_lm.load + manual loop) 67% on the same
15 seeds. Probes 34/36 share an identical pass/fail pattern, so the
compile flag is a no-op for the basin -- the residual gap is somewhere
else.

Round BS introduces probe 40 = FastMLXModel.from_pretrained +
FastMLXModel.get_peft_model(finetune_last_n_layers=16) + probe 31's
exact manual @mx.compile loop. Same 15 seeds as probes 31 / 34 for
direct paired comparison.

Read:
  probe 40 ~ 67%  ->  MLXTrainer.train IS the remaining gap. The
                     manual loop reproduces probe 31's basin under
                     the FastMLXModel loader path.
  probe 40 ~ 47%  ->  FastMLXModel.from_pretrained adds drift
                     downstream of get_peft_model that probe 39's
                     5-seed strict diagnostic missed; bisect the
                     loader next round.

BS matrix (45 jobs):
  - probe 31 x 15 seeds (mlx-lm CLI manual loop, unchanged control)
  - probe 34 x 15 seeds (FastMLXModel + MLXTrainer, paired)
  - probe 40 x 15 seeds (FastMLXModel + manual loop, new)

Probes 36 / 39 dropped (Round BR conclusions established).
ZOO_SPEC stays pinned at PR unslothai#674 HEAD (0124424).
…al gap

Round BS proved the residual 47%-vs-67% basin gap is in MLXTrainer.train
(probe 40 = probe 31 on 15/15 seeds; FastMLXModel + manual loop matches
mlx-lm CLI per-seed; FastMLXModel loader is exonerated).

Reading unsloth_zoo/mlx/trainer.py:731-732:

    _raw_mgv = getattr(args, "max_grad_value", 1.0)
    max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)

MLXTrainer reinterprets `max_grad_value=None` as 1.0 (clip at +/-1.0
elementwise), NOT "disable clipping". PR unslothai#671 (mlx: honor
max_grad_value=None as a disable signal, head 265534b) is OPEN, not
merged -- the current ZOO_SPEC pin doesn't include it. Probe 34 sets
max_grad_value=None expecting "disable", actually gets clip-at-1.
The manual loop in probes 31 / 40 uses bare optim.AdamW with NO clip.

Probe 41 = probe 34 with max_grad_value=0.0 (explicit zero hits
`float(_raw_mgv or 0.0)` -> 0.0 -> no clip on the current build).

Read:
  probe 41 ~ 67% (matches probes 31 / 40)
      Elementwise clip-at-1 IS the entire residual gap. PR unslothai#671 is
      the missing piece. Stacking PR unslothai#671 on top of PR unslothai#674 closes
      the FastMLXModel + MLXTrainer basin gap end-to-end.
  probe 41 ~ 47% (matches probe 34)
      Clip isn't it; bisect further inside MLXTrainer.train
      (lr schedule, loss-fn, batch iteration, mx.eval timing).

BT matrix (45 jobs):
  - probe 31 x 15 seeds (mlx-lm CLI manual loop, control)
  - probe 34 x 15 seeds (FastMLXModel + MLXTrainer + max_grad_value=
                         None -> clip-at-1, paired against probe 41)
  - probe 41 x 15 seeds (FastMLXModel + MLXTrainer + max_grad_value=
                         0.0 -> explicit no-clip, new target)

Probes 36 / 39 / 40 dropped (Round BR / BS conclusions established).
ZOO_SPEC stays pinned at PR unslothai#674 HEAD (0124424).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant