Skip to content

fix(mlx): honor max_grad_value=None as a disable signal#671

Merged
danielhanchen merged 4 commits into
mainfrom
fix-max-grad-value-none
May 19, 2026
Merged

fix(mlx): honor max_grad_value=None as a disable signal#671
danielhanchen merged 4 commits into
mainfrom
fix-max-grad-value-none

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

  • MLXTrainingConfig.max_grad_value is typed float | None = 1.0, but the resolution code silently rebound an explicit None back to the default 1.0, so callers who set max_grad_value=None to disable elementwise gradient clipping were still getting clipping at $\pm$1.0.
  • This PR treats None the same as 0.0 (disable) and clarifies the field docstring. Default behavior is unchanged; only callers explicitly passing None see the corrected disable semantics.

Why

The Optional[float] type annotation naturally suggests None means "off". The field docstring previously said "Set 0.0 to disable" without mentioning what None would do. The internal code on line 732 of unsloth_zoo/mlx/trainer.py was:

```python
_raw_mgv = getattr(args, "max_grad_value", 1.0)
max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
```

i.e. None $\to$ rebound to 1.0 $\to$ clipping enabled at $\pm$1.0. That is the opposite of what users expected.

Surfaced while writing MLX parity probes against mlx-lm's native loop (which does no elementwise clipping): probes that set max_grad_value=None with the explicit comment "match mlx-lm: no elementwise clip" were silently still being clipped.

Behavior

  • MLXTrainingConfig() $\to$ max_grad_value = 1.0 (clipping on). Unchanged.
  • MLXTrainingConfig(max_grad_value=None) $\to$ clipping disabled. (Was: clipping at $\pm$1.0.)
  • MLXTrainingConfig(max_grad_value=0.0) $\to$ clipping disabled. Unchanged.
  • MLXTrainingConfig(max_grad_value=2.5) $\to$ clipping at $\pm$2.5. Unchanged.

Test plan

  • tests/test_mlx_max_grad_value_none.py $\to$ five cases pinning the four-way decision (default, None, 0.0, positive) plus dataclass round-trip.
  • Local: pytest tests/test_mlx_max_grad_value_none.py -v $\to$ 5 passed.

Related

This PR is part of an ongoing MLX vs mlx-lm parity bisection:

  • #669 $\to$ finetune_last_n_layers (layer-selection mismatch).
  • unslothai/unsloth#5564 $\to$ same knob on the CUDA path.
  • #670 $\to$ bf16 $\to$ fp16 downcast warning.
  • This PR $\to$ honor max_grad_value=None as documented.

`MLXTrainingConfig.max_grad_value` is typed `float | None = 1.0`. The
inline docstring on the field was ambiguous and the resolution code
silently rebound `None` to the default 1.0, so callers who set
`max_grad_value=None` to mirror mlx-lm CLI's no-clip default were
still getting elementwise clipping at +/-1.0.

This commit:
- Treats None the same as 0.0 (disable elementwise clipping). Matches
  the natural reading of an `Optional[float]` field where None means
  "off".
- Updates the field docstring to make the rule explicit.

Behavior preserved for the default and for explicit numeric values:
- Field default (no override) -> still clips at 1.0.
- max_grad_value=2.5 -> still clips at 2.5.
- max_grad_value=0.0 -> still disables (unchanged).

Only callers explicitly passing `None` see a change; they now get the
disable behavior they intended.

Tests: tests/test_mlx_max_grad_value_none.py pins the four-way matrix
(default, None, 0.0, positive) plus dataclass round-trip.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request ensures that setting max_grad_value=None in MLXTrainingConfig correctly disables elementwise gradient clipping instead of silently defaulting to 1.0. The changes include updating the logic in _train_inner to treat None as a disable signal (0.0), updating the relevant docstrings, and adding a comprehensive test suite in tests/test_mlx_max_grad_value_none.py to verify the fix and prevent regressions. I have no further feedback to provide.

Implements issue #662's recommended fix. Previously
MLXTrainingConfig.max_grad_value defaulted to 1.0 and the trainer
silently zeroed out a user-supplied max_grad_norm whenever both were
non-zero. A user passing max_grad_norm=1.0 -- the HF/TRL standard --
got an elementwise +/-1.0 clip instead, which is mathematically
different (rotates gradient direction per leaf) and produces a
different convergence basin on MLX vs CUDA for identical configs.

