Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
c588fe6
mlx with studio
Manan17 Apr 14, 2026
48632f8
updating temporary install.sh
Manan17 Apr 14, 2026
1bf5853
adding t_v5 path
Manan17 Apr 14, 2026
728d08c
fixing vision training
Manan17 Apr 14, 2026
297adec
adding chat
Manan17 Apr 14, 2026
e654ac1
minor
mmathew23 Apr 18, 2026
77f53af
Adding export and fixing training issues, inference with lora adaptors
Manan17 Apr 19, 2026
f1673fc
Merge remote-tracking branch 'origin/fix/mlxvlmcompile' into mlx-appl…
Manan17 Apr 19, 2026
f08953d
fix: MLX worker pass load_in_4bit, override is_vlm based on dataset, …
Manan17 Apr 21, 2026
e1f096f
Merge mlx-apple-silicon into main
Manan17 Apr 21, 2026
de85036
update install.sh to point to main branch
Manan17 Apr 21, 2026
be99b87
fix: export returns 3 values (success, message, output_path) matching…
Manan17 Apr 21, 2026
a253b0f
fix(mlx): show training-process peak memory in Studio UI, not system-…
Manan17 Apr 28, 2026
9b49c3c
fix(mlx): make is_bfloat16_supported detect M1/M2 (no native bf16)
Manan17 Apr 28, 2026
dca214d
feat(mlx): wire training_type="Full Finetuning" through MLX worker
Manan17 Apr 28, 2026
0ba5366
fix(mlx): pass save_method='merged_16bit' from Studio's export page
Manan17 Apr 29, 2026
e0ab0b1
fix(studio): pass private to MLX push, return 3-tuples consistently
Manan17 Apr 29, 2026
dfdcf5d
studio wirings
mmathew23 Apr 30, 2026
b42426e
Merge pull request #5 from Manan17/feat/quant_config
Manan17 Apr 30, 2026
477bd7e
fix(mlx): wire train_on_completions for VLM via per-template lookup
Manan17 Apr 30, 2026
2f4e038
wire in lora rslora, init lora weights, random_state
mmathew23 Apr 30, 2026
a9edbfa
loftq studio error message fix
mmathew23 Apr 30, 2026
4d58b95
handle unknown optim and lr scheduler
mmathew23 Apr 30, 2026
f08b021
Merge pull request #6 from Manan17/update/peftkwargs
Manan17 May 1, 2026
b511646
Merge remote-tracking branch 'upstream/main'
Manan17 May 1, 2026
ae8b5d6
feat(mlx): pass finetune_language/attention/mlp/vision flags to FastM…
Manan17 May 1, 2026
72d17a5
feat(mlx,ux): auto-imply finetune_language_layers when user picks att…
Manan17 May 2, 2026
47d992b
fix(mlx): wire top_k, repetition_penalty, and VLM top_p through to ml…
Manan17 May 2, 2026
b76e0eb
feat(mlx): map format_type to MLX save_method, reuse local save dir f…
Manan17 May 3, 2026
d799e2e
Merge branch 'unslothai:main' into main
Manan17 May 4, 2026
8c51668
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 4, 2026
760714c
restore install
Manan17 May 4, 2026
f40554d
fix(mlx): restore FastVisionModel as a distinct class
Manan17 May 5, 2026
b1d5215
Merge branch 'main' into main
Imagineer99 May 5, 2026
0b1baa1
Merge remote-tracking branch 'origin/main' into
danielhanchen May 5, 2026
be874c7
Update pyproject.toml
danielhanchen May 5, 2026
2fba3b6
Update _utils.py
danielhanchen May 5, 2026
d741cc9
fix: developer to api (#5281)
Imagineer99 May 5, 2026
dc0ca40
Studio: harden MLX training and export, restore GPU init guards
danielhanchen May 5, 2026
d8a0beb
Studio: help svg replacement and Unsloth sidebar text (#5282)
Imagineer99 May 5, 2026
832f48c
Chore/help svg (#5283)
Imagineer99 May 5, 2026
6c9345a
Add Apple Silicon MLX routing
Manan17 Apr 9, 2026
7b3b20d
Studio: regression tests for MLX training/export and GPU init ldconfi…
danielhanchen May 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,4 @@ setup_leo.sh
server.pid
*.log
package-lock.json
llama.cpp/
6 changes: 6 additions & 0 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1721,6 +1721,12 @@ else
fi
fi

# ── Install mlx-vlm on Apple Silicon (optional, for VLM training) ──
# Runs only when both checks pass: $OS is "macos" and $_ARCH is "arm64".
# The package is installed with `uv pip install` into the interpreter named
# by $_VENV_PY, i.e. the Studio virtualenv rather than the system Python.
# NOTE(review): the header calls this step "optional", but whether a failed
# install aborts the script depends on run_install_cmd, which is defined
# outside this excerpt — confirm before relying on it being non-fatal.
if [ "$OS" = "macos" ] && [ "$_ARCH" = "arm64" ]; then
substep "installing mlx-vlm (VLM training support)..."
run_install_cmd "install mlx-vlm" uv pip install --python "$_VENV_PY" mlx-vlm
fi

# ── Run studio setup ──
tauri_log "STEP" "Running Studio setup"
# When --local, use the repo's own setup.sh directly.
Expand Down
262 changes: 176 additions & 86 deletions studio/backend/core/export/export.py

Large diffs are not rendered by default.

354 changes: 354 additions & 0 deletions studio/backend/core/inference/mlx_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,354 @@
# SPDX-License-Identifier: AGPL-3.0-only
"""MLX inference backend for Apple Silicon.

Drop-in replacement for InferenceBackend — same interface, uses mlx-lm/mlx-vlm
instead of torch/transformers for model loading and generation.
"""

import threading
from typing import Optional, Generator
from loggers import get_logger

logger = get_logger(__name__)


class MLXInferenceBackend:
    """Inference backend that loads and runs models through MLX.

    Text models are generated with ``mlx_lm`` and vision-language models with
    ``mlx_vlm``; loading goes through ``unsloth_zoo.mlx_loader.FastMLXModel``.
    Only one model is *live* at a time (``_model`` / ``_tokenizer`` /
    ``_processor``); ``models`` additionally keeps one bookkeeping entry per
    model name that has been loaded and not yet unloaded.

    All MLX imports are done lazily inside the methods that need them, so the
    class can be imported on machines without MLX installed.
    """

    def __init__(self):
        # Per-model bookkeeping, keyed by model name (see ``load_model``).
        self.models = {}
        self.active_model_name = None
        self.loading_models = set()
        self.loaded_local_models = []
        self.device = "mlx"
        # Held for the whole streaming loop of a generation so that only one
        # generation runs at a time on this backend instance.
        self._generation_lock = threading.Lock()

        # MLX state for the currently live model.
        self._model = None
        self._tokenizer = None
        self._processor = None
        self._is_vlm = False
        self._config = {}

    def load_model(
        self,
        config,
        max_seq_length = 2048,
        load_in_4bit = True,
        hf_token = None,
        trust_remote_code = False,
        gpu_ids = None,
        dtype = None,
    ) -> bool:
        """Load a model with ``FastMLXModel`` and make it the live model.

        Args:
            config: Model description. ``config.identifier`` is used as the
                model name when present, otherwise ``str(config)``. The
                optional attributes ``is_vision`` and ``is_lora`` are read
                with a default of ``False``.
            max_seq_length: Forwarded to ``FastMLXModel.from_pretrained``.
            load_in_4bit: Forwarded to ``FastMLXModel.from_pretrained``.
            hf_token: Hugging Face token. When truthy it is exported as the
                ``HF_TOKEN`` environment variable and also passed as
                ``token`` to the loader.
            trust_remote_code: Forwarded to ``FastMLXModel.from_pretrained``.
            gpu_ids: Accepted for interface parity; not used by this backend.
            dtype: Forwarded to ``FastMLXModel.from_pretrained``.

        Returns:
            ``True`` once the model is loaded. Failures propagate as
            exceptions rather than a ``False`` return.

        Raises:
            ImportError: If ``unsloth_zoo.mlx_loader`` cannot be imported.
        """
        import mlx.core as mx

        model_name = config.identifier if hasattr(config, "identifier") else str(config)
        is_vision = getattr(config, "is_vision", False)
        is_lora = getattr(config, "is_lora", False)

        if hf_token:
            import os

            os.environ["HF_TOKEN"] = hf_token

        # Raise the wired-memory limit to the device's recommended
        # working-set size when a Metal device is available.
        if mx.metal.is_available():
            mx.set_wired_limit(mx.device_info()["max_recommended_working_set_size"])

        logger.info(
            "Loading %s via %s (is_lora=%s)",
            model_name,
            "mlx-vlm" if is_vision else "mlx-lm",
            is_lora,
        )

        try:
            from unsloth_zoo.mlx_loader import FastMLXModel
        except ImportError as e:
            raise ImportError(
                "Unsloth: MLX inference requires unsloth-zoo with the MLX modules "
                "(unsloth_zoo.mlx_loader). Reinstall via install.sh on Apple Silicon."
            ) from e

        model, tokenizer_or_processor = FastMLXModel.from_pretrained(
            model_name,
            max_seq_length = max_seq_length,
            dtype = dtype,
            load_in_4bit = load_in_4bit,
            token = hf_token,
            trust_remote_code = trust_remote_code,
            text_only = not is_vision,
        )

        self._model = model
        if is_vision:
            # For vision models the loader's second return value is used as
            # a processor; its ``tokenizer`` attribute is used when present,
            # otherwise the processor itself doubles as the tokenizer.
            self._processor = tokenizer_or_processor
            self._tokenizer = getattr(
                tokenizer_or_processor, "tokenizer", tokenizer_or_processor
            )
            self._is_vlm = True
        else:
            self._processor = None
            self._tokenizer = tokenizer_or_processor
            self._is_vlm = False

        # NOTE(review): entries for previously loaded models are not evicted
        # here, so they keep their model objects referenced until
        # ``unload_model`` is called for them — confirm callers unload first.
        self.active_model_name = model_name
        self.models[model_name] = {
            "model": self._model,
            "tokenizer": self._tokenizer,
            "processor": self._processor,
            "is_vision": is_vision,
            "is_lora": is_lora,
            "is_audio": False,
            "audio_type": None,
            "has_audio_input": False,
        }

        logger.info("Model %s loaded successfully", model_name)
        return True

    def _forget_model(self, model_name) -> bool:
        """Drop the bookkeeping for ``model_name``.

        The live model state is cleared only when ``model_name`` is the
        active model (or no model is active). Unloading some other tracked
        model must not tear down the one currently serving requests.

        Returns:
            ``True`` if the live state was cleared, ``False`` if a different
            model is active and was left untouched.
        """
        self.models.pop(model_name, None)

        active = self.active_model_name
        if active is not None and active != model_name:
            return False

        self._model = None
        self._tokenizer = None
        self._processor = None
        self._is_vlm = False
        self.active_model_name = None
        return True

    def unload_model(self, model_name: str) -> bool:
        """Unload ``model_name`` and release cached MLX memory.

        Removes the model's entry from ``models`` and, if it is the active
        model, clears the live model/tokenizer/processor. Then runs a garbage
        collection pass and ``mx.clear_cache()``.

        Returns:
            Always ``True``, including when ``model_name`` was not tracked.
        """
        import mlx.core as mx
        import gc

        self._forget_model(model_name)
        gc.collect()
        mx.clear_cache()
        logger.info("Model %s unloaded", model_name)
        return True

    @staticmethod
    def _with_image_placeholder(messages):
        """Return a copy of ``messages`` with an image slot in the last user turn.

        The most recent message whose role is ``"user"`` gets a
        ``{"type": "image"}`` part placed first in its content:

        * string content becomes ``[image part, text part]``;
        * list content gets the image part prepended unless it already
          contains a dict part of type ``"image"``;
        * any other content shape is left as is.

        The input list, its message dicts and their content lists are never
        mutated; only the affected message is replaced by a shallow copy, so
        a chat history reused across calls stays clean.
        """
        prepared = list(messages)
        for idx in range(len(prepared) - 1, -1, -1):
            msg = prepared[idx]
            if msg.get("role") != "user":
                continue

            content = msg.get("content", "")
            if isinstance(content, str):
                patched = dict(msg)
                patched["content"] = [
                    {"type": "image"},
                    {"type": "text", "text": content},
                ]
                prepared[idx] = patched
            elif isinstance(content, list):
                has_image = any(
                    isinstance(part, dict) and part.get("type") == "image"
                    for part in content
                )
                if not has_image:
                    patched = dict(msg)
                    patched["content"] = [{"type": "image"}, *content]
                    prepared[idx] = patched
            # Only the most recent user message is considered.
            break
        return prepared

    def generate_chat_response(
        self,
        messages,
        system_prompt = "",
        image = None,
        temperature = 0.7,
        top_p = 0.9,
        top_k = 40,
        min_p = 0.0,
        max_new_tokens = 256,
        repetition_penalty = 1.0,
        cancel_event = None,
    ) -> Generator[str, None, None]:
        """Stream a chat completion from the live model.

        Args:
            messages: Chat messages (dicts with ``role`` and ``content``).
                Not modified by this method.
            system_prompt: When non-empty, prepended as a ``system`` message.
            image: Optional image for vision models; ignored for text models.
            temperature, top_p, top_k, min_p, repetition_penalty: Sampling
                controls forwarded to the text or VLM generation path.
            max_new_tokens: Maximum number of tokens to generate.
            cancel_event: Optional object with ``is_set()``; generation stops
                after the current chunk once it is set.

        Yields:
            The cumulative generated text after each generation step.

        Raises:
            RuntimeError: If no model is loaded. As this is a generator, the
                error surfaces on the first iteration, not at call time.
        """
        if self._model is None:
            raise RuntimeError("No model loaded")

        # Build messages with system prompt
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)

        if self._is_vlm:
            # Inject image into the last user message for VLM (on a copy).
            if image is not None:
                full_messages = self._with_image_placeholder(full_messages)
            yield from self._generate_vlm(
                full_messages,
                image,
                temperature,
                top_p,
                top_k,
                min_p,
                max_new_tokens,
                repetition_penalty,
                cancel_event,
            )
        else:
            yield from self._generate_text(
                full_messages,
                temperature,
                top_p,
                top_k,
                min_p,
                max_new_tokens,
                repetition_penalty,
                cancel_event,
            )

    def _generate_text(
        self,
        messages,
        temperature,
        top_p,
        top_k,
        min_p,
        max_new_tokens,
        repetition_penalty,
        cancel_event,
    ):
        """Stream text from a text-only model via ``mlx_lm.stream_generate``.

        Yields the cumulative decoded text after every generated token.

        Raises:
            RuntimeError: If the tokenizer's chat template returns ``None``.
        """
        from mlx_lm import stream_generate
        from mlx_lm.sample_utils import make_sampler, make_logits_processors

        prompt = self._tokenizer.apply_chat_template(
            messages,
            tokenize = False,
            add_generation_prompt = True,
        )
        if prompt is None:
            raise RuntimeError(
                "apply_chat_template returned None — tokenizer may be incompatible"
            )

        sampler = make_sampler(
            temp = temperature,
            top_p = top_p,
            top_k = int(top_k or 0),
            min_p = float(min_p or 0.0),
            min_tokens_to_keep = 1,
        )

        gen_kwargs = dict(
            prompt = prompt,
            max_tokens = max_new_tokens,
            sampler = sampler,
        )
        # Only build a logits processor when we actually have a non-trivial
        # repetition penalty (1.0 is the no-op value).
        if repetition_penalty is not None and float(repetition_penalty) not in (
            0.0,
            1.0,
        ):
            gen_kwargs["logits_processors"] = make_logits_processors(
                repetition_penalty = float(repetition_penalty),
            )

        token_ids = []
        logger.info(
            "Generating: prompt_len=%d, max_tokens=%d, model=%s, tokenizer=%s",
            len(prompt),
            max_new_tokens,
            type(self._model).__name__,
            type(self._tokenizer).__name__,
        )
        with self._generation_lock:
            try:
                for response in stream_generate(
                    self._model,
                    self._tokenizer,
                    **gen_kwargs,
                ):
                    token_ids.append(response.token)
                    # Decode full sequence with skip_special_tokens — same as GPU
                    cumulative = self._tokenizer.decode(
                        token_ids,
                        skip_special_tokens = True,
                    )
                    yield cumulative

                    if cancel_event and cancel_event.is_set():
                        break
            except Exception:
                import traceback

                # Log with the full traceback, then let the caller see the
                # original exception.
                logger.error("stream_generate failed:\n%s", traceback.format_exc())
                raise

    def _generate_vlm(
        self,
        messages,
        image,
        temperature,
        top_p,
        top_k,
        min_p,
        max_new_tokens,
        repetition_penalty,
        cancel_event,
    ):
        """Stream text from a vision-language model via ``mlx_vlm.stream_generate``.

        Yields the cumulative text after every chunk returned by mlx-vlm.

        Raises:
            RuntimeError: If the chat template returns ``None``.
        """
        from mlx_vlm import stream_generate as vlm_stream

        # Apply chat template. Fall back to the tokenizer's template when the
        # processor has no usable ``chat_template`` of its own.
        chat_fn = getattr(self._processor, "apply_chat_template", None)
        if (
            chat_fn is None
            or not hasattr(self._processor, "chat_template")
            or self._processor.chat_template is None
        ):
            tok = getattr(self._processor, "tokenizer", self._processor)
            chat_fn = tok.apply_chat_template

        prompt = chat_fn(messages, tokenize = False, add_generation_prompt = True)
        if prompt is None:
            # Same guard as the text path, so a broken template fails with a
            # clear message instead of a TypeError from len(None) below.
            raise RuntimeError(
                "apply_chat_template returned None — processor may be incompatible"
            )

        # For VLM: always use mlx_vlm's stream_generate which handles
        # pixel_values properly (passes None for text-only, image for VLM)
        images = [image] if image is not None else None

        cumulative = ""
        logger.info(
            "VLM generating: prompt_len=%d, has_image=%s",
            len(prompt),
            image is not None,
        )
        # mlx_vlm.stream_generate forwards **kwargs into generate_step, which
        # accepts temp/top_p/top_k/repetition_penalty (and builds the sampler
        # + logits_processors internally). Pass them through.
        # NOTE: mlx_vlm.generate_step expects ``temperature=`` (long form) —
        # passing ``temp=`` silently falls into **kwargs and is ignored,
        # leaving generation stuck at the default 0.0 (greedy).
        vlm_kwargs = dict(
            max_tokens = max_new_tokens,
            temperature = temperature,
            top_p = top_p,
            top_k = int(top_k or 0),
            min_p = float(min_p or 0.0),
        )
        if repetition_penalty is not None and float(repetition_penalty) not in (
            0.0,
            1.0,
        ):
            vlm_kwargs["repetition_penalty"] = float(repetition_penalty)

        with self._generation_lock:
            for response in vlm_stream(
                self._model,
                self._processor,
                prompt,
                images,
                **vlm_kwargs,
            ):
                token_text = (
                    response.text if hasattr(response, "text") else str(response)
                )
                cumulative += token_text
                yield cumulative
                if cancel_event and cancel_event.is_set():
                    break

    def generate_with_adapter_control(
        self, use_adapter = None, cancel_event = None, **gen_kwargs
    ) -> Generator[str, None, None]:
        """Generate a response; ``use_adapter`` is accepted but ignored.

        All remaining keyword arguments are forwarded unchanged to
        ``generate_chat_response``.
        """
        # MLX LoRA adapter toggling not yet supported — generate normally
        yield from self.generate_chat_response(cancel_event = cancel_event, **gen_kwargs)

    def reset_generation_state(self):
        """Run a garbage collection pass and clear the MLX cache."""
        import mlx.core as mx
        import gc

        gc.collect()
        mx.clear_cache()
Loading