diff --git a/tests/test_mlx_max_grad_value_none.py b/tests/test_mlx_max_grad_value_none.py
new file mode 100644
index 000000000..21b62fe5c
--- /dev/null
+++ b/tests/test_mlx_max_grad_value_none.py
@@ -0,0 +1,135 @@
+# Unsloth Zoo - Utilities for Unsloth
+# Pin MLXTrainingConfig.max_grad_value resolution:
+#   * None (default) -> cheap MLX elementwise clip at 1.0, unless
+#     max_grad_norm > 0 is also passed (then global-norm wins).
+#   * 0.0            -> explicitly disabled.
+#   * positive       -> explicit elementwise opt-in; overrides max_grad_norm.
+
+from __future__ import annotations
+
+import pytest
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _install_mlx_shim():
+    """Run simulate_mlx_on_torch() once before this module's tests
+    (presumably so unsloth_zoo.mlx imports without a real MLX install)."""
+    from mlx_simulation import simulate_mlx_on_torch
+    simulate_mlx_on_torch()
+
+
+def _resolve(raw_mgv, max_grad_norm):
+    """Mirror trainer.py's internal resolution. Returns the (max_grad_value,
+    max_grad_norm) pair the step function will actually use."""
+    # Same normalization as the trainer (float(args.max_grad_norm or 0.0)):
+    # None and 0 both mean "global-norm clipping disabled".
+    max_grad_norm = float(max_grad_norm or 0.0)
+    user_set = raw_mgv is not None
+    if user_set:
+        mgv = float(raw_mgv or 0.0)
+        if max_grad_norm > 0 and mgv > 0:
+            max_grad_norm = 0.0
+    elif max_grad_norm > 0:
+        mgv = 0.0
+    else:
+        mgv = 1.0
+    return mgv, max_grad_norm
+
+
+# -- field defaults ---------------------------------------------------------
+
+
+def test_field_default_is_none_sentinel():
+    """Default is None (a sentinel meaning 'use MLX cheap default')."""
+    from unsloth_zoo.mlx.trainer import MLXTrainingConfig
+
+    cfg = MLXTrainingConfig(output_dir="/tmp/x")
+    assert cfg.max_grad_value is None
+
+
+def test_field_accepts_none():
+    """Field accepts None and round-trips through the dataclass."""
+    from unsloth_zoo.mlx.trainer import MLXTrainingConfig
+
+    cfg = MLXTrainingConfig(max_grad_value=None, output_dir="/tmp/x")
+    assert cfg.max_grad_value is None
+
+
+def test_field_accepts_explicit_positive():
+    """Field accepts positive floats for power users opting in."""
+    from unsloth_zoo.mlx.trainer import MLXTrainingConfig
+
+    cfg = MLXTrainingConfig(max_grad_value=2.5, output_dir="/tmp/x")
+    assert cfg.max_grad_value == 2.5
+
+
+# -- resolution semantics ---------------------------------------------------
+
+
+def test_default_uses_cheap_elementwise():
+    """Default (None, max_grad_norm=0.0) -> elementwise clip at 1.0."""
+    mgv, mgn = _resolve(raw_mgv=None, max_grad_norm=0.0)
+    assert mgv == 1.0
+    assert mgn == 0.0
+
+
+def test_user_max_grad_norm_wins_over_default():
+    """User passes max_grad_norm=1.0 with default max_grad_value=None ->
+    global-norm clipping wins, elementwise disabled. HF/TRL parity."""
+    mgv, mgn = _resolve(raw_mgv=None, max_grad_norm=1.0)
+    assert mgv == 0.0
+    assert mgn == 1.0
+
+
+def test_explicit_zero_disables_elementwise():
+    """Explicit 0.0 disables elementwise. With no max_grad_norm,
+    nothing clips."""
+    mgv, mgn = _resolve(raw_mgv=0.0, max_grad_norm=0.0)
+    assert mgv == 0.0
+    assert mgn == 0.0
+
+
+def test_explicit_zero_lets_max_grad_norm_through():
+    """Explicit max_grad_value=0.0 + max_grad_norm=1.0 -> only norm clipping."""
+    mgv, mgn = _resolve(raw_mgv=0.0, max_grad_norm=1.0)
+    assert mgv == 0.0
+    assert mgn == 1.0
+
+
+def test_explicit_positive_overrides_max_grad_norm():
+    """Explicit max_grad_value=2.0 with max_grad_norm=1.0 -> elementwise
+    wins (existing rule), max_grad_norm zeroed."""
+    mgv, mgn = _resolve(raw_mgv=2.0, max_grad_norm=1.0)
+    assert mgv == 2.0
+    assert mgn == 0.0
+
+
+def test_explicit_positive_alone():
+    """User passes max_grad_value=5.0 with no max_grad_norm -> elementwise at 5."""
+    mgv, mgn = _resolve(raw_mgv=5.0, max_grad_norm=0.0)
+    assert mgv == 5.0
+    assert mgn == 0.0
+
+
+def test_none_max_grad_norm_is_treated_as_disabled():
+    """max_grad_norm=None normalizes to 0.0 (as the trainer's
+    float(args.max_grad_norm or 0.0) does), so norm clipping stays off."""
+    assert _resolve(raw_mgv=None, max_grad_norm=None) == (1.0, 0.0)
+    assert _resolve(raw_mgv=2.0, max_grad_norm=None) == (2.0, 0.0)
+    assert _resolve(raw_mgv=0.0, max_grad_norm=None) == (0.0, 0.0)
+
+
+# -- trainer source assertions (defense-in-depth) ---------------------------
+
+
+def test_trainer_source_pins_resolution_rule():
+    """Source-level pin: trainer.py contains the four-branch resolution.
+    Cheap defense against a future refactor silently regressing the rule."""
+    import inspect
+    from unsloth_zoo.mlx import trainer as T
+
+    src = inspect.getsource(T.MLXTrainer.train) + inspect.getsource(T.MLXTrainer._train_inner)
+    assert "max_grad_norm = float(args.max_grad_norm or 0.0)" in src
+    assert "_user_set_mgv = _raw_mgv is not None" in src
+    assert "elif max_grad_norm > 0:" in src
+    assert "max_grad_value = 1.0" in src
diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py
index 4edd27b7c..253f47035 100644
--- a/unsloth_zoo/mlx/trainer.py
+++ b/unsloth_zoo/mlx/trainer.py
@@ -119,12 +119,12 @@ class MLXTrainingConfig:
     optim: str = "adamw" # "adafactor", "adamw", "adam", "sgd", "muon", "lion"
     weight_decay: float = 0.001
     max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead
-    # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no
-    # cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5
-    # rarely fires on real transformer grads, so the historical 5.0 was
-    # effectively a no-op; 1.0 matches the universal clip_grad_norm=1.0
-    # baseline while staying on MLX's fast tree_map(mx.clip) path.
-    max_grad_value: float | None = 1.0
+    # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf).
+    # None (default) keeps the cheap MLX default of 1.0 unless the user
+    # passes max_grad_norm > 0, in which case global-norm clipping wins.
+    # 0.0 disables. A positive float opts in explicitly and overrides
+    # max_grad_norm with a warning.
+    max_grad_value: float | None = None
     seed: int = 3407
     lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended
     embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended
@@ -723,19 +723,34 @@ def _train_inner(self):
         _needs_grad_scaling = use_lora_plus or use_embedding_lr
         _warned_skip_optimizer_state_grad_norm = False
 
-        # Build step functions following mlx-lm's pattern
+        # Build step functions following mlx-lm's pattern.
+        # Resolution rule:
+        #   * max_grad_value=None (default) -> cheap MLX elementwise clip
+        #     at 1.0, unless the user also passed max_grad_norm > 0 -- in
+        #     that case the user opted into global-norm clipping (HF/TRL
+        #     parity) and elementwise is disabled to avoid double-clip.
+        #   * max_grad_value explicit (float or 0.0) -> honor exactly;
+        #     if both modes are positive, elementwise wins (warn).
+        # max_grad_norm uses MLX's clip_grad_norm helper which materially
+        # increases peak memory on bf16 VLM runs, hence the elementwise
+        # default.
         max_grad_norm = float(args.max_grad_norm or 0.0)
-        # Elementwise clip (clip_grad_value_): leaf-local, free memory.
-        # Prefer value clipping when both clipping modes are requested; global
-        # norm clipping is exact but materially increases memory on MLX.
-        _raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users
-        max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
-        if max_grad_norm > 0 and max_grad_value > 0:
-            print(
-                "Unsloth: max_grad_norm and max_grad_value are both enabled; "
-                "ignoring max_grad_norm in favor of max_grad_value."
-            )
-            max_grad_norm = 0.0
+        _raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users
+        _user_set_mgv = _raw_mgv is not None
+        if _user_set_mgv:
+            max_grad_value = float(_raw_mgv or 0.0)
+            if max_grad_norm > 0 and max_grad_value > 0:
+                print(
+                    "Unsloth: max_grad_norm and max_grad_value are both enabled; "
+                    "ignoring max_grad_norm in favor of max_grad_value."
+                )
+                max_grad_norm = 0.0
+        elif max_grad_norm > 0:
+            # User opted into global-norm clipping; suppress the default elementwise.
+            max_grad_value = 0.0
+        else:
+            # Neither requested -> cheap MLX default.
+            max_grad_value = 1.0
         _clip_grad_value = max_grad_value > 0
         state = [model.state, optimizer.state, mx.random.state]
         # The direct grad_accum==1 fast path delegates clipping to