Changes in this commit (stacked on top of the None-disables fix
already in this PR):
- MLXTrainingConfig.max_grad_value: float | None = None (was 1.0).
- _raw_mgv = getattr(args, 'max_grad_value', None) (was 1.0).
- Resolution rule: None or 0.0 -> elementwise disabled; positive ->
  opt-in. Mutual-exclusion warn-and-override fires only when the
  user has explicitly opted in to elementwise clipping.
- Field docstring rewritten to spell out the rule.

Default behavior is now:
- MLXTrainingConfig() -> no elementwise clip, no global-norm clip
  (max_grad_norm default is already 0.0). Matches the
  unsloth-staging-2 mlx-parity probes' expectation that None == off.
- MLXTrainingConfig(max_grad_norm=1.0) -> global-norm clipping at
  1.0. Matches HF/TRL Trainer default.
- MLXTrainingConfig(max_grad_value=1.0) -> elementwise clipping at
  +/-1.0 (opt-in by power users).

Tests: tests/test_mlx_max_grad_value_none.py updated to pin the new
default and add a positive-value round-trip case. Six cases pass.
@danielhanchen

Copy link
Copy Markdown
Member Author

Expanded scope: this PR now also implements the full recommended fix from #662.

In addition to honoring max_grad_value=None as a disable signal (the original PR), the field default is now None (was 1.0). This makes the MLX path's grad-clip semantics match HF/TRL: a user calling MLXTrainer with max_grad_norm=1.0 now gets global-norm clipping at 1.0 as they intended, instead of getting elementwise +/-1.0 clipping (which rotates gradient direction per leaf and produced different convergence basins than CUDA for identical configs).

Closes #662.

Behavior summary

  • MLXTrainingConfig() -> both clip modes off. Changed from "elementwise clip at 1.0".
  • MLXTrainingConfig(max_grad_norm=1.0) -> global-norm clip at 1.0. Changed from "max_grad_norm silently dropped". Matches HF/TRL.
  • MLXTrainingConfig(max_grad_value=1.0) -> elementwise clip at +/-1.0 (opt-in). Unchanged.
  • MLXTrainingConfig(max_grad_value=None) -> elementwise disabled. Changed from "silently rebound to 1.0".
  • MLXTrainingConfig(max_grad_norm=1.0, max_grad_value=1.0) -> elementwise wins, max_grad_norm zeroed, notice printed. Unchanged.

Migration impact

Anyone who was relying on the implicit max_grad_value=1.0 clip without setting it explicitly will now get NO clipping unless they also pass max_grad_norm or max_grad_value explicitly. #662 explicitly accepts this trade-off because HF/TRL parity is the goal: a user who wrote MLXTrainingConfig(max_grad_norm=1.0) was already expressing intent to clip at norm 1.0.

Tests updated; six cases pass.

@danielhanchen

Copy link
Copy Markdown
Member Author

Round BT verification on Apple Silicon

This PR is the missing piece for end-to-end MLX parity with mlx-lm CLI through the public FastMLXModel + MLXTrainer API. Probe matrix run pinned to PR #674 HEAD (0124424 -- the seed-ordering fix), bisecting whether elementwise clip-at-1 is the residual basin gap on top of #669 + #674.

Run 26084185821 -- 45 jobs green.

probe configuration pass rate per-seed agreement with probe 31
31 mlx_lm.load + linear_to_lora_layers + manual @mx.compile loop 10/15 = 67% 15/15 (self)
34 FastMLXModel + get_peft_model + MLXTrainer.train with max_grad_value=None -> reinterpreted as 1.0 on current build 7/15 = 47% 8/15
41 same as 34 but max_grad_value=0.0 -> hits the disable branch on current build 10/15 = 67% 15/15

Probe 41 reproduces mlx-lm CLI's basin pattern on every single seed. The only difference between probes 34 and 41 is None vs 0.0 for max_grad_value. On the current build, trainer.py:731-732 reinterprets None as 1.0 (silent elementwise clip), so probe 34 actually clips gradients elementwise while the manual loop in probes 31 / 40 does not.

This PR's change to "honor max_grad_value=None as a disable signal" exactly closes the gap: max_grad_value=None (the new default) becomes equivalent to probe 41's max_grad_value=0.0, which is value-for-value identical to mlx-lm CLI's no-clip path.

