diff --git a/.gitignore b/.gitignore index b6786ee655..2f23f18d65 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,4 @@ setup_leo.sh server.pid *.log package-lock.json +llama.cpp/ diff --git a/install.sh b/install.sh index 1d21117d16..c928d2ae54 100755 --- a/install.sh +++ b/install.sh @@ -1721,6 +1721,12 @@ else fi fi +# ── Install mlx-vlm on Apple Silicon (optional, for VLM training) ── +if [ "$OS" = "macos" ] && [ "$_ARCH" = "arm64" ]; then + substep "installing mlx-vlm (VLM training support)..." + run_install_cmd "install mlx-vlm" uv pip install --python "$_VENV_PY" mlx-vlm +fi + # ── Run studio setup ── tauri_log "STEP" "Running Studio setup" # When --local, use the repo's own setup.sh directly. diff --git a/pyproject.toml b/pyproject.toml index c2f884e192..5687ea12f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ huggingfacenotorch = [ ] huggingface = [ "unsloth[huggingfacenotorch]", - "unsloth_zoo>=2026.5.1", + "unsloth_zoo>=2026.4.8", "torchvision", "unsloth[triton]", ] @@ -579,7 +579,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3 ; ('linux' in sys_platform)", ] colab-new = [ - "unsloth_zoo>=2026.5.1", + "unsloth_zoo>=2026.4.8", "packaging", "tyro", "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.5.0", diff --git a/studio/backend/core/export/export.py b/studio/backend/core/export/export.py index 6fee5a38f7..4ab95d896f 100644 --- a/studio/backend/core/export/export.py +++ b/studio/backend/core/export/export.py @@ -9,16 +9,14 @@ import glob import json import structlog +import tempfile from loggers import get_logger import os import shutil from pathlib import Path from typing import Optional, Tuple, List -from peft import PeftModel, PeftModelForCausalLM -from unsloth import FastLanguageModel, FastVisionModel +from unsloth import FastLanguageModel, FastVisionModel, _IS_MLX from huggingface_hub import HfApi, ModelCard -from transformers.modeling_utils import PushToHubMixin -import torch from utils.hardware import clear_gpu_cache from utils.models import is_vision_model, get_base_model_from_lora @@ -26,6 +24,12 @@ from utils.paths import ensure_dir, outputs_root, resolve_export_dir, resolve_output_dir from core.inference import get_inference_backend +# GPU-only imports — guarded for Apple Silicon where these aren't needed +if not _IS_MLX: + from peft import PeftModel, PeftModelForCausalLM + from transformers.modeling_utils import PushToHubMixin + import torch + logger = get_logger(__name__) _LLAMA_CPP_SCRIPTS_WARNING_EMITTED = False @@ -225,7 +229,7 @@ def load_checkpoint( model, tokenizer = FastModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, - dtype = torch.float32, + dtype = None if _IS_MLX else torch.float32, load_in_4bit = False, trust_remote_code = trust_remote_code, ) @@ -262,8 +266,12 @@ def load_checkpoint( trust_remote_code = trust_remote_code, ) - # Check if PEFT model - self.is_peft = isinstance(model, (PeftModel, PeftModelForCausalLM)) + # Check if PEFT / LoRA model + if _IS_MLX: + # MLX doesn't use PeftModel — detect LoRA via adapter_config.json + self.is_peft = adapter_config.exists() + else: + self.is_peft = isinstance(model, (PeftModel, PeftModelForCausalLM)) # Store loaded model self.current_model = model @@ -325,9 +333,7 @@ def export_merged_model( private: Whether to make the repo private Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved model when - ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -341,14 +347,17 @@ def export_merged_model( output_path: Optional[str] = None try: - # Determine save method - if format_type == "4-bit (FP4)": - save_method = "merged_4bit_forced" - elif self._audio_type == "whisper": - # Whisper uses save_method=None for local 16-bit merged save - save_method = None - else: # 16-bit (FP16) - save_method = "merged_16bit" + if _IS_MLX: + mlx_save_method = ( + "merged_4bit" if format_type == "4-bit (FP4)" else "merged_16bit" + ) + else: + if format_type == "4-bit (FP4)": + save_method = "merged_4bit_forced" + elif self._audio_type == "whisper": + save_method = None + else: + save_method = "merged_16bit" # Save locally if requested if save_directory: @@ -356,11 +365,17 @@ def export_merged_model( logger.info(f"Saving merged model locally to: {save_directory}") ensure_dir(Path(save_directory)) - self.current_model.save_pretrained_merged( - save_directory, self.current_tokenizer, save_method = save_method - ) + if _IS_MLX: + self.current_model.save_pretrained_merged( + save_directory, + self.current_tokenizer, + save_method = mlx_save_method, + ) + else: + self.current_model.save_pretrained_merged( + save_directory, self.current_tokenizer, save_method = save_method + ) - # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) logger.info(f"Model saved successfully to {save_directory}") output_path = str(Path(save_directory).resolve()) @@ -376,17 +391,40 @@ def export_merged_model( logger.info(f"Pushing merged model to Hub: {repo_id}") - # Whisper uses save_method=None for local but "merged_16bit" for hub push - hub_save_method = ( - save_method if save_method is not None else "merged_16bit" - ) - self.current_model.push_to_hub_merged( - repo_id, - self.current_tokenizer, - save_method = hub_save_method, - token = hf_token, - private = private, - ) + if _IS_MLX: + if save_directory: + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = save_directory, + token = hf_token, + private = private, + ) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + self.current_model.save_pretrained_merged( + tmp_dir, + self.current_tokenizer, + save_method = mlx_save_method, + ) + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = tmp_dir, + token = hf_token, + private = private, + ) + else: + hub_save_method = ( + save_method if save_method is not None else "merged_16bit" + ) + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_method = hub_save_method, + token = hf_token, + private = private, + ) logger.info(f"Model pushed successfully to {repo_id}") return True, "Model exported successfully", output_path @@ -411,9 +449,7 @@ def export_base_model( Export base model (for non-PEFT models). Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved model when - ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -433,8 +469,16 @@ def export_base_model( logger.info(f"Saving base model locally to: {save_directory}") ensure_dir(Path(save_directory)) - self.current_model.save_pretrained(save_directory) - self.current_tokenizer.save_pretrained(save_directory) + if _IS_MLX: + # MLX: save_pretrained_merged handles non-LoRA models too + # (fuse() is a no-op when there are no LoRA layers) + self.current_model.save_pretrained_merged( + save_directory, + self.current_tokenizer, + ) + else: + self.current_model.save_pretrained(save_directory) + self.current_tokenizer.save_pretrained(save_directory) # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) @@ -452,44 +496,73 @@ def export_base_model( logger.info(f"Pushing base model to Hub: {repo_id}") - # Get base model name from request or model config - base_model = ( - base_model_id - or self.current_model.config._name_or_path - or "unknown" - ) - - # Create repo - hf_api = HfApi(token = hf_token) - repo_id = PushToHubMixin._create_repo( - PushToHubMixin, - repo_id = repo_id, - private = private, - token = hf_token, - ) - username = repo_id.split("/")[0] - - # Create and push model card - content = MODEL_CARD.format( - username = username, - base_model = base_model, - model_type = self.current_model.config.model_type, - method = "", - extra = "unsloth", - ) - card = ModelCard(content) - card.push_to_hub( - repo_id, token = hf_token, commit_message = "Unsloth Model Card" - ) + if _IS_MLX: + if save_directory: + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = save_directory, + token = hf_token, + private = private, + ) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + self.current_model.save_pretrained_merged( + tmp_dir, + self.current_tokenizer, + ) + self.current_model.push_to_hub_merged( + repo_id, + self.current_tokenizer, + save_directory = tmp_dir, + token = hf_token, + private = private, + ) + else: + # Get base model name from request or model config + base_model = ( + base_model_id + or self.current_model.config._name_or_path + or "unknown" + ) - # Upload model files - if save_directory: - hf_api.upload_folder( - folder_path = save_directory, repo_id = repo_id, repo_type = "model" + # Create repo + hf_api = HfApi(token = hf_token) + repo_id = PushToHubMixin._create_repo( + PushToHubMixin, + repo_id = repo_id, + private = private, + token = hf_token, ) - logger.info(f"Model pushed successfully to {repo_id}") - else: - return False, "Local save directory required for Hub upload", None + username = repo_id.split("/")[0] + + # Create and push model card + content = MODEL_CARD.format( + username = username, + base_model = base_model, + model_type = self.current_model.config.model_type, + method = "", + extra = "unsloth", + ) + card = ModelCard(content) + card.push_to_hub( + repo_id, token = hf_token, commit_message = "Unsloth Model Card" + ) + + # Upload model files + if save_directory: + hf_api.upload_folder( + folder_path = save_directory, + repo_id = repo_id, + repo_type = "model", + ) + logger.info(f"Model pushed successfully to {repo_id}") + else: + return ( + False, + "Local save directory required for Hub upload", + None, + ) return True, "Model exported successfully", output_path @@ -519,9 +592,7 @@ def export_gguf( hf_token: Hugging Face token Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory containing the .gguf - files when ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -692,9 +763,7 @@ def export_lora_adapter( Export LoRA adapter only (not merged). Returns: - Tuple of (success, message, output_path). output_path is the - resolved absolute on-disk directory of the saved adapter - when ``save_directory`` was set, else None. + Tuple of (success: bool, message: str, output_path: Optional[str]) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first.", None @@ -710,8 +779,13 @@ def export_lora_adapter( logger.info(f"Saving LoRA adapter locally to: {save_directory}") ensure_dir(Path(save_directory)) - self.current_model.save_pretrained(save_directory) - self.current_tokenizer.save_pretrained(save_directory) + if _IS_MLX: + # MLX: save adapters.safetensors + tokenizer files + self.current_model.save_lora_adapters(save_directory) + self.current_tokenizer.save_pretrained(save_directory) + else: + self.current_model.save_pretrained(save_directory) + self.current_tokenizer.save_pretrained(save_directory) logger.info(f"Adapter saved successfully to {save_directory}") output_path = str(Path(save_directory).resolve()) @@ -726,10 +800,24 @@ def export_lora_adapter( logger.info(f"Pushing LoRA adapter to Hub: {repo_id}") - self.current_model.push_to_hub(repo_id, token = hf_token, private = private) - self.current_tokenizer.push_to_hub( - repo_id, token = hf_token, private = private - ) + if _IS_MLX: + with tempfile.TemporaryDirectory() as tmp_dir: + self.current_model.save_lora_adapters(tmp_dir) + self.current_tokenizer.save_pretrained(tmp_dir) + hf_api = HfApi(token = hf_token) + hf_api.create_repo(repo_id, private = private, exist_ok = True) + hf_api.upload_folder( + folder_path = tmp_dir, + repo_id = repo_id, + repo_type = "model", + ) + else: + self.current_model.push_to_hub( + repo_id, token = hf_token, private = private + ) + self.current_tokenizer.push_to_hub( + repo_id, token = hf_token, private = private + ) logger.info(f"Adapter pushed successfully to {repo_id}") return True, "LoRA adapter exported successfully", output_path diff --git a/studio/backend/core/inference/mlx_inference.py b/studio/backend/core/inference/mlx_inference.py new file mode 100644 index 0000000000..1d2b03ecb9 --- /dev/null +++ b/studio/backend/core/inference/mlx_inference.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: AGPL-3.0-only +"""MLX inference backend for Apple Silicon. + +Drop-in replacement for InferenceBackend — same interface, uses mlx-lm/mlx-vlm +instead of torch/transformers for model loading and generation. +""" + +import threading +from typing import Optional, Generator +from loggers import get_logger + +logger = get_logger(__name__) + + +class MLXInferenceBackend: + def __init__(self): + self.models = {} + self.active_model_name = None + self.loading_models = set() + self.loaded_local_models = [] + self.device = "mlx" + self._generation_lock = threading.Lock() + + # MLX state + self._model = None + self._tokenizer = None + self._processor = None + self._is_vlm = False + self._config = {} + + # Recorded for unload to release pinned memory back to the OS. + self._memory_limits_applied = {} + + def _configure_memory_limits(self): + """Apply Metal memory caps before loading a model. + + Mirrors MLXTrainer._configure_memory_limits's defaults: + memory_limit = 85% of recommended working-set, + wired_limit = min(recommended, memory_limit). Recorded so unload + can lower wired_limit back to release pinned RAM. + """ + import mlx.core as mx + + if not mx.metal.is_available(): + return + info = mx.device_info() + rec_bytes = info.get("max_recommended_working_set_size") + if not rec_bytes or rec_bytes <= 0: + return + rec_gb = rec_bytes / 1e9 + memory_limit_gb = rec_gb * 0.85 + wired_limit_gb = min(rec_gb, memory_limit_gb) + mx.set_memory_limit(int(memory_limit_gb * 1e9)) + mx.set_wired_limit(int(wired_limit_gb * 1e9)) + self._memory_limits_applied = { + "memory_limit_gb": memory_limit_gb, + "wired_limit_gb": wired_limit_gb, + "recommended_gb": rec_gb, + } + logger.info( + "MLX memory caps: memory_limit=%.2f GB, wired_limit=%.2f GB", + memory_limit_gb, + wired_limit_gb, + ) + + def load_model( + self, + config, + max_seq_length = 2048, + load_in_4bit = True, + hf_token = None, + trust_remote_code = False, + gpu_ids = None, + dtype = None, + ) -> bool: + import mlx.core as mx + + model_name = config.identifier if hasattr(config, "identifier") else str(config) + is_vision = getattr(config, "is_vision", False) + + if hf_token: + import os + + os.environ["HF_TOKEN"] = hf_token + self._configure_memory_limits() + + is_lora = getattr(config, "is_lora", False) + + logger.info( + "Loading %s via %s (is_lora=%s)", + model_name, + "mlx-vlm" if is_vision else "mlx-lm", + is_lora, + ) + + try: + from unsloth_zoo.mlx_loader import FastMLXModel + except ImportError as e: + raise ImportError( + "Unsloth: MLX inference requires unsloth-zoo with the MLX modules " + "(unsloth_zoo.mlx_loader). Reinstall via install.sh on Apple Silicon." + ) from e + + model, tokenizer_or_processor = FastMLXModel.from_pretrained( + model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + token = hf_token, + trust_remote_code = trust_remote_code, + text_only = False if is_vision else True, + ) + + if is_vision: + processor = tokenizer_or_processor + self._model = model + self._processor = processor + self._tokenizer = getattr(processor, "tokenizer", processor) + self._is_vlm = True + else: + tokenizer = tokenizer_or_processor + self._model = model + self._tokenizer = tokenizer + self._processor = None + self._is_vlm = False + + self.active_model_name = model_name + self.models[model_name] = { + "model": self._model, + "tokenizer": self._tokenizer, + "processor": self._processor, + "is_vision": is_vision, + "is_lora": getattr(config, "is_lora", False), + "is_audio": False, + "audio_type": None, + "has_audio_input": False, + } + + logger.info("Model %s loaded successfully", model_name) + return True + + def unload_model(self, model_name: str) -> bool: + import mlx.core as mx + import gc + + if model_name in self.models: + del self.models[model_name] + self._model = None + self._tokenizer = None + self._processor = None + if self.active_model_name == model_name: + self.active_model_name = None + gc.collect() + mx.clear_cache() + + if mx.metal.is_available() and self._memory_limits_applied and not self.models: + try: + mx.set_wired_limit(0) + logger.info("MLX wired_limit released back to OS on unload") + except Exception as e: + logger.warning("Failed to release wired_limit: %s", e) + self._memory_limits_applied = {} + logger.info("Model %s unloaded", model_name) + return True + + def generate_chat_response( + self, + messages, + system_prompt = "", + image = None, + temperature = 0.7, + top_p = 0.9, + top_k = 40, + min_p = 0.0, + max_new_tokens = 256, + repetition_penalty = 1.0, + cancel_event = None, + ) -> Generator[str, None, None]: + if self._model is None: + raise RuntimeError("No model loaded") + + # Build messages with system prompt + full_messages = [] + if system_prompt: + full_messages.append({"role": "system", "content": system_prompt}) + full_messages.extend(messages) + + # Inject image into the last user message for VLM + if self._is_vlm and image is not None: + for msg in reversed(full_messages): + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, str): + msg["content"] = [ + {"type": "image"}, + {"type": "text", "text": content}, + ] + elif isinstance(content, list): + # Prepend image if not already there + has_image = any( + p.get("type") == "image" + for p in content + if isinstance(p, dict) + ) + if not has_image: + content.insert(0, {"type": "image"}) + break + + if self._is_vlm: + yield from self._generate_vlm( + full_messages, + image, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ) + else: + yield from self._generate_text( + full_messages, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ) + + def _generate_text( + self, + messages, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ): + from mlx_lm import stream_generate + from mlx_lm.sample_utils import make_sampler, make_logits_processors + + prompt = self._tokenizer.apply_chat_template( + messages, + tokenize = False, + add_generation_prompt = True, + ) + if prompt is None: + raise RuntimeError( + "apply_chat_template returned None — tokenizer may be incompatible" + ) + + sampler = make_sampler( + temp = temperature, + top_p = top_p, + top_k = int(top_k or 0), + min_p = float(min_p or 0.0), + min_tokens_to_keep = 1, + ) + # Only build a logits processor when we actually have a non-trivial + # repetition penalty (1.0 is the no-op value). + logits_processors = None + if repetition_penalty is not None and float(repetition_penalty) not in ( + 0.0, + 1.0, + ): + logits_processors = make_logits_processors( + repetition_penalty = float(repetition_penalty), + ) + + token_ids = [] + logger.info( + "Generating: prompt_len=%d, max_tokens=%d, model=%s, tokenizer=%s", + len(prompt), + max_new_tokens, + type(self._model).__name__, + type(self._tokenizer).__name__, + ) + with self._generation_lock: + try: + gen_kwargs = dict( + prompt = prompt, + max_tokens = max_new_tokens, + sampler = sampler, + ) + if logits_processors is not None: + gen_kwargs["logits_processors"] = logits_processors + for response in stream_generate( + self._model, + self._tokenizer, + **gen_kwargs, + ): + token_ids.append(response.token) + # Decode full sequence with skip_special_tokens — same as GPU + cumulative = self._tokenizer.decode( + token_ids, + skip_special_tokens = True, + ) + yield cumulative + + if cancel_event and cancel_event.is_set(): + break + except Exception as e: + import traceback + + logger.error("stream_generate failed:\n%s", traceback.format_exc()) + raise + + def _generate_vlm( + self, + messages, + image, + temperature, + top_p, + top_k, + min_p, + max_new_tokens, + repetition_penalty, + cancel_event, + ): + from mlx_vlm import stream_generate as vlm_stream + + # Apply chat template + chat_fn = getattr(self._processor, "apply_chat_template", None) + if ( + chat_fn is None + or not hasattr(self._processor, "chat_template") + or self._processor.chat_template is None + ): + tok = getattr(self._processor, "tokenizer", self._processor) + chat_fn = tok.apply_chat_template + + prompt = chat_fn(messages, tokenize = False, add_generation_prompt = True) + + # For VLM: always use mlx_vlm's stream_generate which handles + # pixel_values properly (passes None for text-only, image for VLM) + images = [image] if image is not None else None + + cumulative = "" + logger.info( + "VLM generating: prompt_len=%d, has_image=%s", + len(prompt), + image is not None, + ) + # mlx_vlm.stream_generate forwards **kwargs into generate_step, which + # accepts temp/top_p/top_k/repetition_penalty (and builds the sampler + # + logits_processors internally). Pass them through. + # NOTE: mlx_vlm.generate_step expects ``temperature=`` (long form) — + # passing ``temp=`` silently falls into **kwargs and is ignored, + # leaving generation stuck at the default 0.0 (greedy). + vlm_kwargs = dict( + max_tokens = max_new_tokens, + temperature = temperature, + top_p = top_p, + top_k = int(top_k or 0), + min_p = float(min_p or 0.0), + ) + if repetition_penalty is not None and float(repetition_penalty) not in ( + 0.0, + 1.0, + ): + vlm_kwargs["repetition_penalty"] = float(repetition_penalty) + + with self._generation_lock: + for response in vlm_stream( + self._model, + self._processor, + prompt, + images, + **vlm_kwargs, + ): + token_text = ( + response.text if hasattr(response, "text") else str(response) + ) + cumulative += token_text + yield cumulative + if cancel_event and cancel_event.is_set(): + break + + def generate_with_adapter_control( + self, use_adapter = None, cancel_event = None, **gen_kwargs + ) -> Generator[str, None, None]: + # MLX LoRA adapter toggling not yet supported — generate normally + yield from self.generate_chat_response(cancel_event = cancel_event, **gen_kwargs) + + def reset_generation_state(self): + import mlx.core as mx + import gc + + gc.collect() + mx.clear_cache() diff --git a/studio/backend/core/inference/worker.py b/studio/backend/core/inference/worker.py index fbcce276ba..085a1ab899 100644 --- a/studio/backend/core/inference/worker.py +++ b/studio/backend/core/inference/worker.py @@ -663,6 +663,98 @@ def run_inference_process( model_name = config["model_name"] + # ── 0. MLX fast-path — skip torch/transformers entirely ── + backend_path = str(Path(__file__).resolve().parent.parent.parent) + if backend_path not in sys.path: + sys.path.insert(0, backend_path) + + from utils.hardware import hardware as _hw + + _hw.detect_hardware() + if _hw.DEVICE == _hw.DeviceType.MLX: + try: + _activate_transformers_version(model_name) + except Exception: + pass + try: + from core.inference.mlx_inference import MLXInferenceBackend + + backend = MLXInferenceBackend() + _send_response( + resp_queue, + {"type": "status", "message": "Loading model...", "ts": time.time()}, + ) + _handle_load(backend, config, resp_queue) + except Exception as exc: + _send_response( + resp_queue, + { + "type": "error", + "error": f"MLX inference init failed: {exc}", + "stack": traceback.format_exc(limit = 20), + "ts": time.time(), + }, + ) + return + + # Enter same command loop as GPU path + logger.info("MLX inference subprocess ready, entering command loop") + while True: + try: + cmd = cmd_queue.get(timeout = 1.0) + except _queue.Empty: + continue + except (EOFError, OSError): + return + if cmd is None: + continue + cmd_type = cmd.get("type", "") + try: + if cmd_type == "generate": + cancel_event.clear() + _handle_generate(backend, cmd, resp_queue, cancel_event) + elif cmd_type == "load": + if backend.active_model_name: + backend.unload_model(backend.active_model_name) + _handle_load(backend, cmd, resp_queue) + elif cmd_type == "unload": + _handle_unload(backend, cmd, resp_queue) + elif cmd_type == "cancel": + cancel_event.set() + elif cmd_type == "reset": + cancel_event.set() + backend.reset_generation_state() + _send_response(resp_queue, {"type": "reset_ack", "ts": time.time()}) + elif cmd_type == "status": + _send_response( + resp_queue, + { + "type": "status_response", + "active_model": backend.active_model_name, + "models": { + k: {kk: vv for kk, vv in v.items() if kk != "model"} + for k, v in backend.models.items() + }, + "loading": list(backend.loading_models), + "ts": time.time(), + }, + ) + elif cmd_type == "shutdown": + return + except Exception as exc: + logger.error("MLX command error (%s): %s", cmd_type, exc) + _send_response( + resp_queue, + { + "type": "gen_error" if cmd_type == "generate" else "error", + "request_id": cmd.get("request_id"), + "error": str(exc), + "stack": traceback.format_exc(limit = 20), + "ts": time.time(), + }, + ) + return + # ── 1. Activate correct transformers version BEFORE any ML imports ── try: _activate_transformers_version(model_name) diff --git a/studio/backend/core/training/training.py b/studio/backend/core/training/training.py index 5642faa189..a04ad5ef49 100644 --- a/studio/backend/core/training/training.py +++ b/studio/backend/core/training/training.py @@ -62,6 +62,7 @@ class TrainingProgress: grad_norm: Optional[float] = None num_tokens: Optional[int] = None eval_loss: Optional[float] = None + peak_memory_gb: Optional[float] = None class TrainingBackend: @@ -199,21 +200,27 @@ def start_training(self, job_id: str, **kwargs) -> bool: config["load_in_4bit"] = False # Spawn subprocess — use locals so state is untouched on failure - resolved_gpu_ids, gpu_selection = prepare_gpu_selection( - kwargs.get("gpu_ids"), - model_name = config["model_name"], - hf_token = config["hf_token"] or None, - training_type = config["training_type"], - load_in_4bit = config["load_in_4bit"], - batch_size = config.get("batch_size", 4), - max_seq_length = config.get("max_seq_length", 2048), - lora_rank = config.get("lora_r", 16), - target_modules = config.get("target_modules"), - gradient_checkpointing = config.get("gradient_checkpointing", "unsloth"), - optimizer = config.get("optim", "adamw_8bit"), - ) - config["resolved_gpu_ids"] = resolved_gpu_ids - config["gpu_selection"] = gpu_selection + from utils.hardware import hardware as _hw + + if _hw.DEVICE == _hw.DeviceType.MLX: + config["resolved_gpu_ids"] = None + config["gpu_selection"] = None + else: + resolved_gpu_ids, gpu_selection = prepare_gpu_selection( + kwargs.get("gpu_ids"), + model_name = config["model_name"], + hf_token = config["hf_token"] or None, + training_type = config["training_type"], + load_in_4bit = config["load_in_4bit"], + batch_size = config.get("batch_size", 4), + max_seq_length = config.get("max_seq_length", 2048), + lora_rank = config.get("lora_r", 16), + target_modules = config.get("target_modules"), + gradient_checkpointing = config.get("gradient_checkpointing", "unsloth"), + optimizer = config.get("optim", "adamw_8bit"), + ) + config["resolved_gpu_ids"] = resolved_gpu_ids + config["gpu_selection"] = gpu_selection from .worker import run_training_process @@ -512,6 +519,12 @@ def _handle_event(self, event: dict) -> None: self._progress.grad_norm = event.get("grad_norm") self._progress.num_tokens = event.get("num_tokens") self._progress.eval_loss = event.get("eval_loss") + _peak = event.get("peak_memory_gb") + if _peak is not None: + try: + self._progress.peak_memory_gb = float(_peak) + except (TypeError, ValueError): + pass self._progress.is_training = True status = event.get("status_message", "") if status: diff --git a/studio/backend/core/training/worker.py b/studio/backend/core/training/worker.py index 60b9e994ab..9c017db6de 100644 --- a/studio/backend/core/training/worker.py +++ b/studio/backend/core/training/worker.py @@ -338,6 +338,594 @@ def _activate_transformers_version(model_name: str) -> None: activate_transformers_for_subprocess(model_name) +def _adapt_for_mlx_vlm(items): + """Adapt GPU-path VLM dataset output for mlx-vlm consumption. + + The GPU path embeds PIL images inside messages content as + {"type": "image", "image": PIL_Image}. mlx-vlm's prepare_inputs + needs images at top-level to produce pixel_values — regardless of + model type. Extract them and leave bare {"type": "image"} placeholders. + """ + adapted = [] + for item in items: + images = [] + messages = [] + for msg in item.get("messages", []): + content = msg.get("content", "") + if isinstance(content, list): + new_content = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + img = part.get("image") + if img is not None: + images.append(img) + new_content.append({"type": "image"}) + else: + new_content.append(part) + messages.append({"role": msg["role"], "content": new_content}) + else: + messages.append(msg) + out = {"messages": messages} + if images: + out["image"] = images[0] if len(images) == 1 else images + elif "image" in item: + out["image"] = item["image"] + elif "images" in item: + out["images"] = item["images"] + adapted.append(out) + return adapted + + +_MLX_STUDIO_OPTIM_MAP = { + "adamw_8bit": "adamw", + "paged_adamw_8bit": "adamw", + "adamw_bnb_8bit": "adamw", + "paged_adamw_32bit": "adamw", + "adamw_torch": "adamw", + "adamw_torch_fused": "adamw", + "adamw": "adamw", + "adafactor": "adafactor", + "sgd": "sgd", + "adam": "adam", + "muon": "muon", + "lion": "lion", +} +_MLX_STUDIO_LR_SCHEDULERS = {"linear", "cosine", "constant"} + + +def _normalize_mlx_studio_optimizer(value): + raw = str(value or "adamw_8bit").strip().lower() + try: + return _MLX_STUDIO_OPTIM_MAP[raw] + except KeyError: + supported = ", ".join(sorted(_MLX_STUDIO_OPTIM_MAP)) + raise ValueError( + f"Unsupported optimizer for MLX training: {value!r}. " + f"Supported values: {supported}." + ) + + +def _normalize_mlx_studio_scheduler(value): + raw = str(value or "linear").strip().lower() + if raw not in _MLX_STUDIO_LR_SCHEDULERS: + supported = ", ".join(sorted(_MLX_STUDIO_LR_SCHEDULERS)) + raise ValueError( + f"Unsupported LR scheduler for MLX training: {value!r}. " + f"Supported values: {supported}." + ) + return raw + + +def _run_mlx_training(event_queue, stop_queue, config): + """Self-contained MLX training path for Apple Silicon. + + Uses MLXTrainer from unsloth_zoo directly -- no torch/SFTTrainer needed. + Mirrors the event_queue protocol so the parent process pump works unchanged. + """ + import time + import gc + import math + import threading + import queue as _queue + from pathlib import Path + + def _send(event_type, **kwargs): + if event_type == "status" and "message" not in kwargs: + sm = kwargs.get("status_message") + if sm is not None: + kwargs["message"] = sm + event_queue.put({"type": event_type, "ts": time.time(), **kwargs}) + + _send("status", status_message = "Loading MLX libraries...") + + import mlx.core as mx + + try: + from unsloth_zoo.mlx_loader import FastMLXModel + from unsloth_zoo.mlx_trainer import ( + MLXTrainer, + MLXTrainingConfig, + train_on_responses_only, + ) + except ImportError as e: + raise ImportError( + "Unsloth: MLX training requires unsloth-zoo with the MLX modules " + "(unsloth_zoo.mlx_loader / unsloth_zoo.mlx_trainer). Reinstall via " + "install.sh on Apple Silicon." + ) from e + from datasets import load_dataset + + if mx.metal.is_available(): + info = mx.device_info() + rec_bytes = info.get("max_recommended_working_set_size", 0) or 0 + if rec_bytes > 0: + memory_cap = int(rec_bytes * 0.85) + wired_cap = min(int(rec_bytes), memory_cap) + mx.set_memory_limit(memory_cap) + mx.set_wired_limit(wired_cap) + + model_name = config["model_name"] + hf_token = config.get("hf_token") or None + if hf_token: + os.environ["HF_TOKEN"] = hf_token + + if config.get("use_loftq"): + message = "LoftQ is not supported for MLX training yet." + _send("error", error = message) + raise NotImplementedError(message) + + optim_name = _normalize_mlx_studio_optimizer(config.get("optim", "adamw_8bit")) + lr_scheduler_type = _normalize_mlx_studio_scheduler( + config.get("lr_scheduler_type", "linear") + ) + + # ── 1. Load model ── + # Force text-only if the dataset is not an image dataset, even if the model + # has vision capabilities (e.g. Qwen3.5-VL trained on plain alpaca text). + _send("status", status_message = f"Loading {model_name}...") + is_dataset_image = bool(config.get("is_dataset_image", False)) + training_type = config.get("training_type", "LoRA/QLoRA") + use_lora = training_type == "LoRA/QLoRA" + model, tokenizer = FastMLXModel.from_pretrained( + model_name, + load_in_4bit = config.get("load_in_4bit", True), + full_finetuning = not use_lora, + text_only = None if is_dataset_image else True, + token = hf_token, + trust_remote_code = bool(config.get("trust_remote_code", False)), + random_state = config.get("random_seed", 3407), + ) + + is_vlm = bool(is_dataset_image and getattr(model, "_is_vlm_model", False)) + model._is_vlm_model = is_vlm + + # ── 2. Apply LoRA / full FT ── + # Pass gradient_checkpointing as string ("mlx"/"unsloth"/"none"/etc.) + # get_peft_model and MLXTrainer both accept strings and handle them. + gc_setting = config.get("gradient_checkpointing", "mlx") + if isinstance(gc_setting, str): + use_grad_checkpoint = ( + gc_setting if gc_setting.lower() not in ("false", "") else False + ) + else: + use_grad_checkpoint = gc_setting + + if use_lora: + _send("status", status_message = "Configuring LoRA adapters...") + peft_kwargs = dict( + r = config.get("lora_r", 16), + lora_alpha = config.get("lora_alpha", 16), + lora_dropout = config.get("lora_dropout", 0.0), + use_rslora = config.get("use_rslora", False), + init_lora_weights = config.get("init_lora_weights", True), + random_state = config.get("random_seed", 3407), + target_modules = config.get("target_modules") + or [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + use_gradient_checkpointing = use_grad_checkpoint, + ) + finetune_language = config.get("finetune_language_layers", True) + finetune_attention = config.get("finetune_attention_modules", True) + finetune_mlp = config.get("finetune_mlp_modules", True) + finetune_vision = ( + config.get("finetune_vision_layers", False) if is_vlm else False + ) + + if ( + (finetune_attention or finetune_mlp) + and not finetune_language + and not finetune_vision + ): + finetune_language = True + + peft_kwargs["finetune_language_layers"] = finetune_language + peft_kwargs["finetune_attention_modules"] = finetune_attention + peft_kwargs["finetune_mlp_modules"] = finetune_mlp + if is_vlm: + peft_kwargs["finetune_vision_layers"] = finetune_vision + model = FastMLXModel.get_peft_model(model, **peft_kwargs) + + # ── 3. Load dataset ── + _send("status", status_message = "Loading dataset...") + hf_dataset = config.get("hf_dataset", "") + subset = config.get("subset") + train_split = config.get("train_split", "train") or "train" + eval_split = config.get("eval_split") + slice_start = config.get("dataset_slice_start") + slice_end = config.get("dataset_slice_end") + + def _slice(ds): + if slice_start is not None or slice_end is not None: + start = slice_start if slice_start is not None else 0 + end = slice_end if slice_end is not None else len(ds) - 1 + if end < start: + return ds.select([]) + ds = ds.select(range(start, min(end + 1, len(ds)))) + return ds + + def _load_local(file_paths): + from core.training.trainer import UnslothTrainer + from datasets import load_from_disk + + if len(file_paths) == 1: + p = Path(file_paths[0]) + if p.is_dir() and ( + (p / "dataset_info.json").exists() or (p / "state.json").exists() + ): + return load_from_disk(str(p)) + all_files = UnslothTrainer._resolve_local_files(file_paths) + if not all_files: + raise ValueError("No local dataset files found") + loader = UnslothTrainer._loader_for_files(all_files) + return load_dataset(loader, data_files = all_files, split = "train") + + if hf_dataset: + load_kwargs = {"split": train_split, "token": hf_token} + if subset: + load_kwargs["name"] = subset + dataset = load_dataset(hf_dataset, **load_kwargs) + dataset = _slice(dataset) + elif config.get("local_datasets"): + dataset = _load_local(config["local_datasets"]) + dataset = _slice(dataset) + else: + raise ValueError("No dataset specified") + + # Eval dataset (separate split or local file) + eval_dataset = None + if eval_split and hf_dataset: + eval_kwargs = {"split": eval_split, "token": hf_token} + if subset: + eval_kwargs["name"] = subset + try: + eval_dataset = load_dataset(hf_dataset, **eval_kwargs) + except Exception as e: + _send("status", status_message = f"Eval split load failed: {e}") + eval_dataset = None + elif config.get("local_eval_datasets"): + eval_dataset = _load_local(config["local_eval_datasets"]) + + # ── 3b. Format dataset (VLM or text) ── + # Reuse the GPU path's format pipeline for both VLM (auto-detects OCR/caption/ + # llava/sharegpt+images) and text (alpaca/sharegpt/chatml → "text" column). + format_type = config.get("format_type", "") + try: + from utils.datasets import format_and_template_dataset + + def _fmt_progress(status_message = "", **_kw): + _send("status", status_message = status_message) + + if is_vlm: + _send("status", status_message = "Formatting VLM dataset...") + vlm_info = format_and_template_dataset( + dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = True, + dataset_name = hf_dataset or "local", + progress_callback = _fmt_progress, + ) + if vlm_info.get("success"): + dataset = _adapt_for_mlx_vlm(vlm_info["dataset"]) + else: + errors = vlm_info.get("errors", []) + raise ValueError( + f"VLM dataset format conversion failed: {'; '.join(errors)}" + ) + if eval_dataset is not None: + ev_info = format_and_template_dataset( + eval_dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = True, + dataset_name = hf_dataset or "local", + ) + if ev_info.get("success"): + eval_dataset = _adapt_for_mlx_vlm(ev_info["dataset"]) + + elif format_type: + _send("status", status_message = f"Formatting dataset ({format_type})...") + info = format_and_template_dataset( + dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = False, + format_type = format_type, + dataset_name = hf_dataset or "local", + ) + if info.get("success", True): + dataset = info.get("dataset", dataset) + if eval_dataset is not None: + ev = format_and_template_dataset( + eval_dataset, + model_name = model_name, + tokenizer = tokenizer, + is_vlm = False, + format_type = format_type, + dataset_name = hf_dataset or "local", + ) + if ev.get("success", True): + eval_dataset = ev.get("dataset", eval_dataset) + except ImportError: + _send("status", status_message = "Format helper unavailable, using raw dataset") + + # ── 4. Resolve training steps ── + max_steps = config.get("max_steps", 0) or 0 + num_epochs = config.get("num_epochs", 3) + max_seq_length = config.get("max_seq_length", 2048) + batch_size = config.get("batch_size", 4) + grad_accum = config.get("gradient_accumulation_steps", 4) + + if max_steps <= 0: + max_steps = max( + 1, + math.ceil(len(dataset) / batch_size / grad_accum) * num_epochs, + ) + + lr_value = float(config.get("learning_rate", "2e-4")) + + # Warmup: prefer warmup_steps; fall back to warmup_ratio + warmup_steps = config.get("warmup_steps") + warmup_ratio = config.get("warmup_ratio") + if warmup_steps is None and warmup_ratio is not None: + warmup_steps = int(round(warmup_ratio * max_steps)) + if warmup_steps is None: + warmup_steps = 5 + + # ── 5. Build output dir ── + output_dir = config.get("output_dir", "") + if not output_dir: + output_dir = f"{model_name.replace('/', '_')}_{int(time.time())}" + # Resolve to ~/.unsloth/studio/outputs/ so the export page can find it + from utils.paths import resolve_output_dir, ensure_dir + + output_dir = str(resolve_output_dir(output_dir)) + ensure_dir(Path(output_dir)) + + # ── 6. Create trainer ── + eval_steps_val = config.get("eval_steps", 0) or 0 + if isinstance(eval_steps_val, float) and 0 < eval_steps_val < 1: + # Studio sometimes sends fraction-of-total-steps + eval_steps_val = max(1, int(eval_steps_val * max_steps)) + else: + eval_steps_val = int(eval_steps_val) + + trainer = MLXTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = dataset, + eval_dataset = eval_dataset, + args = MLXTrainingConfig( + per_device_train_batch_size = batch_size, + gradient_accumulation_steps = grad_accum, + max_steps = max_steps, + learning_rate = lr_value, + warmup_steps = warmup_steps, + lr_scheduler_type = lr_scheduler_type, + optim = optim_name, + weight_decay = float(config.get("weight_decay", 0.001) or 0.001), + logging_steps = 1, + max_seq_length = max_seq_length, + seed = config.get("random_seed", 3407), + use_cce = True, + compile = True, + gradient_checkpointing = use_grad_checkpoint, + streaming = is_vlm, + packing = bool(config.get("packing", False)), + output_dir = output_dir, + save_steps = int(config.get("save_steps", 0) or 0), + eval_steps = eval_steps_val, + ), + ) + + # Tell the parent that eval is configured so the frontend shows the eval chart + if eval_dataset is not None and eval_steps_val > 0: + _send("eval_configured") + + # ── 7. Apply train_on_responses_only if requested ── + if config.get("train_on_completions", False): + _send("status", status_message = "Configuring response-only training...") + try: + from utils.datasets import ( + MODEL_TO_TEMPLATE_MAPPER, + TEMPLATE_TO_RESPONSES_MAPPER, + ) + + template_name = MODEL_TO_TEMPLATE_MAPPER.get(model_name.lower()) + markers = ( + TEMPLATE_TO_RESPONSES_MAPPER.get(template_name) + if template_name + else None + ) + if markers: + trainer = train_on_responses_only( + trainer, + instruction_part = markers["instruction"], + response_part = markers["response"], + ) + else: + _send( + "status", + status_message = f"train_on_completions skipped (no template for {model_name})", + ) + except Exception as e: + _send("status", status_message = f"train_on_completions failed: {e}") + + # ── 8. Setup wandb / tensorboard ── + wandb_run = None + tb_writer = None + if config.get("enable_wandb", False): + try: + import wandb as _wandb + + wandb_token = config.get("wandb_token") + if wandb_token: + os.environ["WANDB_API_KEY"] = wandb_token + _wandb_sensitive = {"hf_token", "wandb_token"} + wandb_run = _wandb.init( + project = config.get("wandb_project") or "unsloth-mlx", + config = {k: v for k, v in config.items() if k not in _wandb_sensitive}, + reinit = True, + ) + except Exception as e: + _send("status", status_message = f"wandb init failed: {e}") + if config.get("enable_tensorboard", False): + try: + from tensorboardX import SummaryWriter + except ImportError: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + SummaryWriter = None + if SummaryWriter is not None: + try: + tb_dir = config.get("tensorboard_dir") or f"{output_dir}/runs" + tb_writer = SummaryWriter(log_dir = tb_dir) + except Exception as e: + _send("status", status_message = f"tensorboard init failed: {e}") + else: + _send( + "status", + status_message = "tensorboard unavailable (install tensorboardX)", + ) + + # ── 9. Real-time progress callback ── + _send("status", status_message = f"Training {model_name}...") + + def _on_step(step, total, loss, lr, tok_s, peak_gb, elapsed, num_tokens): + eta = (elapsed / step * (total - step)) if step > 0 else 0 + _send( + "progress", + step = step, + epoch = round(step / total * num_epochs, 2) if total > 0 else 0, + loss = loss, + learning_rate = lr, + total_steps = total, + elapsed_seconds = elapsed, + eta_seconds = max(0, eta), + grad_norm = None, + num_tokens = num_tokens, + eval_loss = None, + status_message = None, + peak_memory_gb = peak_gb, + ) + if wandb_run is not None: + try: + wandb_run.log( + { + "train/loss": loss, + "train/learning_rate": lr, + "train/tokens_per_sec": tok_s, + "train/peak_gb": peak_gb, + "train/num_tokens": num_tokens, + }, + step = step, + ) + except Exception: + pass + if tb_writer is not None: + try: + tb_writer.add_scalar("train/loss", loss, step) + tb_writer.add_scalar("train/learning_rate", lr, step) + tb_writer.add_scalar("train/tokens_per_sec", tok_s, step) + tb_writer.add_scalar("train/peak_gb", peak_gb, step) + except Exception: + pass + + trainer.add_step_callback(_on_step) + + def _on_eval(step, eval_loss, perplexity): + _send("progress", step = step, eval_loss = eval_loss) + if wandb_run is not None: + try: + wandb_run.log( + {"eval/loss": eval_loss, "eval/perplexity": perplexity}, step = step + ) + except Exception: + pass + if tb_writer is not None: + try: + tb_writer.add_scalar("eval/loss", eval_loss, step) + tb_writer.add_scalar("eval/perplexity", perplexity, step) + except Exception: + pass + + trainer.add_eval_callback(_on_eval) + + # ── 10. Stop signal polling ── + _stop_save = [True] # mutable so thread can update; [save_flag] + + def _poll_stop(): + while True: + try: + msg = stop_queue.get(timeout = 1.0) + if msg and msg.get("type") == "stop": + _stop_save[0] = msg.get("save", True) + trainer.stop_requested = True + return + except _queue.Empty: + continue + except (EOFError, OSError): + # why safe: pipe permanently broken, no further messages can arrive + return + + stop_thread = threading.Thread(target = _poll_stop, daemon = True) + stop_thread.start() + + # ── 11. Run training ── + gc.collect() + mx.synchronize() + trainer.train() + + # ── 12. Save and finalize ── + if trainer.stop_requested and not _stop_save[0]: + # User clicked "Cancel" (save=False) — skip saving + _send("complete", output_dir = None, status_message = "Training cancelled") + else: + _send("status", status_message = "Saving model...") + mx.synchronize() + trainer.save_model(output_dir) + _send("complete", output_dir = output_dir, status_message = "Training completed") + + if tb_writer is not None: + try: + tb_writer.close() + except Exception: + pass + if wandb_run is not None: + try: + wandb_run.finish() + except Exception: + pass + + def run_training_process( *, event_queue: Any, @@ -371,6 +959,46 @@ def run_training_process( model_name = config["model_name"] + # ── 0. MLX FAST-PATH (must run before any torch/transformers imports) ── + # Apple Silicon uses MLXTrainer directly -- skip transformers version + # activation, causal-conv1d install, and torch imports entirely. + backend_path = str(Path(__file__).resolve().parent.parent.parent) + if backend_path not in sys.path: + sys.path.insert(0, backend_path) + + from utils.hardware import hardware as _hw + + _hw.detect_hardware() + if _hw.DEVICE == _hw.DeviceType.MLX: + if config.get("is_dataset_audio"): + event_queue.put( + { + "type": "error", + "error": "Audio dataset training is not yet supported on Apple Silicon.", + "stack": "", + "ts": time.time(), + } + ) + return + # Activate correct transformers version (Gemma-4 needs 5.5.0, etc.) + # Must happen before any transformers/mlx-lm imports in _run_mlx_training. + try: + _activate_transformers_version(model_name) + except Exception: + pass # Non-fatal: fall through with whatever version is installed + try: + _run_mlx_training(event_queue, stop_queue, config) + except Exception as exc: + event_queue.put( + { + "type": "error", + "error": str(exc), + "stack": traceback.format_exc(limit = 20), + "ts": time.time(), + } + ) + return + # ── 1. Activate correct transformers version BEFORE any ML imports ── try: _activate_transformers_version(model_name) diff --git a/studio/backend/tests/test_mlx_inference_backend.py b/studio/backend/tests/test_mlx_inference_backend.py new file mode 100644 index 0000000000..868e537372 --- /dev/null +++ b/studio/backend/tests/test_mlx_inference_backend.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +import sys +import types +from types import SimpleNamespace + + +class _DummyMetal: + @staticmethod + def is_available(): + return False + + +class _DummyMX: + metal = _DummyMetal() + + @staticmethod + def set_wired_limit(_limit): + return None + + @staticmethod + def device_info(): + return {"max_recommended_working_set_size": 1024} + + +class _DummyTokenizer: + pass + + +class _DummyProcessor: + tokenizer = _DummyTokenizer() + + +class _DummyModel: + pass + + +def _install_fake_mlx(monkeypatch): + mlx_pkg = types.ModuleType("mlx") + mlx_core = types.ModuleType("mlx.core") + mlx_core.metal = _DummyMetal() + mlx_core.set_wired_limit = _DummyMX.set_wired_limit + mlx_core.device_info = _DummyMX.device_info + mlx_pkg.core = mlx_core + monkeypatch.setitem(sys.modules, "mlx", mlx_pkg) + monkeypatch.setitem(sys.modules, "mlx.core", mlx_core) + + +def _install_fake_fast_mlx(monkeypatch, calls): + class _FastMLXModel: + @staticmethod + def from_pretrained(*args, **kwargs): + calls.append((args, kwargs)) + if kwargs["text_only"] is False: + return _DummyModel(), _DummyProcessor() + return _DummyModel(), _DummyTokenizer() + + unsloth_zoo_pkg = types.ModuleType("unsloth_zoo") + mlx_loader = types.ModuleType("unsloth_zoo.mlx_loader") + mlx_loader.FastMLXModel = _FastMLXModel + unsloth_zoo_pkg.mlx_loader = mlx_loader + monkeypatch.setitem(sys.modules, "unsloth_zoo", unsloth_zoo_pkg) + monkeypatch.setitem(sys.modules, "unsloth_zoo.mlx_loader", mlx_loader) + + +def test_mlx_inference_text_load_forwards_studio_settings(monkeypatch): + _install_fake_mlx(monkeypatch) + calls = [] + _install_fake_fast_mlx(monkeypatch, calls) + + from core.inference.mlx_inference import MLXInferenceBackend + + backend = MLXInferenceBackend() + config = SimpleNamespace(identifier = "fake/text", is_vision = False, is_lora = False) + + assert backend.load_model( + config, + max_seq_length = 4096, + load_in_4bit = False, + hf_token = "hf-token", + trust_remote_code = True, + dtype = "float16", + ) + + assert calls == [ + ( + ("fake/text",), + { + "max_seq_length": 4096, + "dtype": "float16", + "load_in_4bit": False, + "token": "hf-token", + "trust_remote_code": True, + "text_only": True, + }, + ) + ] + assert backend._is_vlm is False + assert isinstance(backend._tokenizer, _DummyTokenizer) + + +def test_mlx_inference_vlm_lora_uses_unsloth_loader_without_native_adapter_rewrite( + monkeypatch, + tmp_path, +): + _install_fake_mlx(monkeypatch) + calls = [] + _install_fake_fast_mlx(monkeypatch, calls) + + def _native_vlm_load(*_args, **_kwargs): + raise AssertionError("Studio MLX VLM inference must use FastMLXModel") + + mlx_vlm = types.ModuleType("mlx_vlm") + mlx_vlm.load = _native_vlm_load + monkeypatch.setitem(sys.modules, "mlx_vlm", mlx_vlm) + + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + cfg_path = adapter_dir / "adapter_config.json" + original_cfg = '{"base_model_name_or_path": "fake/base", "rank": 8}\n' + cfg_path.write_text(original_cfg) + + from core.inference.mlx_inference import MLXInferenceBackend + + backend = MLXInferenceBackend() + config = SimpleNamespace( + identifier = str(adapter_dir), + is_vision = True, + is_lora = True, + base_model = "fake/base", + ) + + assert backend.load_model( + config, + max_seq_length = 8192, + load_in_4bit = True, + hf_token = "hf-token", + trust_remote_code = True, + ) + + assert calls == [ + ( + (str(adapter_dir),), + { + "max_seq_length": 8192, + "dtype": None, + "load_in_4bit": True, + "token": "hf-token", + "trust_remote_code": True, + "text_only": False, + }, + ) + ] + assert cfg_path.read_text() == original_cfg + assert backend._is_vlm is True + assert isinstance(backend._processor, _DummyProcessor) + assert isinstance(backend._tokenizer, _DummyTokenizer) diff --git a/studio/backend/tests/test_mlx_training_worker_config.py b/studio/backend/tests/test_mlx_training_worker_config.py new file mode 100644 index 0000000000..5900af4e3d --- /dev/null +++ b/studio/backend/tests/test_mlx_training_worker_config.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +import importlib.util +import sys +import types +from pathlib import Path + +import pytest + + +def _load_worker_module(): + stub_names = ( + "structlog", + "loggers", + "utils", + "utils.hardware", + "utils.wheel_utils", + ) + previous_modules = {name: sys.modules.get(name) for name in stub_names} + + try: + sys.modules["structlog"] = types.ModuleType("structlog") + + loggers = types.ModuleType("loggers") + loggers.get_logger = lambda *_args, **_kwargs: None + sys.modules["loggers"] = loggers + + utils = types.ModuleType("utils") + utils.__path__ = [] + sys.modules["utils"] = utils + + hardware = types.ModuleType("utils.hardware") + hardware.apply_gpu_ids = lambda *_args, **_kwargs: None + sys.modules["utils.hardware"] = hardware + + wheel_utils = types.ModuleType("utils.wheel_utils") + for name in ( + "direct_wheel_url", + "flash_attn_wheel_url", + "install_wheel", + "probe_torch_wheel_env", + "url_exists", + ): + setattr(wheel_utils, name, lambda *_args, **_kwargs: None) + sys.modules["utils.wheel_utils"] = wheel_utils + + worker_path = ( + Path(__file__).resolve().parents[1] / "core" / "training" / "worker.py" + ) + spec = importlib.util.spec_from_file_location( + "mlx_training_worker_under_test", worker_path + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + finally: + for name, module in previous_modules.items(): + if module is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = module + + +_worker = _load_worker_module() +_normalize_mlx_studio_optimizer = _worker._normalize_mlx_studio_optimizer +_normalize_mlx_studio_scheduler = _worker._normalize_mlx_studio_scheduler + + +def test_mlx_studio_optimizer_aliases_are_explicit(): + assert _normalize_mlx_studio_optimizer("adamw_8bit") == "adamw" + assert _normalize_mlx_studio_optimizer("paged_adamw_8bit") == "adamw" + assert _normalize_mlx_studio_optimizer("adafactor") == "adafactor" + + +def test_mlx_studio_rejects_unknown_optimizer(): + with pytest.raises(ValueError, match = "Unsupported optimizer for MLX training"): + _normalize_mlx_studio_optimizer("adamw_typo") + + +def test_mlx_studio_rejects_unknown_scheduler(): + with pytest.raises(ValueError, match = "Unsupported LR scheduler for MLX training"): + _normalize_mlx_studio_scheduler("linear_typo") diff --git a/studio/backend/utils/hardware/hardware.py b/studio/backend/utils/hardware/hardware.py index c218b7b4b9..3764e38272 100644 --- a/studio/backend/utils/hardware/hardware.py +++ b/studio/backend/utils/hardware/hardware.py @@ -143,6 +143,7 @@ def detect_hardware() -> DeviceType: # --- MLX: Apple Silicon --- if is_apple_silicon() and _has_mlx(): DEVICE = DeviceType.MLX + CHAT_ONLY = False chip = platform.processor() or platform.machine() print(f"Hardware detected: MLX — Apple Silicon ({chip})") return DEVICE @@ -270,19 +271,30 @@ def get_gpu_memory_info() -> Dict[str, Any]: import mlx.core as mx import psutil - # MLX uses unified memory — report system memory as the pool + # MLX uses unified memory. Total = system RAM. GPU memory used + # comes from IORegistry's AGXAccelerator (system-wide, no sudo). total = psutil.virtual_memory().total - # MLX doesn't expose per-process GPU allocation; report 0 as allocated - allocated = 0 + agx = _read_apple_gpu_stats() + allocated = agx.get("vram_used_bytes", 0) if agx else 0 + + try: + info = mx.device_info() + gpu_name = ( + info.get("device_name") + or platform.processor() + or platform.machine() + ) + except Exception: + gpu_name = platform.processor() or platform.machine() return { "available": True, "backend": _backend_label(device), "device": 0, - "device_name": f"Apple Silicon ({platform.processor() or platform.machine()})", + "device_name": f"Apple Silicon ({gpu_name})", "total_gb": total / (1024**3), "allocated_gb": allocated / (1024**3), - "reserved_gb": 0, + "reserved_gb": allocated / (1024**3), "free_gb": (total - allocated) / (1024**3), "utilization_pct": (allocated / total) * 100 if total else 0, } @@ -460,6 +472,39 @@ def _smi_query(func_name: str, *args, **kwargs) -> Optional[Dict[str, Any]]: return None +def _read_apple_gpu_stats() -> Dict[str, Any]: + """Query macOS IORegistry for AGX (Apple GPU) live stats. No sudo needed. + + Returns dict with utilization_pct, vram_used_bytes (system-wide GPU memory). + Returns empty dict on failure. + """ + import subprocess + import re + + try: + result = subprocess.run( + ["ioreg", "-r", "-c", "AGXAccelerator"], + capture_output = True, + timeout = 2, + ) + text = result.stdout.decode("utf-8", errors = "replace") + except Exception: + return {} + + # PerformanceStatistics block has GPU utilization and in-use memory + m = re.search(r'"PerformanceStatistics" = \{([^}]+)\}', text) + if not m: + return {} + stats_str = m.group(1) + pairs = re.findall(r'"([^"]+)"=(\d+)', stats_str) + stats = {k: int(v) for k, v in pairs} + + return { + "utilization_pct": stats.get("Device Utilization %", 0), + "vram_used_bytes": stats.get("In use system memory", 0), + } + + def get_gpu_utilization() -> Dict[str, Any]: """Return a live snapshot of device utilization information.""" device = get_device() @@ -470,6 +515,50 @@ def get_gpu_utilization() -> Dict[str, Any]: result["backend"] = _backend_label(device) return result + # MLX path: single _read_apple_gpu_stats() call carries both VRAM-used + # bytes and GPU utilization %. psutil for unified-memory total is cheap. + if device == DeviceType.MLX: + try: + import psutil + + agx = _read_apple_gpu_stats() + total_bytes = psutil.virtual_memory().total + except Exception as e: + logger.error(f"Error getting MLX GPU utilization: {e}") + return {"available": False, "backend": device.value, "error": str(e)} + if not agx: + return {"available": False, "backend": device.value} + allocated_bytes = agx.get("vram_used_bytes", 0) or 0 + vram_used_gb = allocated_bytes / (1024**3) + total_gb = total_bytes / (1024**3) + + try: + from core.training import get_training_backend + + tb = get_training_backend() + tb_progress = getattr(tb, "_progress", None) + if tb_progress is not None and getattr(tb_progress, "is_training", False): + tb_peak = getattr(tb_progress, "peak_memory_gb", None) + if tb_peak is not None and tb_peak > 0: + vram_used_gb = float(tb_peak) + except Exception: + pass + + return { + "available": True, + "backend": device.value, + "gpu_utilization_pct": agx.get("utilization_pct") if agx else None, + "temperature_c": None, + "vram_used_gb": round(vram_used_gb, 2), + "vram_total_gb": round(total_gb, 2), + "vram_utilization_pct": ( + round((vram_used_gb / total_gb) * 100, 1) if total_gb > 0 else None + ), + "power_draw_w": None, + "power_limit_w": None, + "power_utilization_pct": None, + } + mem = get_gpu_memory_info() if device != DeviceType.CPU and mem.get("available"): return { diff --git a/studio/frontend/src/components/app-sidebar.tsx b/studio/frontend/src/components/app-sidebar.tsx index edcd5120eb..171b4eb92d 100644 --- a/studio/frontend/src/components/app-sidebar.tsx +++ b/studio/frontend/src/components/app-sidebar.tsx @@ -37,13 +37,12 @@ import { Delete02Icon, Download03Icon, GemIcon, - Globe02Icon, Search01Icon, PowerIcon, PencilEdit02Icon, LayoutAlignLeftIcon, - HelpCircleIcon, Settings02Icon, + SourceCodeSquareIcon, ZapIcon, } from "@hugeicons/core-free-icons"; import { @@ -528,7 +527,7 @@ export function AppSidebar() {
{displayTitle} - Unsloth + Studio
@@ -549,8 +548,8 @@ export function AppSidebar() { useSettingsDialogStore.getState().openDialog("api-keys")} > - - API + + Developer New @@ -579,12 +578,6 @@ export function AppSidebar() { - useSettingsDialogStore.getState().openDialog("about")} - > - - Help - setShutdownOpen(true)}> Shutdown diff --git a/studio/frontend/src/config/env.ts b/studio/frontend/src/config/env.ts index 72bb3fa815..3839706d25 100644 --- a/studio/frontend/src/config/env.ts +++ b/studio/frontend/src/config/env.ts @@ -50,7 +50,7 @@ export async function fetchDeviceType(): Promise { if (res.ok) { const data = (await res.json()) as { device_type?: string; chat_only?: boolean }; const deviceType = data.device_type ?? detectLocalPlatform(); - const chatOnly = data.chat_only ?? deviceType === "mac"; + const chatOnly = data.chat_only ?? false; usePlatformStore.setState({ deviceType, chatOnly, fetched: true }); return deviceType; } diff --git a/studio/frontend/src/features/settings/settings-dialog.tsx b/studio/frontend/src/features/settings/settings-dialog.tsx index 63c0a9d388..376af06e9d 100644 --- a/studio/frontend/src/features/settings/settings-dialog.tsx +++ b/studio/frontend/src/features/settings/settings-dialog.tsx @@ -10,11 +10,11 @@ import { import { cn } from "@/lib/utils"; import { Cancel01Icon, - Globe02Icon, - HelpCircleIcon, Message01Icon, PaintBrush02Icon, Settings02Icon, + SourceCodeSquareIcon, + SparklesIcon, UserIcon, } from "@hugeicons/core-free-icons"; import { HugeiconsIcon } from "@hugeicons/react"; @@ -40,8 +40,8 @@ const TABS: TabDef[] = [ { id: "profile", label: "Profile", icon: UserIcon }, { id: "appearance", label: "Appearance", icon: PaintBrush02Icon }, { id: "chat", label: "Chat", icon: Message01Icon }, - { id: "api-keys", label: "API", icon: Globe02Icon, badge: "New" }, - { id: "about", label: "Help", icon: HelpCircleIcon }, + { id: "api-keys", label: "Developer", icon: SourceCodeSquareIcon, badge: "New" }, + { id: "about", label: "Help", icon: SparklesIcon }, ]; function renderTab(tab: SettingsTab) { diff --git a/studio/frontend/src/features/settings/tabs/api-keys-tab.tsx b/studio/frontend/src/features/settings/tabs/api-keys-tab.tsx index ac9ec40543..64c6f520c0 100644 --- a/studio/frontend/src/features/settings/tabs/api-keys-tab.tsx +++ b/studio/frontend/src/features/settings/tabs/api-keys-tab.tsx @@ -63,7 +63,7 @@ export function ApiKeysTab() { return (
-

API

+

Developer

Access Unsloth programmatically via the OpenAI-compatible API.{" "} s.deviceType); const isLora = store.trainingMethod !== "full"; const showVisionLora = store.isVisionModel && store.isDatasetImage === true; const [loraOpen, setLoraOpen] = useState(false); @@ -883,7 +885,11 @@ export function ParamsSection(): ReactElement { None Standard - Unsloth + {platformDeviceType === "mac" ? ( + MLX + ) : ( + Unsloth + )} diff --git a/studio/frontend/src/features/training/lib/model-defaults.ts b/studio/frontend/src/features/training/lib/model-defaults.ts index c40a1e2282..8bc9c4e064 100644 --- a/studio/frontend/src/features/training/lib/model-defaults.ts +++ b/studio/frontend/src/features/training/lib/model-defaults.ts @@ -3,6 +3,7 @@ import type { BackendModelConfig } from "../api/models-api"; import type { TrainingConfigState } from "../types/config"; +import { usePlatformStore } from "@/config/env"; type ModelDefaultsPatch = Partial< Pick< @@ -69,7 +70,13 @@ function toStringArray(value: unknown): string[] | undefined { function toGradientCheckpointing( value: unknown, ): TrainingConfigState["gradientCheckpointing"] | undefined { - if (value === "none" || value === "true" || value === "unsloth") return value; + if (value === "none" || value === "true" || value === "unsloth" || value === "mlx") { + // On Mac, map "unsloth" → "mlx" since Unsloth GC is GPU-only + if (usePlatformStore.getState().deviceType === "mac" && value === "unsloth") { + return "mlx"; + } + return value; + } return undefined; } diff --git a/studio/frontend/src/hooks/use-hf-model-search.ts b/studio/frontend/src/hooks/use-hf-model-search.ts index 77214f38d9..efe4d726be 100644 --- a/studio/frontend/src/hooks/use-hf-model-search.ts +++ b/studio/frontend/src/hooks/use-hf-model-search.ts @@ -6,6 +6,7 @@ import { listModels } from "@huggingface/hub"; import { type CachedResult, cachedModelInfo, primeCacheFromListing } from "@/lib/hf-cache"; import { useCallback, useMemo } from "react"; import { useHfPaginatedSearch } from "./use-hf-paginated-search"; +import { usePlatformStore } from "@/config/env"; export interface HfModelResult { id: string; @@ -16,7 +17,8 @@ export interface HfModelResult { isGguf: boolean; } -const EXCLUDED_TAGS = new Set([ +/** Tags to exclude on GPU (CUDA/ROCm) — MLX models won't load on GPU. */ +const EXCLUDED_TAGS_GPU = new Set([ "gptq", "awq", "exl2", @@ -28,6 +30,18 @@ const EXCLUDED_TAGS = new Set([ "ctranslate2", ]); +/** Tags to exclude on MLX (Mac) — GPU-only quant formats won't load on MLX. */ +const EXCLUDED_TAGS_MLX = new Set([ + "gptq", + "awq", + "exl2", + "onnx", + "openvino", + "coreml", + "tflite", + "ctranslate2", +]); + // Embedding / sentence-transformer models ship with onnx/openvino as additional // export formats — they should not be excluded by the tag check above. const EMBEDDING_TAGS = new Set([ @@ -77,7 +91,7 @@ function estimateSizeFromDtypes( return total > 0 ? total : undefined; } -function makeMapModel(excludeGguf: boolean) { +function makeMapModel(excludeGguf: boolean, excludedTags: Set) { return (raw: unknown): HfModelResult | null => { const m = raw as { name: string; @@ -87,7 +101,7 @@ function makeMapModel(excludeGguf: boolean) { tags?: string[]; }; const isEmbedding = m.tags?.some((t) => EMBEDDING_TAGS.has(t)); - if (!isEmbedding && m.tags?.some((t) => EXCLUDED_TAGS.has(t))) { + if (!isEmbedding && m.tags?.some((t) => excludedTags.has(t))) { return null; } const isGguf = @@ -314,7 +328,9 @@ export function useHfModelSearch( [trimmed, searchQuery, pinnedId, task, accessToken, priorityIds], ); - const mapModel = useMemo(() => makeMapModel(excludeGguf), [excludeGguf]); + const deviceType = usePlatformStore((s) => s.deviceType); + const excludedTags = deviceType === "mac" ? EXCLUDED_TAGS_MLX : EXCLUDED_TAGS_GPU; + const mapModel = useMemo(() => makeMapModel(excludeGguf, excludedTags), [excludeGguf, excludedTags]); const search = useHfPaginatedSearch(createIter, mapModel); // Secondary sort guarantee: unsloth models always float to the top. diff --git a/studio/frontend/src/types/training.ts b/studio/frontend/src/types/training.ts index d65d14fb83..187f54a13b 100644 --- a/studio/frontend/src/types/training.ts +++ b/studio/frontend/src/types/training.ts @@ -10,7 +10,7 @@ export function isAdapterMethod(method: TrainingMethod): boolean { export type StepNumber = 1 | 2 | 3 | 4 | 5; export type DatasetSource = "huggingface" | "upload"; export type DatasetFormat = "auto" | "alpaca" | "chatml" | "sharegpt"; -export type GradientCheckpointing = "none" | "true" | "unsloth"; +export type GradientCheckpointing = "none" | "true" | "unsloth" | "mlx"; export interface WizardState { currentStep: StepNumber; diff --git a/tests/python/test_gpu_init_ldconfig_guard.py b/tests/python/test_gpu_init_ldconfig_guard.py new file mode 100644 index 0000000000..081a6132b4 --- /dev/null +++ b/tests/python/test_gpu_init_ldconfig_guard.py @@ -0,0 +1,46 @@ +import ast +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +GPU_INIT = REPO_ROOT / "unsloth" / "_gpu_init.py" + + +def _find_geteuid_guard(tree: ast.AST): + for node in ast.walk(tree): + if not isinstance(node, ast.If): + continue + for sub in ast.walk(node.test): + if isinstance(sub, ast.Call) and isinstance(sub.func, ast.Attribute): + if sub.func.attr == "geteuid": + return node + return None + + +def test_gpu_init_has_geteuid_guard(): + tree = ast.parse(GPU_INIT.read_text()) + guard = _find_geteuid_guard(tree) + assert ( + guard is not None + ), "_gpu_init.py must guard ldconfig recovery on os.geteuid()" + + +def test_ldconfig_calls_only_inside_geteuid_guard(): + src = GPU_INIT.read_text() + tree = ast.parse(src) + guard = _find_geteuid_guard(tree) + assert guard is not None + guard_src = ast.get_source_segment(src, guard) or "" + ldconfig_lines = [ + line for line in src.splitlines() if "ldconfig" in line and "os.system" in line + ] + for line in ldconfig_lines: + assert line.strip() in guard_src, ( + "os.system('ldconfig ...') must live inside the geteuid guard, " + f"but found unguarded: {line!r}" + ) + + +def test_non_root_branch_warns_when_bnb_present(): + src = GPU_INIT.read_text() + assert "elif bnb is not None" in src + assert "sudo ldconfig" in src diff --git a/tests/studio/test_export_output_path_contract.py b/tests/studio/test_export_output_path_contract.py new file mode 100644 index 0000000000..e99fc42091 --- /dev/null +++ b/tests/studio/test_export_output_path_contract.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +import ast +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +EXPORT = REPO_ROOT / "studio" / "backend" / "core" / "export" / "export.py" + +EXPORT_FNS = ( + "export_merged_model", + "export_base_model", + "export_gguf", + "export_lora_adapter", +) + + +def _find_method(tree, cls_name, method_name): + for cls in ast.walk(tree): + if isinstance(cls, ast.ClassDef) and cls.name == cls_name: + for item in cls.body: + if isinstance(item, ast.FunctionDef) and item.name == method_name: + return item + return None + + +def _return_tuple_arity(fn): + arities = [] + for node in ast.walk(fn): + if isinstance(node, ast.Return) and isinstance(node.value, ast.Tuple): + arities.append(len(node.value.elts)) + return arities + + +def test_export_methods_return_three_tuple_annotation(): + tree = ast.parse(EXPORT.read_text()) + for fn_name in EXPORT_FNS: + fn = _find_method(tree, "ExportBackend", fn_name) + assert fn is not None, f"missing ExportBackend.{fn_name}" + ret = fn.returns + assert isinstance(ret, ast.Subscript), f"{fn_name} return must be Tuple[...]" + slc = ret.slice + elts = slc.elts if isinstance(slc, ast.Tuple) else None + assert ( + elts is not None and len(elts) == 3 + ), f"{fn_name} return annotation must be a 3-tuple, got {ast.dump(ret)}" + + +def test_export_methods_return_three_element_tuples(): + tree = ast.parse(EXPORT.read_text()) + for fn_name in EXPORT_FNS: + fn = _find_method(tree, "ExportBackend", fn_name) + assert fn is not None + arities = _return_tuple_arity(fn) + assert arities, f"{fn_name} has no tuple-return statements" + for arity in arities: + assert arity == 3, f"{fn_name} return tuple arity {arity}, expected 3" + + +def test_local_save_assigns_output_path(): + tree = ast.parse(EXPORT.read_text()) + for fn_name in EXPORT_FNS: + fn = _find_method(tree, "ExportBackend", fn_name) + assert fn is not None + assigns = [] + for node in ast.walk(fn): + if isinstance(node, ast.Assign): + for tgt in node.targets: + if isinstance(tgt, ast.Name) and tgt.id == "output_path": + assigns.append(node) + non_none = [ + a + for a in assigns + if not (isinstance(a.value, ast.Constant) and a.value.value is None) + ] + assert non_none, f"{fn_name} never assigns a non-None output_path" + + +def test_gpu_save_method_bound_for_hub_only(): + tree = ast.parse(EXPORT.read_text()) + fn = _find_method(tree, "ExportBackend", "export_merged_model") + assert fn is not None + found_pre_save_method = False + for node in ast.walk(fn): + if isinstance(node, ast.Try): + for stmt in node.body: + if isinstance(stmt, ast.If): + test = stmt.test + if isinstance(test, ast.Name) and test.id == "_IS_MLX": + for sub in ast.walk( + ast.Module(body = stmt.orelse, type_ignores = []) + ): + if isinstance(sub, ast.Assign) and any( + isinstance(t, ast.Name) and t.id == "save_method" + for t in sub.targets + ): + found_pre_save_method = True + break + if found_pre_save_method: + break + if found_pre_save_method: + break + assert found_pre_save_method, ( + "GPU save_method must be assigned at the top of the try block, " + "before the `if save_directory:` guard, so Hub-only export does not " + "raise UnboundLocalError." + ) + + +def test_mlx_hub_only_uses_temp_directory(): + src = EXPORT.read_text() + assert ( + src.count("tempfile.TemporaryDirectory") >= 3 + ), "expected TemporaryDirectory in merged, base, and lora hub-push paths" + assert "import tempfile" in src.split("class ExportBackend")[0] + + +def test_is_mlx_imported_from_unsloth(): + src = EXPORT.read_text() + assert "from unsloth import" in src + head = src.split("class ExportBackend")[0] + assert "_IS_MLX" in head + assert "_IS_MLX = platform.system()" not in src diff --git a/tests/studio/test_is_mlx_dispatch_gate.py b/tests/studio/test_is_mlx_dispatch_gate.py new file mode 100644 index 0000000000..fc07a497e7 --- /dev/null +++ b/tests/studio/test_is_mlx_dispatch_gate.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +""" +Regression tests for the CUDA-vs-MLX dispatch gates Studio relies on. + +Two gates drive every dispatch decision in Studio's MLX path: + + 1. ``unsloth._IS_MLX`` at the top of ``unsloth/__init__.py`` -- evaluated + once at import time and read by Studio worker code to choose between + the GPU and MLX trainer / inference / export paths. Defined as + ``Darwin AND arm64 AND find_spec("mlx") is not None``. + + 2. ``utils.hardware.detect_hardware()`` -- runtime probe in the Studio + backend. Priority order: CUDA -> XPU -> MLX -> CPU. The MLX branch is + reached only when both CUDA and XPU are unavailable AND the host is + Apple Silicon AND mlx is importable. + +These gates are the canaries for "MLX support accidentally hijacks +CUDA/AMD/Intel users". The tests here: + + * verify the source-level structure of the ``_IS_MLX`` expression so an + accidental rewrite (e.g. dropping the ``arm64`` check) is caught, + * exercise the runtime gate logic under a spoofed Darwin+arm64 platform + with a fake ``mlx`` module in ``sys.modules`` to confirm both gates + flip True together, + * confirm that on the actual Linux+CUDA test host both gates remain in + their CUDA-side state. + +No real MLX install is required; uses the same ``monkeypatch.setitem`` +fake-mlx pattern as ``test_mlx_inference_backend.py``. +""" + +import ast +import importlib +import sys +import types +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +UNSLOTH_INIT = REPO_ROOT / "unsloth" / "__init__.py" + + +# --------------------------------------------------------------------------- +# 1. Source-level structure check on _IS_MLX (no platform dependencies). +# --------------------------------------------------------------------------- + + +def test_is_mlx_gate_uses_three_required_predicates(): + """The _IS_MLX assignment must AND together exactly the three checks + that Studio depends on: Darwin OS, arm64 machine, and an importable + mlx package. Dropping any one of them silently breaks dispatch. + """ + tree = ast.parse(UNSLOTH_INIT.read_text()) + + target = None + for node in ast.walk(tree): + if ( + isinstance(node, ast.Assign) + and len(node.targets) == 1 + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id == "_IS_MLX" + ): + target = node.value + break + assert target is not None, "_IS_MLX assignment not found in unsloth/__init__.py" + assert isinstance(target, ast.BoolOp) and isinstance( + target.op, ast.And + ), "_IS_MLX must be a BoolOp(And) of platform + mlx checks" + + expr_src = ast.unparse(target) + assert ( + "platform.system()" in expr_src and "Darwin" in expr_src + ), "_IS_MLX must check platform.system() == 'Darwin'" + assert ( + "platform.machine()" in expr_src and "arm64" in expr_src + ), "_IS_MLX must check platform.machine() == 'arm64'" + assert ( + "find_spec" in expr_src and "'mlx'" in expr_src + ), "_IS_MLX must check importlib.util.find_spec('mlx')" + + +# --------------------------------------------------------------------------- +# 2. Runtime gate behavior with the platform spoofed to Apple Silicon and a +# fake mlx module in sys.modules. Re-evaluates the same expression +# rather than reloading unsloth (which would cascade-reload torch). +# --------------------------------------------------------------------------- + + +def _evaluate_is_mlx_gate(platform_module, importlib_util): + """Re-evaluate the _IS_MLX expression using injected dependencies. + + Mirrors the assignment in unsloth/__init__.py exactly. + """ + return ( + platform_module.system() == "Darwin" + and platform_module.machine() == "arm64" + and importlib_util.find_spec("mlx") is not None + ) + + +def test_is_mlx_gate_true_on_apple_silicon_with_mlx_present(monkeypatch): + import platform + import importlib.util + + # Inject a fake mlx package so find_spec returns a non-None ModuleSpec. + fake_mlx = types.ModuleType("mlx") + fake_mlx.__spec__ = importlib.machinery.ModuleSpec("mlx", loader = None) + fake_mlx.__path__ = [] + monkeypatch.setitem(sys.modules, "mlx", fake_mlx) + + monkeypatch.setattr(platform, "system", lambda: "Darwin") + monkeypatch.setattr(platform, "machine", lambda: "arm64") + + assert _evaluate_is_mlx_gate(platform, importlib.util) is True + + +def test_is_mlx_gate_false_when_mlx_missing(monkeypatch): + import platform + import importlib.util + + # Apple Silicon platform but no mlx package -> gate must be False. + monkeypatch.delitem(sys.modules, "mlx", raising = False) + monkeypatch.setattr(platform, "system", lambda: "Darwin") + monkeypatch.setattr(platform, "machine", lambda: "arm64") + + real_find_spec = importlib.util.find_spec + + def _no_mlx(name, *args, **kwargs): + if name == "mlx": + return None + return real_find_spec(name, *args, **kwargs) + + monkeypatch.setattr(importlib.util, "find_spec", _no_mlx) + + assert _evaluate_is_mlx_gate(platform, importlib.util) is False + + +def test_is_mlx_gate_false_on_non_apple_silicon(): + """On the real Linux+CUDA / AMD / Intel test host, the gate stays False.""" + import platform + import importlib.util + + if platform.system() == "Darwin" and platform.machine() == "arm64": + # On a Mac CI runner this assertion would not apply; skip there. + import pytest + + pytest.skip("Test host is Apple Silicon; CUDA-side canary doesn't apply.") + + assert _evaluate_is_mlx_gate(platform, importlib.util) is False + + +# --------------------------------------------------------------------------- +# 3. Studio's runtime detect_hardware() picks MLX only when CUDA + XPU are +# both unavailable AND the host is Apple Silicon AND mlx is importable. +# --------------------------------------------------------------------------- + + +def _import_studio_hardware(): + """Lazy import for the Studio hardware module, with the bare-imports + convention that Studio uses (studio/backend on sys.path). + """ + studio_backend = REPO_ROOT / "studio" / "backend" + if str(studio_backend) not in sys.path: + sys.path.insert(0, str(studio_backend)) + from utils.hardware import hardware as hw # type: ignore + + return hw + + +def test_detect_hardware_picks_mlx_when_only_apple_silicon_available(monkeypatch): + hw = _import_studio_hardware() + + # Force CUDA + XPU paths off so detect_hardware falls through to MLX. + import torch + + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + if hasattr(torch, "xpu"): + monkeypatch.setattr(torch.xpu, "is_available", lambda: False) + + # Spoof Apple Silicon and provide an importable mlx.core for _has_mlx(). + import platform + + monkeypatch.setattr(platform, "system", lambda: "Darwin") + monkeypatch.setattr(platform, "machine", lambda: "arm64") + + fake_mlx = types.ModuleType("mlx") + fake_mlx_core = types.ModuleType("mlx.core") + fake_mlx.core = fake_mlx_core + monkeypatch.setitem(sys.modules, "mlx", fake_mlx) + monkeypatch.setitem(sys.modules, "mlx.core", fake_mlx_core) + + detected = hw.detect_hardware() + assert detected == hw.DeviceType.MLX, f"expected MLX, got {detected!r}" + + +def test_detect_hardware_picks_cuda_on_real_host(): + """Canary: on a real CUDA host the MLX branch must NOT be taken even + if mlx happens to be importable. Protects CUDA/AMD/Intel users from + accidental MLX dispatch when MLX support is added. + """ + import torch + + if not torch.cuda.is_available(): + import pytest + + pytest.skip("No CUDA available on this host; canary not applicable.") + + hw = _import_studio_hardware() + detected = hw.detect_hardware() + assert ( + detected == hw.DeviceType.CUDA + ), f"CUDA host must dispatch to CUDA, got {detected!r}" diff --git a/tests/studio/test_mlx_training_worker_behaviors.py b/tests/studio/test_mlx_training_worker_behaviors.py new file mode 100644 index 0000000000..6c067ea00b --- /dev/null +++ b/tests/studio/test_mlx_training_worker_behaviors.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: AGPL-3.0-only + +import ast +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +WORKER = REPO_ROOT / "studio" / "backend" / "core" / "training" / "worker.py" + + +def _find_func(tree, name): + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == name: + return node + return None + + +def test_run_mlx_training_passes_token_to_from_pretrained(): + tree = ast.parse(WORKER.read_text()) + fn = _find_func(tree, "_run_mlx_training") + assert fn is not None + found = False + for node in ast.walk(fn): + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == "from_pretrained" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "FastMLXModel" + ): + kwarg_names = {kw.arg for kw in node.keywords if kw.arg} + assert ( + "token" in kwarg_names + ), f"FastMLXModel.from_pretrained must forward token=hf_token; got {kwarg_names!r}" + found = True + assert found, "FastMLXModel.from_pretrained call not found in _run_mlx_training" + + +def test_wandb_init_strips_secret_keys(): + src = WORKER.read_text() + assert "_wandb_sensitive" in src, "expected a sensitive-key set near wandb.init" + assert '"hf_token"' in src and '"wandb_token"' in src + assert ( + "config = dict(config)" not in src + ), "wandb.init received raw config dict; secrets would leak" + + +def test_local_dataset_loader_uses_load_dataset_path(): + src = WORKER.read_text() + assert "_resolve_local_files" in src + assert "_loader_for_files" in src + assert "data_files = all_files" in src or "data_files=all_files" in src + + +def test_send_aliases_status_message_to_message(): + src = WORKER.read_text() + assert 'kwargs["message"] = sm' in src or 'kwargs["message"]=sm' in src + + +def test_slice_uses_inclusive_end_and_handles_zero(): + src = WORKER.read_text() + assert "min(end + 1, len(ds))" in src or "min(end+1, len(ds))" in src + assert "slice_start if slice_start is not None else 0" in src + assert "slice_end if slice_end is not None else len(ds) - 1" in src + + +def test_poll_stop_returns_on_broken_pipe(): + src = WORKER.read_text() + assert "except (EOFError, OSError)" in src + lines = src.splitlines() + for i, line in enumerate(lines): + if "except (EOFError, OSError)" in line: + for j in range(i + 1, min(i + 6, len(lines))): + stripped = lines[j].strip() + if not stripped or stripped.startswith("#"): + continue + assert stripped.startswith( + "return" + ), f"expected return after EOFError/OSError, got {stripped!r}" + break + break + else: + raise AssertionError("EOFError/OSError handler not found in worker.py") + + +def test_unsloth_zoo_mlx_imports_have_friendly_error(): + src = WORKER.read_text() + assert "from unsloth_zoo.mlx_loader import FastMLXModel" in src + assert "from unsloth_zoo.mlx_trainer import" in src + assert "raise ImportError" in src + assert "install.sh" in src diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 9db9ae0a32..9b620a5c76 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -12,348 +12,117 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings, importlib, sys -from packaging.version import Version -import os, re, subprocess, inspect, functools -import numpy as np +import os, platform, importlib.util -# Log Unsloth is being used os.environ["UNSLOTH_IS_PRESENT"] = "1" -# Check if modules that need patching are already imported -critical_modules = ["trl", "transformers", "peft"] -already_imported = [mod for mod in critical_modules if mod in sys.modules] - -# Fix some issues before importing other packages -from .import_fixes import ( - fix_message_factory_issue, - check_fbgemm_gpu_version, - disable_broken_causal_conv1d, - disable_broken_vllm, - configure_amdgpu_asic_id_table_path, - torchvision_compatibility_check, - fix_diffusers_warnings, - fix_huggingface_hub, +# Detect Apple Silicon + MLX before any torch/numpy imports +_IS_MLX = ( + platform.system() == "Darwin" + and platform.machine() == "arm64" + and importlib.util.find_spec("mlx") is not None ) -# Configure libdrm ids table path early so ROCm can resolve AMD GPU names. -configure_amdgpu_asic_id_table_path() -disable_broken_causal_conv1d() -disable_broken_vllm() -fix_message_factory_issue() -check_fbgemm_gpu_version() -torchvision_compatibility_check() -fix_diffusers_warnings() -fix_huggingface_hub() -del configure_amdgpu_asic_id_table_path -del disable_broken_causal_conv1d -del disable_broken_vllm -del fix_message_factory_issue -del check_fbgemm_gpu_version -del torchvision_compatibility_check -del fix_diffusers_warnings -del fix_huggingface_hub - -# This check is critical because Unsloth optimizes these libraries by modifying -# their code at import time. If they're imported first, the original (slower, -# more memory-intensive) implementations will be used instead of Unsloth's -# optimized versions, potentially causing OOM errors or slower training. -if already_imported: - # stacklevel=2 makes warning point to user's import line rather than this library code, - # showing them exactly where to fix the import order in their script - warnings.warn( - f"WARNING: Unsloth should be imported before [{', '.join(already_imported)}] " - f"to ensure all optimizations are applied. Your code may run slower or encounter " - f"memory issues without these optimizations.\n\n" - f"Please restructure your imports with 'import unsloth' at the top of your file.", - stacklevel = 2, - ) -del already_imported, critical_modules - -# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so -# enabling it will require much more work, so we have to prioritize. Please understand! -# We do have a beta version, which you can contact us about! -# Thank you for your understanding and we appreciate it immensely! - -# Fixes https://github.com/unslothai/unsloth/issues/1266 -os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" - -# [TODO] Check why some GPUs don't work -# "pinned_use_cuda_host_register:True,"\ -# "pinned_num_register_threads:8" - - -from importlib.metadata import version as importlib_version -from importlib.metadata import PackageNotFoundError - -# Check for unsloth_zoo -try: - unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2026.3.4"): - print( - "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n" - "Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`" - ) - # if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0": - # try: - # os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") - # except: - # try: - # os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") - # except: - # raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`") - import unsloth_zoo -except PackageNotFoundError: - raise ImportError( - f"Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo` then retry!" - ) -except: - raise -del PackageNotFoundError, importlib_version - -# Try importing PyTorch and check version -try: - import torch -except ModuleNotFoundError: - raise ImportError( - "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n" - "We have some installation instructions on our Github page." +if _IS_MLX: + try: + import unsloth_zoo + except ImportError as _e: + raise ImportError( + "Unsloth: MLX support requires `unsloth-zoo` with MLX modules. " + "Reinstall with `pip install unsloth-zoo` or rerun install.sh." + ) from _e + # The mlx_trainer / mlx_loader submodules ship with unsloth-zoo's MLX + # support. An older installed unsloth-zoo (e.g. from PyPI before the + # MLX release lands) will satisfy `import unsloth_zoo` but be missing + # these submodules. Surface the same friendly install hint instead of + # a raw ImportError on the submodule path. + try: + from unsloth_zoo.mlx_trainer import MLXTrainer, MLXTrainingConfig + from unsloth_zoo.mlx_loader import FastMLXModel + except ImportError as _e: + raise ImportError( + "Unsloth: MLX support requires an unsloth-zoo build that includes " + "`unsloth_zoo.mlx_trainer` and `unsloth_zoo.mlx_loader`. Upgrade with " + "`pip install -U unsloth-zoo` or rerun install.sh." + ) from _e + + # Load raw_text helpers without executing dataprep/__init__.py, which + # imports synthetic.py -> torch and would defeat the torch-free MLX path. + from pathlib import Path as _Path + + _raw_text_path = _Path(__file__).resolve().parent / "dataprep" / "raw_text.py" + _raw_text_spec = importlib.util.spec_from_file_location( + "unsloth._mlx_raw_text", _raw_text_path ) -except: - raise - -from unsloth_zoo.device_type import ( - is_hip, - get_device_type, - DEVICE_TYPE, - DEVICE_TYPE_TORCH, - DEVICE_COUNT, - ALLOW_PREQUANTIZED_MODELS, -) - -# Fix other issues -from .import_fixes import ( - fix_xformers_performance_issue, - fix_vllm_aimv2_issue, - check_vllm_torch_sm100_compatibility, - fix_vllm_guided_decoding_params, - fix_trl_vllm_ascend, - fix_vllm_pdl_blackwell, - fix_triton_compiled_kernel_missing_attrs, - patch_trunc_normal_precision_issue, - ignore_logger_messages, - patch_ipykernel_hf_xet, - patch_trackio, - patch_datasets, - patch_enable_input_require_grads, - fix_openenv_no_vllm, - patch_openspiel_env_async, - fix_executorch, - patch_vllm_for_notebooks, - patch_torchcodec_audio_decoder, - disable_torchcodec_if_broken, - disable_broken_wandb, - patch_peft_weight_converter_compatibility, -) - -fix_xformers_performance_issue() -fix_vllm_aimv2_issue() -# Check vLLM + torch < 2.9.0 + SM100 compatibility BEFORE importing vLLM -check_vllm_torch_sm100_compatibility() -fix_vllm_guided_decoding_params() -fix_trl_vllm_ascend() -fix_vllm_pdl_blackwell() -fix_triton_compiled_kernel_missing_attrs() -patch_trunc_normal_precision_issue() -ignore_logger_messages() -patch_ipykernel_hf_xet() -patch_trackio() -patch_datasets() -patch_enable_input_require_grads() -fix_openenv_no_vllm() -patch_openspiel_env_async() -fix_executorch() -patch_vllm_for_notebooks() -patch_torchcodec_audio_decoder() -disable_torchcodec_if_broken() -disable_broken_wandb() -patch_peft_weight_converter_compatibility() - -del fix_xformers_performance_issue -del fix_vllm_aimv2_issue -del check_vllm_torch_sm100_compatibility -del fix_vllm_guided_decoding_params -del fix_trl_vllm_ascend -del fix_vllm_pdl_blackwell -del fix_triton_compiled_kernel_missing_attrs -del patch_trunc_normal_precision_issue -del ignore_logger_messages -del patch_ipykernel_hf_xet -del patch_trackio -del patch_datasets -del patch_enable_input_require_grads -del fix_openenv_no_vllm -del patch_openspiel_env_async -del fix_executorch -del patch_vllm_for_notebooks -del patch_torchcodec_audio_decoder -del disable_torchcodec_if_broken -del disable_broken_wandb -del patch_peft_weight_converter_compatibility - -# Torch 2.4 has including_emulation -if DEVICE_TYPE == "cuda": - major_version, minor_version = torch.cuda.get_device_capability() - SUPPORTS_BFLOAT16 = major_version >= 8 - - old_is_bf16_supported = torch.cuda.is_bf16_supported - if "including_emulation" in str(inspect.signature(old_is_bf16_supported)): - - def is_bf16_supported(including_emulation = False): - return old_is_bf16_supported(including_emulation) - - torch.cuda.is_bf16_supported = is_bf16_supported - else: - - def is_bf16_supported(): - return SUPPORTS_BFLOAT16 - - torch.cuda.is_bf16_supported = is_bf16_supported - del major_version, minor_version -elif DEVICE_TYPE == "hip": - SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported() -elif DEVICE_TYPE == "xpu": - # torch.xpu.is_bf16_supported() does not have including_emulation - # set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported() - SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported() + if _raw_text_spec is None or _raw_text_spec.loader is None: + raise ImportError("Unsloth: could not load MLX raw_text dataprep helpers.") + _raw_text = importlib.util.module_from_spec(_raw_text_spec) + _raw_text_spec.loader.exec_module(_raw_text) + RawTextDataLoader = _raw_text.RawTextDataLoader + TextPreprocessor = _raw_text.TextPreprocessor + del _raw_text, _raw_text_spec, _raw_text_path, _Path + + __version__ = unsloth_zoo.__version__ + DEVICE_TYPE = "mlx" + + class FastLanguageModel: + @staticmethod + def from_pretrained(*args, **kwargs): + return FastMLXModel.from_pretrained(*args, **kwargs) + + @staticmethod + def get_peft_model(*args, **kwargs): + return FastMLXModel.get_peft_model(*args, **kwargs) + + @staticmethod + def for_inference(*args, **kwargs): + return args[0] if args else None + + class FastVisionModel(FastLanguageModel): + @staticmethod + def from_pretrained(*args, **kwargs): + kwargs.setdefault("text_only", False) + return FastMLXModel.from_pretrained(*args, **kwargs) + + @staticmethod + def for_training(*args, **kwargs): + return args[0] if args else None + + FastTextModel = FastLanguageModel + FastModel = FastLanguageModel + + class FastSentenceTransformer: + @staticmethod + def from_pretrained(*args, **kwargs): + raise NotImplementedError( + "Unsloth: FastSentenceTransformer is not yet supported on MLX." + ) -# For Gradio HF Spaces? -# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ: -import triton + @staticmethod + def get_peft_model(*args, **kwargs): + raise NotImplementedError( + "Unsloth: FastSentenceTransformer is not yet supported on MLX." + ) -if DEVICE_TYPE == "cuda": - libcuda_dirs = lambda: None - if Version(triton.__version__) >= Version("3.0.0"): + def is_bfloat16_supported(): try: - from triton.backends.nvidia.driver import libcuda_dirs - except: - pass - else: - from triton.common.build import libcuda_dirs - - # Try loading bitsandbytes and triton - try: - import bitsandbytes as bnb - except: - print( - "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!" - ) - bnb = None - try: - cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 - libcuda_dirs() - except: - # Only run the ldconfig recovery when we can actually run - # ldconfig (root). On non-root environments (shared HPC, - # locked-down containers, CI runners, etc.) the recovery would - # shell out to `ldconfig` and fail with "Permission denied", - # which is especially noisy for users who don't even have - # bitsandbytes installed and are just doing 16bit/full - # finetuning. libcuda_dirs() is used by both triton and bnb, - # so we still run the recovery whenever we're root, regardless - # of whether bnb is installed. - if hasattr(os, "geteuid") and os.geteuid() == 0: - warnings.warn("Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA.") + import mlx.core as mx - if os.path.exists("/usr/lib64-nvidia"): - os.system("ldconfig /usr/lib64-nvidia") - elif os.path.exists("/usr/local"): - # Sometimes bitsandbytes cannot be linked properly in Runpod for example - possible_cudas = ( - subprocess.check_output(["ls", "-al", "/usr/local"]) - .decode("utf-8") - .split("\n") - ) - find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$") - possible_cudas = [find_cuda.search(x) for x in possible_cudas] - possible_cudas = [x.group(1) for x in possible_cudas if x is not None] + name = mx.device_info().get("device_name", "") or "" + return not name.startswith(("Apple M1", "Apple M2")) + except Exception: + return True - # Try linking cuda folder, or everything in local - if len(possible_cudas) == 0: - os.system("ldconfig /usr/local/") - else: - find_number = re.compile(r"([\d\.]{2,})") - latest_cuda = np.argsort( - [float(find_number.search(x).group(1)) for x in possible_cudas] - )[::-1][0] - latest_cuda = possible_cudas[latest_cuda] - os.system(f"ldconfig /usr/local/{latest_cuda}") - del find_number, latest_cuda - del possible_cudas, find_cuda + is_bf16_supported = is_bfloat16_supported - if bnb is not None: - importlib.reload(bnb) - importlib.reload(triton) - try: - libcuda_dirs = lambda: None - if Version(triton.__version__) >= Version("3.0.0"): - try: - from triton.backends.nvidia.driver import libcuda_dirs - except: - pass - else: - from triton.common.build import libcuda_dirs - cdequantize_blockwise_fp32 = ( - bnb.functional.lib.cdequantize_blockwise_fp32 - ) - libcuda_dirs() - except: - warnings.warn( - "Unsloth: CUDA is not linked properly.\n" - "Try running `python -m bitsandbytes` then `python -m xformers.info`\n" - "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n" - "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n" - "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n" - "Unsloth will still run for now, but maybe it might crash - let's hope it works!" - ) - elif bnb is not None: - # Non-root + bnb installed: we can't run ldconfig ourselves, - # but bnb is going to crash later when the user actually uses - # 4bit quantization - tell them how to fix it manually so - # they're not surprised by an opaque error down the road. - warnings.warn( - "Unsloth: CUDA is not linked properly.\n" - "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n" - "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n" - "Unsloth will still run for now, but maybe it might crash - let's hope it works!" + class UnslothVisionDataCollator: + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "Unsloth: UnslothVisionDataCollator is not used on MLX. " + "Use the MLX trainer/data path instead." ) - del libcuda_dirs -elif DEVICE_TYPE == "hip": - # NO-OP for rocm device - pass -elif DEVICE_TYPE == "xpu": - import bitsandbytes as bnb - - # TODO: check triton for intel installed properly. - pass - -from .models import * -from .models import __version__ -from .save import * -from .chat_templates import * -from .tokenizer_utils import * -from .trainer import * - -# Export dataprep utilities for CLI and downstream users -from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor -from unsloth_zoo.rl_environments import ( - check_python_modules, - create_locked_down_function, - execute_with_time_limit, - Benchmarker, - is_port_open, - launch_openenv, -) -# Patch TRL trainers for backwards compatibility -_patch_trl_trainer() +else: + # GPU path: load everything from _gpu_init + from ._gpu_init import * + from ._gpu_init import __version__ diff --git a/unsloth/_gpu_init.py b/unsloth/_gpu_init.py new file mode 100644 index 0000000000..2fc4bfde3c --- /dev/null +++ b/unsloth/_gpu_init.py @@ -0,0 +1,346 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings, importlib, sys +from packaging.version import Version +import os, re, subprocess, inspect, functools +import numpy as np + +# Log Unsloth is being used +os.environ["UNSLOTH_IS_PRESENT"] = "1" + +# Check if modules that need patching are already imported +critical_modules = ["trl", "transformers", "peft"] +already_imported = [mod for mod in critical_modules if mod in sys.modules] + +# Fix some issues before importing other packages +from .import_fixes import ( + fix_message_factory_issue, + check_fbgemm_gpu_version, + disable_broken_causal_conv1d, + disable_broken_vllm, + configure_amdgpu_asic_id_table_path, + torchvision_compatibility_check, + fix_diffusers_warnings, + fix_huggingface_hub, +) + +# Configure libdrm ids table path early so ROCm can resolve AMD GPU names. +configure_amdgpu_asic_id_table_path() +disable_broken_causal_conv1d() +disable_broken_vllm() +fix_message_factory_issue() +check_fbgemm_gpu_version() +torchvision_compatibility_check() +fix_diffusers_warnings() +fix_huggingface_hub() +del configure_amdgpu_asic_id_table_path +del disable_broken_causal_conv1d +del disable_broken_vllm +del fix_message_factory_issue +del check_fbgemm_gpu_version +del torchvision_compatibility_check +del fix_diffusers_warnings +del fix_huggingface_hub + +# This check is critical because Unsloth optimizes these libraries by modifying +# their code at import time. If they're imported first, the original (slower, +# more memory-intensive) implementations will be used instead of Unsloth's +# optimized versions, potentially causing OOM errors or slower training. +if already_imported: + # stacklevel=2 makes warning point to user's import line rather than this library code, + # showing them exactly where to fix the import order in their script + warnings.warn( + f"WARNING: Unsloth should be imported before [{', '.join(already_imported)}] " + f"to ensure all optimizations are applied. Your code may run slower or encounter " + f"memory issues without these optimizations.\n\n" + f"Please restructure your imports with 'import unsloth' at the top of your file.", + stacklevel = 2, + ) +del already_imported, critical_modules + +# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so +# enabling it will require much more work, so we have to prioritize. Please understand! +# We do have a beta version, which you can contact us about! +# Thank you for your understanding and we appreciate it immensely! + +# Fixes https://github.com/unslothai/unsloth/issues/1266 +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +# [TODO] Check why some GPUs don't work +# "pinned_use_cuda_host_register:True,"\ +# "pinned_num_register_threads:8" + + +from importlib.metadata import version as importlib_version +from importlib.metadata import PackageNotFoundError + +# Check for unsloth_zoo +try: + unsloth_zoo_version = importlib_version("unsloth_zoo") + if Version(unsloth_zoo_version) < Version("2026.3.4"): + print( + "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n" + "Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`" + ) + # if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0": + # try: + # os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") + # except: + # try: + # os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") + # except: + # raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`") + import unsloth_zoo +except PackageNotFoundError: + raise ImportError( + f"Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo` then retry!" + ) +except: + raise +del PackageNotFoundError, importlib_version + +# Try importing PyTorch and check version +try: + import torch +except ModuleNotFoundError: + raise ImportError( + "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n" + "We have some installation instructions on our Github page." + ) +except: + raise + +from unsloth_zoo.device_type import ( + is_hip, + get_device_type, + DEVICE_TYPE, + DEVICE_TYPE_TORCH, + DEVICE_COUNT, + ALLOW_PREQUANTIZED_MODELS, +) + +# Fix other issues +from .import_fixes import ( + fix_xformers_performance_issue, + fix_vllm_aimv2_issue, + check_vllm_torch_sm100_compatibility, + fix_vllm_guided_decoding_params, + fix_vllm_pdl_blackwell, + fix_triton_compiled_kernel_missing_attrs, + patch_trunc_normal_precision_issue, + ignore_logger_messages, + patch_ipykernel_hf_xet, + patch_trackio, + patch_datasets, + patch_enable_input_require_grads, + fix_openenv_no_vllm, + patch_openspiel_env_async, + fix_executorch, + patch_vllm_for_notebooks, + patch_torchcodec_audio_decoder, + disable_torchcodec_if_broken, + disable_broken_wandb, + fix_trl_vllm_ascend, + patch_peft_weight_converter_compatibility, +) + +fix_xformers_performance_issue() +fix_vllm_aimv2_issue() +# Check vLLM + torch < 2.9.0 + SM100 compatibility BEFORE importing vLLM +check_vllm_torch_sm100_compatibility() +fix_vllm_guided_decoding_params() +fix_trl_vllm_ascend() +fix_vllm_pdl_blackwell() +fix_triton_compiled_kernel_missing_attrs() +patch_trunc_normal_precision_issue() +ignore_logger_messages() +patch_ipykernel_hf_xet() +patch_trackio() +patch_datasets() +patch_enable_input_require_grads() +fix_openenv_no_vllm() +patch_openspiel_env_async() +fix_executorch() +patch_vllm_for_notebooks() +patch_torchcodec_audio_decoder() +disable_torchcodec_if_broken() +disable_broken_wandb() +patch_peft_weight_converter_compatibility() + +del fix_xformers_performance_issue +del fix_vllm_aimv2_issue +del check_vllm_torch_sm100_compatibility +del fix_vllm_guided_decoding_params +del fix_trl_vllm_ascend +del fix_vllm_pdl_blackwell +del fix_triton_compiled_kernel_missing_attrs +del patch_trunc_normal_precision_issue +del ignore_logger_messages +del patch_ipykernel_hf_xet +del patch_trackio +del patch_datasets +del patch_enable_input_require_grads +del fix_openenv_no_vllm +del patch_openspiel_env_async +del fix_executorch +del patch_vllm_for_notebooks +del patch_torchcodec_audio_decoder +del disable_torchcodec_if_broken +del disable_broken_wandb +del patch_peft_weight_converter_compatibility + +# Torch 2.4 has including_emulation +if DEVICE_TYPE == "cuda": + major_version, minor_version = torch.cuda.get_device_capability() + SUPPORTS_BFLOAT16 = major_version >= 8 + + old_is_bf16_supported = torch.cuda.is_bf16_supported + if "including_emulation" in str(inspect.signature(old_is_bf16_supported)): + + def is_bf16_supported(including_emulation = False): + return old_is_bf16_supported(including_emulation) + + torch.cuda.is_bf16_supported = is_bf16_supported + else: + + def is_bf16_supported(): + return SUPPORTS_BFLOAT16 + + torch.cuda.is_bf16_supported = is_bf16_supported + del major_version, minor_version +elif DEVICE_TYPE == "hip": + SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported() +elif DEVICE_TYPE == "xpu": + # torch.xpu.is_bf16_supported() does not have including_emulation + # set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported() + SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported() + +# For Gradio HF Spaces? +# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ: +import triton + +if DEVICE_TYPE == "cuda": + libcuda_dirs = lambda: None + if Version(triton.__version__) >= Version("3.0.0"): + try: + from triton.backends.nvidia.driver import libcuda_dirs + except: + pass + else: + from triton.common.build import libcuda_dirs + + # Try loading bitsandbytes and triton + try: + import bitsandbytes as bnb + except: + print( + "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!" + ) + bnb = None + try: + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + libcuda_dirs() + except: + if hasattr(os, "geteuid") and os.geteuid() == 0: + warnings.warn("Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA.") + + if os.path.exists("/usr/lib64-nvidia"): + os.system("ldconfig /usr/lib64-nvidia") + elif os.path.exists("/usr/local"): + # Sometimes bitsandbytes cannot be linked properly in Runpod for example + possible_cudas = ( + subprocess.check_output(["ls", "-al", "/usr/local"]) + .decode("utf-8") + .split("\n") + ) + find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$") + possible_cudas = [find_cuda.search(x) for x in possible_cudas] + possible_cudas = [x.group(1) for x in possible_cudas if x is not None] + + # Try linking cuda folder, or everything in local + if len(possible_cudas) == 0: + os.system("ldconfig /usr/local/") + else: + find_number = re.compile(r"([\d\.]{2,})") + latest_cuda = np.argsort( + [float(find_number.search(x).group(1)) for x in possible_cudas] + )[::-1][0] + latest_cuda = possible_cudas[latest_cuda] + os.system(f"ldconfig /usr/local/{latest_cuda}") + del find_number, latest_cuda + del possible_cudas, find_cuda + + if bnb is not None: + importlib.reload(bnb) + importlib.reload(triton) + try: + libcuda_dirs = lambda: None + if Version(triton.__version__) >= Version("3.0.0"): + try: + from triton.backends.nvidia.driver import libcuda_dirs + except: + pass + else: + from triton.common.build import libcuda_dirs + cdequantize_blockwise_fp32 = ( + bnb.functional.lib.cdequantize_blockwise_fp32 + ) + libcuda_dirs() + except: + warnings.warn( + "Unsloth: CUDA is not linked properly.\n" + "Try running `python -m bitsandbytes` then `python -m xformers.info`\n" + "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n" + "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n" + "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n" + "Unsloth will still run for now, but maybe it might crash - let's hope it works!" + ) + elif bnb is not None: + warnings.warn( + "Unsloth: CUDA is not linked properly.\n" + "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n" + "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n" + "Unsloth will still run for now, but maybe it might crash - let's hope it works!" + ) + del libcuda_dirs +elif DEVICE_TYPE == "hip": + # NO-OP for rocm device + pass +elif DEVICE_TYPE == "xpu": + import bitsandbytes as bnb + + # TODO: check triton for intel installed properly. + pass + +from .models import * +from .models import __version__ +from .save import * +from .chat_templates import * +from .tokenizer_utils import * +from .trainer import * + +# Export dataprep utilities for CLI and downstream users +from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor +from unsloth_zoo.rl_environments import ( + check_python_modules, + create_locked_down_function, + execute_with_time_limit, + Benchmarker, + is_port_open, + launch_openenv, +) + +# Patch TRL trainers for backwards compatibility +_patch_trl_trainer()