-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
MLX training support for Studio on Apple Silicon #5340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
797ae4c
aaa2fab
051b202
da2b371
e5bf387
09fcc55
f9e8416
dcbb50c
9dded6a
085b4f4
83e577c
b0447f6
d51181f
3d9389d
55a3a00
07f4150
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -417,6 +417,55 @@ def _normalize_mlx_studio_scheduler(value): | |||||||||||||||||||||||||||||||||
| return raw | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _resolve_mlx_local_dataset_files(file_paths: list) -> list[str]: | ||||||||||||||||||||||||||||||||||
| """Resolve Studio local dataset uploads without importing the GPU trainer.""" | ||||||||||||||||||||||||||||||||||
| from utils.paths import resolve_dataset_path | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| all_files: list[str] = [] | ||||||||||||||||||||||||||||||||||
| for dataset_file in file_paths or []: | ||||||||||||||||||||||||||||||||||
| file_path = ( | ||||||||||||||||||||||||||||||||||
| dataset_file | ||||||||||||||||||||||||||||||||||
| if os.path.isabs(dataset_file) | ||||||||||||||||||||||||||||||||||
| else str(resolve_dataset_path(dataset_file)) | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| file_path_obj = Path(file_path) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if file_path_obj.is_dir(): | ||||||||||||||||||||||||||||||||||
| parquet_dir = ( | ||||||||||||||||||||||||||||||||||
| file_path_obj / "parquet-files" | ||||||||||||||||||||||||||||||||||
| if (file_path_obj / "parquet-files").exists() | ||||||||||||||||||||||||||||||||||
| else file_path_obj | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| parquet_files = sorted(parquet_dir.glob("*.parquet")) | ||||||||||||||||||||||||||||||||||
| if parquet_files: | ||||||||||||||||||||||||||||||||||
| all_files.extend(str(p) for p in parquet_files) | ||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| candidates: list[Path] = [] | ||||||||||||||||||||||||||||||||||
| for ext in (".json", ".jsonl", ".csv", ".parquet"): | ||||||||||||||||||||||||||||||||||
| candidates.extend(sorted(file_path_obj.glob(f"*{ext}"))) | ||||||||||||||||||||||||||||||||||
| if candidates: | ||||||||||||||||||||||||||||||||||
| all_files.extend(str(c) for c in candidates) | ||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| raise ValueError(f"No supported data files in directory: {file_path_obj}") | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+444
to
+451
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current logic collects files of all supported extensions into a single list. If a directory contains mixed file types (e.g., both
Suggested change
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| all_files.append(str(file_path_obj)) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return all_files | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _mlx_local_dataset_loader_for_files(files: list[str]) -> str: | ||||||||||||||||||||||||||||||||||
| first_ext = Path(files[0]).suffix.lower() | ||||||||||||||||||||||||||||||||||
| if first_ext in (".json", ".jsonl"): | ||||||||||||||||||||||||||||||||||
| return "json" | ||||||||||||||||||||||||||||||||||
| if first_ext == ".csv": | ||||||||||||||||||||||||||||||||||
| return "csv" | ||||||||||||||||||||||||||||||||||
| if first_ext == ".parquet": | ||||||||||||||||||||||||||||||||||
| return "parquet" | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"Unsupported dataset format: {files[0]}") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _run_mlx_training(event_queue, stop_queue, config): | ||||||||||||||||||||||||||||||||||
| """Self-contained MLX training path for Apple Silicon. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
@@ -442,16 +491,16 @@ def _send(event_type, **kwargs): | |||||||||||||||||||||||||||||||||
| import mlx.core as mx | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx_loader import FastMLXModel | ||||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx_trainer import ( | ||||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx.loader import FastMLXModel | ||||||||||||||||||||||||||||||||||
| from unsloth_zoo.mlx.trainer import ( | ||||||||||||||||||||||||||||||||||
| MLXTrainer, | ||||||||||||||||||||||||||||||||||
| MLXTrainingConfig, | ||||||||||||||||||||||||||||||||||
| train_on_responses_only, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| except ImportError as e: | ||||||||||||||||||||||||||||||||||
| raise ImportError( | ||||||||||||||||||||||||||||||||||
| "Unsloth: MLX training requires unsloth-zoo with the MLX modules " | ||||||||||||||||||||||||||||||||||
| "(unsloth_zoo.mlx_loader / unsloth_zoo.mlx_trainer). Reinstall via " | ||||||||||||||||||||||||||||||||||
| "(unsloth_zoo.mlx.loader / unsloth_zoo.mlx.trainer). Reinstall via " | ||||||||||||||||||||||||||||||||||
| "install.sh on Apple Silicon." | ||||||||||||||||||||||||||||||||||
| ) from e | ||||||||||||||||||||||||||||||||||
| from datasets import load_dataset | ||||||||||||||||||||||||||||||||||
|
|
@@ -572,7 +621,6 @@ def _slice(ds): | |||||||||||||||||||||||||||||||||
| return ds | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _load_local(file_paths): | ||||||||||||||||||||||||||||||||||
| from core.training.trainer import UnslothTrainer | ||||||||||||||||||||||||||||||||||
| from datasets import load_from_disk | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if len(file_paths) == 1: | ||||||||||||||||||||||||||||||||||
|
|
@@ -581,10 +629,10 @@ def _load_local(file_paths): | |||||||||||||||||||||||||||||||||
| (p / "dataset_info.json").exists() or (p / "state.json").exists() | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| return load_from_disk(str(p)) | ||||||||||||||||||||||||||||||||||
| all_files = UnslothTrainer._resolve_local_files(file_paths) | ||||||||||||||||||||||||||||||||||
| all_files = _resolve_mlx_local_dataset_files(file_paths) | ||||||||||||||||||||||||||||||||||
| if not all_files: | ||||||||||||||||||||||||||||||||||
| raise ValueError("No local dataset files found") | ||||||||||||||||||||||||||||||||||
| loader = UnslothTrainer._loader_for_files(all_files) | ||||||||||||||||||||||||||||||||||
| loader = _mlx_local_dataset_loader_for_files(all_files) | ||||||||||||||||||||||||||||||||||
| return load_dataset(loader, data_files = all_files, split = "train") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if hf_dataset: | ||||||||||||||||||||||||||||||||||
|
|
@@ -718,6 +766,10 @@ def _fmt_progress(status_message = "", **_kw): | |||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| eval_steps_val = int(eval_steps_val) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # MLX: value-clip grads to [-5, 5]; norm clipping disabled for compile-friendliness. | ||||||||||||||||||||||||||||||||||
| max_grad_norm = 0.0 | ||||||||||||||||||||||||||||||||||
| max_grad_value = 5.0 # TODO: expose MLX grad-clip in Studio UI for power users | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+769
to
+771
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Even though this commit adds Useful? React with 👍 / 👎. |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| trainer = MLXTrainer( | ||||||||||||||||||||||||||||||||||
| model = model, | ||||||||||||||||||||||||||||||||||
| tokenizer = tokenizer, | ||||||||||||||||||||||||||||||||||
|
|
@@ -732,6 +784,8 @@ def _fmt_progress(status_message = "", **_kw): | |||||||||||||||||||||||||||||||||
| lr_scheduler_type = lr_scheduler_type, | ||||||||||||||||||||||||||||||||||
| optim = optim_name, | ||||||||||||||||||||||||||||||||||
| weight_decay = float(config.get("weight_decay", 0.001) or 0.001), | ||||||||||||||||||||||||||||||||||
| max_grad_norm = max_grad_norm, | ||||||||||||||||||||||||||||||||||
| max_grad_value = max_grad_value, | ||||||||||||||||||||||||||||||||||
| logging_steps = 1, | ||||||||||||||||||||||||||||||||||
| max_seq_length = max_seq_length, | ||||||||||||||||||||||||||||||||||
| seed = config.get("random_seed", 3407), | ||||||||||||||||||||||||||||||||||
|
|
@@ -820,7 +874,17 @@ def _fmt_progress(status_message = "", **_kw): | |||||||||||||||||||||||||||||||||
| # ── 9. Real-time progress callback ── | ||||||||||||||||||||||||||||||||||
| _send("status", status_message = f"Training {model_name}...") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens): | ||||||||||||||||||||||||||||||||||
| def _on_step( | ||||||||||||||||||||||||||||||||||
| step, | ||||||||||||||||||||||||||||||||||
| total, | ||||||||||||||||||||||||||||||||||
| loss, | ||||||||||||||||||||||||||||||||||
| lr, | ||||||||||||||||||||||||||||||||||
| tok_s, | ||||||||||||||||||||||||||||||||||
| peak_gb, | ||||||||||||||||||||||||||||||||||
| elapsed, | ||||||||||||||||||||||||||||||||||
| num_tokens, | ||||||||||||||||||||||||||||||||||
| grad_norm = None, | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| eta = (elapsed / step * (total - step)) if step > 0 else 0 | ||||||||||||||||||||||||||||||||||
| _send( | ||||||||||||||||||||||||||||||||||
| "progress", | ||||||||||||||||||||||||||||||||||
|
|
@@ -831,7 +895,7 @@ def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens): | |||||||||||||||||||||||||||||||||
| total_steps = total, | ||||||||||||||||||||||||||||||||||
| elapsed_seconds = elapsed, | ||||||||||||||||||||||||||||||||||
| eta_seconds = max(0, eta), | ||||||||||||||||||||||||||||||||||
| grad_norm = None, | ||||||||||||||||||||||||||||||||||
| grad_norm = grad_norm, | ||||||||||||||||||||||||||||||||||
| num_tokens = num_tokens, | ||||||||||||||||||||||||||||||||||
| eval_loss = None, | ||||||||||||||||||||||||||||||||||
| status_message = None, | ||||||||||||||||||||||||||||||||||
|
|
@@ -846,6 +910,11 @@ def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens): | |||||||||||||||||||||||||||||||||
| "train/tokens_per_sec": tok_s, | ||||||||||||||||||||||||||||||||||
| "train/peak_gb": peak_gb, | ||||||||||||||||||||||||||||||||||
| "train/num_tokens": num_tokens, | ||||||||||||||||||||||||||||||||||
| **( | ||||||||||||||||||||||||||||||||||
| {"train/grad_norm": grad_norm} | ||||||||||||||||||||||||||||||||||
| if grad_norm is not None | ||||||||||||||||||||||||||||||||||
| else {} | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||
| step = step, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
@@ -857,6 +926,8 @@ def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens): | |||||||||||||||||||||||||||||||||
| tb_writer.add_scalar("train/learning_rate", lr, step) | ||||||||||||||||||||||||||||||||||
| tb_writer.add_scalar("train/tokens_per_sec", tok_s, step) | ||||||||||||||||||||||||||||||||||
| tb_writer.add_scalar("train/peak_gb", peak_gb, step) | ||||||||||||||||||||||||||||||||||
| if grad_norm is not None: | ||||||||||||||||||||||||||||||||||
| tb_writer.add_scalar("train/grad_norm", grad_norm, step) | ||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -262,6 +262,19 @@ def _check_lora_dropout(cls, v: float) -> float: | |
| max_steps: Optional[int] = Field(None, description = "Maximum training steps") | ||
| save_steps: int = Field(100, description = "Steps between checkpoints") | ||
| weight_decay: float = Field(0.001, description = "Weight decay") | ||
| max_grad_norm: float = Field( | ||
| 0.0, | ||
| ge = 0, | ||
| description = "Global gradient norm clipping threshold. Set 0 to disable.", | ||
| ) | ||
| max_grad_value: Optional[float] = Field( | ||
| None, | ||
| ge = 0, | ||
| description = ( | ||
| "Elementwise gradient value clipping threshold. Set 0 to disable. " | ||
| "If omitted, MLX defaults to 1 unless max_grad_norm is set." | ||
| ), | ||
| ) | ||
|
Comment on lines
+265
to
+277
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default values for |
||
| random_seed: int = Field(42, description = "Random seed") | ||
| packing: bool = Field(False, description = "Enable sequence packing") | ||
| optim: str = Field("adamw_8bit", description = "Optimizer") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check for
parquet-filesshould use.is_dir()instead of.exists(). If a file namedparquet-filesexists in the directory, the subsequent.glob()call will fail or return no results. Additionally, consider moving this entire dataset resolution logic to a shared utility file (e.g.,utils/paths.py) to avoid duplication withUnslothTrainerwhile still keeping it accessible to the MLX path without importing Torch.