Full MLX parity story (3 PRs, all filed)

  1. fix(mlx): expose finetune_last_n_layers for parity with mlx-lm CLI #669 finetune_last_n_layers (matches mlx-lm CLI's CONFIG_DEFAULTS['num_layers']=16) -- trainable-param-set parity.
  2. fix(mlx): seed mx.random immediately before linear_to_lora_layers #674 seed mx.random immediately before linear_to_lora_layers (matches mlx_lm/tuner/lora.py:223) -- LoRA-init parity, verified by probe 39 (dloss=0 step-for-step).
  3. fix(mlx): honor max_grad_value=None as a disable signal #671 (this PR) max_grad_value=None -> disable -- trainer-side update-math parity, verified by probe 41 (15/15 seed agreement with mlx-lm CLI manual loop).

With all three merged, FastMLXModel.from_pretrained -> FastMLXModel.get_peft_model -> MLXTrainer.train produces the same trained model as the mlx-lm CLI path on a per-seed basis.

Per code-comment policy: the field docstring and inline note now state
the resolution rule in one line each; HF/TRL parity rationale and
empirical results live in the commit message of 265534b and the PR
description.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: da43c206c7

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

max_grad_value: float | None = 1.0
# Elementwise clip ([-max_grad_value, max_grad_value], per-leaf).
# None or 0.0 disables; positive opts in and overrides max_grad_norm.
max_grad_value: float | None = None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve the default grad clipping behavior

Changing the dataclass default to None disables elementwise clipping for every caller that uses MLXTrainingConfig() or omits max_grad_value, whereas the previous behavior clipped at ±1.0 and the commit description says default behavior should be unchanged. The intended fix can still let an explicit max_grad_value=None disable clipping while keeping the field default (and the getattr fallback below) at 1.0; otherwise existing MLX training runs silently lose their default gradient clipping.

Useful? React with 👍 / 👎.

PR #671 v1 set max_grad_value=None as the default (no elementwise clip)
to fix the original bug where a user passing max_grad_norm=1.0 silently
got an elementwise +/-1.0 clip. That fix worked, but it lost the memory
benefit of the cheap MLX elementwise default — max_grad_norm uses MLX's
clip_grad_norm helper which materially increases peak memory on bf16
VLM runs.

This commit keeps the API surface (max_grad_value: float | None = None)
but redesigns the trainer-side resolution so None is a SENTINEL meaning
'use the MLX default 1.0 unless the user opted into global-norm clipping':

  | max_grad_value | max_grad_norm | effective                          |
  |----------------|---------------|-------------------------------------|
  | None (default) | 0.0 (default) | elementwise 1.0 (cheap)             |
  | None (default) | 1.0 (user)    | global norm 1.0 (HF/TRL parity)     |
  | 0.0 (user)     | 0.0 (default) | no clipping                         |
  | 0.0 (user)     | 1.0 (user)    | global norm 1.0                     |
  | 5.0 (user)     | 0.0           | elementwise 5.0                     |
  | 5.0 (user)     | 1.0           | elementwise 5.0 (warn, norm dropped)|

The original bug (user passes max_grad_norm=1.0, gets silent clip-at-1.0
elementwise) is still fixed because the resolution now detects 'user
opted into max_grad_norm with default max_grad_value' as a distinct
case and routes through the global-norm path.

Tests rewritten as ten paired (max_grad_value, max_grad_norm) cases
plus a source-level pin so a future refactor cannot silently regress
the four-branch resolution.
@danielhanchen

Copy link
Copy Markdown
Member Author

Redesign: keep cheap MLX default, but user max_grad_norm wins

Per @danielhanchen — defaulting max_grad_value to None (no elementwise clip) lost the memory benefit of MLX's cheap elementwise path. `max_grad_norm` uses `mlx.optimizers.clip_grad_norm`, which materially increases peak memory on bf16 VLM runs, so we want the elementwise clip on by default.

Updated commit `f14f7176` keeps `max_grad_value: float | None = None` as the dataclass field (API unchanged) but the trainer-side resolution now treats None as a sentinel meaning "use the MLX default 1.0 unless the user opted into global-norm clipping":

max_grad_value max_grad_norm effective
None (default) 0.0 (default) elementwise 1.0 (cheap, no memory blowup)
None (default) 1.0 (user) global norm 1.0 (HF/TRL parity)
0.0 (user) 0.0 (default) no clipping at all
0.0 (user) 1.0 (user) global norm 1.0
5.0 (user) 0.0 elementwise 5.0
5.0 (user) 1.0 (user) elementwise 5.0 (warn, norm dropped)

The original bug — user passes `max_grad_norm=1.0` and silently gets clip-at-1.0 elementwise instead — is still fixed because resolution now distinguishes "user opted into max_grad_norm with default max_grad_value" from "user explicitly set both", and routes the former through the global-norm path.

