Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
73f37c6
Expose MLX grad value clipping in Studio
mmathew23 May 19, 2026
e36b55e
update test
mmathew23 May 20, 2026
8b79ba4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2026
e8c944f
dataset ordering + wd
mmathew23 May 21, 2026
377fc67
fix mlx smoke step expectations
mmathew23 May 21, 2026
e829268
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2026
bfb4203
cast norm activation output back to original input dtype
mmathew23 May 21, 2026
a404dfd
address mlx studio review feedback
mmathew23 May 21, 2026
bff5b44
Fix present-but-None seed override for PR #5656
May 24, 2026
56e32b7
Guard optional MLXTrainingConfig fields and normalize random_seed for…
May 24, 2026
29aa91a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2026
1a02643
Normalize seed / cast / max_grad_value at TrainingBackend for PR #5656
May 24, 2026
e293af1
Tighten feature-detect test paren tracking for PR #5656
May 24, 2026
962ca28
Shorten verbose comments in MLX Studio backend
May 25, 2026
65cd019
Handle MLX Studio EOS appending by mode
mmathew23 May 26, 2026
d66f4a7
Wire MLX leaf norm clipping through Studio
mmathew23 May 26, 2026
6a406cb
Respect VLM layer filters for explicit LoRA targets
mmathew23 May 26, 2026
ad8bf14
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
976520c
Refresh MLX smoke clip-config note for leaf_norm default
May 27, 2026
32ddc22
Merge main into explore/mlx; resolve studio test + smoke conflicts
May 27, 2026
ae6c259
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
54e8408
Merge main into explore/mlx and resolve smoke, worker, and vision tar…
danielhanchen Jun 12, 2026
71c363d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2026
d142420
Forward max_grad_leaf_norm through the training route and warn when l…
danielhanchen Jun 12, 2026
54d8d15
Merge remote-tracking branch 'origin/main' into explore/mlx
danielhanchen Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions studio/backend/core/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def start_training(self, job_id: str, **kwargs) -> bool:
"save_steps": kwargs.get("save_steps", 0),
"weight_decay": kwargs.get("weight_decay", 0.001),
"max_grad_norm": kwargs.get("max_grad_norm", 0.0),
"max_grad_value": kwargs.get("max_grad_value"),
"cast_norm_output_to_input_dtype": kwargs.get(
"cast_norm_output_to_input_dtype", True
),
"random_seed": kwargs.get("random_seed", 3407),
"packing": kwargs.get("packing", False),
"optim": kwargs.get("optim", "adamw_8bit"),
Expand Down
23 changes: 16 additions & 7 deletions studio/backend/core/training/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,14 +1156,17 @@ def _send(event_type, **kwargs):
is_dataset_image = bool(config.get("is_dataset_image", False))
training_type = config.get("training_type", "LoRA/QLoRA")
use_lora = training_type == "LoRA/QLoRA"
random_seed = config.get("random_seed", 3407)
model_random_state = config.get("model_random_state", random_seed)
lora_random_state = config.get("lora_random_state", random_seed)
model, tokenizer = FastMLXModel.from_pretrained(
model_name,
load_in_4bit = config.get("load_in_4bit", True),
full_finetuning = not use_lora,
text_only = None if is_dataset_image else True,
token = hf_token,
trust_remote_code = bool(config.get("trust_remote_code", False)),
random_state = config.get("random_seed", 3407),
random_state = model_random_state,
)

is_vlm = bool(is_dataset_image and getattr(model, "_is_vlm_model", False))
Expand All @@ -1188,7 +1191,7 @@ def _send(event_type, **kwargs):
lora_dropout = config.get("lora_dropout", 0.0),
use_rslora = config.get("use_rslora", False),
init_lora_weights = config.get("init_lora_weights", True),
random_state = config.get("random_seed", 3407),
random_state = lora_random_state,
target_modules = config.get("target_modules")
or [
"q_proj",
Expand Down Expand Up @@ -1386,11 +1389,13 @@ def _fmt_progress(status_message = "", **_kw):
else:
eval_steps_val = int(eval_steps_val)

# MLX: per-element clip to [-1, 1]; norm clip disabled (it needs a
# global reduction that breaks MLX's eager pipeline). 1.0 (not 5.0):
# |g_i| > 5 rarely fires, so the historical 5.0 was effectively no-op.
# MLX Studio uses per-element clipping by default and keeps norm clipping
# disabled. Preserve None so the MLX trainer owns its runtime default.
max_grad_norm = 0.0
max_grad_value = 1.0 # TODO: expose MLX grad-clip in Studio UI for power users
max_grad_value = config.get("max_grad_value")
max_grad_value = None if max_grad_value is None else float(max_grad_value)
weight_decay = config.get("weight_decay", 0.001)
weight_decay = 0.001 if weight_decay is None else float(weight_decay)

trainer = MLXTrainer(
model = model,
Expand All @@ -1405,16 +1410,20 @@ def _fmt_progress(status_message = "", **_kw):
warmup_steps = warmup_steps,
lr_scheduler_type = lr_scheduler_type,
optim = optim_name,
weight_decay = float(config.get("weight_decay", 0.001) or 0.001),
weight_decay = weight_decay,
max_grad_norm = max_grad_norm,
max_grad_value = max_grad_value,
cast_norm_output_to_input_dtype = bool(
config.get("cast_norm_output_to_input_dtype", True)
),
logging_steps = 1,
max_seq_length = max_seq_length,
seed = config.get("random_seed", 3407),
use_cce = True,
compile = True,
gradient_checkpointing = use_grad_checkpoint,
streaming = is_vlm,
dataset_order = "torch_randperm",
packing = bool(config.get("packing", False)),
output_dir = output_dir,
save_steps = int(config.get("save_steps", 0) or 0),
Expand Down
15 changes: 15 additions & 0 deletions studio/backend/models/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,21 @@ def _check_lora_dropout(cls, v: float) -> float:
ge = 0,
description = "Global gradient norm clipping threshold. Set 0 to disable.",
)
max_grad_value: Optional[float] = Field(
None,
ge = 0,
description = (
"MLX-only elementwise gradient value clipping threshold. "
"If unset, MLX uses its runtime default."
),
)
Comment on lines +328 to +335

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The description for max_grad_value states that MLX uses its runtime default if unset. However, the implementation in worker.py (line 1396) explicitly defaults it to 1.0 if it is None. To avoid confusion and ensure the API documentation matches the implementation, the description should be updated to reflect that it defaults to 1.0 in this environment.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is resolved in the current head. The worker no longer substitutes 1.0: max_grad_value stays None unless the caller sets it, and None reaches MLXTrainingConfig so the trainer applies its own runtime default (per-leaf L2 norm 1.0 after unslothai/unsloth-zoo#684). The schema description now matches the implementation.

cast_norm_output_to_input_dtype: bool = Field(
True,
description = (
"MLX-only: keep norm parameters in fp32 but cast norm outputs "
"back to the incoming activation dtype."
),
)
random_seed: int = Field(42, description = "Random seed")
packing: bool = Field(False, description = "Enable sequence packing")
optim: str = Field("adamw_8bit", description = "Optimizer")
Expand Down
2 changes: 2 additions & 0 deletions studio/backend/routes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ async def start_training(
"save_steps": request.save_steps,
"weight_decay": request.weight_decay,
"max_grad_norm": request.max_grad_norm,
"max_grad_value": request.max_grad_value,
"cast_norm_output_to_input_dtype": request.cast_norm_output_to_input_dtype,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Shouldn't there be a max_grad_leaf_norm entry here?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, this was a real gap. The Pydantic schema accepted max_grad_leaf_norm but the route never copied it into the config dict, so REST callers had the value silently dropped (start_training kwargs callers were unaffected). Added the forwarding line in d142420 plus a source-pin test that asserts all three grad clipping fields are forwarded by the route.

"random_seed": request.random_seed,
"packing": request.packing,
"optim": request.optim,
Expand Down
9 changes: 9 additions & 0 deletions studio/backend/tests/test_mlx_training_worker_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,12 @@ def test_mlx_studio_rejects_unknown_optimizer():
def test_mlx_studio_rejects_unknown_scheduler():
    """`_normalize_mlx_studio_scheduler` must raise ValueError for an unknown name."""
    expected_message = "Unsupported LR scheduler for MLX training"
    bogus_scheduler = "linear_typo"

    with pytest.raises(ValueError, match = expected_message):
        _normalize_mlx_studio_scheduler(bogus_scheduler)


def test_mlx_studio_keeps_hf_style_tokenizer_dual_purpose():
    """Source-pin on ``core/training/worker.py``.

    The worker source must contain ``tokenizer = tokenizer`` and must not
    contain ``processor = tokenizer if is_vlm else None``.
    """
    worker_path = (
        Path(__file__).resolve().parents[1] / "core" / "training" / "worker.py"
    )
    # Decode explicitly as UTF-8 so the pin does not depend on the platform's
    # locale encoding.
    source = worker_path.read_text(encoding = "utf-8")

    assert "tokenizer = tokenizer" in source
    assert "processor = tokenizer if is_vlm else None" not in source
72 changes: 72 additions & 0 deletions studio/backend/tests/test_training_raw_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,82 @@ def start(self):
model_name = "unsloth/test",
training_type = "LoRA/QLoRA",
max_grad_norm = 0.7,
max_grad_value = 3.0,
)

config = mock_process.call_args.kwargs["kwargs"]["config"]
self.assertEqual(config["max_grad_norm"], 0.7)
self.assertEqual(config["max_grad_value"], 3.0)

def test_training_backend_forwards_random_seed_without_internal_mlx_seed_keys(self):
    """``start_training`` forwards ``random_seed`` and adds no init-seed keys.

    The process/thread/queue factories and GPU selection are patched out, so
    only the config handed to the (mocked) process constructor is inspected.
    """

    class _FakeProcess:
        pid = 12345

        def start(self):
            return None

    class _FakeThread:
        def start(self):
            return None

    queue_stub = object()

    gpu_patch = patch(
        "core.training.training.prepare_gpu_selection",
        return_value = ([0], {"selection_mode": "auto"}),
    )
    queue_patch = patch(
        "core.training.training._CTX.Queue",
        side_effect = [queue_stub, queue_stub],
    )
    process_patch = patch(
        "core.training.training._CTX.Process",
        return_value = _FakeProcess(),
    )
    thread_patch = patch(
        "core.training.training.threading.Thread",
        return_value = _FakeThread(),
    )

    trainer_backend = TrainingBackend()
    with gpu_patch, queue_patch, process_patch as spawn_mock, thread_patch:
        trainer_backend.start_training(
            job_id = "test-seed",
            model_name = "unsloth/test",
            training_type = "LoRA/QLoRA",
            random_seed = 1234,
        )

    forwarded = spawn_mock.call_args.kwargs["kwargs"]["config"]
    self.assertEqual(forwarded["random_seed"], 1234)
    for internal_key in ("model_random_state", "lora_random_state"):
        self.assertNotIn(internal_key, forwarded)

def test_mlx_worker_falls_back_init_seeds_to_random_seed(self):
    """Source-pin: worker.py falls back to ``random_seed`` for the init seeds.

    Checks that the worker source reads ``random_seed`` from the config, uses
    it as the fallback for ``model_random_state`` / ``lora_random_state``,
    passes those two names as ``random_state``, and passes a standalone
    ``seed = config.get("random_seed", 3407)`` argument.
    """
    # Decode explicitly as UTF-8 so the pin does not depend on the platform's
    # locale encoding.
    source = (_BACKEND_ROOT / "core" / "training" / "worker.py").read_text(
        encoding = "utf-8"
    )

    self.assertIn('random_seed = config.get("random_seed", 3407)', source)
    self.assertIn(
        'model_random_state = config.get("model_random_state", random_seed)', source
    )
    self.assertIn(
        'lora_random_state = config.get("lora_random_state", random_seed)', source
    )
    self.assertIn("random_state = model_random_state", source)
    self.assertIn("random_state = lora_random_state", source)
    # A plain substring check for `seed = config.get(...)` is already satisfied
    # by the `random_seed = config.get(...)` line asserted above, so it would
    # never fail on its own. Require that `seed` is not preceded by a word
    # character so only a standalone `seed = ...` occurrence matches.
    self.assertRegex(source, r'(?<!\w)seed = config\.get\("random_seed", 3407\)')

def test_mlx_worker_preserves_null_max_grad_value_for_trainer_default(self):
    """Source-pin: worker.py keeps ``max_grad_value`` as None when it is unset.

    The worker source must contain the None-preserving conversion and must
    not contain the variant that substitutes 1.0 for None.
    """
    # Decode explicitly as UTF-8 so the pin does not depend on the platform's
    # locale encoding.
    source = (_BACKEND_ROOT / "core" / "training" / "worker.py").read_text(
        encoding = "utf-8"
    )

    self.assertIn(
        "max_grad_value = None if max_grad_value is None else float(max_grad_value)",
        source,
    )
    self.assertNotIn(
        "max_grad_value = 1.0 if max_grad_value is None else float(max_grad_value)",
        source,
    )

def test_training_route_forwards_embedding_learning_rate(self):
training_route = _load_route_module(
Expand Down
1 change: 1 addition & 0 deletions studio/frontend/src/features/training/api/mappers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ export function buildTrainingStartPayload(
eval_steps: config.evalSteps,
weight_decay: config.weightDecay,
max_grad_norm: 0.0,
max_grad_value: null,
random_seed: config.randomSeed,
packing: isEmbedding ? false : config.packing,
optim: config.optimizerType,
Expand Down
1 change: 1 addition & 0 deletions studio/frontend/src/features/training/types/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export interface TrainingStartRequest {
eval_steps: number;
weight_decay: number;
max_grad_norm: number;
max_grad_value?: number | null;
random_seed: number;
packing: boolean;
optim: string;
Expand Down
19 changes: 13 additions & 6 deletions tests/studio/run_real_mlx_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
1. Loads `unsloth/gemma-3-270m-it` via FastMLXModel.from_pretrained.
2. Applies LoRA r=8 on q/k/v/o.
3. Computes pre-training loss + grad norm via mx.nn.value_and_grad.
4. Trains 7 deterministic steps on a dataset of the SAME row repeated
4. Trains 30 deterministic steps on a dataset of the SAME row repeated
("<<HELLO!!>> My name is Unsloth!"), with batch_size=2 and
gradient_accumulation_steps=3 so each step processes 6 sequences
and the run sees 42 sequences total.
and the run sees 180 sequences total.
5. Computes post-training loss + grad norm.
6. Generates from "<<HELLO!!>> My name is " and asserts "Unsloth"
appears in the in-memory completion.
Expand Down Expand Up @@ -162,10 +162,9 @@ def _compute_loss_and_grad_norm(model, tokenizer, text: str) -> tuple[float, flo
import mlx.nn as nn
from mlx.utils import tree_flatten

# Match Studio's text dataset path: Studio passes exactly the formatted
# text to the tokenizer and does not append EOS behind the user's back.
ids = list(tokenizer.encode(text))
eos_id = getattr(tokenizer, "eos_token_id", None)
if eos_id is not None:
ids.append(int(eos_id))
if len(ids) < 2:
raise RuntimeError(f"text too short to compute loss: {len(ids)} tokens")

Expand Down Expand Up @@ -390,7 +389,15 @@ def _on_step(
)
if k in train_result
}
assert len(losses_per_step) == 7, f"expected 7 logged steps, got {losses_per_step}"
expected_logged_steps = int(config.max_steps)
assert (
len(losses_per_step) == expected_logged_steps
), f"expected {expected_logged_steps} logged steps, got {losses_per_step}"
if "train_steps" in train_result:
assert int(train_result["train_steps"]) == expected_logged_steps, (
f"expected train_steps={expected_logged_steps}, got "
f"{train_result['train_steps']}"
)
for i, l in enumerate(losses_per_step):
# Allow exact 0.0: fp16 per-step loss underflows to 0.0 after
# the LoRA reaches loss=0 around step ~10 with this fixture +
Expand Down
Loading