fix(mlx): honor max_grad_value=None as a disable signal #671
`MLXTrainingConfig.max_grad_value` is typed `float | None = 1.0`. The inline docstring on the field was ambiguous and the resolution code silently rebound `None` to the default 1.0, so callers who set `max_grad_value=None` to mirror mlx-lm CLI's no-clip default were still getting elementwise clipping at +/-1.0.

This commit:
- Treats None the same as 0.0 (disable elementwise clipping). Matches the natural reading of an `Optional[float]` field where None means "off".
- Updates the field docstring to make the rule explicit.

Behavior preserved for the default and for explicit numeric values:
- Field default (no override) -> still clips at 1.0.
- max_grad_value=2.5 -> still clips at 2.5.
- max_grad_value=0.0 -> still disables (unchanged).

Only callers explicitly passing `None` see a change; they now get the disable behavior they intended.

Tests: tests/test_mlx_max_grad_value_none.py pins the four-way matrix (default, None, 0.0, positive) plus dataclass round-trip.
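The rebinding described in this commit message can be reproduced in isolation. The sketch below is illustrative only: `resolve_old` copies the pre-fix expression quoted in the PR description, while `resolve_new` is an assumed one-line form of the "None behaves like 0.0" rule, not the trainer's actual code. Both function names are hypothetical.

```python
def resolve_old(raw):
    # Pre-fix expression as quoted in the PR description:
    # an explicit None is rebound to the default 1.0.
    return 1.0 if raw is None else float(raw or 0.0)


def resolve_new(raw):
    # Assumed form of this commit's rule: None and 0.0 both mean "off".
    return float(raw or 0.0)


# The four-way matrix from the commit message: default, None, 0.0, positive.
for raw in (1.0, None, 0.0, 2.5):
    print(f"{raw!r:>5}: old={resolve_old(raw)} new={resolve_new(raw)}")
```

Only the `None` row differs between the two functions, which is the single behavior change this commit claims.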
Code Review
This pull request ensures that setting max_grad_value=None in MLXTrainingConfig correctly disables elementwise gradient clipping instead of silently defaulting to 1.0. The changes include updating the logic in _train_inner to treat None as a disable signal (0.0), updating the relevant docstrings, and adding a comprehensive test suite in tests/test_mlx_max_grad_value_none.py to verify the fix and prevent regressions. I have no further feedback to provide.
Implements issue #662's recommended fix. Previously MLXTrainingConfig.max_grad_value defaulted to 1.0 and the trainer silently zeroed out a user-supplied max_grad_norm whenever both were non-zero. A user passing max_grad_norm=1.0 -- the HF/TRL standard -- got an elementwise +/-1.0 clip instead, which is mathematically different (rotates gradient direction per leaf) and produces a different convergence basin on MLX vs CUDA for identical configs.

Changes in this commit (stacked on top of the None-disables fix already in this PR):
- MLXTrainingConfig.max_grad_value: float | None = None (was 1.0).
- _raw_mgv = getattr(args, 'max_grad_value', None) (was 1.0).
- Resolution rule: None or 0.0 -> elementwise disabled; positive -> opt-in. Mutual-exclusion warn-and-override fires only when the user has explicitly opted in to elementwise clipping.
- Field docstring rewritten to spell out the rule.

Default behavior is now:
- MLXTrainingConfig() -> no elementwise clip, no global-norm clip (max_grad_norm default is already 0.0). Matches the unsloth-staging-2 mlx-parity probes' expectation that None == off.
- MLXTrainingConfig(max_grad_norm=1.0) -> global-norm clipping at 1.0. Matches HF/TRL Trainer default.
- MLXTrainingConfig(max_grad_value=1.0) -> elementwise clipping at +/-1.0 (opt-in by power users).

Tests: tests/test_mlx_max_grad_value_none.py updated to pin the new default and add a positive-value round-trip case. Six cases pass.
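The distinction between elementwise and global-norm clipping drawn in this commit message can be checked on a toy gradient. The sketch below uses plain Python lists rather than MLX arrays, treats the whole gradient as a single vector, and assumes a small epsilon in the norm rescale; all function names are illustrative and not part of the trainer.

```python
import math


def clip_elementwise(grad, max_value):
    # Clamp every component to [-max_value, max_value] independently.
    return [max(-max_value, min(max_value, g)) for g in grad]


def clip_global_norm(grad, max_norm):
    # Rescale the whole vector so its L2 norm is at most max_norm.
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-6))  # epsilon is an assumption
    return [g * scale for g in grad]


def cosine(a, b):
    # Cosine similarity: 1.0 means the two vectors point the same way.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


grad = [3.0, 0.5, -4.0]
by_value = clip_elementwise(grad, 1.0)
by_norm = clip_global_norm(grad, 1.0)

print("elementwise:", by_value, "cosine vs original:", cosine(grad, by_value))
print("global norm:", by_norm, "cosine vs original:", cosine(grad, by_norm))
```

The elementwise result has a cosine similarity below 1 with the original gradient (its direction changed), while the global-norm result keeps the direction and only shrinks the length.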
Expanded scope: this PR now also implements the full recommended fix from In addition to honoring Closes

Behavior summary

Migration impact
Anyone who was relying on the implicit Tests updated; six cases pass.

Round BT verification on Apple Silicon
This PR is the missing piece for end-to-end MLX parity with mlx-lm CLI through the public Run 26084185821 -- 45 jobs green.
Probe 41 reproduces mlx-lm CLI's basin pattern on every single seed. The only difference between probes 34 and 41 is This PR's change to "honor

Full MLX parity story (3 PRs, all filed)
With all three merged,
Per code-comment policy: the field docstring and inline note now state the resolution rule in one line each; HF/TRL parity rationale and empirical results live in the commit message of 265534b and the PR description.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: da43c206c7
| max_grad_value: float | None = 1.0 | ||
| # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf). | ||
| # None or 0.0 disables; positive opts in and overrides max_grad_norm. | ||
| max_grad_value: float | None = None |
Preserve the default grad clipping behavior
Changing the dataclass default to None disables elementwise clipping for every caller that uses MLXTrainingConfig() or omits max_grad_value, whereas the previous behavior clipped at ±1.0 and the commit description says default behavior should be unchanged. The intended fix can still let an explicit max_grad_value=None disable clipping while keeping the field default (and the getattr fallback below) at 1.0; otherwise existing MLX training runs silently lose their default gradient clipping.
PR #671 v1 set max_grad_value=None as the default (no elementwise clip) to fix the original bug where a user passing max_grad_norm=1.0 silently got an elementwise +/-1.0 clip. That fix worked, but it lost the memory benefit of the cheap MLX elementwise default: max_grad_norm uses MLX's clip_grad_norm helper, which materially increases peak memory on bf16 VLM runs.

This commit keeps the API surface (max_grad_value: float | None = None) but redesigns the trainer-side resolution so None is a SENTINEL meaning 'use the MLX default 1.0 unless the user opted into global-norm clipping':

| max_grad_value | max_grad_norm | effective |
|----------------|---------------|-------------------------------------|
| None (default) | 0.0 (default) | elementwise 1.0 (cheap) |
| None (default) | 1.0 (user) | global norm 1.0 (HF/TRL parity) |
| 0.0 (user) | 0.0 (default) | no clipping |
| 0.0 (user) | 1.0 (user) | global norm 1.0 |
| 5.0 (user) | 0.0 | elementwise 5.0 |
| 5.0 (user) | 1.0 | elementwise 5.0 (warn, norm dropped) |

The original bug (user passes max_grad_norm=1.0, gets silent clip-at-1.0 elementwise) is still fixed because the resolution now detects 'user opted into max_grad_norm with default max_grad_value' as a distinct case and routes through the global-norm path.

Tests rewritten as ten paired (max_grad_value, max_grad_norm) cases plus a source-level pin so a future refactor cannot silently regress the four-branch resolution.
Redesign: keep cheap MLX default, but user max_grad_norm wins

Per @danielhanchen: defaulting max_grad_value to None (no elementwise clip) lost the memory benefit of MLX's cheap elementwise path. `max_grad_norm` uses `mlx.optimizers.clip_grad_norm`, which materially increases peak memory on bf16 VLM runs, so we want the elementwise clip on by default. Updated commit `f14f7176` keeps `max_grad_value: float | None = None` as the dataclass field (API unchanged), but the trainer-side resolution now treats None as a sentinel meaning "use the MLX default 1.0 unless the user opted into global-norm clipping".

The original bug (user passes `max_grad_norm=1.0` and silently gets clip-at-1.0 elementwise instead) is still fixed because resolution now distinguishes "user opted into max_grad_norm with default max_grad_value" from "user explicitly set both", and routes the former through the global-norm path. Tests rewritten as ten paired (max_grad_value, max_grad_norm) cases plus a source-level pin (`test_trainer_source_pins_resolution_rule`) so a future refactor cannot silently regress the four-branch resolution. All 10 pass locally.
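The six-row resolution table from the commit message can be written as a small pure function. This is a minimal sketch under stated assumptions: `resolve_grad_clip` is a hypothetical name (the trainer resolves these values inline in `_train_inner`), and `warnings.warn` is used here where the PR's code prints its override message.

```python
import warnings


def resolve_grad_clip(raw_mgv, max_grad_norm):
    """Return the effective (max_grad_value, max_grad_norm) pair.

    Hypothetical helper; None is the sentinel for "user did not set
    max_grad_value".
    """
    max_grad_norm = float(max_grad_norm or 0.0)
    if raw_mgv is not None:
        # User set max_grad_value explicitly (0.0 disables, positive opts in).
        mgv = float(raw_mgv)
        if mgv > 0 and max_grad_norm > 0:
            warnings.warn("ignoring max_grad_norm in favor of max_grad_value.")
            max_grad_norm = 0.0
        return mgv, max_grad_norm
    if max_grad_norm > 0:
        # User opted into global-norm clipping; suppress the elementwise default.
        return 0.0, max_grad_norm
    # Neither requested: fall back to the cheap elementwise default.
    return 1.0, 0.0


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for mgv, mgn in [(None, 0.0), (None, 1.0), (0.0, 0.0),
                     (0.0, 1.0), (5.0, 0.0), (5.0, 1.0)]:
        print((mgv, mgn), "->", resolve_grad_clip(mgv, mgn))
```

Keeping the rule in one importable function would let tests exercise the real logic rather than a hand-written mirror. The sketch does not validate negative or NaN inputs.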
danielhanchen left a comment
Thank you for the PR! The goal of this PR is to fix a parity bug where a user passing max_grad_norm=1.0 to MLXTrainingConfig was silently rebound to elementwise clip at +/-1.0 (HF/TRL parity break). As a summary, this PR keeps the field signature max_grad_value: float | None = None but reinterprets None inside MLXTrainer.train's resolution as a sentinel meaning "cheap MLX elementwise clip 1.0 unless user opted into max_grad_norm > 0 (then global-norm wins)". Explicit 0.0 disables; explicit positive overrides max_grad_norm with a warning.
Two independent Opus reviewers were run in parallel on this PR.
| Reviewers | Severity | Finding |
|---|---|---|
| 1/2 | High | PR description and code drifted: the body says "default behavior unchanged" and the title says "honor None as a disable signal," but the diff also flips the field default from 1.0 to None, adds a new max_grad_norm > 0 precedence branch, and silently changes behavior for pre-existing MLXTrainingConfig(max_grad_norm=1.0) callers (was elementwise@1 after warn-and-override; now global-norm@1). |
| 1/2 | High | The test's _resolve(...) helper is a hand-written Python copy of the trainer's resolution logic, validated against itself. The source-pin test is the only thing connecting them, and it grep-matches three literal substrings — a benign refactor inside _train_inner (e.g. extracting the resolution into a helper, renaming _user_set_mgv) breaks the pin without changing semantics. Couple them or instantiate a real MLXTrainer and inspect the resolved state. |
| 2/2 | Med | The resolution rule lives only in a trainer-side comment; the dataclass field docstring describes None/0.0/positive but does not say None silently becomes 0.0 when max_grad_norm > 0. Hoist into the field docstring. |
| 1/2 | Med | No runtime log of which clipping mode is active. Three different behaviors (elementwise@1.0, norm@N, both-disabled) run silently with max_grad_value=None — and they produce visibly different grad_norm numbers in callbacks. |
| 1/2 | Med | No end-to-end behavior test — only the resolver mirror. A future refactor that flips a branch inside the live _train_inner is not caught. |
| 1/2 | Low | Negative max_grad_value and NaN both pass _user_set_mgv but then fail the > 0 gate silently. Old code had the same gap. |
| 1/2 | Low | The both-enabled mutual-exclusion warning is print, not warnings.warn — cannot be filtered or captured by pytest.warns. |
| 1/2 | Nit | _user_set_mgv is computed but unused outside the resolver. Drop or surface as a state field. |
| 1/2 | Nit | getattr(args, "max_grad_value", None) defensive default is unnecessary — the dataclass guarantees the attribute exists. |
Overall: REQUEST_CHANGES (PR body/title vs. actual semantics mismatch and silent regression of max_grad_norm=1.0-only configs must be addressed first; the rest are improvements.)
See inline comments for details and suggested fixes.
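Two of the lower-severity findings in the table (negative/NaN inputs and `print` versus `warnings.warn`) can be illustrated together. This is a rough sketch, not code from the PR: `validate_max_grad_value` and `warn_norm_overridden` are hypothetical names, and raising `ValueError` on bad input is one possible policy among several.

```python
import math
import warnings


def validate_max_grad_value(raw):
    """Reject values that would otherwise silently fail the > 0 gate."""
    if raw is None:
        return None  # sentinel: let the resolver pick the default
    value = float(raw)
    if math.isnan(value) or value < 0:
        raise ValueError(f"max_grad_value must be None or >= 0, got {raw!r}")
    return value


def warn_norm_overridden(max_grad_value, max_grad_norm):
    # Unlike print, warnings.warn can be filtered by users and captured in tests.
    warnings.warn(
        f"max_grad_value={max_grad_value} set; "
        f"ignoring max_grad_norm={max_grad_norm}.",
        UserWarning,
        stacklevel=2,
    )


# Capture the warning the same way a test harness would.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_norm_overridden(5.0, 1.0)

print(len(caught), caught[0].category.__name__)
```

With a real warning in place, a test could assert on the override through `pytest.warns(UserWarning)` instead of scraping stdout.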
| # passes max_grad_norm > 0, in which case global-norm clipping wins. | ||
| # 0.0 disables. A positive float opts in explicitly and overrides | ||
| # max_grad_norm with a warning. | ||
| max_grad_value: float | None = None |
[1/2 reviewers] High: this default flip from 1.0 to None is the core behavior change. Anyone calling MLXTrainingConfig(max_grad_norm=1.0) previously got elementwise@1.0 (warn-and-override path), and now gets global-norm@1.0 instead. That means a different gradient update every step on identical user configs. Worth surfacing in the PR description as a deliberate semantics change rather than "default behavior unchanged", and worth a release note for downstream pipelines.
The field docstring should also spell out the sentinel semantics, since the rule itself lives only in the trainer:
| max_grad_value: float | None = None | |
| # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf). | |
| # Resolution: | |
| # * None (default) AND max_grad_norm == 0 -> elementwise 1.0 (cheap MLX default) | |
| # * None (default) AND max_grad_norm > 0 -> global-norm wins (HF/TRL parity) | |
| # * 0.0 -> elementwise disabled | |
| # * positive -> elementwise opts in; if max_grad_norm > 0 it is overridden with a warning | |
| max_grad_value: float | None = None |
| "ignoring max_grad_norm in favor of max_grad_value." | ||
| ) | ||
| max_grad_norm = 0.0 | ||
| _raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users |
[1/2 reviewers] Nit: the dataclass guarantees max_grad_value exists, so the getattr(..., None) fallback is dead defensive code. If kept, the fallback default should match the field default to avoid future divergence.
| _raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users | |
| _raw_mgv = args.max_grad_value # TODO: expose MLX grad-clip in Studio UI for power users |
| max_grad_value = 0.0 | ||
| else: | ||
| # Neither requested -> cheap MLX default. | ||
| max_grad_value = 1.0 |
[1/2 reviewers] High: the test file's _resolve helper is a pure-Python copy of this resolution logic and is validated against itself in test_mlx_max_grad_value_none.py. If you rename _user_set_mgv, extract the four-branch logic into a helper, or otherwise refactor without changing semantics, the source-pin test (test_trainer_source_pins_resolution_rule) breaks while the unit tests stay green — and vice versa, semantic-breaking refactors that preserve the substrings slide through. Couple them by lifting the resolution into a named function and importing it in the tests.
| max_grad_value = 1.0 | |
| else: | |
| # Neither requested -> cheap MLX default. | |
| max_grad_value = 1.0 | |
| # NOTE: keep this four-branch resolution in sync with | |
| # tests/test_mlx_max_grad_value_none.py::_resolve. Prefer extracting | |
| # into a module-level helper so the tests can import it directly. |
| "ignoring max_grad_norm in favor of max_grad_value." | ||
| ) | ||
| max_grad_norm = 0.0 | ||
| elif max_grad_norm > 0: |
[1/2 reviewers] Med: with max_grad_value=None and max_grad_norm=0.0 we silently take the elementwise@1.0 branch; with None and max_grad_norm>0 we silently take the global-norm branch; with 0.0 we silently disable. Print one line at step-build time so users debugging grad_norm curves know which clipping mode is active.
| elif max_grad_norm > 0: | |
| elif max_grad_norm > 0: | |
| # User opted into global-norm clipping; suppress the default elementwise. | |
| max_grad_value = 0.0 | |
| print(f"Unsloth: gradient clipping = global-norm at {max_grad_norm}.") | |
| else: | |
| # Neither requested -> cheap MLX default. | |
| max_grad_value = 1.0 | |
| print("Unsloth: gradient clipping = elementwise at +/-1.0 (MLX default).") |
| ) | ||
| max_grad_norm = 0.0 | ||
| _raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users | ||
| _user_set_mgv = _raw_mgv is not None |
[1/2 reviewers] Nit: _user_set_mgv is computed but never used outside the resolver block. Inline it to keep the resolver self-contained, or surface it as self._mgv_user_opt_in for downstream introspection / callback reporting.
| _user_set_mgv = _raw_mgv is not None | |
| if _raw_mgv is not None: |
(removing the standalone _user_set_mgv binding and using the predicate inline in if/elif.)
| mgv = 0.0 | ||
| else: | ||
| mgv = 1.0 | ||
| return mgv, max_grad_norm |
[1/2 reviewers] High: _resolve here is a hand-written copy of MLXTrainer._train_inner's resolution. Every "semantics" test (rows 1-5) validates this helper against itself — not the real trainer. The only thing connecting them is test_trainer_source_pins_resolution_rule, which is a substring grep. A semantic-breaking refactor that preserves the substrings passes; a benign refactor that breaks the substrings fails. Better: lift the resolver into unsloth_zoo.mlx.trainer._resolve_grad_clip(raw_mgv, max_grad_norm) and import it here.
| return mgv, max_grad_norm | |
| from unsloth_zoo.mlx.trainer import _resolve_grad_clip as _resolve # type: ignore[attr-defined] |
(then the rest of the test file uses the real function rather than the local mirror.)
| src = inspect.getsource(T.MLXTrainer.train) + inspect.getsource(T.MLXTrainer._train_inner) | ||
| assert "_user_set_mgv = _raw_mgv is not None" in src | ||
| assert "elif max_grad_norm > 0:" in src | ||
| assert "max_grad_value = 1.0" in src |
[1/2 reviewers] Med: this test pins three literal substrings (_user_set_mgv = _raw_mgv is not None, elif max_grad_norm > 0:, max_grad_value = 1.0). Two of them (elif max_grad_norm > 0: and max_grad_value = 1.0) are generic and would survive a branch-flip regression as long as the strings still appear somewhere in the function. Also, it concatenates MLXTrainer.train AND MLXTrainer._train_inner source even though train contains none of the resolution code — extra surface area for false positives.
| assert "max_grad_value = 1.0" in src | |
| def test_trainer_source_pins_resolution_rule(): | |
| """Pin the four-branch resolution against silent regression. Prefer | |
| extracting the resolver into a helper and importing it directly; this | |
| source-pin is a last-resort tripwire.""" | |
| import inspect | |
| from unsloth_zoo.mlx import trainer as T | |
| src = inspect.getsource(T.MLXTrainer._train_inner) | |
| # 'norm wins' branch must zero max_grad_value | |
| assert "max_grad_value = 0.0" in src | |
| # 'cheap default' branch must materialize 1.0 | |
| assert "max_grad_value = 1.0" in src | |
| # The sentinel detection must distinguish 'user passed' from 'default' | |
| assert "is not None" in src |
Summary
- `MLXTrainingConfig.max_grad_value` is typed `float | None = 1.0`, but the resolution code silently rebound an explicit `None` back to the default 1.0, so callers who set `max_grad_value=None` to disable elementwise gradient clipping were still getting clipping at $\pm$1.0.
- Treats `None` the same as `0.0` (disable) and clarifies the field docstring. Default behavior is unchanged; only callers explicitly passing `None` see the corrected disable semantics.

Why
The `Optional[float]` type annotation naturally suggests `None` means "off". The field docstring previously said "Set 0.0 to disable" without mentioning what `None` would do. The internal code on line 732 of `unsloth_zoo/mlx/trainer.py` was:
```python
_raw_mgv = getattr(args, "max_grad_value", 1.0)
max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
```
i.e. `None` $\to$ rebound to `1.0` $\to$ clipping enabled at $\pm$1.0. That is the opposite of what users expected.

Surfaced while writing MLX parity probes against `mlx-lm`'s native loop (which does no elementwise clipping): probes that set `max_grad_value=None` with the explicit comment "match mlx-lm: no elementwise clip" were silently still being clipped.

Behavior
| Call | Effective behavior |
|---|---|
| `MLXTrainingConfig()` | `max_grad_value = 1.0` (clipping on). Unchanged. |
| `MLXTrainingConfig(max_grad_value=None)` | Clipping disabled (previously rebound to 1.0). |
| `MLXTrainingConfig(max_grad_value=0.0)` | Clipping disabled. Unchanged. |
| `MLXTrainingConfig(max_grad_value=2.5)` | Clips at 2.5. Unchanged. |

Test plan
- `tests/test_mlx_max_grad_value_none.py`
- `pytest tests/test_mlx_max_grad_value_none.py -v`

Related
This PR is part of an ongoing MLX vs `mlx-lm` parity bisection: #669 `finetune_last_n_layers` (layer-selection mismatch). unslothai/unsloth#5564 #670 `max_grad_value=None` as documented.