Tests rewritten as ten paired (max_grad_value, max_grad_norm) cases plus a source-level pin (`test_trainer_source_pins_resolution_rule`) so a future refactor cannot silently regress the four-branch resolution. All 10 pass locally.

@danielhanchen danielhanchen left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! The goal of this PR is to fix a parity bug where a user passing max_grad_norm=1.0 to MLXTrainingConfig was silently rebound to elementwise clip at +/-1.0 (HF/TRL parity break). As a summary, this PR keeps the field signature max_grad_value: float | None = None but reinterprets None inside MLXTrainer.train's resolution as a sentinel meaning "cheap MLX elementwise clip 1.0 unless user opted into max_grad_norm > 0 (then global-norm wins)". Explicit 0.0 disables; explicit positive overrides max_grad_norm with a warning.

Two independent Opus reviewers were run in parallel on this PR.

Reviewers Severity Finding
1/2 High PR description and code drifted: the body says "default behavior unchanged" and the title says "honor None as a disable signal," but the diff also flips the field default from 1.0 to None, adds a new max_grad_norm > 0 precedence branch, and silently changes behavior for pre-existing MLXTrainingConfig(max_grad_norm=1.0) callers (was elementwise@1 after warn-and-override; now global-norm@1).
1/2 High The test's _resolve(...) helper is a hand-written Python copy of the trainer's resolution logic, validated against itself. The source-pin test is the only thing connecting them, and it grep-matches three literal substrings — a benign refactor inside _train_inner (e.g. extracting the resolution into a helper, renaming _user_set_mgv) breaks the pin without changing semantics. Couple them or instantiate a real MLXTrainer and inspect the resolved state.
2/2 Med The resolution rule lives only in a trainer-side comment; the dataclass field docstring describes None/0.0/positive but does not say None silently becomes 0.0 when max_grad_norm > 0. Hoist into the field docstring.
1/2 Med No runtime log of which clipping mode is active. Three different behaviors (elementwise@1.0, norm@N, both-disabled) run silently with max_grad_value=None — and they produce visibly different grad_norm numbers in callbacks.
1/2 Med No end-to-end behavior test — only the resolver mirror. A future refactor that flips a branch inside the live _train_inner is not caught.
1/2 Low Negative max_grad_value and NaN both pass _user_set_mgv but then fail the > 0 gate silently. Old code had the same gap.
1/2 Low The both-enabled mutual-exclusion warning is print, not warnings.warn — cannot be filtered or captured by pytest.warns.
1/2 Nit _user_set_mgv is computed but unused outside the resolver. Drop or surface as a state field.
1/2 Nit getattr(args, "max_grad_value", None) defensive default is unnecessary — the dataclass guarantees the attribute exists.

Overall: REQUEST_CHANGES (PR body/title vs. actual semantics mismatch and silent regression of max_grad_norm=1.0-only configs must be addressed first; the rest are improvements.)

See inline comments for details and suggested fixes.

# passes max_grad_norm > 0, in which case global-norm clipping wins.
# 0.0 disables. A positive float opts in explicitly and overrides
# max_grad_norm with a warning.
max_grad_value: float | None = None

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] High: this default flip from 1.0 to None is the core behavior change. Anyone calling MLXTrainingConfig(max_grad_norm=1.0) previously got elementwise@1.0 (warn-and-override path), and now gets global-norm@1.0 instead. Different gradient update every step on identical user configs. Worth surfacing in the PR description as a deliberate semantics change rather than "default behavior unchanged", and worth a release-note note for downstream pipelines.

The field docstring should also spell out the sentinel semantics, since the rule itself lives only in the trainer:

Suggested change
max_grad_value: float | None = None
# Elementwise clip ([-max_grad_value, max_grad_value], per-leaf).
# Resolution:
# * None (default) AND max_grad_norm == 0 -> elementwise 1.0 (cheap MLX default)
# * None (default) AND max_grad_norm > 0 -> global-norm wins (HF/TRL parity)
# * 0.0 -> elementwise disabled
# * positive -> elementwise opts in; if max_grad_norm > 0 it is overridden with a warning
max_grad_value: float | None = None

"ignoring max_grad_norm in favor of max_grad_value."
)
max_grad_norm = 0.0
_raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: the dataclass guarantees max_grad_value exists, so the getattr(..., None) fallback is dead defensive code. If kept, the fallback default should match the field default to avoid future divergence.

Suggested change
_raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users
_raw_mgv = args.max_grad_value # TODO: expose MLX grad-clip in Studio UI for power users

