Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions unsloth_zoo/mlx/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ class MLXTrainingConfig:
optim: str = "adamw" # "adafactor", "adamw", "adam", "sgd", "muon", "lion"
weight_decay: float = 0.001
max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead
# Elementwise clipping (PyTorch's torch.nn.utils.clip_grad_value_).
# Clamps every grad value to [-max_grad_value, max_grad_value] leaf-by-leaf
# with no cross-leaf reduction. Set 0.0 to disable.
max_grad_value: float | None = 5.0
# Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no
# cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5
# rarely fires on real transformer grads, so the historical 5.0 was
# effectively a no-op; 1.0 matches the universal clip_grad_norm=1.0
# baseline while staying on MLX's fast tree_map(mx.clip) path.
max_grad_value: float | None = 1.0
seed: int = 3407
lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended
embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended
Expand Down Expand Up @@ -726,8 +728,8 @@ def _train_inner(self):
# Elementwise clip (clip_grad_value_): leaf-local, free memory.
# Prefer value clipping when both clipping modes are requested; global
# norm clipping is exact but materially increases memory on MLX.
_raw_mgv = getattr(args, "max_grad_value", 5.0) # TODO: expose MLX grad-clip in Studio UI for power users
max_grad_value = 5.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
_raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users
max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
Comment on lines +731 to +732

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default value `1.0` is hardcoded here, duplicating the default value defined in `MLXTrainingConfig`. This creates a maintenance risk if the default is updated in the configuration but missed here. Additionally, the `or 0.0` inside the `float()` conversion is redundant, because `None` is already handled by the conditional branch and `float(0.0)` correctly returns `0.0`.

Consider referencing the default value from `MLXTrainingConfig` directly to ensure consistency.

Suggested change
_raw_mgv = getattr(args, "max_grad_value", 1.0) # TODO: expose MLX grad-clip in Studio UI for power users
max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
_raw_mgv = getattr(args, "max_grad_value", MLXTrainingConfig.max_grad_value) # TODO: expose MLX grad-clip in Studio UI for power users
max_grad_value = float(MLXTrainingConfig.max_grad_value if _raw_mgv is None else _raw_mgv)

if max_grad_norm > 0 and max_grad_value > 0:
print(
"Unsloth: max_grad_norm and max_grad_value are both enabled; "
Expand Down
Loading