-
Notifications
You must be signed in to change notification settings - Fork 267
fix(mlx): honor max_grad_value=None as a disable signal #671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b59cca2
265534b
da43c20
f14f717
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,121 @@ | ||||||||||||||||||||||||||||||||
| # Unsloth Zoo - Utilities for Unsloth | ||||||||||||||||||||||||||||||||
| # Pin MLXTrainingConfig.max_grad_value resolution: | ||||||||||||||||||||||||||||||||
| # * None (default) -> cheap MLX elementwise clip at 1.0, unless | ||||||||||||||||||||||||||||||||
| # max_grad_norm > 0 is also passed (then global-norm wins). | ||||||||||||||||||||||||||||||||
| # * 0.0 -> explicitly disabled. | ||||||||||||||||||||||||||||||||
| # * positive -> explicit elementwise opt-in; overrides max_grad_norm. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| @pytest.fixture(autouse=True, scope="module") | ||||||||||||||||||||||||||||||||
| def _install_mlx_shim(): | ||||||||||||||||||||||||||||||||
| from mlx_simulation import simulate_mlx_on_torch | ||||||||||||||||||||||||||||||||
| simulate_mlx_on_torch() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _resolve(raw_mgv, max_grad_norm): | ||||||||||||||||||||||||||||||||
| """Mirror trainer.py's internal resolution. Returns the (max_grad_value, | ||||||||||||||||||||||||||||||||
| max_grad_norm) pair the step function will actually use.""" | ||||||||||||||||||||||||||||||||
| user_set = raw_mgv is not None | ||||||||||||||||||||||||||||||||
| if user_set: | ||||||||||||||||||||||||||||||||
| mgv = float(raw_mgv or 0.0) | ||||||||||||||||||||||||||||||||
| if max_grad_norm > 0 and mgv > 0: | ||||||||||||||||||||||||||||||||
| max_grad_norm = 0.0 | ||||||||||||||||||||||||||||||||
| elif max_grad_norm > 0: | ||||||||||||||||||||||||||||||||
| mgv = 0.0 | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| mgv = 1.0 | ||||||||||||||||||||||||||||||||
| return mgv, max_grad_norm | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # -- field defaults --------------------------------------------------------- | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_field_default_is_none_sentinel(): | ||||||||||||||||||||||||||||||||
| """Default is None (a sentinel meaning 'use MLX cheap default').""" | ||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx.trainer import MLXTrainingConfig | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| cfg = MLXTrainingConfig(output_dir="/tmp/x") | ||||||||||||||||||||||||||||||||
| assert cfg.max_grad_value is None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_field_accepts_none(): | ||||||||||||||||||||||||||||||||
| """Field accepts None and round-trips through the dataclass.""" | ||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx.trainer import MLXTrainingConfig | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| cfg = MLXTrainingConfig(max_grad_value=None, output_dir="/tmp/x") | ||||||||||||||||||||||||||||||||
| assert cfg.max_grad_value is None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_field_accepts_explicit_positive(): | ||||||||||||||||||||||||||||||||
| """Field accepts positive floats for power users opting in.""" | ||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx.trainer import MLXTrainingConfig | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| cfg = MLXTrainingConfig(max_grad_value=2.5, output_dir="/tmp/x") | ||||||||||||||||||||||||||||||||
| assert cfg.max_grad_value == 2.5 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # -- resolution semantics --------------------------------------------------- | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_default_uses_cheap_elementwise(): | ||||||||||||||||||||||||||||||||
| """Default (None, max_grad_norm=0.0) -> elementwise clip at 1.0.""" | ||||||||||||||||||||||||||||||||
| mgv, mgn = _resolve(raw_mgv=None, max_grad_norm=0.0) | ||||||||||||||||||||||||||||||||
| assert mgv == 1.0 | ||||||||||||||||||||||||||||||||
| assert mgn == 0.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_user_max_grad_norm_wins_over_default(): | ||||||||||||||||||||||||||||||||
| """User passes max_grad_norm=1.0 with default max_grad_value=None -> | ||||||||||||||||||||||||||||||||
| global-norm clipping wins, elementwise disabled. HF/TRL parity.""" | ||||||||||||||||||||||||||||||||
| mgv, mgn = _resolve(raw_mgv=None, max_grad_norm=1.0) | ||||||||||||||||||||||||||||||||
| assert mgv == 0.0 | ||||||||||||||||||||||||||||||||
| assert mgn == 1.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_explicit_zero_disables_elementwise(): | ||||||||||||||||||||||||||||||||
| """Explicit 0.0 disables elementwise. With no max_grad_norm, | ||||||||||||||||||||||||||||||||
| nothing clips.""" | ||||||||||||||||||||||||||||||||
| mgv, mgn = _resolve(raw_mgv=0.0, max_grad_norm=0.0) | ||||||||||||||||||||||||||||||||
| assert mgv == 0.0 | ||||||||||||||||||||||||||||||||
| assert mgn == 0.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_explicit_zero_lets_max_grad_norm_through(): | ||||||||||||||||||||||||||||||||
| """Explicit max_grad_value=0.0 + max_grad_norm=1.0 -> only norm clipping.""" | ||||||||||||||||||||||||||||||||
| mgv, mgn = _resolve(raw_mgv=0.0, max_grad_norm=1.0) | ||||||||||||||||||||||||||||||||
| assert mgv == 0.0 | ||||||||||||||||||||||||||||||||
| assert mgn == 1.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_explicit_positive_overrides_max_grad_norm(): | ||||||||||||||||||||||||||||||||
| """Explicit max_grad_value=2.0 with max_grad_norm=1.0 -> elementwise | ||||||||||||||||||||||||||||||||
| wins (existing rule), max_grad_norm zeroed.""" | ||||||||||||||||||||||||||||||||
| mgv, mgn = _resolve(raw_mgv=2.0, max_grad_norm=1.0) | ||||||||||||||||||||||||||||||||
| assert mgv == 2.0 | ||||||||||||||||||||||||||||||||
| assert mgn == 0.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_explicit_positive_alone(): | ||||||||||||||||||||||||||||||||
| """User passes max_grad_value=5.0 with no max_grad_norm -> elementwise at 5.""" | ||||||||||||||||||||||||||||||||
| mgv, mgn = _resolve(raw_mgv=5.0, max_grad_norm=0.0) | ||||||||||||||||||||||||||||||||
| assert mgv == 5.0 | ||||||||||||||||||||||||||||||||
| assert mgn == 0.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # -- trainer source assertions (defense-in-depth) --------------------------- | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_trainer_source_pins_resolution_rule(): | ||||||||||||||||||||||||||||||||
| """Source-level pin: trainer.py contains the four-branch resolution. | ||||||||||||||||||||||||||||||||
| Cheap defense against a future refactor silently regressing the rule.""" | ||||||||||||||||||||||||||||||||
| import inspect | ||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx import trainer as T | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| src = inspect.getsource(T.MLXTrainer.train) + inspect.getsource(T.MLXTrainer._train_inner) | ||||||||||||||||||||||||||||||||
| assert "_user_set_mgv = _raw_mgv is not None" in src | ||||||||||||||||||||||||||||||||
| assert "elif max_grad_norm > 0:" in src | ||||||||||||||||||||||||||||||||
| assert "max_grad_value = 1.0" in src | ||||||||||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] Med: this test pins three literal substrings (
Suggested change
|
||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -119,12 +119,12 @@ class MLXTrainingConfig: | |||||||||||||||||||
| optim: str = "adamw" # "adafactor", "adamw", "adam", "sgd", "muon", "lion" | ||||||||||||||||||||
| weight_decay: float = 0.001 | ||||||||||||||||||||
| max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead | ||||||||||||||||||||
| # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no | ||||||||||||||||||||
| # cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5 | ||||||||||||||||||||
| # rarely fires on real transformer grads, so the historical 5.0 was | ||||||||||||||||||||
| # effectively a no-op; 1.0 matches the universal clip_grad_norm=1.0 | ||||||||||||||||||||
| # baseline while staying on MLX's fast tree_map(mx.clip) path. | ||||||||||||||||||||
| max_grad_value: float | None = 1.0 | ||||||||||||||||||||
| # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf). | ||||||||||||||||||||
| # None (default) keeps the cheap MLX default of 1.0 unless the user | ||||||||||||||||||||
| # passes max_grad_norm > 0, in which case global-norm clipping wins. | ||||||||||||||||||||
| # 0.0 disables. A positive float opts in explicitly and overrides | ||||||||||||||||||||
| # max_grad_norm with a warning. | ||||||||||||||||||||
| max_grad_value: float | None = None | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Changing the dataclass default to Useful? React with 👍 / 👎.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] High: this default flip from The field docstring should also spell out the sentinel semantics, since the rule itself lives only in the trainer:
Suggested change
|
||||||||||||||||||||
| seed: int = 3407 | ||||||||||||||||||||
| lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended | ||||||||||||||||||||
| embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended | ||||||||||||||||||||
|
|
@@ -723,19 +723,34 @@ def _train_inner(self): | |||||||||||||||||||
| _needs_grad_scaling = use_lora_plus or use_embedding_lr | ||||||||||||||||||||
| _warned_skip_optimizer_state_grad_norm = False | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Build step functions following mlx-lm's pattern | ||||||||||||||||||||
| # Build step functions following mlx-lm's pattern. | ||||||||||||||||||||
| # Resolution rule: | ||||||||||||||||||||
| # * max_grad_value=None (default) -> cheap MLX elementwise clip | ||||||||||||||||||||
| # at 1.0, unless the user also passed max_grad_norm > 0 -- in | ||||||||||||||||||||
| # that case the user opted into global-norm clipping (HF/TRL | ||||||||||||||||||||
| # parity) and elementwise is disabled to avoid double-clip. | ||||||||||||||||||||
| # * max_grad_value explicit (float or 0.0) -> honor exactly; | ||||||||||||||||||||
| # if both modes are positive, elementwise wins (warn). | ||||||||||||||||||||
| # max_grad_norm uses MLX's clip_grad_norm helper which materially | ||||||||||||||||||||
| # increases peak memory on bf16 VLM runs, hence the elementwise | ||||||||||||||||||||
| # default. | ||||||||||||||||||||
| max_grad_norm = float(args.max_grad_norm or 0.0) | ||||||||||||||||||||
| # Elementwise clip (clip_grad_value_): leaf-local, free memory. | ||||||||||||||||||||
| # Prefer value clipping when both clipping modes are requested; global | ||||||||||||||||||||
| # norm clipping is exact but materially increases memory on MLX. | ||||||||||||||||||||
| _raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users | ||||||||||||||||||||
| max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0) | ||||||||||||||||||||
| if max_grad_norm > 0 and max_grad_value > 0: | ||||||||||||||||||||
| print( | ||||||||||||||||||||
| "Unsloth: max_grad_norm and max_grad_value are both enabled; " | ||||||||||||||||||||
| "ignoring max_grad_norm in favor of max_grad_value." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| max_grad_norm = 0.0 | ||||||||||||||||||||
| _raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users | ||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] Nit: the dataclass guarantees
Suggested change
|
||||||||||||||||||||
| _user_set_mgv = _raw_mgv is not None | ||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] Nit:
Suggested change
(removing the standalone |
||||||||||||||||||||
| if _user_set_mgv: | ||||||||||||||||||||
| max_grad_value = float(_raw_mgv or 0.0) | ||||||||||||||||||||
| if max_grad_norm > 0 and max_grad_value > 0: | ||||||||||||||||||||
| print( | ||||||||||||||||||||
| "Unsloth: max_grad_norm and max_grad_value are both enabled; " | ||||||||||||||||||||
| "ignoring max_grad_norm in favor of max_grad_value." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| max_grad_norm = 0.0 | ||||||||||||||||||||
| elif max_grad_norm > 0: | ||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] Med: with
Suggested change
|
||||||||||||||||||||
| # User opted into global-norm clipping; suppress the default elementwise. | ||||||||||||||||||||
| max_grad_value = 0.0 | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| # Neither requested -> cheap MLX default. | ||||||||||||||||||||
| max_grad_value = 1.0 | ||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [1/2 reviewers] High: the test file's
Suggested change
|
||||||||||||||||||||
| _clip_grad_value = max_grad_value > 0 | ||||||||||||||||||||
| state = [model.state, optimizer.state, mx.random.state] | ||||||||||||||||||||
| # The direct grad_accum==1 fast path delegates clipping to | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[1/2 reviewers] High:
_resolvehere is a hand-written copy ofMLXTrainer._train_inner's resolution. Every "semantics" test (rows 1-5) validates this helper against itself — not the real trainer. The only thing connecting them istest_trainer_source_pins_resolution_rule, which is a substring grep. A semantic-breaking refactor that preserves the substrings passes; a benign refactor that breaks the substrings fails. Better: lift the resolver intounsloth_zoo.mlx.trainer._resolve_grad_clip(raw_mgv, max_grad_norm)and import it here.(then the rest of the test file uses the real function rather than the local mirror.)