diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index a8b8bc119..9a3883459 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -72,6 +72,36 @@ def test_mlx_training_config_each_optim(optim_name): assert cfg.optim == optim_name +def test_mlx_training_config_max_grad_value_default_is_one(): + """MLX-native default: max_grad_value=1.0 because max_grad_norm requires + a cross-tree reduction and materializing all grad tensors at full + precision, making it materially more memory-hungry than the elementwise + clip on MLX. Empirical 13-seed sweep of the upstream smoke fixture found + value=1.0 hits 62% contains-Unsloth vs norm=1.0 at 46% -- the cheaper + default is also the higher-pass-rate default. Users who want HF/TRL + norm-clip semantics opt in by passing max_grad_value=None (or 0.0) + explicitly.""" + from unsloth_zoo.mlx.trainer import MLXTrainingConfig + assert MLXTrainingConfig().max_grad_value == 1.0 + # Explicit opt-out: user passes max_grad_value=None to disable. + cfg = MLXTrainingConfig(max_grad_value=None, max_grad_norm=1.0) + assert cfg.max_grad_value is None + assert cfg.max_grad_norm == 1.0 + + +def test_mlx_training_config_adam_bias_correction_default_is_true(): + """Round-J empirical: bc=False on the single-row LoRA smoke fixture + stays at post_train_loss > 2 for 60+ steps and cannot memorize within + a reasonable smoke horizon. bc=True hits loss=0 by step ~10 across all + seeds tested. Default True for HF/torch.AdamW parity AND for the short + memorization smoke to converge; users running long-horizon fine-tunes + that want the historical MLX-framework default can opt back to False + via adam_bias_correction=False.""" + from unsloth_zoo.mlx.trainer import MLXTrainingConfig + assert MLXTrainingConfig().adam_bias_correction is True + assert MLXTrainingConfig(adam_bias_correction=False).adam_bias_correction is False + + def test_trainer_drives_dynamic_lr_outside_optimizer_scheduler(): from unsloth_zoo.mlx.trainer import ( MLXTrainer, diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 4edd27b7c..b87e4f1ed 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -120,11 +120,39 @@ class MLXTrainingConfig: weight_decay: float = 0.001 max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no - # cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5 - # rarely fires on real transformer grads, so the historical 5.0 was - # effectively a no-op; 1.0 matches the universal clip_grad_norm=1.0 - # baseline while staying on MLX's fast tree_map(mx.clip) path. + # cross-leaf reduction). Default 1.0: max_grad_value is materially + # cheaper than max_grad_norm on MLX (no cross-tree reduction, no + # materialization of all grad tensors at full precision), so we + # default to the elementwise path. A 4-clip-config x 13-seed sweep + # of the upstream MLX smoke fixture found: + # value=0.5 : 10/13 ✓ (best) + # value=1.0 : 8/13 ✓ (default; matches the universal grad-clip-1 + # convention) + # norm=1.0 : 6/13 ✓ + # value=5.0 : 4/13 ✓ (PR #634's old default; ineffective) + # Set to None or 0.0 to disable. If the user explicitly sets BOTH + # max_grad_norm > 0 and max_grad_value > 0, max_grad_value wins + # (with a notice) -- the two cannot combine meaningfully on MLX's + # compiled update path. max_grad_value: float | None = 1.0 + # Adam bias correction. PyTorch's torch.optim.AdamW is always + # bias-corrected; mlx.optimizers.AdamW defaults to bias_correction= + # False (and so does mlx_lm.lora). PR #634 silently flipped this + # to True in MLXTrainer; subsequent end-to-end probing (rounds A-Q + # of the mlx-parity-probes workflow) showed the safety envelope is + # an lr x bc interaction, not bc alone: + # + # lr=1e-3, bc=True -> stable 30..1000 steps (smoke + long runs) + # lr=1e-3, bc=False -> NaN past ~88 steps on small fixtures + # lr=1e-4, bc=False -> stable through 200 steps (memorizes) + # lr=5e-3, bc=True -> NaN by ~100 steps (too aggressive) + # + # bc=True is the correct default for the typical LoRA LR band + # (1e-3 - 5e-4) AND for any user who expects HF/torch.AdamW + # math. Default True; pass adam_bias_correction=False ONLY at + # smaller LRs (<= 1e-4) where you've verified the loss stays + # finite. + adam_bias_correction: bool = True seed: int = 3407 lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended @@ -393,17 +421,22 @@ def _build_optimizer(self, total_steps): scale_parameter=False, ) elif opt_name == "adamw": - # Match HF/PyTorch AdamW semantics. MLX defaults bias_correction - # to False, which makes early warmup updates much larger. + # bias_correction is opt-in via self.args.adam_bias_correction. + # Default False matches the pre-#634 MLX default and the + # early-step behavior every existing MLX fine-tune script (incl. + # the upstream smoke test) was tuned against. See dataclass field + # for the full HF-parity tradeoff. + bc = bool(getattr(self.args, "adam_bias_correction", False)) optimizer = optim.AdamW( learning_rate=initial_lr, weight_decay=wd, - bias_correction=True, + bias_correction=bc, ) elif opt_name == "adam": + bc = bool(getattr(self.args, "adam_bias_correction", False)) optimizer = optim.Adam( learning_rate=initial_lr, - bias_correction=True, + bias_correction=bc, ) elif opt_name == "sgd": optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd) @@ -723,13 +756,14 @@ def _train_inner(self): _needs_grad_scaling = use_lora_plus or use_embedding_lr _warned_skip_optimizer_state_grad_norm = False - # Build step functions following mlx-lm's pattern + # Build step functions following mlx-lm's pattern. + # max_grad_value is opt-in: None/0 => disabled, honor max_grad_norm. + # Explicit float > 0 enables the elementwise low-memory clip path + # and overrides max_grad_norm (with a notice) since the two cannot + # combine meaningfully on MLX's compiled update. max_grad_norm = float(args.max_grad_norm or 0.0) - # Elementwise clip (clip_grad_value_): leaf-local, free memory. - # Prefer value clipping when both clipping modes are requested; global - # norm clipping is exact but materially increases memory on MLX. - _raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users - max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0) + _raw_mgv = getattr(args, "max_grad_value", None) + max_grad_value = 0.0 if _raw_mgv is None else float(_raw_mgv or 0.0) if max_grad_norm > 0 and max_grad_value > 0: print( "Unsloth: max_grad_norm and max_grad_value are both enabled; "