MLX Training updates #5656
Changes from 8 commits
73f37c6
e36b55e
8b79ba4
e8c944f
377fc67
e829268
bfb4203
a404dfd
bff5b44
56e32b7
29aa91a
1a02643
e293af1
962ca28
65cd019
d66f4a7
6a406cb
ad8bf14
976520c
32ddc22
ae6c259
54e8408
71c363d
d142420
54d8d15
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -216,6 +216,8 @@ async def start_training( | |
| "save_steps": request.save_steps, | ||
| "weight_decay": request.weight_decay, | ||
| "max_grad_norm": request.max_grad_norm, | ||
| "max_grad_value": request.max_grad_value, | ||
| "cast_norm_output_to_input_dtype": request.cast_norm_output_to_input_dtype, | ||
Collaborator
NIT: There should be
Member
Good catch, this was a real gap. The Pydantic schema accepted max_grad_leaf_norm but the route never copied it into the config dict, so REST callers had the value silently dropped (start_training kwargs callers were unaffected). Added the forwarding line in d142420 plus a source-pin test that asserts all three grad clipping fields are forwarded by the route.
| "random_seed": request.random_seed, | ||
| "packing": request.packing, | ||
| "optim": request.optim, | ||
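The forwarding gap discussed in the review reply above can be illustrated with a small sketch. Everything here is hypothetical: `TrainingRequest` is a dataclass stand-in for the Pydantic schema, `build_config` stands in for the dict built inside the route, and `missing_fields` is a behavioral check rather than the source-pin test the PR actually adds. Only the three field names come from the thread.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# The three gradient clipping fields named in the review thread.
GRAD_CLIP_FIELDS = ("max_grad_norm", "max_grad_value", "max_grad_leaf_norm")


@dataclass
class TrainingRequest:
    """Hypothetical stand-in for the request schema (defaults are assumed)."""
    max_grad_norm: Optional[float] = None
    max_grad_value: Optional[float] = None
    max_grad_leaf_norm: Optional[float] = None


def build_config(request: TrainingRequest) -> Dict[str, Any]:
    """Hypothetical stand-in for the route: copy each field explicitly."""
    return {
        "max_grad_norm": request.max_grad_norm,
        "max_grad_value": request.max_grad_value,
        # Omitting this line reproduces the bug: the schema accepts the
        # field, but the value never reaches the training config.
        "max_grad_leaf_norm": request.max_grad_leaf_norm,
    }


def missing_fields(config: Dict[str, Any]) -> List[str]:
    """Return the clipping fields that were not forwarded into the config."""
    return [name for name in GRAD_CLIP_FIELDS if name not in config]


config = build_config(TrainingRequest(max_grad_leaf_norm=0.5))
assert missing_fields(config) == []
assert config["max_grad_leaf_norm"] == 0.5
```

A check keyed on a shared tuple of field names fails as soon as a new clipping field is added to the schema without a matching forwarding line, which is the class of regression the reply describes.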
The description for max_grad_value states that MLX uses its runtime default if unset. However, the implementation in worker.py (line 1396) explicitly defaults it to 1.0 if it is None. To avoid confusion and ensure the API documentation matches the implementation, the description should be updated to reflect that it defaults to 1.0 in this environment.
This is resolved in the current head. The worker no longer substitutes 1.0: max_grad_value stays None unless the caller sets it, and None reaches MLXTrainingConfig so the trainer applies its own runtime default (per-leaf L2 norm 1.0 after unslothai/unsloth-zoo#684). The schema description now matches the implementation.
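As a rough illustration of the behavior described in this reply, the sketch below passes `None` through untouched and lets the clipping step fall back to a per-leaf L2 norm of 1.0. It is pure Python over lists, not MLX code; `clip_grads`, `clip_leaf_l2`, and `clip_by_value` are hypothetical names, and treating max_grad_value as element-wise clamping is an assumption based on the field name.

```python
import math
from typing import Dict, List, Optional

# Assumed fallback, taken from the reply: per-leaf L2 norm of 1.0.
DEFAULT_LEAF_NORM = 1.0


def clip_leaf_l2(leaf: List[float], max_norm: float) -> List[float]:
    """Rescale one gradient leaf so its L2 norm does not exceed max_norm."""
    norm = math.sqrt(sum(g * g for g in leaf))
    if norm <= max_norm:
        return list(leaf)
    scale = max_norm / norm
    return [g * scale for g in leaf]


def clip_by_value(leaf: List[float], max_value: float) -> List[float]:
    """Clamp every element of a leaf to [-max_value, max_value] (assumed)."""
    return [max(-max_value, min(max_value, g)) for g in leaf]


def clip_grads(
    grads: Dict[str, List[float]],
    max_grad_value: Optional[float] = None,
) -> Dict[str, List[float]]:
    """Apply the caller's setting, or the fallback when it is None."""
    if max_grad_value is None:
        # No substitution upstream: None arrives here and the fallback applies.
        return {k: clip_leaf_l2(v, DEFAULT_LEAF_NORM) for k, v in grads.items()}
    return {k: clip_by_value(v, max_grad_value) for k, v in grads.items()}


grads = {"w": [3.0, 4.0], "b": [0.3, 0.4]}
unset = clip_grads(grads)              # per-leaf L2 fallback
explicit = clip_grads(grads, 1.0)      # element-wise clamp
```

With these inputs the two paths differ for the `w` leaf: the fallback rescales it to unit norm, while an explicit value of 1.0 clamps each element independently. That difference is why substituting 1.0 in the worker would not have been equivalent to leaving the field unset.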