MLX Training updates #5656
for more information, see https://pre-commit.ci
Code Review
This pull request introduces support for elementwise gradient value clipping (max_grad_value) in MLX training and refines random state handling by separating model and LoRA initialization seeds. The changes span the backend training logic, API models, and frontend types. Feedback was provided regarding a discrepancy between the max_grad_value field description and its actual default implementation in the worker, which should be aligned to avoid confusion.
max_grad_value: Optional[float] = Field(
    None,
    ge = 0,
    description = (
        "MLX-only elementwise gradient value clipping threshold. "
        "If unset, MLX uses its runtime default."
    ),
)
The description for max_grad_value states that MLX uses its runtime default if unset. However, the implementation in worker.py (line 1396) explicitly defaults it to 1.0 if it is None. To avoid confusion and ensure the API documentation matches the implementation, the description should be updated to reflect that it defaults to 1.0 in this environment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8b79ba4235
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
model_random_state = config.get("model_random_state", 3407)
lora_random_state = config.get("lora_random_state", 3407)
Fall back MLX init seeds to random_seed
_run_mlx_training now reads model_random_state/lora_random_state directly, defaulting both to 3407, but Studio never sets those keys (the request/config path still only forwards random_seed in routes/training.py and core/training/training.py). As a result, a user-supplied random_seed no longer controls model load or LoRA initialization RNG, so seed-based reproducibility is broken for any non-3407 seed (including the API default path where random_seed is 42).
max_grad_value = config.get("max_grad_value")
max_grad_value = 1.0 if max_grad_value is None else float(max_grad_value)
Honor null max_grad_value as MLX runtime default
The new request model documents that leaving max_grad_value unset should defer to MLX defaults, but this code forces None to 1.0 before building MLXTrainingConfig. Because the frontend now sends max_grad_value: null by default, callers cannot actually use the trainer/runtime default and the API behavior contradicts its contract.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e82926845d
model_random_state = config.get("model_random_state", 3407)
lora_random_state = config.get("lora_random_state", 3407)
Use request random_seed for MLX model/LoRA init
_run_mlx_training now reads model_random_state/lora_random_state, but the request path still only forwards random_seed (routes/training.py and core/training/training.py). In practice, user-provided seeds no longer control from_pretrained(... random_state=...) or get_peft_model(... random_state=...), so runs with non-3407 seeds are not reproducible on MLX even though seed is still set for the trainer loop.
max_grad_value = config.get("max_grad_value")
max_grad_value = 1.0 if max_grad_value is None else float(max_grad_value)
Preserve null max_grad_value instead of forcing 1.0
The new API contract says max_grad_value can be unset to let MLX use its runtime default (TrainingStartRequest.max_grad_value description), but this code rewrites None to 1.0 before building MLXTrainingConfig. That makes null behaviorally different from the documented contract and prevents callers from actually opting into the trainer/runtime default.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bfb4203400
danielhanchen
left a comment
Looks good, Matthew. Static review (Studio backend MLX worker is Apple Silicon only, so this is review only on the CUDA host I have).
End-to-end trace of max_grad_value:
- studio/backend/models/training.py:267-281 accepts Optional[float] (None preserved by Pydantic).
- studio/backend/routes/training.py:218 forwards into the worker dict unchanged.
- studio/backend/core/training/training.py:218-225 forwards into the config dict unchanged.
- studio/backend/core/training/worker.py:1389-1397 reads it, leaves it None, only coerces to float when a numeric value is present.
- studio/frontend/src/features/training/api/mappers.ts:84 sends max_grad_value: null by default.
No `x or 1.0` fallback downstream, so the API contract (null in -> null out) is preserved end-to-end. The `weight_decay = 0.001 if weight_decay is None else float(weight_decay)` normalization at worker.py:1395-1396 is also cleaner than the previous `float(config.get(..., 0.001) or 0.001)` (which had the well-known "user explicitly passes 0.0 -> coerced to 0.001" trap).
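The difference between the two normalization styles can be sketched as follows. This is only an illustration of the expressions quoted in the review, not the worker's actual code: the helper names `old_weight_decay`, `new_weight_decay`, and `read_max_grad_value` are hypothetical.

```python
def old_weight_decay(config: dict) -> float:
    # Previous pattern: `or` treats every falsy value (0.0 included) as missing.
    return float(config.get("weight_decay", 0.001) or 0.001)


def new_weight_decay(config: dict) -> float:
    # New pattern: only None triggers the default, so an explicit 0.0 survives.
    weight_decay = config.get("weight_decay")
    return 0.001 if weight_decay is None else float(weight_decay)


def read_max_grad_value(config: dict):
    # None is passed through untouched so the trainer can apply its own default;
    # any numeric input is coerced to float.
    max_grad_value = config.get("max_grad_value")
    return None if max_grad_value is None else float(max_grad_value)


print(old_weight_decay({"weight_decay": 0.0}))  # explicit zero is lost
print(new_weight_decay({"weight_decay": 0.0}))  # explicit zero is kept
```

The `is None` check is what keeps "unset" and "zero" distinguishable, which matters for both weight decay and a clipping threshold.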
model_random_state / lora_random_state defaulting to random_seed when absent (worker.py:1156-1170) reads correctly. test_training_backend_forwards_random_seed_without_internal_mlx_seed_keys asserts the absent-key path; one thing the test suite does not assert is the present-but-None case (config["model_random_state"] = None would override the seed with None, because config.get("model_random_state", random_seed) only falls back when the key is missing, not when it's present-and-None). Probably not reachable from Studio today since the request schema doesn't expose those keys, but if anything downstream ever does, the semantics may surprise. Easy fix: config.get("model_random_state") or random_seed if you want explicit-null to mean "inherit", or document the present-vs-absent distinction.
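A minimal sketch of the present-vs-absent distinction, using plain dicts in place of the real worker config. The three function names are hypothetical; the first two bodies mirror the expressions quoted above, and the third is an alternative that is not proposed anywhere in the thread.

```python
def seed_with_get_default(config: dict, random_seed: int):
    # dict.get only uses the default when the key is missing entirely.
    return config.get("model_random_state", random_seed)


def seed_with_or(config: dict, random_seed: int):
    # The "easy fix": explicit None now inherits random_seed.
    # Caveat: `or` also replaces a legitimate seed of 0.
    return config.get("model_random_state") or random_seed


def seed_with_none_check(config: dict, random_seed: int):
    # Stricter variant: only None (absent or explicit) inherits random_seed.
    value = config.get("model_random_state")
    return random_seed if value is None else value


explicit_null = {"model_random_state": None}
print(seed_with_get_default(explicit_null, 42))  # None leaks through
print(seed_with_or(explicit_null, 42))           # falls back to 42
```

If a seed of 0 should remain valid, the `is None` form is the safer of the two fallbacks.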
tests/studio/run_real_mlx_smoke.py 30-step refresh matches the gate from #5537. Dropping the eos_id append in _compute_loss_and_grad_norm so the smoke loss probe matches Studio's text dataset path is the right move; the prior 7-step assertion was stale per #5622. I can't actually run the smoke from here (CUDA-only host), so this rides on macOS CI evidence.
This PR pairs with unsloth-zoo#684 - cast_norm_output_to_input_dtype and max_grad_value=None semantics only do useful work once MLXTrainingArguments accepts them. Worth a merge-order note in case #684 lands later.
Approving subject to MLX CI green on the smoke test refresh.
test_training_raw_support.py transitively imports the full studio backend (core.training.training -> matplotlib, etc.). Adding every transitive dep to the Windows install smoke is whack-a-mole and defeats the smoke's purpose. test_mlx_training_worker_config.py already covers PR unslothai#5656's wiring (model_random_state / lora_random_state fallback, max_grad_value None preservation, dataset_order=torch_randperm) via source-text assertions on worker.py. The test stubs out structlog/loggers/utils itself, so it works with just stdlib. Drop the broader test from the Windows job.
studio/backend/core/training/worker.py
`config.get("model_random_state", random_seed)` only fills the
default when the key is absent. When a caller passes
`config["model_random_state"] = None` explicitly (which happens
any time a JSON payload sends an explicit `null`), the old code
forwarded `None` to FastMLXModel and disabled deterministic init
silently. Same for `lora_random_state`. Treat absent and explicit
None the same way: fall back to random_seed.
studio/backend/tests/test_training_raw_support.py
Update the source-string assertions to match the new lines.
Pushed one small follow-up on top of
The Pydantic schema does not expose
The PR unslothai#684 and PR unslothai#5656 heads were just updated with maintainer fixes (restored compiler.py UNSLOTH_RETURN_LOGITS elif, GPT-2 ln_* matching, Qwen3-VL flag wiring, default-branch reseed; plus seed present-but-None fix). Bump the three workflow files (comment-only) so the paths filters re-fire and we get a fresh signal on all three runners against the updated PR heads.
Round 2 of reviewer-driven fixes landed on the PR heads:
zoo PR unslothai#684: 0753b115
- merged origin/main (restores unslothai#690 / unslothai#691 gpt-oss eager attn)
- cleaned up norm cast monkey patch in train() finally
- raise on streaming+dataset_order text combo
- VLM baseline CE full-sequence forward parity with CCE
- scheduler test now matches HF linear-no-warmup behavior
unsloth PR unslothai#5656: bff5b44 (unchanged since last run)
Re-fire all three workflows so we get a fresh signal.
… PR unslothai#5656
The MLX worker now passes `cast_norm_output_to_input_dtype` and `dataset_order` only when the linked unsloth-zoo dataclass actually declares them. Released zoo trees that predate the paired PR can still construct `MLXTrainingConfig` without raising `TypeError: unexpected keyword argument`. Once the dependency floor is bumped to a release that contains both fields, the feature-detect guards become no-ops.
`random_seed = config.get("random_seed", 3407)` was unguarded against explicit `None` from raw / backend callers. The same value seeded the trainer and was the fallback target for `model_random_state` / `lora_random_state`. Normalize once at the top of the function and use the normalized value everywhere so an explicit `None` cannot reach FastMLXModel / get_peft_model / MLXTrainingConfig.
Existing seed source-pattern test updated to match the new normalize helper. New test asserts the feature-detection guards exist and that the unconditional kwargs do not include the gated fields.
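One way such a feature-detect guard could look is sketched below. The thread does not show how the worker actually inspects the dataclass, so `dataclasses.fields` is an assumption here, and `OldConfig`, `NewConfig`, `build_kwargs`, and the `"default"` ordering value are hypothetical stand-ins rather than unsloth-zoo code.

```python
from dataclasses import dataclass, fields


@dataclass
class OldConfig:  # hypothetical stand-in for a config that predates the new fields
    learning_rate: float = 2e-4


@dataclass
class NewConfig:  # hypothetical stand-in for a config that declares them
    learning_rate: float = 2e-4
    cast_norm_output_to_input_dtype: bool = True
    dataset_order: str = "default"


GATED = ("cast_norm_output_to_input_dtype", "dataset_order")


def build_kwargs(config_cls, requested: dict) -> dict:
    # Forward a gated key only if the target dataclass declares it;
    # every other key is passed unconditionally.
    declared = {f.name for f in fields(config_cls)}
    return {
        key: value
        for key, value in requested.items()
        if key not in GATED or key in declared
    }


requested = {
    "learning_rate": 1e-4,
    "cast_norm_output_to_input_dtype": False,
    "dataset_order": "torch_randperm",
}
old = OldConfig(**build_kwargs(OldConfig, requested))  # no TypeError
new = NewConfig(**build_kwargs(NewConfig, requested))  # gated fields applied
```

Once every supported release declares both fields, the filter passes everything through and the guard has no effect.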
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Pushed
Tests: Added one new assertion to
Not addressing in this PR:
Yell if anything looks off.
…othai#5656
Round-3 review consensus: the per-field guards that landed in the MLX worker only protect the MLX path. The same `TrainingBackend.start_training` config still reaches the CUDA/text trainer at `worker.py:2267`, the embedding LoRA init at `worker.py:2450`, and embedding TrainingArguments at `worker.py:2624` with raw `None` values, so an explicit `random_seed=None` from a raw / backend caller still breaks non-MLX training even after the previous fix.
Move the normalization into `TrainingBackend.start_training` itself, where it runs once for every training mode:
- `_coerce_seed(value)`: explicit `None`, non-int, or absent all become 3407. Every downstream worker now sees an int.
- `_coerce_optional_bool(value, default)`: explicit `None` falls back to `default` instead of `bool(None) == False`. Also normalizes the common raw-config / YAML string aliases ("true" / "false" / "0" / "1"). Used for `cast_norm_output_to_input_dtype`.
- `_coerce_optional_nonneg_float(name, value)`: rejects negative numerics from raw / backend callers, matching the Pydantic `ge=0` constraint the HTTP route already enforces. Used for `max_grad_value`.
worker.py MLX path: the existing `bool(config.get(key, True))` for `cast_norm_output_to_input_dtype` was changed to also fall back on explicit `None`, so direct worker callers (bypassing `TrainingBackend.start_training`) are equally safe. `max_grad_value` also raises on negative values inside the worker for the same reason.
TrainingStartRequest.random_seed default bumped from 42 to 3407 so direct REST callers that omit the field receive the same default as the Studio frontend and the MLX worker.
New regression test exercises the three new helpers across explicit None, valid values, string aliases, and negative-value rejection.
Pushed
Background context (already in earlier rounds):
Not changed in this PR:
Yell if anything looks off.
The block-extraction used , which stops at the first inner closing paren (e.g. ) and would silently miss a future unconditional / added later in the same dict literal. Switched to proper paren-depth tracking so the unconditional block is checked end-to-end.
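A minimal sketch of paren-depth tracking for this kind of source-text check. The pattern the test originally used is not shown in the thread, so the naive variant below is only a guess at the failure mode, and `extract_paren_block` is a hypothetical name.

```python
def extract_paren_block(source: str, open_index: int) -> str:
    """Return the text from the '(' at open_index through its matching ')'.

    Simplification: parentheses inside string literals or comments are not
    treated specially, which is acceptable for a small source-text assertion.
    """
    if source[open_index] != "(":
        raise ValueError("open_index must point at '('")
    depth = 0
    for i in range(open_index, len(source)):
        ch = source[i]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return source[open_index : i + 1]
    raise ValueError("unbalanced parentheses")


source = "kwargs = dict(a=float(x), b=int(y), c=1)\nother = 2"
start = source.index("dict(") + len("dict")

# Naive approach: stop at the first ')' and miss everything after float(x).
naive = source[start : source.index(")", start) + 1]

# Depth tracking: inner calls are skipped and the whole argument list is kept.
block = extract_paren_block(source, start)
```

With the full block in hand, an assertion that a gated key is absent covers every entry rather than only those before the first nested call.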
Added commit
Rationale / parity check:
The patch is intentionally conditional:
    raw_text_mode = training_type == "Continued Pretraining" or format_type == "raw"
    mlx_config_kwargs["append_eos"] = bool(raw_text_mode)
So raw/CPT text still lets MLX append EOS, matching the CUDA raw-text path, while formatted SFT text does not get an extra EOS behind Studio's back.
Validation run locally:
    pytest -q studio/backend/tests/test_training_raw_support.py
    # 11 passed
Rationale / guardrails for the local Studio/vision push: When callers provide explicit VLM LoRA target_modules together with layer filters, FastVisionModel still needs to route the explicit targets through get_peft_regex. Otherwise the layer filters are ignored and adapters can be attached outside the requested language/vision scope. Do not revert this to plain list(target_modules) for explicit module lists. The CUDA/Studio-facing contract is that explicit targets and layer filters compose: target_modules selects module names, while finetune_language_layers / finetune_vision_layers / finetune_attention_modules / finetune_mlp_modules constrain where those targets are allowed. The regression test covers the language-only explicit q_proj case and source-checks that explicit targets are wrapped through get_peft_regex when filters are active.
Reviewer / maintainer guardrail for the next Studio/vision push: The local VLM LoRA targeting fix is intentional and should not be reverted to plain When callers provide explicit VLM
Without the regex wrapping, explicit target lists can ignore the language/vision layer filters and attach adapters outside the requested scope. The added test covers the language-only explicit
Verified the Studio-side wiring for
Plumbing path verified
Feature-detect guard keeps backwards compat with older
Tests passing (15)
Including the new
Why the new default matters in Studio
Studio users default to LoRA training on Apple Silicon where memory headroom is tight. The new MLX default (
Detailed parity data in unslothai/unsloth-zoo#684 review comment.
Tested across precedence
Each resolution path covered by
LGTM from my side on the Studio plumbing.
Trim the 11-line comment block to 5 lines and correct the stale claim that MLXTrainingConfig defaults to max_grad_value=1.0. The new default is max_grad_leaf_norm=1.0 (same memory profile as elementwise but direction-preserving). The smoke still pins max_grad_value=1.0 explicitly to keep the 13-seed pass-rate fixture stable.
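The "direction-preserving" distinction can be illustrated with plain Python lists. This sketch assumes `max_grad_leaf_norm` rescales each gradient tensor (leaf) so that its L2 norm stays at or below the threshold; the thread does not spell out the exact definition, and neither function below is MLX or unsloth-zoo code.

```python
import math


def clip_by_value(grad, max_value):
    # Elementwise clipping: clamp each entry to [-max_value, max_value].
    return [max(-max_value, min(max_value, g)) for g in grad]


def clip_by_leaf_norm(grad, max_norm):
    # Per-leaf norm clipping (assumed semantics): rescale the whole tensor
    # when its L2 norm exceeds max_norm, leaving its direction unchanged.
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm
    return [g * scale for g in grad]


grad = [3.0, 4.0]                       # L2 norm is 5
by_value = clip_by_value(grad, 1.0)     # both entries clamp to 1.0
by_norm = clip_by_leaf_norm(grad, 1.0)  # roughly [0.6, 0.8]
```

Elementwise clipping flattens the 3:4 ratio between the components to 1:1, while the norm-based variant keeps the ratio and only shrinks the magnitude. Both operate on one tensor at a time, which is consistent with the "same memory profile" remark.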
Merges 116 main commits (gemini provider, oxc validator package-lock,
uninstall script relocation, lockfile audit, etc). Two content conflicts
resolved:
- studio/backend/tests/test_mlx_training_worker_config.py: both branches
appended a new test (HEAD's tokenizer dual-purpose check, main's VLM
resize math). Both kept side-by-side; both pass.
- tests/studio/run_real_mlx_smoke.py: HEAD's stronger len + train_steps
assertion kept; main's auto-following comment kept.
16 Studio backend tests pass post-merge.
    tuple,
    str,
)
if type(target_modules) in (list, tuple) and (
NIT: Should we at least warn that both are specified and that one is being chosen over the other?
"weight_decay": request.weight_decay,
"max_grad_norm": request.max_grad_norm,
"max_grad_value": request.max_grad_value,
"cast_norm_output_to_input_dtype": request.cast_norm_output_to_input_dtype,
NIT: Shouldn't there be a max_grad_leaf_norm entry here?
Expose max grad values and set default random seeds in Studio.