Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions studio/backend/core/inference/mlx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def load_model(
)

try:
from unsloth_zoo.mlx_loader import FastMLXModel
from unsloth_zoo.mlx.loader import FastMLXModel
except ImportError as e:
raise ImportError(
"Unsloth: MLX inference requires unsloth-zoo with the MLX modules "
"(unsloth_zoo.mlx_loader). Reinstall via install.sh on Apple Silicon."
"(unsloth_zoo.mlx.loader). Reinstall via install.sh on Apple Silicon."
) from e

model, tokenizer_or_processor = FastMLXModel.from_pretrained(
Expand Down
2 changes: 2 additions & 0 deletions studio/backend/core/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def start_training(self, job_id: str, **kwargs) -> bool:
"max_steps": kwargs.get("max_steps", 0),
"save_steps": kwargs.get("save_steps", 0),
"weight_decay": kwargs.get("weight_decay", 0.001),
"max_grad_norm": kwargs.get("max_grad_norm", 0.0),
"max_grad_value": kwargs.get("max_grad_value"),
"random_seed": kwargs.get("random_seed", 3407),
"packing": kwargs.get("packing", False),
"optim": kwargs.get("optim", "adamw_8bit"),
Expand Down
87 changes: 79 additions & 8 deletions studio/backend/core/training/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,55 @@ def _normalize_mlx_studio_scheduler(value):
return raw


def _resolve_mlx_local_dataset_files(file_paths: list) -> list[str]:
"""Resolve Studio local dataset uploads without importing the GPU trainer."""
from utils.paths import resolve_dataset_path

all_files: list[str] = []
for dataset_file in file_paths or []:
file_path = (
dataset_file
if os.path.isabs(dataset_file)
else str(resolve_dataset_path(dataset_file))
)
file_path_obj = Path(file_path)

if file_path_obj.is_dir():
parquet_dir = (
file_path_obj / "parquet-files"
if (file_path_obj / "parquet-files").exists()
else file_path_obj
)
Comment on lines +434 to +438

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check for parquet-files should use .is_dir() instead of .exists(). If a file named parquet-files exists in the directory, the subsequent .glob() call will fail or return no results. Additionally, consider moving this entire dataset resolution logic to a shared utility file (e.g., utils/paths.py) to avoid duplication with UnslothTrainer while still keeping it accessible to the MLX path without importing Torch.

            parquet_dir = file_path_obj / "parquet-files"
            if not parquet_dir.is_dir():
                parquet_dir = file_path_obj

parquet_files = sorted(parquet_dir.glob("*.parquet"))
if parquet_files:
all_files.extend(str(p) for p in parquet_files)
continue

candidates: list[Path] = []
for ext in (".json", ".jsonl", ".csv", ".parquet"):
candidates.extend(sorted(file_path_obj.glob(f"*{ext}")))
if candidates:
all_files.extend(str(c) for c in candidates)
continue

raise ValueError(f"No supported data files in directory: {file_path_obj}")
Comment on lines +444 to +451

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic collects files of all supported extensions into a single list. If a directory contains mixed file types (e.g., both .json and .csv), load_dataset will fail because it is called with a single loader type determined by the first file's extension. The loop should break after finding the first extension that matches any files to ensure consistency within the directory.

Suggested change
candidates: list[Path] = []
for ext in (".json", ".jsonl", ".csv", ".parquet"):
candidates.extend(sorted(file_path_obj.glob(f"*{ext}")))
if candidates:
all_files.extend(str(c) for c in candidates)
continue
raise ValueError(f"No supported data files in directory: {file_path_obj}")
for ext in (".json", ".jsonl", ".csv", ".parquet"):
matches = sorted(file_path_obj.glob(f"*{ext}"))
if matches:
all_files.extend(str(m) for m in matches)
break
else:
raise ValueError(f"No supported data files in directory: {file_path_obj}")
continue


all_files.append(str(file_path_obj))

return all_files


def _mlx_local_dataset_loader_for_files(files: list[str]) -> str:
first_ext = Path(files[0]).suffix.lower()
if first_ext in (".json", ".jsonl"):
return "json"
if first_ext == ".csv":
return "csv"
if first_ext == ".parquet":
return "parquet"
raise ValueError(f"Unsupported dataset format: {files[0]}")


def _run_mlx_training(event_queue, stop_queue, config):
"""Self-contained MLX training path for Apple Silicon.

Expand All @@ -442,16 +491,16 @@ def _send(event_type, **kwargs):
import mlx.core as mx

try:
from unsloth_zoo.mlx_loader import FastMLXModel
from unsloth_zoo.mlx_trainer import (
from unsloth_zoo.mlx.loader import FastMLXModel
from unsloth_zoo.mlx.trainer import (
MLXTrainer,
MLXTrainingConfig,
train_on_responses_only,
)
except ImportError as e:
raise ImportError(
"Unsloth: MLX training requires unsloth-zoo with the MLX modules "
"(unsloth_zoo.mlx_loader / unsloth_zoo.mlx_trainer). Reinstall via "
"(unsloth_zoo.mlx.loader / unsloth_zoo.mlx.trainer). Reinstall via "
"install.sh on Apple Silicon."
) from e
from datasets import load_dataset
Expand Down Expand Up @@ -572,7 +621,6 @@ def _slice(ds):
return ds

def _load_local(file_paths):
from core.training.trainer import UnslothTrainer
from datasets import load_from_disk

if len(file_paths) == 1:
Expand All @@ -581,10 +629,10 @@ def _load_local(file_paths):
(p / "dataset_info.json").exists() or (p / "state.json").exists()
):
return load_from_disk(str(p))
all_files = UnslothTrainer._resolve_local_files(file_paths)
all_files = _resolve_mlx_local_dataset_files(file_paths)
if not all_files:
raise ValueError("No local dataset files found")
loader = UnslothTrainer._loader_for_files(all_files)
loader = _mlx_local_dataset_loader_for_files(all_files)
return load_dataset(loader, data_files = all_files, split = "train")

if hf_dataset:
Expand Down Expand Up @@ -718,6 +766,10 @@ def _fmt_progress(status_message = "", **_kw):
else:
eval_steps_val = int(eval_steps_val)

# MLX: value-clip grads to [-5, 5]; norm clipping disabled for compile-friendliness.
max_grad_norm = 0.0
max_grad_value = 5.0 # TODO: expose MLX grad-clip in Studio UI for power users
Comment on lines +769 to +771

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Honor configured MLX gradient clipping controls

Even though this commit adds max_grad_norm/max_grad_value to the API payload and subprocess config, the MLX worker overwrites them here with fixed values before constructing MLXTrainingConfig. In API-driven runs that set max_grad_norm or disable value clipping with max_grad_value=0, those requested settings are silently ignored and every run uses norm clipping disabled plus value clipping at 5.0. This is fresh evidence beyond the earlier config-forwarding comment because the fields are now forwarded, but these hard-coded assignments still discard them.

Useful? React with 👍 / 👎.


trainer = MLXTrainer(
model = model,
tokenizer = tokenizer,
Expand All @@ -732,6 +784,8 @@ def _fmt_progress(status_message = "", **_kw):
lr_scheduler_type = lr_scheduler_type,
optim = optim_name,
weight_decay = float(config.get("weight_decay", 0.001) or 0.001),
max_grad_norm = max_grad_norm,
max_grad_value = max_grad_value,
logging_steps = 1,
max_seq_length = max_seq_length,
seed = config.get("random_seed", 3407),
Expand Down Expand Up @@ -820,7 +874,17 @@ def _fmt_progress(status_message = "", **_kw):
# ── 9. Real-time progress callback ──
_send("status", status_message = f"Training {model_name}...")

def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens):
def _on_step(
step,
total,
loss,
lr,
tok_s,
peak_gb,
elapsed,
num_tokens,
grad_norm = None,
):
eta = (elapsed / step * (total - step)) if step > 0 else 0
_send(
"progress",
Expand All @@ -831,7 +895,7 @@ def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens):
total_steps = total,
elapsed_seconds = elapsed,
eta_seconds = max(0, eta),
grad_norm = None,
grad_norm = grad_norm,
num_tokens = num_tokens,
eval_loss = None,
status_message = None,
Expand All @@ -846,6 +910,11 @@ def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens):
"train/tokens_per_sec": tok_s,
"train/peak_gb": peak_gb,
"train/num_tokens": num_tokens,
**(
{"train/grad_norm": grad_norm}
if grad_norm is not None
else {}
),
},
step = step,
)
Expand All @@ -857,6 +926,8 @@ def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens):
tb_writer.add_scalar("train/learning_rate", lr, step)
tb_writer.add_scalar("train/tokens_per_sec", tok_s, step)
tb_writer.add_scalar("train/peak_gb", peak_gb, step)
if grad_norm is not None:
tb_writer.add_scalar("train/grad_norm", grad_norm, step)
except Exception:
pass

Expand Down
13 changes: 13 additions & 0 deletions studio/backend/models/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,19 @@ def _check_lora_dropout(cls, v: float) -> float:
max_steps: Optional[int] = Field(None, description = "Maximum training steps")
save_steps: int = Field(100, description = "Steps between checkpoints")
weight_decay: float = Field(0.001, description = "Weight decay")
max_grad_norm: float = Field(
0.0,
ge = 0,
description = "Global gradient norm clipping threshold. Set 0 to disable.",
)
max_grad_value: Optional[float] = Field(
None,
ge = 0,
description = (
"Elementwise gradient value clipping threshold. Set 0 to disable. "
"If omitted, MLX defaults to 1 unless max_grad_norm is set."
),
)
Comment on lines +265 to +277

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default values for max_grad_norm (0.0) and max_grad_value (3.0) appear to be optimized for MLX on Apple Silicon but are now applied as global defaults for the entire Training API. This changes the default behavior for non-MLX (Torch) training runs, which typically default to max_grad_norm=1.0 and no elementwise clipping. Consider keeping these defaults platform-neutral and applying MLX-specific defaults within the MLX worker logic (which already has fallback logic for these keys).

random_seed: int = Field(42, description = "Random seed")
packing: bool = Field(False, description = "Enable sequence packing")
optim: str = Field("adamw_8bit", description = "Optimizer")
Expand Down
2 changes: 2 additions & 0 deletions studio/backend/routes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ async def start_training(
"max_steps": request.max_steps,
"save_steps": request.save_steps,
"weight_decay": request.weight_decay,
"max_grad_norm": request.max_grad_norm,
"max_grad_value": request.max_grad_value,
"random_seed": request.random_seed,
"packing": request.packing,
"optim": request.optim,
Expand Down
9 changes: 6 additions & 3 deletions studio/backend/tests/test_mlx_inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ def from_pretrained(*args, **kwargs):
return _DummyModel(), _DummyTokenizer()

unsloth_zoo_pkg = types.ModuleType("unsloth_zoo")
mlx_loader = types.ModuleType("unsloth_zoo.mlx_loader")
mlx_pkg = types.ModuleType("unsloth_zoo.mlx")
mlx_loader = types.ModuleType("unsloth_zoo.mlx.loader")
mlx_loader.FastMLXModel = _FastMLXModel
unsloth_zoo_pkg.mlx_loader = mlx_loader
unsloth_zoo_pkg.mlx = mlx_pkg
mlx_pkg.loader = mlx_loader
monkeypatch.setitem(sys.modules, "unsloth_zoo", unsloth_zoo_pkg)
monkeypatch.setitem(sys.modules, "unsloth_zoo.mlx_loader", mlx_loader)
monkeypatch.setitem(sys.modules, "unsloth_zoo.mlx", mlx_pkg)
monkeypatch.setitem(sys.modules, "unsloth_zoo.mlx.loader", mlx_loader)


def test_mlx_inference_text_load_forwards_studio_settings(monkeypatch):
Expand Down
44 changes: 44 additions & 0 deletions studio/backend/tests/test_training_raw_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,50 @@ def start(self):
self.assertTrue(config["load_in_4bit"])
self.assertEqual(config["embedding_learning_rate"], 1e-5)

def test_training_backend_forwards_grad_clipping_controls(self):
    """``max_grad_norm`` / ``max_grad_value`` reach the subprocess config as given."""

    class _StubProcess:
        # Stand-in for the spawned worker process; nothing is launched.
        pid = 12345

        def start(self):
            return None

    class _StubThread:
        # Stand-in for the thread the backend starts alongside the process.
        def start(self):
            return None

    backend = TrainingBackend()
    queue_stub = object()

    gpu_patch = patch(
        "core.training.training.prepare_gpu_selection",
        return_value = ([0], {"selection_mode": "auto"}),
    )
    queue_patch = patch(
        "core.training.training._CTX.Queue",
        side_effect = [queue_stub, queue_stub],
    )
    process_patch = patch(
        "core.training.training._CTX.Process",
        return_value = _StubProcess(),
    )
    thread_patch = patch(
        "core.training.training.threading.Thread",
        return_value = _StubThread(),
    )

    with gpu_patch, queue_patch, process_patch as mock_process, thread_patch:
        backend.start_training(
            job_id = "test-grad-clip",
            model_name = "unsloth/test",
            training_type = "LoRA/QLoRA",
            max_grad_norm = 0.7,
            max_grad_value = 0.0,
        )

    # The config dict is handed to the process constructor under kwargs["config"].
    forwarded = mock_process.call_args.kwargs["kwargs"]["config"]
    self.assertEqual(forwarded["max_grad_norm"], 0.7)
    self.assertEqual(forwarded["max_grad_value"], 0.0)

def test_training_route_forwards_embedding_learning_rate(self):
training_route = _load_route_module(
"training_route_module_raw_support",
Expand Down
25 changes: 24 additions & 1 deletion studio/backend/utils/datasets/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@
{}"""


def _is_mlx_runtime() -> bool:
try:
from unsloth_zoo.mlx import is_mlx_available
except ImportError:
return False
return is_mlx_available()


def _chat_template_kwargs() -> dict:
    """Build the extra keyword arguments for ``get_chat_template`` calls.

    Returns an empty dict unless ``_is_mlx_runtime()`` is true; in that case
    ``patch_saving`` is set to ``False`` and ``use_zoo_tokenizer_patch`` to
    ``True``.
    """
    if _is_mlx_runtime():
        return {"patch_saving": False, "use_zoo_tokenizer_patch": True}
    return {}


def get_tokenizer_chat_template(tokenizer, model_name):
"""
Gets appropriate chat template for tokenizer based on model.
Expand Down Expand Up @@ -60,6 +77,7 @@ def get_tokenizer_chat_template(tokenizer, model_name):
tokenizer = get_chat_template(
tokenizer,
chat_template = matched_template,
**_chat_template_kwargs(),
)
except Exception as e:
logger.info(f"⚠️ Failed to apply Unsloth template '{matched_template}': {e}")
Expand All @@ -79,6 +97,7 @@ def get_tokenizer_chat_template(tokenizer, model_name):
tokenizer = get_chat_template(
tokenizer,
chat_template = "chatml",
**_chat_template_kwargs(),
)
except Exception as e:
logger.info(f"⚠️ Failed to apply default ChatML template: {e}")
Expand Down Expand Up @@ -255,7 +274,11 @@ def _apply_custom_mapping(examples):
if not (hasattr(tokenizer, 'chat_template') and tokenizer.chat_template):
try:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(tokenizer, chat_template = "alpaca")
tokenizer = get_chat_template(
tokenizer,
chat_template = "alpaca",
**_chat_template_kwargs(),
)
logger.info(f"📝 Set alpaca chat template on tokenizer for model saving")
except Exception as e:
logger.info(f"⚠️ Could not set alpaca template on tokenizer: {e}")
Expand Down
1 change: 1 addition & 0 deletions studio/frontend/src/features/training/api/mappers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export function buildTrainingStartPayload(
save_steps: config.saveSteps,
eval_steps: config.evalSteps,
weight_decay: config.weightDecay,
max_grad_norm: 0.0,
random_seed: config.randomSeed,
packing: isEmbedding ? false : config.packing,
optim: config.optimizerType,
Expand Down
1 change: 1 addition & 0 deletions studio/frontend/src/features/training/types/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export interface TrainingStartRequest {
save_steps: number;
eval_steps: number;
weight_decay: number;
max_grad_norm: number;
random_seed: number;
packing: boolean;
optim: string;
Expand Down
6 changes: 3 additions & 3 deletions tests/studio/run_real_mlx_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ def cmd_train(args) -> int:
workdir.mkdir(parents = True, exist_ok = True)

import mlx.core as mx
from unsloth_zoo.mlx_loader import FastMLXModel
from unsloth_zoo.mlx_trainer import MLXTrainer, MLXTrainingConfig
from unsloth_zoo.mlx.loader import FastMLXModel
from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig

hf_token = os.environ.get("HF_TOKEN") or None

Expand Down Expand Up @@ -440,7 +440,7 @@ def cmd_reload(args) -> int:
return _reload_gguf(save_dir, metrics)

import mlx.core as mx
from unsloth_zoo.mlx_loader import FastMLXModel
from unsloth_zoo.mlx.loader import FastMLXModel
from mlx_lm import generate

hf_token = os.environ.get("HF_TOKEN") or None
Expand Down
Loading
Loading