diff --git a/.gitignore b/.gitignore index b6786ee655..2f23f18d65 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,4 @@ setup_leo.sh server.pid *.log package-lock.json +llama.cpp/ diff --git a/install.sh b/install.sh index 1d21117d16..c928d2ae54 100755 --- a/install.sh +++ b/install.sh @@ -1721,6 +1721,12 @@ else fi fi +# ── Install mlx-vlm on Apple Silicon (optional, for VLM training) ── +if [ "$OS" = "macos" ] && [ "$_ARCH" = "arm64" ]; then + substep "installing mlx-vlm (VLM training support)..." + run_install_cmd "install mlx-vlm" uv pip install --python "$_VENV_PY" mlx-vlm +fi + # ── Run studio setup ── tauri_log "STEP" "Running Studio setup" # When --local, use the repo's own setup.sh directly. diff --git a/pyproject.toml b/pyproject.toml index c2f884e192..5687ea12f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ huggingfacenotorch = [ ] huggingface = [ "unsloth[huggingfacenotorch]", - "unsloth_zoo>=2026.5.1", + "unsloth_zoo>=2026.4.8", "torchvision", "unsloth[triton]", ] @@ -579,7 +579,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3 ; ('linux' in sys_platform)", ] colab-new = [ - "unsloth_zoo>=2026.5.1", + "unsloth_zoo>=2026.4.8", "packaging", "tyro", "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.5.0", diff --git a/studio/backend/core/export/export.py b/studio/backend/core/export/export.py index 6fee5a38f7..4ab95d896f 100644 --- a/studio/backend/core/export/export.py +++ b/studio/backend/core/export/export.py @@ -9,16 +9,14 @@ import glob import json import structlog +import tempfile from loggers import get_logger import os import shutil from pathlib import Path from typing import Optional, Tuple, List -from peft import PeftModel, PeftModelForCausalLM -from unsloth import FastLanguageModel, FastVisionModel +from unsloth import FastLanguageModel, FastVisionModel, _IS_MLX from huggingface_hub import HfApi, ModelCard -from transformers.modeling_utils import PushToHubMixin -import torch from utils.hardware import clear_gpu_cache from utils.models import is_vision_model, get_base_model_from_lora @@ -26,6 +24,12 @@ from utils.paths import ensure_dir, outputs_root, resolve_export_dir, resolve_output_dir from core.inference import get_inference_backend +# GPU-only imports — guarded for Apple Silicon where these aren't needed +if not _IS_MLX: + from peft import PeftModel, PeftModelForCausalLM + from transformers.modeling_utils import PushToHubMixin + import torch + logger = get_logger(__name__) _LLAMA_CPP_SCRIPTS_WARNING_EMITTED = False @@ -225,7 +229,7 @@ def load_checkpoint( model, tokenizer = FastModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, - dtype = torch.float32, + dtype = None if _IS_MLX else torch.float32, load_in_4bit = False, trust_remote_code = trust_remote_code, ) @@ -262,8 +266,12 @@ def load_checkpoint( trust_remote_code = trust_remote_code, ) - # Check if PEFT model - self.is_peft = isinstance(model, (PeftModel, PeftModelForCausalLM)) + # Check if PEFT / LoRA model + if _IS_MLX: + # MLX doesn't use PeftModel — detect LoRA via adapter_config.json + self.is_peft = adapter_config.exists() + else: + self.is_peft = isinstance(model, (PeftModel, PeftModelForCausalLM)) # Store loaded model self.current_model = model @@ -325,9 +333,7 @@ def export_merged_model( private: Whether to make the repo private Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved model when - ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -341,14 +347,17 @@ def export_merged_model( output_path: Optional[str] = None try: - # Determine save method - if format_type == "4-bit (FP4)": - save_method = "merged_4bit_forced" - elif self._audio_type == "whisper": - # Whisper uses save_method=None for local 16-bit merged save - save_method = None - else: # 16-bit (FP16) - save_method = "merged_16bit" + if _IS_MLX: + mlx_save_method = ( + "merged_4bit" if format_type == "4-bit (FP4)" else "merged_16bit" + ) + else: + if format_type == "4-bit (FP4)": + save_method = "merged_4bit_forced" + elif self._audio_type == "whisper": + save_method = None + else: + save_method = "merged_16bit" # Save locally if requested if save_directory: @@ -356,11 +365,17 @@ def export_merged_model( logger.info(f"Saving merged model locally to: {save_directory}") ensure_dir(Path(save_directory)) - self.current_model.save_pretrained_merged( - save_directory, self.current_tokenizer, save_method = save_method - ) + if _IS_MLX: + self.current_model.save_pretrained_merged( + save_directory, + self.current_tokenizer, + save_method = mlx_save_method, + ) + else: + self.current_model.save_pretrained_merged( + save_directory, self.current_tokenizer, save_method = save_method + ) - # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) logger.info(f"Model saved successfully to {save_directory}") output_path = str(Path(save_directory).resolve()) @@ -376,17 +391,40 @@ def export_merged_model( logger.info(f"Pushing merged model to Hub: {repo_id}") - # Whisper uses save_method=None for local but "merged_16bit" for hub push - hub_save_method = ( - save_method if save_method is not None else "merged_16bit" - ) - self.current_model.push_to_hub_merged( - repo_id, - self.current_tokenizer, - save_method = hub_save_method, - token = hf_token, - private = private, - ) + if _IS_MLX: + if save_directory: + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = save_directory, + token = hf_token, + private = private, + ) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + self.current_model.save_pretrained_merged( + tmp_dir, + self.current_tokenizer, + save_method = mlx_save_method, + ) + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = tmp_dir, + token = hf_token, + private = private, + ) + else: + hub_save_method = ( + save_method if save_method is not None else "merged_16bit" + ) + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_method = hub_save_method, + token = hf_token, + private = private, + ) logger.info(f"Model pushed successfully to {repo_id}") return True, "Model exported successfully", output_path @@ -411,9 +449,7 @@ def export_base_model( Export base model (for non-PEFT models). Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved model when - ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -433,8 +469,16 @@ def export_base_model( logger.info(f"Saving base model locally to: {save_directory}") ensure_dir(Path(save_directory)) - self.current_model.save_pretrained(save_directory) - self.current_tokenizer.save_pretrained(save_directory) + if _IS_MLX: + # MLX: save_pretrained_merged handles non-LoRA models too + # (fuse() is a no-op when there are no LoRA layers) + self.current_model.save_pretrained_merged( + save_directory, + self.current_tokenizer, + ) + else: + self.current_model.save_pretrained(save_directory) + self.current_tokenizer.save_pretrained(save_directory) # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) @@ -452,44 +496,73 @@ def export_base_model( logger.info(f"Pushing base model to Hub: {repo_id}") - # Get base model name from request or model config - base_model = ( - base_model_id - or self.current_model.config._name_or_path - or "unknown" - ) - - # Create repo - hf_api = HfApi(token = hf_token) - repo_id = PushToHubMixin._create_repo( - PushToHubMixin, - repo_id = repo_id, - private = private, - token = hf_token, - ) - username = repo_id.split("/")[0] - - # Create and push model card - content = MODEL_CARD.format( - username = username, - base_model = base_model, - model_type = self.current_model.config.model_type, - method = "", - extra = "unsloth", - ) - card = ModelCard(content) - card.push_to_hub( - repo_id, token = hf_token, commit_message = "Unsloth Model Card" - ) + if _IS_MLX: + if save_directory: + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = save_directory, + token = hf_token, + private = private, + ) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + self.current_model.save_pretrained_merged( + tmp_dir, + self.current_tokenizer, + ) + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = tmp_dir, + token = hf_token, + private = private, + ) + else: + # Get base model name from request or model config + base_model = ( + base_model_id + or self.current_model.config._name_or_path + or "unknown" + ) - # Upload model files - if save_directory: - hf_api.upload_folder( - folder_path = save_directory, repo_id = repo_id, repo_type = "model" + # Create repo + hf_api = HfApi(token = hf_token) + repo_id = PushToHubMixin._create_repo( + PushToHubMixin, + repo_id = repo_id, + private = private, + token = hf_token, ) - logger.info(f"Model pushed successfully to {repo_id}") - else: - return False, "Local save directory required for Hub upload", None + username = repo_id.split("/")[0] + + # Create and push model card + content = MODEL_CARD.format( + username = username, + base_model = base_model, + model_type = self.current_model.config.model_type, + method = "", + extra = "unsloth", + ) + card = ModelCard(content) + card.push_to_hub( + repo_id, token = hf_token, commit_message = "Unsloth Model Card" + ) + + # Upload model files + if save_directory: + hf_api.upload_folder( + folder_path = save_directory, + repo_id = repo_id, + repo_type = "model", + ) + logger.info(f"Model pushed successfully to {repo_id}") + else: + return ( + False, + "Local save directory required for Hub upload", + None, + ) return True, "Model exported successfully", output_path @@ -519,9 +592,7 @@ def export_gguf( hf_token: Hugging Face token Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory containing the .gguf - files when ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -692,9 +763,7 @@ def export_lora_adapter( Export LoRA adapter only (not merged). Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved adapter - when ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -710,8 +779,13 @@ def export_lora_adapter( logger.info(f"Saving LoRA adapter locally to: {save_directory}") ensure_dir(Path(save_directory)) - self.current_model.save_pretrained(save_directory) - self.current_tokenizer.save_pretrained(save_directory) + if _IS_MLX: + # MLX: save adapters.safetensors + tokenizer files + self.current_model.save_lora_adapters(save_directory) + self.current_tokenizer.save_pretrained(save_directory) + else: + self.current_model.save_pretrained(save_directory) + self.current_tokenizer.save_pretrained(save_directory) logger.info(f"Adapter saved successfully to {save_directory}") output_path = str(Path(save_directory).resolve()) @@ -726,10 +800,24 @@ def export_lora_adapter( logger.info(f"Pushing LoRA adapter to Hub: {repo_id}") - self.current_model.push_to_hub(repo_id, token = hf_token, private = private) - self.current_tokenizer.push_to_hub( - repo_id, token = hf_token, private = private - ) + if _IS_MLX: + with tempfile.TemporaryDirectory() as tmp_dir: + self.current_model.save_lora_adapters(tmp_dir) + self.current_tokenizer.save_pretrained(tmp_dir) + hf_api = HfApi(token = hf_token) + hf_api.create_repo(repo_id, private = private, exist_ok = True) + hf_api.upload_folder( + folder_path = tmp_dir, + repo_id = repo_id, + repo_type = "model", + ) + else: + self.current_model.push_to_hub( + repo_id, token = hf_token, private = private + ) + self.current_tokenizer.push_to_hub( + repo_id, token = hf_token, private = private + ) logger.info(f"Adapter pushed successfully to {repo_id}") return True, "LoRA adapter exported successfully", output_path diff --git a/studio/backend/core/inference/mlx_inference.py b/studio/backend/core/inference/mlx_inference.py new file mode 100644 index 0000000000..1d2b03ecb9 --- /dev/null +++ b/studio/backend/core/inference/mlx_inference.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: AGPL-3.0-only +"""MLX inference backend for Apple Silicon. + +Drop-in replacement for InferenceBackend — same interface, uses mlx-lm/mlx-vlm +instead of torch/transformers for model loading and generation. +""" + +import threading +from typing import Optional, Generator +from loggers import get_logger + +logger = get_logger(__name__) + + +class MLXInferenceBackend: + def __init__(self): + self.models = {} + self.active_model_name = None + self.loading_models = set() + self.loaded_local_models = [] + self.device = "mlx" + self._generation_lock = threading.Lock() + + # MLX state + self._model = None + self._tokenizer = None + self._processor = None + self._is_vlm = False + self._config = {} + + # Recorded for unload to release pinned memory back to the OS. + self._memory_limits_applied = {} + + def _configure_memory_limits(self): + """Apply Metal memory caps before loading a model. + + Mirrors MLXTrainer._configure_memory_limits's defaults: + memory_limit = 85% of recommended working-set, + wired_limit = min(recommended, memory_limit). Recorded so unload + can lower wired_limit back to release pinned RAM. + """ + import mlx.core as mx + + if not mx.metal.is_available(): + return + info = mx.device_info() + rec_bytes = info.get("max_recommended_working_set_size") + if not rec_bytes or rec_bytes <= 0: + return + rec_gb = rec_bytes / 1e9 + memory_limit_gb = rec_gb * 0.85 + wired_limit_gb = min(rec_gb, memory_limit_gb) + mx.set_memory_limit(int(memory_limit_gb * 1e9)) + mx.set_wired_limit(int(wired_limit_gb * 1e9)) + self._memory_limits_applied = { + "memory_limit_gb": memory_limit_gb, + "wired_limit_gb": wired_limit_gb, + "recommended_gb": rec_gb, + } + logger.info( + "MLX memory caps: memory_limit=%.2f GB, wired_limit=%.2f GB", + memory_limit_gb, + wired_limit_gb, + ) + + def load_model( + self, + config, + max_seq_length = 2048, + load_in_4bit = True, + hf_token = None, + trust_remote_code = False, + gpu_ids = None, + dtype = None, + ) -> bool: + import mlx.core as mx + + model_name = config.identifier if hasattr(config, "identifier") else str(config) + is_vision = getattr(config, "is_vision", False) + + if hf_token: + import os + + os.environ["HF_TOKEN"] = hf_token + self._configure_memory_limits() + + is_lora = getattr(config, "is_lora", False) + + logger.info( + "Loading %s via %s (is_lora=%s)", + model_name, + "mlx-vlm" if is_vision else "mlx-lm", + is_lora, + ) + + try: + from unsloth_zoo.mlx_loader import FastMLXModel + except ImportError as e: + raise ImportError( + "Unsloth: MLX inference requires unsloth-zoo with the MLX modules " + "(unsloth_zoo.mlx_loader). Reinstall via install.sh on Apple Silicon." + ) from e + + model, tokenizer_or_processor = FastMLXModel.from_pretrained( + model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + token = hf_token, + trust_remote_code = trust_remote_code, + text_only = False if is_vision else True, + ) + + if is_vision: + processor = tokenizer_or_processor + self._model = model + self._processor = processor + self._tokenizer = getattr(processor, "tokenizer", processor) + self._is_vlm = True + else: + tokenizer = tokenizer_or_processor + self._model = model + self._tokenizer = tokenizer + self._processor = None + self._is_vlm = False + + self.active_model_name = model_name + self.models[model_name] = { + "model": self._model, + "tokenizer": self._tokenizer, + "processor": self._processor, + "is_vision": is_vision, + "is_lora": getattr(config, "is_lora", False), + "is_audio": False, + "audio_type": None, + "has_audio_input": False, + } + + logger.info("Model %s loaded successfully", model_name) + return True + + def unload_model(self, model_name: str) -> bool: + import mlx.core as mx + import gc + + if model_name in self.models: + del self.models[model_name] + self._model = None + self._tokenizer = None + self._processor = None + if self.active_model_name == model_name: + self.active_model_name = None + gc.collect() + mx.clear_cache() + + if mx.metal.is_available() and self._memory_limits_applied and not self.models: + try: + mx.set_wired_limit(0) + logger.info("MLX wired_limit released back to OS on unload") + except Exception as e: + logger.warning("Failed to release wired_limit: %s", e) + self._memory_limits_applied = {} + logger.info("Model %s unloaded", model_name) + return True + + def generate_chat_response( + self, + messages, + system_prompt = "", + image = None, + temperature = 0.7, + top_p = 0.9, + top_k = 40, + min_p = 0.0, + max_new_tokens = 256, + repetition_penalty = 1.0, + cancel_event = None, + ) -> Generator[str, None, None]: + if self._model is None: + raise RuntimeError("No model loaded") + + # Build messages with system prompt + full_messages = [] + if system_prompt: + full_messages.append({"role": "system", "content": system_prompt}) + full_messages.extend(messages) + + # Inject image into the last user message for VLM + if self._is_vlm and image is not None: + for msg in reversed(full_messages): + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, str): + msg["content"] = [ + {"type": "image"}, + {"type": "text", "text": content}, + ] + elif isinstance(content, list): + # Prepend image if not already there + has_image = any( + p.get("type") == "image" + for p in content + if isinstance(p, dict) + ) + if not has_image: + content.insert(0, {"type": "image"}) + break + + if self._is_vlm: + yield from self._generate_vlm( + full_messages, + image, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ) + else: + yield from self._generate_text( + full_messages, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ) + + def _generate_text( + self, + messages, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ): + from mlx_lm import stream_generate + from mlx_lm.sample_utils import make_sampler, make_logits_processors + + prompt = self._tokenizer.apply_chat_template( + messages, + tokenize = False, + add_generation_prompt = True, + ) + if prompt is None: + raise RuntimeError( + "apply_chat_template returned None — tokenizer may be incompatible" + ) + + sampler = make_sampler( + temp = temperature, + top_p = top_p, + top_k = int(top_k or 0), + min_p = float(min_p or 0.0), + min_tokens_to_keep = 1, + ) + # Only build a logits processor when we actually have a non-trivial + # repetition penalty (1.0 is the no-op value). + logits_processors = None + if repetition_penalty is not None and float(repetition_penalty) not in ( + 0.0, + 1.0, + ): + logits_processors = make_logits_processors( + repetition_penalty = float(repetition_penalty), + ) + + token_ids = [] + logger.info( + "Generating: prompt_len=%d, max_tokens=%d, model=%s, tokenizer=%s", + len(prompt), + max_new_tokens, + type(self._model).__name__, + type(self._tokenizer).__name__, + ) + with self._generation_lock: + try: + gen_kwargs = dict( + prompt = prompt, + max_tokens = max_new_tokens, + sampler = sampler, + ) + if logits_processors is not None: + gen_kwargs["logits_processors"] = logits_processors + for response in stream_generate( + self._model, + self._tokenizer, + **gen_kwargs, + ): + token_ids.append(response.token) + # Decode full sequence with skip_special_tokens — same as GPU + cumulative = self._tokenizer.decode( + token_ids, + skip_special_tokens = True, + ) + yield cumulative + + if cancel_event and cancel_event.is_set(): + break + except Exception as e: + import traceback + + logger.error("stream_generate failed:\n%s", traceback.format_exc()) + raise + + def _generate_vlm( + self, + messages, + image, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ): + from mlx_vlm import stream_generate as vlm_stream + + # Apply chat template + chat_fn = getattr(self._processor, "apply_chat_template", None) + if ( + chat_fn is None + or not hasattr(self._processor, "chat_template") + or self._processor.chat_template is None + ): + tok = getattr(self._processor, "tokenizer", self._processor) + chat_fn = tok.apply_chat_template + + prompt = chat_fn(messages, tokenize = False, add_generation_prompt = True) + + # For VLM: always use mlx_vlm's stream_generate which handles + # pixel_values properly (passes None for text-only, image for VLM) + images = [image] if image is not None else None + + cumulative = "" + logger.info( + "VLM generating: prompt_len=%d, has_image=%s", + len(prompt), + image is not None, + ) + # mlx_vlm.stream_generate forwards **kwargs into generate_step, which + # accepts temp/top_p/top_k/repetition_penalty (and builds the sampler + # + logits_processors internally). Pass them through. + # NOTE: mlx_vlm.generate_step expects ``temperature=`` (long form) — + # passing ``temp=`` silently falls into **kwargs and is ignored, + # leaving generation stuck at the default 0.0 (greedy). + vlm_kwargs = dict( + max_tokens = max_new_tokens, + temperature = temperature, + top_p = top_p, + top_k = int(top_k or 0), + min_p = float(min_p or 0.0), + ) + if repetition_penalty is not None and float(repetition_penalty) not in ( + 0.0, + 1.0, + ): + vlm_kwargs["repetition_penalty"] = float(repetition_penalty) + + with self._generation_lock: + for response in vlm_stream( + self._model, + self._processor, + prompt, + images, + **vlm_kwargs, + ): + token_text = ( + response.text if hasattr(response, "text") else str(response) + ) + cumulative += token_text + yield cumulative + if cancel_event and cancel_event.is_set(): + break + + def generate_with_adapter_control( + self, use_adapter = None, cancel_event = None, **gen_kwargs + ) -> Generator[str, None, None]: + # MLX LoRA adapter toggling not yet supported — generate normally + yield from self.generate_chat_response(cancel_event = cancel_event, **gen_kwargs) + + def reset_generation_state(self): + import mlx.core as mx + import gc + + gc.collect() + mx.clear_cache() diff --git a/studio/backend/core/inference/worker.py b/studio/backend/core/inference/worker.py index fbcce276ba..085a1ab899 100644 --- a/studio/backend/core/inference/worker.py +++ b/studio/backend/core/inference/worker.py @@ -663,6 +663,98 @@ def run_inference_process( model_name = config["model_name"] + # ── 0. MLX fast-path — skip torch/transformers entirely ── + backend_path = str(Path(__file__).resolve().parent.parent.parent) + if backend_path not in sys.path: + sys.path.insert(0, backend_path) + + from utils.hardware import hardware as _hw + + _hw.detect_hardware() + if _hw.DEVICE == _hw.DeviceType.MLX: + try: + _activate_transformers_version(model_name) + except Exception: + pass + try: + from core.inference.mlx_inference import MLXInferenceBackend + + backend = MLXInferenceBackend() + _send_response( + resp_queue, + {"type": "status", "message": "Loading model...", "ts": time.time()}, + ) + _handle_load(backend, config, resp_queue) + except Exception as exc: + _send_response( + resp_queue, + { + "type": "error", + "error": f"MLX inference init failed: {exc}", + "stack": traceback.format_exc(limit = 20), + "ts": time.time(), + }, + ) + return + + # Enter same command loop as GPU path + logger.info("MLX inference subprocess ready, entering command loop") + while True: + try: + cmd = cmd_queue.get(timeout = 1.0) + except _queue.Empty: + continue + except (EOFError, OSError): + return + if cmd is None: + continue + cmd_type = cmd.get("type", "") + try: + if cmd_type == "generate": + cancel_event.clear() + _handle_generate(backend, cmd, resp_queue, cancel_event) + elif cmd_type == "load": + if backend.active_model_name: + backend.unload_model(backend.active_model_name) + _handle_load(backend, cmd, resp_queue) + elif cmd_type == "unload": + _handle_unload(backend, cmd, resp_queue) + elif cmd_type == "cancel": + cancel_event.set() + elif cmd_type == "reset": + cancel_event.set() + backend.reset_generation_state() + _send_response(resp_queue, {"type": "reset_ack", "ts": time.time()}) + elif cmd_type == "status": + _send_response( + resp_queue, + { + "type": "status_response", + "active_model": backend.active_model_name, + "models": { + k: {kk: vv for kk, vv in v.items() if kk != "model"} + for k, v in backend.models.items() + }, + "loading": list(backend.loading_models), + "ts": time.time(), + }, + ) + elif cmd_type == "shutdown": + return + except Exception as exc: + logger.error("MLX command error (%s): %s", cmd_type, exc) + _send_response( + resp_queue, + { + "type": "gen_error" if cmd_type == "generate" else "error", + "request_id": cmd.get("request_id"), + "error": str(exc), + "stack": traceback.format_exc(limit = 20), + "ts": time.time(), + }, + ) + return + # ── 1. Activate correct transformers version BEFORE any ML imports ── try: _activate_transformers_version(model_name) diff --git a/studio/backend/core/training/training.py b/studio/backend/core/training/training.py index 5642faa189..a04ad5ef49 100644 --- a/studio/backend/core/training/training.py +++ b/studio/backend/core/training/training.py @@ -62,6 +62,7 @@ class TrainingProgress: grad_norm: Optional[float] = None num_tokens: Optional[int] = None eval_loss: Optional[float] = None + peak_memory_gb: Optional[float] = None class TrainingBackend: @@ -199,21 +200,27 @@ def start_training(self, job_id: str, **kwargs) -> bool: config["load_in_4bit"] = False # Spawn subprocess — use locals so state is untouched on failure - resolved_gpu_ids, gpu_selection = prepare_gpu_selection( - kwargs.get("gpu_ids"), - model_name = config["model_name"], - hf_token = config["hf_token"] or None, - training_type = config["training_type"], - load_in_4bit = config["load_in_4bit"], - batch_size = config.get("batch_size", 4), - max_seq_length = config.get("max_seq_length", 2048), - lora_rank = config.get("lora_r", 16), - target_modules = config.get("target_modules"), - gradient_checkpointing = config.get("gradient_checkpointing", "unsloth"), - optimizer = config.get("optim", "adamw_8bit"), - ) - config["resolved_gpu_ids"] = resolved_gpu_ids - config["gpu_selection"] = gpu_selection + from utils.hardware import hardware as _hw + + if _hw.DEVICE == _hw.DeviceType.MLX: + config["resolved_gpu_ids"] = None + config["gpu_selection"] = None + else: + resolved_gpu_ids, gpu_selection = prepare_gpu_selection( + kwargs.get("gpu_ids"), + model_name = config["model_name"], + hf_token = config["hf_token"] or None, + training_type = config["training_type"], + load_in_4bit = config["load_in_4bit"], + batch_size = config.get("batch_size", 4), + max_seq_length = config.get("max_seq_length", 2048), + lora_rank = config.get("lora_r", 16), + target_modules = config.get("target_modules"), + gradient_checkpointing = config.get("gradient_checkpointing", "unsloth"), + optimizer = config.get("optim", "adamw_8bit"), + ) + config["resolved_gpu_ids"] = resolved_gpu_ids + config["gpu_selection"] = gpu_selection from .worker import run_training_process @@ -512,6 +519,12 @@ def _handle_event(self, event: dict) -> None: self._progress.grad_norm = event.get("grad_norm") self._progress.num_tokens = event.get("num_tokens") self._progress.eval_loss = event.get("eval_loss") + _peak = event.get("peak_memory_gb") + if _peak is not None: + try: + self._progress.peak_memory_gb = float(_peak) + except (TypeError, ValueError): + pass self._progress.is_training = True status = event.get("status_message", "") if status: diff --git a/studio/backend/core/training/worker.py b/studio/backend/core/training/worker.py index 60b9e994ab..9c017db6de 100644 --- a/studio/backend/core/training/worker.py +++ b/studio/backend/core/training/worker.py @@ -338,6 +338,594 @@ def _activate_transformers_version(model_name: str) -> None: activate_transformers_for_subprocess(model_name) +def _adapt_for_mlx_vlm(items): + """Adapt GPU-path VLM dataset output for mlx-vlm consumption. + + The GPU path embeds PIL images inside messages content as + {"type": "image", "image": PIL_Image}. mlx-vlm's prepare_inputs + needs images at top-level to produce pixel_values — regardless of + model type. Extract them and leave bare {"type": "image"} placeholders. + """ + adapted = [] + for item in items: + images = [] + messages = [] + for msg in item.get("messages", []): + content = msg.get("content", "") + if isinstance(content, list): + new_content = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + img = part.get("image") + if img is not None: + images.append(img) + new_content.append({"type": "image"}) + else: + new_content.append(part) + messages.append({"role": msg["role"], "content": new_content}) + else: + messages.append(msg) + out = {"messages": messages} + if images: + out["image"] = images[0] if len(images) == 1 else images + elif "image" in item: + out["image"] = item["image"] + elif "images" in item: + out["images"] = item["images"] + adapted.append(out) + return adapted + + +_MLX_STUDIO_OPTIM_MAP = { + "adamw_8bit": "adamw", + "paged_adamw_8bit": "adamw", + "adamw_bnb_8bit": "adamw", + "paged_adamw_32bit": "adamw", + "adamw_torch": "adamw", + "adamw_torch_fused": "adamw", + "adamw": "adamw", + "adafactor": "adafactor", + "sgd": "sgd", + "adam": "adam", + "muon": "muon", + "lion": "lion", +} +_MLX_STUDIO_LR_SCHEDULERS = {"linear", "cosine", "constant"} + + +def _normalize_mlx_studio_optimizer(value): + raw = str(value or "adamw_8bit").strip().lower() + try: + return _MLX_STUDIO_OPTIM_MAP[raw] + except KeyError: + supported = ", ".join(sorted(_MLX_STUDIO_OPTIM_MAP)) + raise ValueError( + f"Unsupported optimizer for MLX training: {value!r}. " + f"Supported values: {supported}." + ) + + +def _normalize_mlx_studio_scheduler(value): + raw = str(value or "linear").strip().lower() + if raw not in _MLX_STUDIO_LR_SCHEDULERS: + supported = ", ".join(sorted(_MLX_STUDIO_LR_SCHEDULERS)) + raise ValueError( + f"Unsupported LR scheduler for MLX training: {value!r}. " + f"Supported values: {supported}." + ) + return raw + + +def _run_mlx_training(event_queue, stop_queue, config): + """Self-contained MLX training path for Apple Silicon. + + Uses MLXTrainer from unsloth_zoo directly -- no torch/SFTTrainer needed. + Mirrors the event_queue protocol so the parent process pump works unchanged. + """ + import time + import gc + import math + import threading + import queue as _queue + from pathlib import Path + + def _send(event_type, **kwargs): + if event_type == "status" and "message" not in kwargs: + sm = kwargs.get("status_message") + if sm is not None: + kwargs["message"] = sm + event_queue.put({"type": event_type, "ts": time.time(), **kwargs}) + + _send("status", status_message = "Loading MLX libraries...") + + import mlx.core as mx + + try: + from unsloth_zoo.mlx_loader import FastMLXModel + from unsloth_zoo.mlx_trainer import ( + MLXTrainer, + MLXTrainingConfig, + train_on_responses_only, + ) + except ImportError as e: + raise ImportError( + "Unsloth: MLX training requires unsloth-zoo with the MLX modules " + "(unsloth_zoo.mlx_loader / unsloth_zoo.mlx_trainer). Reinstall via " + "install.sh on Apple Silicon." + ) from e + from datasets import load_dataset + + if mx.metal.is_available(): + info = mx.device_info() + rec_bytes = info.get("max_recommended_working_set_size", 0) or 0 + if rec_bytes > 0: + memory_cap = int(rec_bytes * 0.85) + wired_cap = min(int(rec_bytes), memory_cap) + mx.set_memory_limit(memory_cap) + mx.set_wired_limit(wired_cap) + + model_name = config["model_name"] + hf_token = config.get("hf_token") or None + if hf_token: + os.environ["HF_TOKEN"] = hf_token + + if config.get("use_loftq"): + message = "LoftQ is not supported for MLX training yet." + _send("error", error = message) + raise NotImplementedError(message) + + optim_name = _normalize_mlx_studio_optimizer(config.get("optim", "adamw_8bit")) + lr_scheduler_type = _normalize_mlx_studio_scheduler( + config.get("lr_scheduler_type", "linear") + ) + + # ── 1. Load model ── + # Force text-only if the dataset is not an image dataset, even if the model + # has vision capabilities (e.g. Qwen3.5-VL trained on plain alpaca text). + _send("status", status_message = f"Loading {model_name}...") + is_dataset_image = bool(config.get("is_dataset_image", False)) + training_type = config.get("training_type", "LoRA/QLoRA") + use_lora = training_type == "LoRA/QLoRA" + model, tokenizer = FastMLXModel.from_pretrained( + model_name, + load_in_4bit = config.get("load_in_4bit", True), + full_finetuning = not use_lora, + text_only = None if is_dataset_image else True, + token = hf_token, + trust_remote_code = bool(config.get("trust_remote_code", False)), + random_state = config.get("random_seed", 3407), + ) + + is_vlm = bool(is_dataset_image and getattr(model, "_is_vlm_model", False)) + model._is_vlm_model = is_vlm + + # ── 2. Apply LoRA / full FT ── + # Pass gradient_checkpointing as string ("mlx"/"unsloth"/"none"/etc.) + # get_peft_model and MLXTrainer both accept strings and handle them. + gc_setting = config.get("gradient_checkpointing", "mlx") + if isinstance(gc_setting, str): + use_grad_checkpoint = ( + gc_setting if gc_setting.lower() not in ("false", "") else False + ) + else: + use_grad_checkpoint = gc_setting + + if use_lora: + _send("status", status_message = "Configuring LoRA adapters...") + peft_kwargs = dict( + r = config.get("lora_r", 16), + lora_alpha = config.get("lora_alpha", 16), + lora_dropout = config.get("lora_dropout", 0.0), + use_rslora = config.get("use_rslora", False), + init_lora_weights = config.get("init_lora_weights", True), + random_state = config.get("random_seed", 3407), + target_modules = config.get("target_modules") + or [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + use_gradient_checkpointing = use_grad_checkpoint, + ) + finetune_language = config.get("finetune_language_layers", True) + finetune_attention = config.get("finetune_attention_modules", True) + finetune_mlp = config.get("finetune_mlp_modules", True) + finetune_vision = ( + config.get("finetune_vision_layers", False) if is_vlm else False + ) + + if ( + (finetune_attention or finetune_mlp) + and not finetune_language + and not finetune_vision + ): + finetune_language = True + + peft_kwargs["finetune_language_layers"] = finetune_language + peft_kwargs["finetune_attention_modules"] = finetune_attention + peft_kwargs["finetune_mlp_modules"] = finetune_mlp + if is_vlm: + peft_kwargs["finetune_vision_layers"] = finetune_vision + model = FastMLXModel.get_peft_model(model, **peft_kwargs) + + # ── 3. Load dataset ── + _send("status", status_message = "Loading dataset...") + hf_dataset = config.get("hf_dataset", "") + subset = config.get("subset") + train_split = config.get("train_split", "train") or "train" + eval_split = config.get("eval_split") + slice_start = config.get("dataset_slice_start") + slice_end = config.get("dataset_slice_end") + + def _slice(ds): + if slice_start is not None or slice_end is not None: + start = slice_start if slice_start is not None else 0 + end = slice_end if slice_end is not None else len(ds) - 1 + if end < start: + return ds.select([]) + ds = ds.select(range(start, min(end + 1, len(ds)))) + return ds + + def _load_local(file_paths): + from core.training.trainer import UnslothTrainer + from datasets import load_from_disk + + if len(file_paths) == 1: + p = Path(file_paths[0]) + if p.is_dir() and ( + (p / "dataset_info.json").exists() or (p / "state.json").exists() + ): + return load_from_disk(str(p)) + all_files = UnslothTrainer._resolve_local_files(file_paths) + if not all_files: + raise ValueError("No local dataset files found") + loader = UnslothTrainer._loader_for_files(all_files) + return load_dataset(loader, data_files = all_files, split = "train") + + if hf_dataset: + load_kwargs = {"split": train_split, "token": hf_token} + if subset: + load_kwargs["name"] = subset + dataset = load_dataset(hf_dataset, **load_kwargs) + dataset = _slice(dataset) + elif config.get("local_datasets"): + dataset = _load_local(config["local_datasets"]) + dataset = _slice(dataset) + else: + raise ValueError("No dataset specified") + + # Eval dataset (separate split or local file) + eval_dataset = None + if eval_split and hf_dataset: + eval_kwargs = {"split": eval_split, "token": hf_token} + if subset: + eval_kwargs["name"] = subset + try: + eval_dataset = load_dataset(hf_dataset, **eval_kwargs) + except Exception as e: + _send("status", status_message = f"Eval split load failed: {e}") + eval_dataset = None + elif config.get("local_eval_datasets"): + eval_dataset = _load_local(config["local_eval_datasets"]) + + # ── 3b. Format dataset (VLM or text) ── + # Reuse the GPU path's format pipeline for both VLM (auto-detects OCR/caption/ + # llava/sharegpt+images) and text (alpaca/sharegpt/chatml → "text" column). + format_type = config.get("format_type", "") + try: + from utils.datasets import format_and_template_dataset + + def _fmt_progress(status_message = "", **_kw): + _send("status", status_message = status_message) + + if is_vlm: + _send("status", status_message = "Formatting VLM dataset...") + vlm_info = format_and_template_dataset( + dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = True, + dataset_name = hf_dataset or "local", + progress_callback = _fmt_progress, + ) + if vlm_info.get("success"): + dataset = _adapt_for_mlx_vlm(vlm_info["dataset"]) + else: + errors = vlm_info.get("errors", []) + raise ValueError( + f"VLM dataset format conversion failed: {'; '.join(errors)}" + ) + if eval_dataset is not None: + ev_info = format_and_template_dataset( + eval_dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = True, + dataset_name = hf_dataset or "local", + ) + if ev_info.get("success"): + eval_dataset = _adapt_for_mlx_vlm(ev_info["dataset"]) + + elif format_type: + _send("status", status_message = f"Formatting dataset ({format_type})...") + info = format_and_template_dataset( + dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = False, + format_type = format_type, + dataset_name = hf_dataset or "local", + ) + if info.get("success", True): + dataset = info.get("dataset", dataset) + if eval_dataset is not None: + ev = format_and_template_dataset( + eval_dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = False, + format_type = format_type, + dataset_name = hf_dataset or "local", + ) + if ev.get("success", True): + eval_dataset = ev.get("dataset", eval_dataset) + except ImportError: + _send("status", status_message = "Format helper unavailable, using raw dataset") + + # ── 4. Resolve training steps ── + max_steps = config.get("max_steps", 0) or 0 + num_epochs = config.get("num_epochs", 3) + max_seq_length = config.get("max_seq_length", 2048) + batch_size = config.get("batch_size", 4) + grad_accum = config.get("gradient_accumulation_steps", 4) + + if max_steps <= 0: + max_steps = max( + 1, + math.ceil(len(dataset) / batch_size / grad_accum) * num_epochs, + ) + + lr_value = float(config.get("learning_rate", "2e-4")) + + # Warmup: prefer warmup_steps; fall back to warmup_ratio + warmup_steps = config.get("warmup_steps") + warmup_ratio = config.get("warmup_ratio") + if warmup_steps is None and warmup_ratio is not None: + warmup_steps = int(round(warmup_ratio * max_steps)) + if warmup_steps is None: + warmup_steps = 5 + + # ── 5. Build output dir ── + output_dir = config.get("output_dir", "") + if not output_dir: + output_dir = f"{model_name.replace('/', '_')}_{int(time.time())}" + # Resolve to ~/.unsloth/studio/outputs/ so the export page can find it + from utils.paths import resolve_output_dir, ensure_dir + + output_dir = str(resolve_output_dir(output_dir)) + ensure_dir(Path(output_dir)) + + # ── 6. Create trainer ── + eval_steps_val = config.get("eval_steps", 0) or 0 + if isinstance(eval_steps_val, float) and 0 < eval_steps_val < 1: + # Studio sometimes sends fraction-of-total-steps + eval_steps_val = max(1, int(eval_steps_val * max_steps)) + else: + eval_steps_val = int(eval_steps_val) + + trainer = MLXTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = dataset, + eval_dataset = eval_dataset, + args = MLXTrainingConfig( + per_device_train_batch_size = batch_size, + gradient_accumulation_steps = grad_accum, + max_steps = max_steps, + learning_rate = lr_value, + warmup_steps = warmup_steps, + lr_scheduler_type = lr_scheduler_type, + optim = optim_name, + weight_decay = float(config.get("weight_decay", 0.001) or 0.001), + logging_steps = 1, + max_seq_length = max_seq_length, + seed = config.get("random_seed", 3407), + use_cce = True, + compile = True, + gradient_checkpointing = use_grad_checkpoint, + streaming = is_vlm, + packing = bool(config.get("packing", False)), + output_dir = output_dir, + save_steps = int(config.get("save_steps", 0) or 0), + eval_steps = eval_steps_val, + ), + ) + + # Tell the parent that eval is configured so the frontend shows the eval chart + if eval_dataset is not None and eval_steps_val > 0: + _send("eval_configured") + + # ── 7. Apply train_on_responses_only if requested ── + if config.get("train_on_completions", False): + _send("status", status_message = "Configuring response-only training...") + try: + from utils.datasets import ( + MODEL_TO_TEMPLATE_MAPPER, + TEMPLATE_TO_RESPONSES_MAPPER, + ) + + template_name = MODEL_TO_TEMPLATE_MAPPER.get(model_name.lower()) + markers = ( + TEMPLATE_TO_RESPONSES_MAPPER.get(template_name) + if template_name + else None + ) + if markers: + trainer = train_on_responses_only( + trainer, + instruction_part = markers["instruction"], + response_part = markers["response"], + ) + else: + _send( + "status", + status_message = f"train_on_completions skipped (no template for {model_name})", + ) + except Exception as e: + _send("status", status_message = f"train_on_completions failed: {e}") + + # ── 8. Setup wandb / tensorboard ── + wandb_run = None + tb_writer = None + if config.get("enable_wandb", False): + try: + import wandb as _wandb + + wandb_token = config.get("wandb_token") + if wandb_token: + os.environ["WANDB_API_KEY"] = wandb_token + _wandb_sensitive = {"hf_token", "wandb_token"} + wandb_run = _wandb.init( + project = config.get("wandb_project") or "unsloth-mlx", + config = {k: v for k, v in config.items() if k not in _wandb_sensitive}, + reinit = True, + ) + except Exception as e: + _send("status", status_message = f"wandb init failed: {e}") + if config.get("enable_tensorboard", False): + try: + from tensorboardX import SummaryWriter + except ImportError: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + SummaryWriter = None + if SummaryWriter is not None: + try: + tb_dir = config.get("tensorboard_dir") or f"{output_dir}/runs" + tb_writer = SummaryWriter(log_dir = tb_dir) + except Exception as e: + _send("status", status_message = f"tensorboard init failed: {e}") + else: + _send( + "status", + status_message = "tensorboard unavailable (install tensorboardX)", + ) + + # ── 9. Real-time progress callback ── + _send("status", status_message = f"Training {model_name}...") + + def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens): + eta = (elapsed / step * (total - step)) if step > 0 else 0 + _send( + "progress", + step = step, + epoch = round(step / total * num_epochs, 2) if total > 0 else 0, + loss = loss, + learning_rate = lr, + total_steps = total, + elapsed_seconds = elapsed, + eta_seconds = max(0, eta), + grad_norm = None, + num_tokens = num_tokens, + eval_loss = None, + status_message = None, + peak_memory_gb = peak_gb, + ) + if wandb_run is not None: + try: + wandb_run.log( + { + "train/loss": loss, + "train/learning_rate": lr, + "train/tokens_per_sec": tok_s, + "train/peak_gb": peak_gb, + "train/num_tokens": num_tokens, + }, + step = step, + ) + except Exception: + pass + if tb_writer is not None: + try: + tb_writer.add_scalar("train/loss", loss, step) + tb_writer.add_scalar("train/learning_rate", lr, step) + tb_writer.add_scalar("train/tokens_per_sec", tok_s, step) + tb_writer.add_scalar("train/peak_gb", peak_gb, step) + except Exception: + pass + + trainer.add_step_callback(_on_step) + + def _on_eval(step, eval_loss, perplexity): + _send("progress", step = step, eval_loss = eval_loss) + if wandb_run is not None: + try: + wandb_run.log( + {"eval/loss": eval_loss, "eval/perplexity": perplexity}, step = step + ) + except Exception: + pass + if tb_writer is not None: + try: + tb_writer.add_scalar("eval/loss", eval_loss, step) + tb_writer.add_scalar("eval/perplexity", perplexity, step) + except Exception: + pass + + trainer.add_eval_callback(_on_eval) + + # ── 10. Stop signal polling ── + _stop_save = [True] # mutable so thread can update; [save_flag] + + def _poll_stop(): + while True: + try: + msg = stop_queue.get(timeout = 1.0) + if msg and msg.get("type") == "stop": + _stop_save[0] = msg.get("save", True) + trainer.stop_requested = True + return + except _queue.Empty: + continue + except (EOFError, OSError): + # why safe: pipe permanently broken, no further messages can arrive + return + + stop_thread = threading.Thread(target = _poll_stop, daemon = True) + stop_thread.start() + + # ── 11. Run training ── + gc.collect() + mx.synchronize() + trainer.train() + + # ── 12. Save and finalize ── + if trainer.stop_requested and not _stop_save[0]: + # User clicked "Cancel" (save=False) — skip saving + _send("complete", output_dir = None, status_message = "Training cancelled") + else: + _send("status", status_message = "Saving model...") + mx.synchronize() + trainer.save_model(output_dir) + _send("complete", output_dir = output_dir, status_message = "Training completed") + + if tb_writer is not None: + try: + tb_writer.close() + except Exception: + pass + if wandb_run is not None: + try: + wandb_run.finish() + except Exception: + pass + + def run_training_process( *, event_queue: Any, @@ -371,6 +959,46 @@ def run_training_process( model_name = config["model_name"] + # ── 0. MLX FAST-PATH (must run before any torch/transformers imports) ── + # Apple Silicon uses MLXTrainer directly -- skip transformers version + # activation, causal-conv1d install, and torch imports entirely. + backend_path = str(Path(__file__).resolve().parent.parent.parent) + if backend_path not in sys.path: + sys.path.insert(0, backend_path) + + from utils.hardware import hardware as _hw + + _hw.detect_hardware() + if _hw.DEVICE == _hw.DeviceType.MLX: + if config.get("is_dataset_audio"): + event_queue.put( + { + "type": "error", + "error": "Audio dataset training is not yet supported on Apple Silicon.", + "stack": "", + "ts": time.time(), + } + ) + return + # Activate correct transformers version (Gemma-4 needs 5.5.0, etc.) + # Must happen before any transformers/mlx-lm imports in _run_mlx_training. + try: + _activate_transformers_version(model_name) + except Exception: + pass # Non-fatal: fall through with whatever version is installed + try: + _run_mlx_training(event_queue, stop_queue, config) + except Exception as exc: + event_queue.put( + { + "type": "error", + "error": str(exc), + "stack": traceback.format_exc(limit = 20), + "ts": time.time(), + } + ) + return + # ── 1. Activate correct transformers version BEFORE any ML imports ── try: _activate_transformers_version(model_name) diff --git a/studio/backend/tests/test_mlx_inference_backend.py b/studio/backend/tests/test_mlx_inference_backend.py new file mode 100644 index 0000000000..868e537372 --- /dev/null +++ b/studio/backend/tests/test_mlx_inference_backend.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +import sys +import types +from types import SimpleNamespace + + +class _DummyMetal: + @staticmethod + def is_available(): + return False + + +class _DummyMX: + metal = _DummyMetal() + + @staticmethod + def set_wired_limit(_limit): + return None + + @staticmethod + def device_info(): + return {"max_recommended_working_set_size": 1024} + + +class _DummyTokenizer: + pass + + +class _DummyProcessor: + tokenizer = _DummyTokenizer() + + +class _DummyModel: + pass + + +def _install_fake_mlx(monkeypatch): + mlx_pkg = types.ModuleType("mlx") + mlx_core = types.ModuleType("mlx.core") + mlx_core.metal = _DummyMetal() + mlx_core.set_wired_limit = _DummyMX.set_wired_limit + mlx_core.device_info = _DummyMX.device_info + mlx_pkg.core = mlx_core + monkeypatch.setitem(sys.modules, "mlx", mlx_pkg) + monkeypatch.setitem(sys.modules, "mlx.core", mlx_core) + + +def _install_fake_fast_mlx(monkeypatch, calls): + class _FastMLXModel: + @staticmethod + def from_pretrained(*args, **kwargs): + calls.append((args, kwargs)) + if kwargs["text_only"] is False: + return _DummyModel(), _DummyProcessor() + return _DummyModel(), _DummyTokenizer() + + unsloth_zoo_pkg = types.ModuleType("unsloth_zoo") + mlx_loader = types.ModuleType("unsloth_zoo.mlx_loader") + mlx_loader.FastMLXModel = _FastMLXModel + unsloth_zoo_pkg.mlx_loader = mlx_loader + monkeypatch.setitem(sys.modules, "unsloth_zoo", unsloth_zoo_pkg) + monkeypatch.setitem(sys.modules, "unsloth_zoo.mlx_loader", mlx_loader) + + +def test_mlx_inference_text_load_forwards_studio_settings(monkeypatch): + _install_fake_mlx(monkeypatch) + calls = [] + _install_fake_fast_mlx(monkeypatch, calls) + + from core.inference.mlx_inference import MLXInferenceBackend + + backend = MLXInferenceBackend() + config = SimpleNamespace(identifier = "fake/text", is_vision = False, is_lora = False) + + assert backend.load_model( + config, + max_seq_length = 4096, + load_in_4bit = False, + hf_token = "hf-token", + trust_remote_code = True, + dtype = "float16", + ) + + assert calls == [ + ( + ("fake/text",), + { + "max_seq_length": 4096, + "dtype": "float16", + "load_in_4bit": False, + "token": "hf-token", + "trust_remote_code": True, + "text_only": True, + }, + ) + ] + assert backend._is_vlm is False + assert isinstance(backend._tokenizer, _DummyTokenizer) + + +def test_mlx_inference_vlm_lora_uses_unsloth_loader_without_native_adapter_rewrite( + monkeypatch, + tmp_path, +): + _install_fake_mlx(monkeypatch) + calls = [] + _install_fake_fast_mlx(monkeypatch, calls) + + def _native_vlm_load(*_args, **_kwargs): + raise AssertionError("Studio MLX VLM inference must use FastMLXModel") + + mlx_vlm = types.ModuleType("mlx_vlm") + mlx_vlm.load = _native_vlm_load + monkeypatch.setitem(sys.modules, "mlx_vlm", mlx_vlm) + + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + cfg_path = adapter_dir / "adapter_config.json" + original_cfg = '{"base_model_name_or_path": "fake/base", "rank": 8}\n' + cfg_path.write_text(original_cfg) + + from core.inference.mlx_inference import MLXInferenceBackend + + backend = MLXInferenceBackend() + config = SimpleNamespace( + identifier = str(adapter_dir), + is_vision = True, + is_lora = True, + base_model = "fake/base", + ) + + assert backend.load_model( + config, + max_seq_length = 8192, + load_in_4bit = True, + hf_token = "hf-token", + trust_remote_code = True, + ) + + assert calls == [ + ( + (str(adapter_dir),), + { + "max_seq_length": 8192, + "dtype": None, + "load_in_4bit": True, + "token": "hf-token", + "trust_remote_code": True, + "text_only": False, + }, + ) + ] + assert cfg_path.read_text() == original_cfg + assert backend._is_vlm is True + assert isinstance(backend._processor, _DummyProcessor) + assert isinstance(backend._tokenizer, _DummyTokenizer) diff --git a/studio/backend/tests/test_mlx_training_worker_config.py b/studio/backend/tests/test_mlx_training_worker_config.py new file mode 100644 index 0000000000..5900af4e3d --- /dev/null +++ b/studio/backend/tests/test_mlx_training_worker_config.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +import importlib.util +import sys +import types +from pathlib import Path + +import pytest + + +def _load_worker_module(): + stub_names = ( + "structlog", + "loggers", + "utils", + "utils.hardware", + "utils.wheel_utils", + ) + previous_modules = {name: sys.modules.get(name) for name in stub_names} + + try: + sys.modules["structlog"] = types.ModuleType("structlog") + + loggers = types.ModuleType("loggers") + loggers.get_logger = lambda *_args, **_kwargs: None + sys.modules["loggers"] = loggers + + utils = types.ModuleType("utils") + utils.__path__ = [] + sys.modules["utils"] = utils + + hardware = types.ModuleType("utils.hardware") + hardware.apply_gpu_ids = lambda *_args, **_kwargs: None + sys.modules["utils.hardware"] = hardware + + wheel_utils = types.ModuleType("utils.wheel_utils") + for name in ( + "direct_wheel_url", + "flash_attn_wheel_url", + "install_wheel", + "probe_torch_wheel_env", + "url_exists", + ): + setattr(wheel_utils, name, lambda *_args, **_kwargs: None) + sys.modules["utils.wheel_utils"] = wheel_utils + + worker_path = ( + Path(__file__).resolve().parents[1] / "core" / "training" / "worker.py" + ) + spec = importlib.util.spec_from_file_location( + "mlx_training_worker_under_test", worker_path + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + finally: + for name, module in previous_modules.items(): + if module is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = module + + +_worker = _load_worker_module() +_normalize_mlx_studio_optimizer = _worker._normalize_mlx_studio_optimizer +_normalize_mlx_studio_scheduler = _worker._normalize_mlx_studio_scheduler + + +def test_mlx_studio_optimizer_aliases_are_explicit(): + assert _normalize_mlx_studio_optimizer("adamw_8bit") == "adamw" + assert _normalize_mlx_studio_optimizer("paged_adamw_8bit") == "adamw" + assert _normalize_mlx_studio_optimizer("adafactor") == "adafactor" + + +def test_mlx_studio_rejects_unknown_optimizer(): + with pytest.raises(ValueError, match = "Unsupported optimizer for MLX training"): + _normalize_mlx_studio_optimizer("adamw_typo") + + +def test_mlx_studio_rejects_unknown_scheduler(): + with pytest.raises(ValueError, match = "Unsupported LR scheduler for MLX training"): + _normalize_mlx_studio_scheduler("linear_typo") diff --git a/studio/backend/utils/hardware/hardware.py b/studio/backend/utils/hardware/hardware.py index c218b7b4b9..3764e38272 100644 --- a/studio/backend/utils/hardware/hardware.py +++ b/studio/backend/utils/hardware/hardware.py @@ -143,6 +143,7 @@ def detect_hardware() -> DeviceType: # --- MLX: Apple Silicon --- if is_apple_silicon() and _has_mlx(): DEVICE = DeviceType.MLX + CHAT_ONLY = False chip = platform.processor() or platform.machine() print(f"Hardware detected: MLX — Apple Silicon ({chip})") return DEVICE @@ -270,19 +271,30 @@ def get_gpu_memory_info() -> Dict[str, Any]: import mlx.core as mx import psutil - # MLX uses unified memory — report system memory as the pool + # MLX uses unified memory. Total = system RAM. GPU memory used + # comes from IORegistry's AGXAccelerator (system-wide, no sudo). total = psutil.virtual_memory().total - # MLX doesn't expose per-process GPU allocation; report 0 as allocated - allocated = 0 + agx = _read_apple_gpu_stats() + allocated = agx.get("vram_used_bytes", 0) if agx else 0 + + try: + info = mx.device_info() + gpu_name = ( + info.get("device_name") + or platform.processor() + or platform.machine() + ) + except Exception: + gpu_name = platform.processor() or platform.machine() return { "available": True, "backend": _backend_label(device), "device": 0, - "device_name": f"Apple Silicon ({platform.processor() or platform.machine()})", + "device_name": f"Apple Silicon ({gpu_name})", "total_gb": total / (1024**3), "allocated_gb": allocated / (1024**3), - "reserved_gb": 0, + "reserved_gb": allocated / (1024**3), "free_gb": (total - allocated) / (1024**3), "utilization_pct": (allocated / total) * 100 if total else 0, } @@ -460,6 +472,39 @@ def _smi_query(func_name: str, *args, **kwargs) -> Optional[Dict[str, Any]]: return None +def _read_apple_gpu_stats() -> Dict[str, Any]: + """Query macOS IORegistry for AGX (Apple GPU) live stats. No sudo needed. + + Returns dict with utilization_pct, vram_used_bytes (system-wide GPU memory). + Returns empty dict on failure. + """ + import subprocess + import re + + try: + result = subprocess.run( + ["ioreg", "-r", "-c", "AGXAccelerator"], + capture_output = True, + timeout = 2, + ) + text = result.stdout.decode("utf-8", errors = "replace") + except Exception: + return {} + + # PerformanceStatistics block has GPU utilization and in-use memory + m = re.search(r'"PerformanceStatistics" = \{([^}]+)\}', text) + if not m: + return {} + stats_str = m.group(1) + pairs = re.findall(r'"([^"]+)"=(\d+)', stats_str) + stats = {k: int(v) for k, v in pairs} + + return { + "utilization_pct": stats.get("Device Utilization %", 0), + "vram_used_bytes": stats.get("In use system memory", 0), + } + + def get_gpu_utilization() -> Dict[str, Any]: """Return a live snapshot of device utilization information.""" device = get_device() @@ -470,6 +515,50 @@ def get_gpu_utilization() -> Dict[str, Any]: result["backend"] = _backend_label(device) return result + # MLX path: single _read_apple_gpu_stats() call carries both VRAM-used + # bytes and GPU utilization %. psutil for unified-memory total is cheap. + if device == DeviceType.MLX: + try: + import psutil + + agx = _read_apple_gpu_stats() + total_bytes = psutil.virtual_memory().total + except Exception as e: + logger.error(f"Error getting MLX GPU utilization: {e}") + return {"available": False, "backend": device.value, "error": str(e)} + if not agx: + return {"available": False, "backend": device.value} + allocated_bytes = agx.get("vram_used_bytes", 0) or 0 + vram_used_gb = allocated_bytes / (1024**3) + total_gb = total_bytes / (1024**3) + + try: + from core.training import get_training_backend + + tb = get_training_backend() + tb_progress = getattr(tb, "_progress", None) + if tb_progress is not None and getattr(tb_progress, "is_training", False): + tb_peak = getattr(tb_progress, "peak_memory_gb", None) + if tb_peak is not None and tb_peak > 0: + vram_used_gb = float(tb_peak) + except Exception: + pass + + return { + "available": True, + "backend": device.value, + "gpu_utilization_pct": agx.get("utilization_pct") if agx else None, + "temperature_c": None, + "vram_used_gb": round(vram_used_gb, 2), + "vram_total_gb": round(total_gb, 2), + "vram_utilization_pct": ( + round((vram_used_gb / total_gb) * 100, 1) if total_gb > 0 else None + ), + "power_draw_w": None, + "power_limit_w": None, + "power_utilization_pct": None, + } + mem = get_gpu_memory_info() if device != DeviceType.CPU and mem.get("available"): return { diff --git a/studio/frontend/src/components/app-sidebar.tsx b/studio/frontend/src/components/app-sidebar.tsx index edcd5120eb..171b4eb92d 100644 --- a/studio/frontend/src/components/app-sidebar.tsx +++ b/studio/frontend/src/components/app-sidebar.tsx @@ -37,13 +37,12 @@ import { Delete02Icon, Download03Icon, GemIcon, - Globe02Icon, Search01Icon, PowerIcon, PencilEdit02Icon, LayoutAlignLeftIcon, - HelpCircleIcon, Settings02Icon, + SourceCodeSquareIcon, ZapIcon, } from "@hugeicons/core-free-icons"; import { @@ -528,7 +527,7 @@ export function AppSidebar() {