max_grad_value = 0.0
else:
# Neither requested -> cheap MLX default.
max_grad_value = 1.0

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] High: the test file's _resolve helper is a pure-Python copy of this resolution logic and is validated against itself in test_mlx_max_grad_value_none.py. If you rename _user_set_mgv, extract the four-branch logic into a helper, or otherwise refactor without changing semantics, the source-pin test (test_trainer_source_pins_resolution_rule) breaks while the unit tests stay green — and vice versa, semantic-breaking refactors that preserve the substrings slide through. Couple them by lifting the resolution into a named function and importing it in the tests.

Suggested change
max_grad_value = 1.0
else:
# Neither requested -> cheap MLX default.
max_grad_value = 1.0
# NOTE: keep this four-branch resolution in sync with
# tests/test_mlx_max_grad_value_none.py::_resolve. Prefer extracting
# into a module-level helper so the tests can import it directly.

"ignoring max_grad_norm in favor of max_grad_value."
)
max_grad_norm = 0.0
elif max_grad_norm > 0:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Med: with max_grad_value=None and max_grad_norm=0.0 we silently take the elementwise@1.0 branch; with None and max_grad_norm>0 we silently take the global-norm branch; with 0.0 we silently disable. Print one line at step-build time so users debugging grad_norm curves know which clipping mode is active.

Suggested change
elif max_grad_norm > 0:
elif max_grad_norm > 0:
# User opted into global-norm clipping; suppress the default elementwise.
max_grad_value = 0.0
print(f"Unsloth: gradient clipping = global-norm at {max_grad_norm}.")
else:
# Neither requested -> cheap MLX default.
max_grad_value = 1.0
print("Unsloth: gradient clipping = elementwise at +/-1.0 (MLX default).")

)
max_grad_norm = 0.0
_raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users
_user_set_mgv = _raw_mgv is not None

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: _user_set_mgv is computed but never used outside the resolver block. Inline it to keep the resolver self-contained, or surface it as self._mgv_user_opt_in for downstream introspection / callback reporting.

Suggested change
_user_set_mgv = _raw_mgv is not None
if _raw_mgv is not None:

(removing the standalone _user_set_mgv binding and using the predicate inline in if/elif.)

mgv = 0.0
else:
mgv = 1.0
return mgv, max_grad_norm

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] High: _resolve here is a hand-written copy of MLXTrainer._train_inner's resolution. Every "semantics" test (rows 1-5) validates this helper against itself — not the real trainer. The only thing connecting them is test_trainer_source_pins_resolution_rule, which is a substring grep. A semantic-breaking refactor that preserves the substrings passes; a benign refactor that breaks the substrings fails. Better: lift the resolver into unsloth_zoo.mlx.trainer._resolve_grad_clip(raw_mgv, max_grad_norm) and import it here.

Suggested change
return mgv, max_grad_norm
from unsloth_zoo.mlx.trainer import _resolve_grad_clip as _resolve # type: ignore[attr-defined]

(then the rest of the test file uses the real function rather than the local mirror.)

src = inspect.getsource(T.MLXTrainer.train) + inspect.getsource(T.MLXTrainer._train_inner)
assert "_user_set_mgv = _raw_mgv is not None" in src
assert "elif max_grad_norm > 0:" in src
assert "max_grad_value = 1.0" in src

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Med: this test pins three literal substrings (_user_set_mgv = _raw_mgv is not None, elif max_grad_norm > 0:, max_grad_value = 1.0). Two of them (elif max_grad_norm > 0: and max_grad_value = 1.0) are generic and would survive a branch-flip regression as long as the strings still appear somewhere in the function. Also, it concatenates MLXTrainer.train AND MLXTrainer._train_inner source even though train contains none of the resolution code — extra surface area for false positives.

Suggested change
assert "max_grad_value = 1.0" in src
def test_trainer_source_pins_resolution_rule():
"""Pin the four-branch resolution against silent regression. Prefer
extracting the resolver into a helper and importing it directly; this
source-pin is a last-resort tripwire."""
import inspect
from unsloth_zoo.mlx import trainer as T
src = inspect.getsource(T.MLXTrainer._train_inner)
# 'norm wins' branch must zero max_grad_value
assert "max_grad_value = 0.0" in src
# 'cheap default' branch must materialize 1.0
assert "max_grad_value = 1.0" in src
# The sentinel detection must distinguish 'user passed' from 'default'
assert "is not None" in src

@danielhanchen danielhanchen merged commit 6efe9ac into main May 19, 2026
9 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant