Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions tests/test_mlx_max_grad_value_none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Unsloth Zoo - Utilities for Unsloth
# Pin MLXTrainingConfig.max_grad_value resolution:
# * None (default) -> cheap MLX elementwise clip at 1.0, unless
# max_grad_norm > 0 is also passed (then global-norm wins).
# * 0.0 -> explicitly disabled.
# * positive -> explicit elementwise opt-in; overrides max_grad_norm.

from __future__ import annotations

import pytest


@pytest.fixture(autouse=True, scope="module")
def _install_mlx_shim():
    """Run ``simulate_mlx_on_torch()`` before the tests in this module.

    The fixture is autouse and module-scoped, so every test in this file
    depends on it without requesting it and it runs once for the module.
    NOTE(review): presumably the shim makes ``unsloth_zoo.mlx`` importable
    on machines without MLX -- confirm against ``mlx_simulation``.
    """
    import mlx_simulation

    mlx_simulation.simulate_mlx_on_torch()


def _resolve(raw_mgv, max_grad_norm):
    """Mirror trainer.py's internal gradient-clip resolution.

    Args:
        raw_mgv: The configured ``max_grad_value``. ``None`` is the
            "not set by the user" sentinel; ``0.0`` disables elementwise
            clipping; a positive number opts into elementwise clipping.
        max_grad_norm: The configured ``max_grad_norm``. ``None`` and
            ``0`` both mean "global-norm clipping disabled".

    Returns:
        The ``(max_grad_value, max_grad_norm)`` pair, both floats, that the
        step function will actually use.
    """
    # The trainer normalizes with float(args.max_grad_norm or 0.0) before
    # branching; do the same so None / int inputs resolve identically here
    # instead of raising TypeError on the comparisons below.
    max_grad_norm = float(max_grad_norm or 0.0)

    if raw_mgv is None:
        # Sentinel: an explicit global-norm request wins; otherwise fall
        # back to the elementwise default of 1.0.
        mgv = 0.0 if max_grad_norm > 0 else 1.0
        return mgv, max_grad_norm

    # Explicit value: honor it exactly. When both modes are positive the
    # elementwise clip takes precedence and norm clipping is switched off.
    mgv = float(raw_mgv or 0.0)
    if max_grad_norm > 0 and mgv > 0:
        max_grad_norm = 0.0
    return mgv, max_grad_norm

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] High: _resolve here is a hand-written copy of MLXTrainer._train_inner's resolution. Every "semantics" test (rows 1-5) validates this helper against itself — not the real trainer. The only thing connecting them is test_trainer_source_pins_resolution_rule, which is a substring grep. A semantic-breaking refactor that preserves the substrings passes; a benign refactor that breaks the substrings fails. Better: lift the resolver into unsloth_zoo.mlx.trainer._resolve_grad_clip(raw_mgv, max_grad_norm) and import it here.

Suggested change
return mgv, max_grad_norm
from unsloth_zoo.mlx.trainer import _resolve_grad_clip as _resolve # type: ignore[attr-defined]

(then the rest of the test file uses the real function rather than the local mirror.)



# -- field defaults ---------------------------------------------------------


def test_field_default_is_none_sentinel():
    """Omitting ``max_grad_value`` leaves the field at the ``None`` sentinel."""
    from unsloth_zoo.mlx.trainer import MLXTrainingConfig

    config = MLXTrainingConfig(output_dir="/tmp/x")
    stored = config.max_grad_value
    assert stored is None


def test_field_accepts_none():
    """An explicit ``max_grad_value=None`` is stored unchanged on the config."""
    from unsloth_zoo.mlx.trainer import MLXTrainingConfig

    config = MLXTrainingConfig(max_grad_value=None, output_dir="/tmp/x")
    stored = config.max_grad_value
    assert stored is None


def test_field_accepts_explicit_positive():
    """A positive float passed as ``max_grad_value`` is stored as given."""
    from unsloth_zoo.mlx.trainer import MLXTrainingConfig

    requested = 2.5
    config = MLXTrainingConfig(max_grad_value=requested, output_dir="/tmp/x")
    assert config.max_grad_value == requested


# -- resolution semantics ---------------------------------------------------


def test_default_uses_cheap_elementwise():
    """Sentinel ``None`` with ``max_grad_norm=0.0`` resolves to elementwise
    clipping at 1.0 and leaves norm clipping off."""
    clip_value, clip_norm = _resolve(raw_mgv=None, max_grad_norm=0.0)
    assert (clip_value, clip_norm) == (1.0, 0.0)


def test_user_max_grad_norm_wins_over_default():
    """``max_grad_norm=1.0`` with the ``None`` sentinel keeps global-norm
    clipping and switches elementwise clipping off (HF/TRL parity)."""
    expected = (0.0, 1.0)
    clip_value, clip_norm = _resolve(raw_mgv=None, max_grad_norm=1.0)
    assert (clip_value, clip_norm) == expected


def test_explicit_zero_disables_elementwise():
    """An explicit ``0.0`` turns elementwise clipping off; with
    ``max_grad_norm=0.0`` as well, neither clipping mode is active."""
    clip_value, clip_norm = _resolve(raw_mgv=0.0, max_grad_norm=0.0)
    assert (clip_value, clip_norm) == (0.0, 0.0)


def test_explicit_zero_lets_max_grad_norm_through():
    """Explicit ``max_grad_value=0.0`` plus ``max_grad_norm=1.0`` leaves
    only global-norm clipping enabled."""
    expected = (0.0, 1.0)
    clip_value, clip_norm = _resolve(raw_mgv=0.0, max_grad_norm=1.0)
    assert (clip_value, clip_norm) == expected


def test_explicit_positive_overrides_max_grad_norm():
    """Explicit ``max_grad_value=2.0`` together with ``max_grad_norm=1.0``:
    elementwise clipping wins and ``max_grad_norm`` is zeroed."""
    clip_value, clip_norm = _resolve(raw_mgv=2.0, max_grad_norm=1.0)
    assert (clip_value, clip_norm) == (2.0, 0.0)


def test_explicit_positive_alone():
    """Explicit ``max_grad_value=5.0`` with ``max_grad_norm=0.0`` resolves
    to elementwise clipping at 5.0 and no norm clipping."""
    expected = (5.0, 0.0)
    clip_value, clip_norm = _resolve(raw_mgv=5.0, max_grad_norm=0.0)
    assert (clip_value, clip_norm) == expected


# -- trainer source assertions (defense-in-depth) ---------------------------


def test_trainer_source_pins_resolution_rule():
    """Source-level pin: ``_train_inner`` contains the four-branch resolution.

    Cheap tripwire against a future refactor silently regressing the rule.
    Only ``MLXTrainer._train_inner`` is inspected, so a matching substring
    in another method cannot satisfy a pin by accident.
    """
    import inspect
    from unsloth_zoo.mlx import trainer as T

    src = inspect.getsource(T.MLXTrainer._train_inner)
    pinned_fragments = (
        # Sentinel detection: "user passed a value" vs. "left at default".
        "_user_set_mgv = _raw_mgv is not None",
        # The branch taken when only global-norm clipping was requested...
        "elif max_grad_norm > 0:",
        # ...which must switch elementwise clipping off.
        "max_grad_value = 0.0",
        # Fallback branch: materialize the elementwise default of 1.0.
        "max_grad_value = 1.0",
    )
    missing = [fragment for fragment in pinned_fragments if fragment not in src]
    # Report every absent pin at once so a failure names the exact rule.
    assert not missing, f"_train_inner no longer contains: {missing!r}"

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Med: this test pins three literal substrings (_user_set_mgv = _raw_mgv is not None, elif max_grad_norm > 0:, max_grad_value = 1.0). Two of them (elif max_grad_norm > 0: and max_grad_value = 1.0) are generic and would survive a branch-flip regression as long as the strings still appear somewhere in the function. Also, it concatenates MLXTrainer.train AND MLXTrainer._train_inner source even though train contains none of the resolution code — extra surface area for false positives.

Suggested change
assert "max_grad_value = 1.0" in src
def test_trainer_source_pins_resolution_rule():
"""Pin the four-branch resolution against silent regression. Prefer
extracting the resolver into a helper and importing it directly; this
source-pin is a last-resort tripwire."""
import inspect
from unsloth_zoo.mlx import trainer as T
src = inspect.getsource(T.MLXTrainer._train_inner)
# 'norm wins' branch must zero max_grad_value
assert "max_grad_value = 0.0" in src
# 'cheap default' branch must materialize 1.0
assert "max_grad_value = 1.0" in src
# The sentinel detection must distinguish 'user passed' from 'default'
assert "is not None" in src

51 changes: 33 additions & 18 deletions unsloth_zoo/mlx/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ class MLXTrainingConfig:
optim: str = "adamw" # "adafactor", "adamw", "adam", "sgd", "muon", "lion"
weight_decay: float = 0.001
max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead
# Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no
# cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5
# rarely fires on real transformer grads, so the historical 5.0 was
# effectively a no-op; 1.0 matches the universal clip_grad_norm=1.0
# baseline while staying on MLX's fast tree_map(mx.clip) path.
max_grad_value: float | None = 1.0
# Elementwise clip ([-max_grad_value, max_grad_value], per-leaf).
# None (default) keeps the cheap MLX default of 1.0 unless the user
# passes max_grad_norm > 0, in which case global-norm clipping wins.
# 0.0 disables. A positive float opts in explicitly and overrides
# max_grad_norm with a warning.
max_grad_value: float | None = None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve the default grad clipping behavior

Changing the dataclass default to None disables elementwise clipping for every caller that uses MLXTrainingConfig() or omits max_grad_value, whereas the previous behavior clipped at ±1.0 and the commit description says default behavior should be unchanged. The intended fix can still let an explicit max_grad_value=None disable clipping while keeping the field default (and the getattr fallback below) at 1.0; otherwise existing MLX training runs silently lose their default gradient clipping.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] High: this default flip from 1.0 to None is the core behavior change. Anyone calling MLXTrainingConfig(max_grad_norm=1.0) previously got elementwise@1.0 (warn-and-override path), and now gets global-norm@1.0 instead. That means a different gradient update on every step for identical user configs. It is worth surfacing this in the PR description as a deliberate semantics change rather than "default behavior unchanged", and adding a release-note entry for downstream pipelines.

The field docstring should also spell out the sentinel semantics, since the rule itself lives only in the trainer:

Suggested change
max_grad_value: float | None = None
# Elementwise clip ([-max_grad_value, max_grad_value], per-leaf).
# Resolution:
# * None (default) AND max_grad_norm == 0 -> elementwise 1.0 (cheap MLX default)
# * None (default) AND max_grad_norm > 0 -> global-norm wins (HF/TRL parity)
# * 0.0 -> elementwise disabled
# * positive -> elementwise opts in; if max_grad_norm > 0 it is overridden with a warning
max_grad_value: float | None = None

seed: int = 3407
lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended
embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended
Expand Down Expand Up @@ -723,19 +723,34 @@ def _train_inner(self):
_needs_grad_scaling = use_lora_plus or use_embedding_lr
_warned_skip_optimizer_state_grad_norm = False

# Build step functions following mlx-lm's pattern
# Build step functions following mlx-lm's pattern.
# Resolution rule:
# * max_grad_value=None (default) -> cheap MLX elementwise clip
# at 1.0, unless the user also passed max_grad_norm > 0 -- in
# that case the user opted into global-norm clipping (HF/TRL
# parity) and elementwise is disabled to avoid double-clip.
# * max_grad_value explicit (float or 0.0) -> honor exactly;
# if both modes are positive, elementwise wins (warn).
# max_grad_norm uses MLX's clip_grad_norm helper which materially
# increases peak memory on bf16 VLM runs, hence the elementwise
# default.
max_grad_norm = float(args.max_grad_norm or 0.0)
# Elementwise clip (clip_grad_value_): leaf-local, free memory.
# Prefer value clipping when both clipping modes are requested; global
# norm clipping is exact but materially increases memory on MLX.
_raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users
max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
if max_grad_norm > 0 and max_grad_value > 0:
print(
"Unsloth: max_grad_norm and max_grad_value are both enabled; "
"ignoring max_grad_norm in favor of max_grad_value."
)
max_grad_norm = 0.0
_raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: the dataclass guarantees max_grad_value exists, so the getattr(..., None) fallback is dead defensive code. If kept, the fallback default should match the field default to avoid future divergence.

Suggested change
_raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users
_raw_mgv = args.max_grad_value # TODO: expose MLX grad-clip in Studio UI for power users

_user_set_mgv = _raw_mgv is not None

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: _user_set_mgv is computed but never used outside the resolver block. Inline it to keep the resolver self-contained, or surface it as self._mgv_user_opt_in for downstream introspection / callback reporting.

Suggested change
_user_set_mgv = _raw_mgv is not None
if _raw_mgv is not None:

(removing the standalone _user_set_mgv binding and using the predicate inline in if/elif.)

if _user_set_mgv:
max_grad_value = float(_raw_mgv or 0.0)
if max_grad_norm > 0 and max_grad_value > 0:
print(
"Unsloth: max_grad_norm and max_grad_value are both enabled; "
"ignoring max_grad_norm in favor of max_grad_value."
)
max_grad_norm = 0.0
elif max_grad_norm > 0:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Med: with max_grad_value=None and max_grad_norm=0.0 we silently take the elementwise@1.0 branch; with None and max_grad_norm>0 we silently take the global-norm branch; with 0.0 we silently disable. Print one line at step-build time so users debugging grad_norm curves know which clipping mode is active.

Suggested change
elif max_grad_norm > 0:
elif max_grad_norm > 0:
# User opted into global-norm clipping; suppress the default elementwise.
max_grad_value = 0.0
print(f"Unsloth: gradient clipping = global-norm at {max_grad_norm}.")
else:
# Neither requested -> cheap MLX default.
max_grad_value = 1.0
print("Unsloth: gradient clipping = elementwise at +/-1.0 (MLX default).")

# User opted into global-norm clipping; suppress the default elementwise.
max_grad_value = 0.0
else:
# Neither requested -> cheap MLX default.
max_grad_value = 1.0

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] High: the test file's _resolve helper is a pure-Python copy of this resolution logic and is validated against itself in test_mlx_max_grad_value_none.py. If you rename _user_set_mgv, extract the four-branch logic into a helper, or otherwise refactor without changing semantics, the source-pin test (test_trainer_source_pins_resolution_rule) breaks while the unit tests stay green — and vice versa, semantic-breaking refactors that preserve the substrings slide through. Couple them by lifting the resolution into a named function and importing it in the tests.

Suggested change
max_grad_value = 1.0
else:
# Neither requested -> cheap MLX default.
max_grad_value = 1.0
# NOTE: keep this four-branch resolution in sync with
# tests/test_mlx_max_grad_value_none.py::_resolve. Prefer extracting
# into a module-level helper so the tests can import it directly.

_clip_grad_value = max_grad_value > 0
state = [model.state, optimizer.state, mx.random.state]
# The direct grad_accum==1 fast path delegates clipping to
Expand Down
Loading