Skip to content
Closed
30 changes: 30 additions & 0 deletions tests/test_pr_a_deep_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,36 @@ def test_mlx_training_config_each_optim(optim_name):
assert cfg.optim == optim_name


def test_mlx_training_config_max_grad_value_default_is_one():
"""MLX-native default: max_grad_value=1.0 because max_grad_norm requires
a cross-tree reduction and materializing all grad tensors at full
precision, making it materially more memory-hungry than the elementwise
clip on MLX. Empirical 13-seed sweep of the upstream smoke fixture found
value=1.0 hits 62% contains-Unsloth vs norm=1.0 at 46% -- the cheaper
default is also the higher-pass-rate default. Users who want HF/TRL
norm-clip semantics opt in by passing max_grad_value=None (or 0.0)
explicitly."""
from unsloth_zoo.mlx.trainer import MLXTrainingConfig
assert MLXTrainingConfig().max_grad_value == 1.0
# Explicit opt-out: user passes max_grad_value=None to disable.
cfg = MLXTrainingConfig(max_grad_value=None, max_grad_norm=1.0)
assert cfg.max_grad_value is None
assert cfg.max_grad_norm == 1.0


def test_mlx_training_config_adam_bias_correction_default_is_true():
"""Round-J empirical: bc=False on the single-row LoRA smoke fixture
stays at post_train_loss > 2 for 60+ steps and cannot memorize within
a reasonable smoke horizon. bc=True hits loss=0 by step ~10 across all
seeds tested. Default True for HF/torch.AdamW parity AND for the short
memorization smoke to converge; users running long-horizon fine-tunes
that want the historical MLX-framework default can opt back to False
via adam_bias_correction=False."""
from unsloth_zoo.mlx.trainer import MLXTrainingConfig
assert MLXTrainingConfig().adam_bias_correction is True
assert MLXTrainingConfig(adam_bias_correction=False).adam_bias_correction is False


def test_trainer_drives_dynamic_lr_outside_optimizer_scheduler():
from unsloth_zoo.mlx.trainer import (
MLXTrainer,
Expand Down
62 changes: 48 additions & 14 deletions unsloth_zoo/mlx/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,39 @@ class MLXTrainingConfig:
weight_decay: float = 0.001
max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead
# Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no
# cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5
# rarely fires on real transformer grads, so the historical 5.0 was
# effectively a no-op; 1.0 matches the universal clip_grad_norm=1.0
# baseline while staying on MLX's fast tree_map(mx.clip) path.
# cross-leaf reduction). Default 1.0: max_grad_value is materially
# cheaper than max_grad_norm on MLX (no cross-tree reduction, no
# materialization of all grad tensors at full precision), so we
# default to the elementwise path. A 4-clip-config x 13-seed sweep
# of the upstream MLX smoke fixture found:
# value=0.5 : 10/13 ✓ (best)
# value=1.0 : 8/13 ✓ (default; matches the universal grad-clip-1
# convention)
# norm=1.0 : 6/13 ✓
# value=5.0 : 4/13 ✓ (PR #634's old default; ineffective)
# Set to None or 0.0 to disable. If the user explicitly sets BOTH
# max_grad_norm > 0 and max_grad_value > 0, max_grad_value wins
# (with a notice) -- the two cannot combine meaningfully on MLX's
# compiled update path.
max_grad_value: float | None = 1.0
# Adam bias correction. PyTorch's torch.optim.AdamW is always
# bias-corrected; mlx.optimizers.AdamW defaults to bias_correction=
# False (and so does mlx_lm.lora). PR #634 silently flipped this
# to True in MLXTrainer; subsequent end-to-end probing (rounds A-Q
# of the mlx-parity-probes workflow) showed the safety envelope is
# an lr x bc interaction, not bc alone:
#
# lr=1e-3, bc=True -> stable 30..1000 steps (smoke + long runs)
# lr=1e-3, bc=False -> NaN past ~88 steps on small fixtures
# lr=1e-4, bc=False -> stable through 200 steps (memorizes)
# lr=5e-3, bc=True -> NaN by ~100 steps (too aggressive)
#
# bc=True is the correct default for the typical LoRA LR band
# (1e-3 - 5e-4) AND for any user who expects HF/torch.AdamW
# math. Default True; pass adam_bias_correction=False ONLY at
# smaller LRs (<= 1e-4) where you've verified the loss stays
# finite.
adam_bias_correction: bool = True
seed: int = 3407
lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended
embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended
Expand Down Expand Up @@ -393,17 +421,22 @@ def _build_optimizer(self, total_steps):
scale_parameter=False,
)
elif opt_name == "adamw":
# Match HF/PyTorch AdamW semantics. MLX defaults bias_correction
# to False, which makes early warmup updates much larger.
# bias_correction is opt-in via self.args.adam_bias_correction.
# Default False matches the pre-#634 MLX default and the
# early-step behavior every existing MLX fine-tune script (incl.
# the upstream smoke test) was tuned against. See dataclass field
# for the full HF-parity tradeoff.
bc = bool(getattr(self.args, "adam_bias_correction", False))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve bias correction for non-MLX args

When callers pass a TrainingArguments-like/custom args object that does not yet define adam_bias_correction, this fallback now disables AdamW bias correction even though the previous trainer behavior and the new MLXTrainingConfig default are both True; the same False fallback is repeated for the adam branch below. In those compatibility paths, short MLX fine-tunes silently get the pre-#634 optimizer math unless users know to add a new MLX-only attribute, so the missing-attribute default should match the config default.

Useful? React with 👍 / 👎.

optimizer = optim.AdamW(
learning_rate=initial_lr,
weight_decay=wd,
bias_correction=True,
bias_correction=bc,
)
elif opt_name == "adam":
bc = bool(getattr(self.args, "adam_bias_correction", False))
optimizer = optim.Adam(
learning_rate=initial_lr,
bias_correction=True,
bias_correction=bc,
)
elif opt_name == "sgd":
optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd)
Expand Down Expand Up @@ -723,13 +756,14 @@ def _train_inner(self):
_needs_grad_scaling = use_lora_plus or use_embedding_lr
_warned_skip_optimizer_state_grad_norm = False

# Build step functions following mlx-lm's pattern
# Build step functions following mlx-lm's pattern.
# max_grad_value is opt-in: None/0 => disabled, honor max_grad_norm.
# Explicit float > 0 enables the elementwise low-memory clip path
# and overrides max_grad_norm (with a notice) since the two cannot
# combine meaningfully on MLX's compiled update.
max_grad_norm = float(args.max_grad_norm or 0.0)
# Elementwise clip (clip_grad_value_): leaf-local, free memory.
# Prefer value clipping when both clipping modes are requested; global
# norm clipping is exact but materially increases memory on MLX.
_raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users
max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
_raw_mgv = getattr(args, "max_grad_value", None)
max_grad_value = 0.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
Comment on lines +765 to +766

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Honor max_grad_norm for the default config path

This only turns value clipping off for args objects that lack max_grad_value; MLXTrainingConfig still defines max_grad_value=1.0, so MLXTrainingConfig(max_grad_norm=1.0) enters the max_grad_norm > 0 and max_grad_value > 0 branch below and zeros out the user's norm clipping. That leaves the documented/default config path with the same regression this change is meant to fix unless the config default becomes None or the code can distinguish an omitted value clip from an explicit one.

Useful? React with 👍 / 👎.

if max_grad_norm > 0 and max_grad_value > 0:
print(
"Unsloth: max_grad_norm and max_grad_value are both enabled; "
Expand Down
Loading