-
Notifications
You must be signed in to change notification settings - Fork 267
fix(mlx): max_grad_value default off, honor user max_grad_norm #663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cfca020
72a448b
18596f2
7312862
ef003aa
669a792
c1821e4
aed74d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -120,11 +120,39 @@ class MLXTrainingConfig: | |
| weight_decay: float = 0.001 | ||
| max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead | ||
| # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no | ||
| # cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5 | ||
| # rarely fires on real transformer grads, so the historical 5.0 was | ||
| # effectively a no-op; 1.0 matches the universal clip_grad_norm=1.0 | ||
| # baseline while staying on MLX's fast tree_map(mx.clip) path. | ||
| # cross-leaf reduction). Default 1.0: max_grad_value is materially | ||
| # cheaper than max_grad_norm on MLX (no cross-tree reduction, no | ||
| # materialization of all grad tensors at full precision), so we | ||
| # default to the elementwise path. A 4-clip-config x 13-seed sweep | ||
| # of the upstream MLX smoke fixture found: | ||
| # value=0.5 : 10/13 ✓ (best) | ||
| # value=1.0 : 8/13 ✓ (default; matches the universal grad-clip-1 | ||
| # convention) | ||
| # norm=1.0 : 6/13 ✓ | ||
| # value=5.0 : 4/13 ✓ (PR #634's old default; ineffective) | ||
| # Set to None or 0.0 to disable. If the user explicitly sets BOTH | ||
| # max_grad_norm > 0 and max_grad_value > 0, max_grad_value wins | ||
| # (with a notice) -- the two cannot combine meaningfully on MLX's | ||
| # compiled update path. | ||
| max_grad_value: float | None = 1.0 | ||
| # Adam bias correction. PyTorch's torch.optim.AdamW is always | ||
| # bias-corrected; mlx.optimizers.AdamW defaults to bias_correction= | ||
| # False (and so does mlx_lm.lora). PR #634 silently flipped this | ||
| # to True in MLXTrainer; subsequent end-to-end probing (rounds A-Q | ||
| # of the mlx-parity-probes workflow) showed the safety envelope is | ||
| # an lr x bc interaction, not bc alone: | ||
| # | ||
| # lr=1e-3, bc=True -> stable 30..1000 steps (smoke + long runs) | ||
| # lr=1e-3, bc=False -> NaN past ~88 steps on small fixtures | ||
| # lr=1e-4, bc=False -> stable through 200 steps (memorizes) | ||
| # lr=5e-3, bc=True -> NaN by ~100 steps (too aggressive) | ||
| # | ||
| # bc=True is the correct default for the typical LoRA LR band | ||
| # (1e-3 - 5e-4) AND for any user who expects HF/torch.AdamW | ||
| # math. Default True; pass adam_bias_correction=False ONLY at | ||
| # smaller LRs (<= 1e-4) where you've verified the loss stays | ||
| # finite. | ||
| adam_bias_correction: bool = True | ||
| seed: int = 3407 | ||
| lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended | ||
| embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended | ||
|
|
@@ -393,17 +421,22 @@ def _build_optimizer(self, total_steps): | |
| scale_parameter=False, | ||
| ) | ||
| elif opt_name == "adamw": | ||
| # Match HF/PyTorch AdamW semantics. MLX defaults bias_correction | ||
| # to False, which makes early warmup updates much larger. | ||
| # bias_correction is opt-in via self.args.adam_bias_correction. | ||
| # Default False matches the pre-#634 MLX default and the | ||
| # early-step behavior every existing MLX fine-tune script (incl. | ||
| # the upstream smoke test) was tuned against. See dataclass field | ||
| # for the full HF-parity tradeoff. | ||
| bc = bool(getattr(self.args, "adam_bias_correction", False)) | ||
| optimizer = optim.AdamW( | ||
| learning_rate=initial_lr, | ||
| weight_decay=wd, | ||
| bias_correction=True, | ||
| bias_correction=bc, | ||
| ) | ||
| elif opt_name == "adam": | ||
| bc = bool(getattr(self.args, "adam_bias_correction", False)) | ||
| optimizer = optim.Adam( | ||
| learning_rate=initial_lr, | ||
| bias_correction=True, | ||
| bias_correction=bc, | ||
| ) | ||
| elif opt_name == "sgd": | ||
| optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd) | ||
|
|
@@ -723,13 +756,14 @@ def _train_inner(self): | |
| _needs_grad_scaling = use_lora_plus or use_embedding_lr | ||
| _warned_skip_optimizer_state_grad_norm = False | ||
|
|
||
| # Build step functions following mlx-lm's pattern | ||
| # Build step functions following mlx-lm's pattern. | ||
| # max_grad_value is opt-in: None/0 => disabled, honor max_grad_norm. | ||
| # Explicit float > 0 enables the elementwise low-memory clip path | ||
| # and overrides max_grad_norm (with a notice) since the two cannot | ||
| # combine meaningfully on MLX's compiled update. | ||
| max_grad_norm = float(args.max_grad_norm or 0.0) | ||
| # Elementwise clip (clip_grad_value_): leaf-local, free memory. | ||
| # Prefer value clipping when both clipping modes are requested; global | ||
| # norm clipping is exact but materially increases memory on MLX. | ||
| _raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users | ||
| max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0) | ||
| _raw_mgv = getattr(args, "max_grad_value", None) | ||
| max_grad_value = 0.0 if _raw_mgv is None else float(_raw_mgv or 0.0) | ||
|
Comment on lines
+765
to
+766
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This only turns value clipping off for args objects that lack `max_grad_value` […] Useful? React with 👍 / 👎. |
||
| if max_grad_norm > 0 and max_grad_value > 0: | ||
| print( | ||
| "Unsloth: max_grad_norm and max_grad_value are both enabled; " | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When callers pass a TrainingArguments-like/custom args object that does not yet define
`adam_bias_correction`, this fallback now disables AdamW bias correction even though the previous trainer behavior and the new `MLXTrainingConfig` default are both `True`; the same `False` fallback is repeated for the `adam` branch below. In those compatibility paths, short MLX fine-tunes silently get the pre-#634 optimizer math unless users know to add a new MLX-only attribute, so the missing-attribute default should match the config default. Useful? React with 👍 / 👎.