",
+ delta_text="div>",
+ )
+ assert r is not None
+ assert "content" in r
+ assert "<" in r["content"]
+ assert "div>" in r["content"]
+
+ def test_streaming_multiple_function_blocks(self, parser):
+ """Test streaming with multiple
{"a": 1}',
+ "\n",
+ "
",
+ "2",
+ "",
+ ]
+ accumulated = ""
+ emitted_calls = []
+ for chunk in chunks:
+ prev = accumulated
+ accumulated += chunk
+ r = parser.extract_tool_calls_streaming(
+ previous_text=prev,
+ current_text=accumulated,
+ delta_text=chunk,
+ )
+ if r is not None and "tool_calls" in r:
+ emitted_calls.extend(r["tool_calls"])
+ assert len(emitted_calls) == 2
+ assert emitted_calls[0]["function"]["name"] == "func1"
+ assert emitted_calls[1]["function"]["name"] == "func2"
diff --git a/vllm_mlx/api/anthropic_adapter.py b/vllm_mlx/api/anthropic_adapter.py
index dbb94200f..62c6757b5 100644
--- a/vllm_mlx/api/anthropic_adapter.py
+++ b/vllm_mlx/api/anthropic_adapter.py
@@ -9,6 +9,7 @@
"""
import json
+import re
import uuid
from .anthropic_models import (
@@ -60,6 +61,10 @@ def anthropic_to_openai(request: AnthropicRequest) -> ChatCompletionRequest:
system_text = "\n".join(parts)
else:
system_text = str(request.system)
+ # Strip per-request billing/tracking headers injected by some
+ # clients (e.g. Claude Code). These contain a per-request hash
+ # that prevents prefix-cache reuse across turn boundaries.
+ system_text = re.sub(r"x-anthropic-billing-header:[^\n]*\n?", "", system_text)
messages.append(Message(role="system", content=system_text))
# Convert each message
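For illustration, a minimal standalone sketch of what this substitution does to an incoming system prompt; the header value is made up, the regex is the one used above:

```python
import re

system_text = (
    "You are Claude Code.\n"
    "x-anthropic-billing-header: req-7f3a9c1e\n"
    "Follow the user's instructions."
)

# Drop the whole header line (plus its trailing newline) so identical system
# prompts produce identical prefixes and the prefix cache can be reused.
cleaned = re.sub(r"x-anthropic-billing-header:[^\n]*\n?", "", system_text)
print(cleaned)
# You are Claude Code.
# Follow the user's instructions.
```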
diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py
index a5bc6f776..e8854a5e6 100644
--- a/vllm_mlx/api/anthropic_models.py
+++ b/vllm_mlx/api/anthropic_models.py
@@ -84,8 +84,10 @@ class AnthropicUsage(BaseModel):
class AnthropicResponseContentBlock(BaseModel):
"""A content block in the Anthropic response."""
- type: str # "text" or "tool_use"
+ type: str # "text", "thinking", or "tool_use"
text: str | None = None
+ # thinking block
+ thinking: str | None = None
# tool_use fields
id: str | None = None
name: str | None = None
diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py
index f7bcaaaa5..8af8c9dca 100644
--- a/vllm_mlx/api/models.py
+++ b/vllm_mlx/api/models.py
@@ -159,6 +159,10 @@ class ChatCompletionRequest(BaseModel):
messages: list[Message]
temperature: float | None = None
top_p: float | None = None
+ top_k: int | None = None
+ min_p: float | None = None
+ presence_penalty: float | None = None
+ repetition_penalty: float | None = None  # mlx-lm style (>1.0 penalizes)
max_tokens: int | None = None
stream: bool = False
stream_options: StreamOptions | None = (
@@ -175,12 +179,16 @@ class ChatCompletionRequest(BaseModel):
# MLLM-specific parameters
video_fps: float | None = None
video_max_frames: int | None = None
# Request timeout in seconds (None = use server default)
timeout: float | None = None
# SpecPrefill: per-request enable/disable (None = server decides)
specprefill: bool | None = None
# SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default)
specprefill_keep_pct: float | None = None
+ # Enable/disable thinking mode (None = server default, typically True)
+ enable_thinking: bool | None = None
class AssistantMessage(BaseModel):
@@ -239,11 +247,21 @@ class CompletionRequest(BaseModel):
prompt: str | list[str]
temperature: float | None = None
top_p: float | None = None
+ top_k: int | None = None
+ min_p: float | None = None
+ presence_penalty: float | None = None
+ repetition_penalty: float | None = None  # mlx-lm style (>1.0 penalizes)
max_tokens: int | None = None
stream: bool = False
stop: list[str] | None = None
# Request timeout in seconds (None = use server default)
timeout: float | None = None
+ # SpecPrefill: per-request enable/disable (None = server decides)
+ specprefill: bool | None = None
+ # SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default)
+ specprefill_keep_pct: float | None = None
class CompletionChoice(BaseModel):
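A hedged usage sketch of the new per-request sampling fields on the chat endpoint; the host, port, and model name are placeholders:

```python
import json
from urllib.request import Request, urlopen

payload = {
    "model": "my-local-model",                    # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 40,                                  # new field
    "min_p": 0.05,                                # new field
    "presence_penalty": 0.5,                      # new field
    "repetition_penalty": 1.1,                    # new field, mlx-lm style (>1.0 penalizes)
    "enable_thinking": False,                     # new field, per-request thinking toggle
}
req = Request(
    "http://localhost:8000/v1/chat/completions",  # assumed local server address
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
print(urlopen(req).read().decode())
```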
diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py
index 1443c1674..364b65993 100644
--- a/vllm_mlx/api/tool_calling.py
+++ b/vllm_mlx/api/tool_calling.py
@@ -89,6 +89,7 @@ def parse_tool_calls(
Parse tool calls from model output.
Supports multiple formats:
+ - MiniMax: <minimax:tool_call><invoke name="fn"><parameter name="arg">value</parameter></invoke></minimax:tool_call>
    - Qwen3 bracket: [Calling tool: function_name({"arg": "value"})]
    - Qwen: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
    - Llama: <function=name>{"arg": "value"}</function>
@@ -106,6 +107,47 @@ def parse_tool_calls(
tool_calls = []
cleaned_text = text
+ # Pattern for MiniMax-style: <minimax:tool_call><invoke ...>...</invoke></minimax:tool_call>
+ minimax_pattern = r"<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>"
+ minimax_matches = re.findall(minimax_pattern, text, re.DOTALL)
+
+ for invoke_block in minimax_matches:
+ # Parse <invoke name="..."> blocks within the tool_call
+ invoke_pattern = r'<invoke name="([^"]+)">(.*?)</invoke>'
+ invoke_matches = re.findall(invoke_pattern, invoke_block, re.DOTALL)
+
+ for name, params_block in invoke_matches:
+ # Parse <parameter name="...">value</parameter> pairs
+ param_pattern = r'<parameter name="([^"]+)">\s*(.*?)\s*</parameter>'
+ params = re.findall(param_pattern, params_block, re.DOTALL)
+ arguments = {}
+ for p_name, p_value in params:
+ # Try to parse value as JSON (for nested objects/arrays/numbers)
+ try:
+ arguments[p_name] = json.loads(p_value)
+ except (json.JSONDecodeError, ValueError):
+ arguments[p_name] = p_value
+
+ tool_calls.append(
+ ToolCall(
+ id=f"call_{uuid.uuid4().hex[:8]}",
+ type="function",
+ function=FunctionCall(
+ name=name.strip(),
+ arguments=json.dumps(arguments),
+ ),
+ )
+ )
+
+ # Remove MiniMax tool call tags from cleaned text
+ if minimax_matches:
+ cleaned_text = re.sub(
+ r"\s*.*?\s*",
+ "",
+ cleaned_text,
+ flags=re.DOTALL,
+ ).strip()
+
# Pattern for Qwen3 bracket-style: [Calling tool: function_name({...})]
bracket_pattern = r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]"
bracket_matches = re.findall(bracket_pattern, text, re.DOTALL)
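A self-contained sketch of the MiniMax extraction flow, assuming the `<minimax:tool_call>` / `<invoke name=...>` / `<parameter name=...>` tag shapes matched by the patterns above; the sample text and function name are invented:

```python
import json
import re

text = (
    "Let me check the weather.\n"
    '<minimax:tool_call><invoke name="get_weather">'
    '<parameter name="city">Berlin</parameter>'
    '<parameter name="units">celsius</parameter>'
    "</invoke></minimax:tool_call>"
)

calls = []
for block in re.findall(
    r"<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>", text, re.DOTALL
):
    for name, params in re.findall(
        r'<invoke name="([^"]+)">(.*?)</invoke>', block, re.DOTALL
    ):
        args = {}
        for p_name, p_value in re.findall(
            r'<parameter name="([^"]+)">\s*(.*?)\s*</parameter>', params, re.DOTALL
        ):
            try:
                args[p_name] = json.loads(p_value)   # nested objects/arrays/numbers
            except (json.JSONDecodeError, ValueError):
                args[p_name] = p_value               # plain strings stay strings
        calls.append({"name": name, "arguments": args})

print(calls)
# [{'name': 'get_weather', 'arguments': {'city': 'Berlin', 'units': 'celsius'}}]
```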
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index 9fdbfef13..6218dce7d 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -20,7 +20,9 @@
r"<\|im_end\|>|<\|im_start\|>|<\|endoftext\|>|"
r"<\|end\|>|<\|eot_id\|>|<\|start_header_id\|>|<\|end_header_id\|>|"
r"<\|channel\|>|<\|message\|>|<\|start\|>|<\|return\|>|<\|call\|>|<\|constrain\|>|"
- r"|||\[PAD\]|\[SEP\]|\[CLS\]"
+ r"|||\[PAD\]|\[SEP\]|\[CLS\]|"
+ r"\[e~\[|\]~b\][a-z]*|\]~!b\[|"
+ r"?tool_call>|?tool_call_reasoning>"
)
@@ -121,6 +123,7 @@ def clean_output_text(text: str) -> str:
("", ""),
("", ""),
(""),
+ ("<|tool_call>", ""),
("[TOOL_CALL]", "[/TOOL_CALL]"),
("[Calling tool", "]\n"), # Qwen3 bracket-style: [Calling tool: func({...})]\n
]
@@ -339,6 +342,8 @@ def flush(self) -> list[tuple[str, str]]:
"PaliGemma", # PaliGemma
"gemma-3",
"gemma3", # Gemma 3 (multimodal)
+ "gemma-4",
+ "gemma4", # Gemma 4 (multimodal: vision + audio)
"medgemma",
"MedGemma", # MedGemma (medical multimodal with SigLIP vision encoder)
"pixtral",
@@ -353,6 +358,8 @@ def flush(self) -> list[tuple[str, str]]:
"InternVL", # InternVL
"deepseek-vl",
"DeepSeek-VL", # DeepSeek-VL
+ "Qwen3.5-",
+ "qwen3_5", # Qwen3.5 MoE (natively multimodal, hybrid ArraysCache+KVCache)
]
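A rough sketch of how a substring-based check could consume this pattern list; `looks_like_mllm` is a hypothetical stand-in for the real `is_mllm_model` helper, and the case handling is an assumption:

```python
MLLM_MODEL_PATTERNS = [
    "gemma-3", "gemma3",
    "gemma-4", "gemma4",          # new entries
    "Qwen3.5-", "qwen3_5",        # new entries
    "pixtral", "InternVL", "deepseek-vl",
]

def looks_like_mllm(model_name: str) -> bool:
    # Simple substring match over the pattern list.
    return any(pattern in model_name for pattern in MLLM_MODEL_PATTERNS)

print(looks_like_mllm("mlx-community/Qwen3.5-27B-4bit"))        # True
print(looks_like_mllm("mlx-community/Llama-3.2-3B-Instruct"))   # False
```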
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 8a90bc9be..07dd17fe1 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -37,6 +37,13 @@ def serve_command(args):
print("Example: --enable-auto-tool-choice --tool-call-parser mistral")
sys.exit(1)
+ # Validate gpu-memory-utilization range
+ if not (0.0 < args.gpu_memory_utilization <= 1.0):
+ print(
+ "Error: --gpu-memory-utilization must be between 0.0 (exclusive) and 1.0 (inclusive)"
+ )
+ sys.exit(1)
+
# Configure server security settings
server._api_key = args.api_key
server._default_timeout = args.timeout
@@ -105,6 +112,21 @@ def serve_command(args):
print(" Reasoning: Use --reasoning-parser to enable")
print("=" * 60)
+ # Pre-download model with retry/timeout
+ from .api.utils import is_mllm_model
+ from .utils.download import DownloadConfig, ensure_model_downloaded
+
+ download_config = DownloadConfig(
+ download_timeout=args.download_timeout,
+ max_retries=args.download_retries,
+ offline=getattr(args, "offline", False),
+ )
+ ensure_model_downloaded(
+ args.model,
+ config=download_config,
+ is_mllm=is_mllm_model(args.model),
+ )
+
print(f"Loading model: {args.model}")
print(f"Default max tokens: {args.max_tokens}")
@@ -150,6 +172,9 @@ def serve_command(args):
kv_cache_quantization_bits=args.kv_cache_quantization_bits,
kv_cache_quantization_group_size=args.kv_cache_quantization_group_size,
kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens,
+ mllm_prefill_step_size=(
+ args.mllm_prefill_step_size if args.mllm_prefill_step_size > 0 else None
+ ),
)
print("Mode: Continuous batching (for multiple concurrent users)")
@@ -196,7 +221,8 @@ def serve_command(args):
scheduler_config=scheduler_config,
stream_interval=args.stream_interval if args.continuous_batching else 1,
max_tokens=args.max_tokens,
- force_mllm=args.mllm,
+ force_mllm=getattr(args, "mllm", False),
+ gpu_memory_utilization=args.gpu_memory_utilization,
served_model_name=args.served_model_name,
mtp=args.enable_mtp,
prefill_step_size=args.prefill_step_size,
@@ -211,6 +237,23 @@ def serve_command(args):
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+def download_command(args):
+ """Download a model to local cache without starting a server."""
+ from .utils.download import DownloadConfig, ensure_model_downloaded
+
+ config = DownloadConfig(
+ download_timeout=args.timeout,
+ max_retries=args.retries,
+ )
+ print(f"Downloading model: {args.model}")
+ path = ensure_model_downloaded(
+ args.model,
+ config=config,
+ is_mllm=args.mllm,
+ )
+ print(f"Model ready at: {path}")
+
+
def bench_command(args):
"""Run benchmark."""
import asyncio
@@ -249,6 +292,7 @@ async def run_benchmark():
kv_cache_quantization_group_size=args.kv_cache_quantization_group_size,
kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens,
)
+
engine_config = EngineConfig(
model_name=args.model,
scheduler_config=scheduler_config,
@@ -593,7 +637,8 @@ def bench_kv_cache_command(args):
)
-def main():
+def create_parser() -> argparse.ArgumentParser:
+ """Build the top-level CLI parser."""
parser = argparse.ArgumentParser(
description="vllm-mlx: Apple Silicon MLX backend for vLLM",
formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -627,6 +672,12 @@ def main():
serve_parser.add_argument(
"--completion-batch-size", type=int, default=32, help="Completion batch size"
)
+ serve_parser.add_argument(
+ "--mllm-prefill-step-size",
+ type=int,
+ default=0,
+ help="Override MLLM prefill-step guard (0=use MLLM default: 1024)",
+ )
serve_parser.add_argument(
"--enable-prefix-cache",
action="store_true",
@@ -704,6 +755,14 @@ def main():
action="store_true",
help="Enable continuous batching for multiple concurrent users (slower for single user)",
)
+ serve_parser.add_argument(
+ "--gpu-memory-utilization",
+ type=float,
+ default=0.90,
+ help="Fraction of device memory for Metal allocation limit and emergency "
+ "cache clear threshold (0.0-1.0, default: 0.90). Increase to 0.95 for "
+ "large models (200GB+) that need more memory headroom.",
+ )
# Paged cache options (experimental)
serve_parser.add_argument(
"--use-paged-cache",
@@ -832,18 +891,23 @@ def main():
"qwen3_coder",
"llama",
"hermes",
+ "harmony",
+ "gpt-oss",
"deepseek",
"kimi",
"granite",
"nemotron",
"xlam",
"functionary",
+ "gemma4",
"glm47",
+ "minimax",
],
help=(
"Select the tool call parser for the model. Options: "
"auto (auto-detect), mistral, qwen, qwen3_coder, llama, hermes, "
- "deepseek, kimi, granite, nemotron, xlam, functionary, glm47. "
+ "harmony, gpt-oss, deepseek, gemma4, kimi, granite, nemotron, "
+ "xlam, functionary, glm47, minimax. "
"Required for --enable-auto-tool-choice."
),
)
@@ -888,6 +952,24 @@ def main():
default=None,
help="Pre-load an embedding model at startup (e.g. mlx-community/embeddinggemma-300m-6bit)",
)
+ # Download options
+ serve_parser.add_argument(
+ "--download-timeout",
+ type=int,
+ default=300,
+ help="Per-file download timeout in seconds (default: 300)",
+ )
+ serve_parser.add_argument(
+ "--download-retries",
+ type=int,
+ default=3,
+ help="Number of download retry attempts (default: 3)",
+ )
+ serve_parser.add_argument(
+ "--offline",
+ action="store_true",
+ help="Offline mode — only use locally cached models",
+ )
# Bench command
bench_parser = subparsers.add_parser("bench", help="Run benchmark")
bench_parser.add_argument("model", type=str, help="Model to benchmark")
@@ -1023,6 +1105,34 @@ def main():
help="Quantization group size (default: 64)",
)
+ # Download command
+ download_parser = subparsers.add_parser(
+ "download", help="Download a model to local cache without starting a server"
+ )
+ download_parser.add_argument("model", type=str, help="Model to download")
+ download_parser.add_argument(
+ "--timeout",
+ type=int,
+ default=300,
+ help="Per-file download timeout in seconds (default: 300)",
+ )
+ download_parser.add_argument(
+ "--retries",
+ type=int,
+ default=3,
+ help="Number of retry attempts (default: 3)",
+ )
+ download_parser.add_argument(
+ "--mllm",
+ action="store_true",
+ help="Download as multimodal model (broader file patterns)",
+ )
+
+ return parser
+
+
+def main():
+ parser = create_parser()
args = parser.parse_args()
if args.command == "serve":
@@ -1033,6 +1143,8 @@ def main():
bench_detok_command(args)
elif args.command == "bench-kv-cache":
bench_kv_cache_command(args)
+ elif args.command == "download":
+ download_command(args)
else:
parser.print_help()
sys.exit(1)
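With parser construction split into `create_parser()`, the new flags and subcommand can be exercised directly; a small sketch, where the model name and option values are examples only:

```python
from vllm_mlx.cli import create_parser

parser = create_parser()

# The new `download` subcommand and its options
args = parser.parse_args(
    ["download", "mlx-community/some-model-4bit",
     "--timeout", "600", "--retries", "5", "--mllm"]
)
assert args.command == "download"
assert args.timeout == 600 and args.retries == 5 and args.mllm is True
```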
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index 0f0f8f0f1..cb9c8aad8 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -137,6 +137,7 @@ def __init__(
scheduler_config: Any | None = None,
stream_interval: int = 1,
force_mllm: bool = False,
+ gpu_memory_utilization: float = 0.90,
):
"""
Initialize the batched engine.
@@ -147,11 +148,14 @@ def __init__(
scheduler_config: Optional scheduler configuration
stream_interval: Tokens to batch before streaming (1=every token)
force_mllm: Force loading as MLLM even if not auto-detected
+ gpu_memory_utilization: Fraction of device memory for Metal allocation
+ limit and emergency threshold (0.0-1.0, default 0.90)
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._scheduler_config = scheduler_config
self._stream_interval = stream_interval
+ self._gpu_memory_utilization = gpu_memory_utilization
self._is_mllm = force_mllm or is_mllm_model(model_name)
self._model = None
@@ -207,6 +211,10 @@ async def _start_mllm(self) -> None:
self._model = self._mllm_instance.model
self._processor = self._mllm_instance.processor
+ # Inject MTP support if enabled
+ if self._scheduler_config and self._scheduler_config.enable_mtp:
+ self._inject_mtp_mllm()
+
# Create MLLM scheduler config with batch generator support
if self._scheduler_config and hasattr(self._scheduler_config, "max_num_seqs"):
max_num_seqs = self._scheduler_config.max_num_seqs
@@ -219,12 +227,38 @@ async def _start_mllm(self) -> None:
self._scheduler_config, "completion_batch_size", 16
)
+ cache_memory_mb = getattr(self._scheduler_config, "cache_memory_mb", None)
+ enable_mtp = (
+ self._scheduler_config.enable_mtp if self._scheduler_config else False
+ )
+ mtp_num_draft = getattr(self._scheduler_config, "mtp_num_draft_tokens", 1)
+ kv_quant = getattr(self._scheduler_config, "kv_cache_quantization", False)
+ kv_bits = getattr(self._scheduler_config, "kv_cache_quantization_bits", 8)
+ kv_group_size = getattr(
+ self._scheduler_config, "kv_cache_quantization_group_size", 64
+ )
+
+ # Forward MLLM prefill-step override only when explicitly configured.
+ # This keeps default behavior unchanged for MLLM (1024) unless set.
+ prefill_step_size = getattr(
+ self._scheduler_config, "mllm_prefill_step_size", None
+ )
+ mllm_extra = {}
+ if prefill_step_size is not None:
+ mllm_extra["prefill_step_size"] = prefill_step_size
mllm_config = MLLMSchedulerConfig(
max_num_seqs=max_num_seqs,
prefill_batch_size=prefill_batch_size,
completion_batch_size=completion_batch_size,
enable_vision_cache=True,
vision_cache_size=100,
+ cache_memory_mb=cache_memory_mb,
+ enable_mtp=enable_mtp,
+ mtp_num_draft_tokens=mtp_num_draft,
+ kv_cache_quantization=kv_quant,
+ kv_cache_quantization_bits=kv_bits,
+ kv_cache_quantization_group_size=kv_group_size,
+ **mllm_extra,
)
# Create and start MLLM scheduler
@@ -238,9 +272,58 @@ async def _start_mllm(self) -> None:
logger.info(
f"MLLM Scheduler started with continuous batching: "
f"max_num_seqs={max_num_seqs}, prefill_batch={prefill_batch_size}, "
- f"completion_batch={completion_batch_size}"
+ f"completion_batch={completion_batch_size}, "
+ f"prefill_step_size={mllm_config.prefill_step_size}"
)
+ def _inject_mtp_mllm(self) -> None:
+ """Inject MTP weights into the MLLM model's language_model."""
+ import json
+ from pathlib import Path
+
+ from mlx_lm.utils import _download
+
+ model = self._model
+ model_path = Path(_download(self._model_name))
+ config_path = model_path / "config.json"
+ if not config_path.exists():
+ logger.warning("[MTP-MLLM] No config.json found, skipping MTP")
+ return
+
+ with open(config_path) as f:
+ config = json.load(f)
+
+ text_config = config.get("text_config", config)
+ num_mtp = text_config.get("mtp_num_hidden_layers", 0)
+ if num_mtp == 0:
+ num_mtp = text_config.get(
+ "num_nextn_predict_layers",
+ config.get("num_nextn_predict_layers", 0),
+ )
+ if num_mtp == 0:
+ logger.info("[MTP-MLLM] No MTP layers in config, skipping")
+ return
+
+ # Navigate to text model
+ text_model = model
+ if hasattr(model, "language_model"):
+ text_model = model.language_model
+ if getattr(text_model, "mtp", None) is not None:
+ logger.info("[MTP-MLLM] Model already has MTP, skipping injection")
+ return
+
+ model_type = text_config.get("model_type", config.get("model_type", ""))
+ if "qwen3_5" in model_type:
+ from ..patches.qwen3_5_mtp import inject_mtp_support
+
+ ok = inject_mtp_support(model, model_path, config)
+ if ok:
+ logger.info("[MTP-MLLM] Qwen3.5 MTP injected successfully")
+ else:
+ logger.warning("[MTP-MLLM] Qwen3.5 MTP injection failed")
+ else:
+ logger.info(f"[MTP-MLLM] MTP not supported for model_type={model_type}")
+
async def _start_llm(self) -> None:
"""Start the LLM engine with AsyncEngineCore."""
from ..engine_core import AsyncEngineCore, EngineConfig
@@ -261,9 +344,10 @@ async def _start_llm(self) -> None:
# Validate MTP support if enabled
if self._scheduler_config and self._scheduler_config.enable_mtp:
+ from ..patches.qwen3_5_mtp import validate_mtp_support as validate_35
from ..patches.qwen3_next_mtp import validate_mtp_support
- if validate_mtp_support(self._model):
+ if validate_mtp_support(self._model) or validate_35(self._model):
logger.info("[MTP] Model validated for MTP speculative decoding")
else:
logger.warning(
@@ -283,13 +367,14 @@ async def _start_llm(self) -> None:
device_info.get("memory_size", 0),
)
if max_recommended > 0:
- soft_limit = int(max_recommended * 0.90)
+ soft_limit = int(max_recommended * self._gpu_memory_utilization)
mx.set_memory_limit(soft_limit)
mx.set_cache_limit(32 * 1024 * 1024 * 1024) # 32GB
+ pct = self._gpu_memory_utilization * 100
logger.info(
f"Metal memory limits set: "
f"allocation_limit={soft_limit / 1e9:.1f}GB "
- f"(90% of {max_recommended / 1e9:.1f}GB), "
+ f"({pct:.0f}% of {max_recommended / 1e9:.1f}GB), "
f"cache_limit=32GB"
)
except Exception as e:
@@ -301,6 +386,7 @@ async def _start_llm(self) -> None:
model_name=self._model_name,
scheduler_config=scheduler_config,
stream_interval=self._stream_interval,
+ gpu_memory_utilization=self._gpu_memory_utilization,
)
# Create async engine
@@ -336,6 +422,7 @@ def _apply_chat_template(
tools: list[dict] | None = None,
num_images: int = 0,
chat_template_kwargs: dict[str, Any] | None = None,
+ enable_thinking: bool | None = None,
) -> str:
"""Apply chat template to messages.
@@ -364,9 +451,13 @@ def _apply_chat_template(
if self._is_mllm and num_images > 0:
messages = self._prepare_mllm_messages(messages)
+ # Per-request enable_thinking override; default: True unless coder model.
+ if enable_thinking is None:
+ enable_thinking = "coder" not in self._model_name.lower()
template_kwargs = {
"tokenize": False,
"add_generation_prompt": True,
+ "enable_thinking": enable_thinking,
}
if chat_template_kwargs:
template_kwargs.update(chat_template_kwargs)
@@ -380,7 +471,7 @@ def _apply_chat_template(
except TypeError as e:
# Some templates don't accept extra kwargs; retry without them.
logger.debug(f"Chat template TypeError, retrying without extras: {e}")
- for key in ["tools", *(chat_template_kwargs or {}).keys()]:
+ for key in ["tools", "enable_thinking", *(chat_template_kwargs or {}).keys()]:
template_kwargs.pop(key, None)
return template_applicator.apply_chat_template(
messages, **template_kwargs
@@ -466,10 +557,15 @@ async def generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
+ top_k=kwargs.pop("top_k", 0),
+ min_p=kwargs.pop("min_p", 0.0),
+ presence_penalty=kwargs.pop("presence_penalty", 0.0),
+ repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
)
return GenerationOutput(
text=clean_output_text(output.output_text),
+ tokens=output.output_token_ids,
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finish_reason=output.finish_reason,
@@ -482,6 +578,10 @@ async def generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
+ top_k=kwargs.pop("top_k", 0),
+ min_p=kwargs.pop("min_p", 0.0),
+ presence_penalty=kwargs.pop("presence_penalty", 0.0),
+ repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
stop=stop or [],
)
@@ -494,6 +594,7 @@ async def generate(
return GenerationOutput(
text=text,
+ tokens=output.output_token_ids,
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finish_reason=output.finish_reason,
@@ -538,6 +639,10 @@ async def stream_generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
+ top_k=kwargs.pop("top_k", 0),
+ min_p=kwargs.pop("min_p", 0.0),
+ presence_penalty=kwargs.pop("presence_penalty", 0.0),
+ repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
)
async for output in self._mllm_scheduler.stream_outputs(request_id):
@@ -558,6 +663,10 @@ async def stream_generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
+ top_k=kwargs.pop("top_k", 0),
+ min_p=kwargs.pop("min_p", 0.0),
+ presence_penalty=kwargs.pop("presence_penalty", 0.0),
+ repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
stop=stop or [],
)
@@ -624,12 +733,16 @@ async def chat(
template_tools = convert_tools_for_template(tools) if tools else None
chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})
+ # Per-request enable_thinking override
+ enable_thinking = kwargs.pop("enable_thinking", None)
+
# Apply chat template
prompt = self._apply_chat_template(
messages,
template_tools,
num_images=len(all_images),
chat_template_kwargs=chat_template_kwargs,
+ enable_thinking=enable_thinking,
)
return await self.generate(
@@ -748,12 +861,16 @@ async def stream_chat(
template_tools = convert_tools_for_template(tools) if tools else None
chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})
+ # Per-request enable_thinking override
+ enable_thinking = kwargs.pop("enable_thinking", None)
+
# Apply chat template
prompt = self._apply_chat_template(
messages,
template_tools,
num_images=len(all_images),
chat_template_kwargs=chat_template_kwargs,
+ enable_thinking=enable_thinking,
)
# Compute prefix boundary for cache
@@ -789,14 +906,27 @@ def get_stats(self) -> dict[str, Any]:
if self._mllm_scheduler:
mllm_stats = self._mllm_scheduler.get_stats()
stats["mllm_scheduler"] = mllm_stats
- # Promote Metal memory stats to top-level for /v1/status
+ # Promote stats to top-level for /v1/status and monitoring
for key in (
+ "running",
+ "num_running",
+ "num_waiting",
+ "num_requests_processed",
+ "total_prompt_tokens",
+ "total_completion_tokens",
"metal_active_memory_gb",
"metal_peak_memory_gb",
"metal_cache_memory_gb",
+ "memory_aware_cache",
+ "paged_cache",
+ "prefix_cache",
+ "requests",
):
if key in mllm_stats:
stats[key] = mllm_stats[key]
+ # MLLM engine is always "running" once loaded
+ if "running" not in stats:
+ stats["running"] = self._loaded
elif self._engine:
stats.update(self._engine.get_stats())
@@ -804,20 +934,28 @@ def get_stats(self) -> dict[str, Any]:
def get_cache_stats(self) -> dict[str, Any] | None:
"""Get cache statistics."""
- if self._mllm_scheduler and self._mllm_scheduler.vision_cache:
- return self._mllm_scheduler.vision_cache.get_stats()
+ if self._mllm_scheduler and self._mllm_scheduler.batch_generator:
+ return self._mllm_scheduler.batch_generator.get_vision_cache_stats()
elif self._engine:
return self._engine.get_cache_stats()
return None
def save_cache_to_disk(self, cache_dir: str) -> bool:
"""Save prefix cache to disk for persistence across restarts."""
+ if self._mllm_scheduler and self._mllm_scheduler.batch_generator:
+ pc = self._mllm_scheduler.batch_generator.prefix_cache
+ if pc is not None:
+ return pc.save_to_disk(cache_dir)
if self._engine:
return self._engine.save_cache_to_disk(cache_dir)
return False
def load_cache_from_disk(self, cache_dir: str) -> int:
"""Load prefix cache from disk. Returns number of entries loaded."""
+ if self._mllm_scheduler and self._mllm_scheduler.batch_generator:
+ pc = self._mllm_scheduler.batch_generator.prefix_cache
+ if pc is not None:
+ return pc.load_from_disk(cache_dir)
if self._engine:
return self._engine.load_cache_from_disk(cache_dir)
return 0
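For intuition, the Metal soft-limit logic from `_start_llm` reduced to a standalone sketch; only the allocation limit scales with `gpu_memory_utilization`, the 32GB cache cap stays fixed:

```python
import mlx.core as mx

gpu_memory_utilization = 0.95  # e.g. a large model that needs extra headroom

info = mx.device_info()
max_recommended = info.get(
    "max_recommended_working_set_size", info.get("memory_size", 0)
)
if max_recommended > 0:
    soft_limit = int(max_recommended * gpu_memory_utilization)
    mx.set_memory_limit(soft_limit)
    mx.set_cache_limit(32 * 1024 * 1024 * 1024)  # 32GB, unchanged by the flag
    print(
        f"allocation_limit={soft_limit / 1e9:.1f}GB "
        f"({gpu_memory_utilization * 100:.0f}% of {max_recommended / 1e9:.1f}GB)"
    )
```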
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index 768a0f4ad..681118998 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -226,6 +226,24 @@ async def stop(self) -> None:
self._system_kv_token_count = 0
logger.info("SimpleEngine stopped")
+ async def _run_blocking_serialized(self, func, /, *args, **kwargs):
+ """Run a blocking MLX operation under the generation lock.
+
+ Cancellation must not release the async lock before the worker thread
+ finishes, or a follow-up request can enter MLX/Metal concurrently and
+ corrupt the command-buffer state.
+ """
+ async with self._generation_lock:
+ task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs))
+ try:
+ return await asyncio.shield(task)
+ except asyncio.CancelledError:
+ try:
+ await task
+ except BaseException:
+ pass
+ raise
+
async def generate(
self,
prompt: str,
@@ -238,13 +256,27 @@ async def generate(
"""
Generate a complete response (non-streaming).
+ Thin accumulator over stream_generate(). stream_generate() is the
+ only code path that consumes per-request SpecPrefill overrides
+ (`specprefill`, `specprefill_keep_pct`) and routes through
+ _stream_generate_specprefill() when engaged. The prior direct
+ self._model.generate() path silently dropped those overrides for
+ non-streaming /v1/completions callers, so extra_body.specprefill
+ was advertised by the server but had no effect on this route.
+
+ By iterating stream_generate() and returning the last
+ GenerationOutput, non-streaming clients get the same SpecPrefill
+ engagement, accurate prompt_tokens reporting, and per-request
+ override support as streaming clients.
+
Args:
prompt: Input text
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
- **kwargs: Additional model-specific parameters
+ **kwargs: Additional parameters forwarded to stream_generate,
+ including per-request `specprefill` / `specprefill_keep_pct`
Returns:
GenerationOutput with complete text
@@ -252,30 +284,29 @@ async def generate(
if not self._loaded:
await self.start()
- async with self._generation_lock:
- # Run in thread pool to allow asyncio timeout to work
- output = await asyncio.to_thread(
- self._model.generate,
- prompt=prompt,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- stop=stop,
- **kwargs,
- )
-
- # Clean output text
- text = clean_output_text(output.text)
-
- return GenerationOutput(
- text=text,
- tokens=getattr(output, "tokens", []),
- prompt_tokens=getattr(output, "prompt_tokens", 0),
- completion_tokens=getattr(
- output, "completion_tokens", len(getattr(output, "tokens", []))
- ),
- finish_reason=output.finish_reason,
- )
+ last_output: GenerationOutput | None = None
+ async for output in self.stream_generate(
+ prompt=prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ stop=stop,
+ **kwargs,
+ ):
+ last_output = output
+
+ if last_output is None:
+ return GenerationOutput(text="", finish_reason="stop")
+
+ text = clean_output_text(last_output.text)
+ return GenerationOutput(
+ text=text,
+ tokens=list(last_output.tokens),
+ prompt_tokens=last_output.prompt_tokens,
+ completion_tokens=last_output.completion_tokens,
+ finish_reason=last_output.finish_reason,
+ finished=True,
+ )
async def stream_generate(
self,
@@ -439,61 +470,84 @@ async def chat(
chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})
+ # mlx-lm non-streaming chat with tools can stall indefinitely on some
+ # local models, while the streaming path completes normally. Reuse the
+ # streaming implementation and aggregate its final state so both chat
+ # APIs share the same tool-capable execution path.
+ if tools and not self._is_mllm:
+ final_output = GenerationOutput(text="")
+ async for output in self.stream_chat(
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ tools=tools,
+ images=images,
+ videos=videos,
+ chat_template_kwargs=chat_template_kwargs,
+ **kwargs,
+ ):
+ final_output = output
+ text = clean_output_text(final_output.text)
+ return GenerationOutput(
+ text=text,
+ tokens=list(final_output.tokens),
+ prompt_tokens=final_output.prompt_tokens,
+ completion_tokens=final_output.completion_tokens,
+ finish_reason=final_output.finish_reason,
+ )
+
# Convert tools for template if provided
template_tools = convert_tools_for_template(tools) if tools else None
- async with self._generation_lock:
- if self._is_mllm:
- # For MLLM, use the chat method which handles images/videos
- # Run in thread pool to allow asyncio timeout to work
- if chat_template_kwargs:
- kwargs["chat_template_kwargs"] = chat_template_kwargs
- output = await asyncio.to_thread(
- self._model.chat,
- messages=messages,
- max_tokens=max_tokens,
- temperature=temperature,
- tools=template_tools,
- **kwargs,
- )
- text = clean_output_text(output.text)
- return GenerationOutput(
- text=text,
- prompt_tokens=output.prompt_tokens,
- completion_tokens=output.completion_tokens,
- finish_reason=output.finish_reason,
- )
- else:
- # For LLM, use the chat method
- # Run in thread pool to allow asyncio timeout to work
- output = await asyncio.to_thread(
- self._model.chat,
- messages=messages,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- tools=template_tools,
- chat_template_kwargs=chat_template_kwargs,
- **kwargs,
- )
- text = clean_output_text(output.text)
- # Count prompt tokens from the full templated prompt
- tokenizer = self._model.tokenizer
- template_kwargs = {
- "tokenize": True,
- "add_generation_prompt": True,
- }
- if template_tools:
- template_kwargs["tools"] = template_tools
- prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs)
- prompt_token_count = len(prompt_ids)
- return GenerationOutput(
- text=text,
- tokens=output.tokens,
- prompt_tokens=prompt_token_count,
- completion_tokens=len(output.tokens),
- finish_reason=output.finish_reason,
- )
+ if self._is_mllm:
+ if chat_template_kwargs:
+ kwargs["chat_template_kwargs"] = chat_template_kwargs
+ output = await self._run_blocking_serialized(
+ self._model.chat,
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ tools=template_tools,
+ **kwargs,
+ )
+ text = clean_output_text(output.text)
+ return GenerationOutput(
+ text=text,
+ prompt_tokens=output.prompt_tokens,
+ completion_tokens=output.completion_tokens,
+ finish_reason=output.finish_reason,
+ )
+ else:
+ output = await self._run_blocking_serialized(
+ self._model.chat,
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ tools=template_tools,
+ chat_template_kwargs=chat_template_kwargs,
+ **kwargs,
+ )
+ text = clean_output_text(output.text)
+ # Preserve upstream prompt accounting while routing the blocking
+ # chat call through the cancellation-safe serialized runner.
+ tokenizer = self._model.tokenizer
+ template_kwargs = {
+ "tokenize": True,
+ "add_generation_prompt": True,
+ }
+ if template_tools:
+ template_kwargs["tools"] = template_tools
+ prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs)
+ prompt_token_count = len(prompt_ids)
+ return GenerationOutput(
+ text=text,
+ tokens=output.tokens,
+ prompt_tokens=prompt_token_count,
+ completion_tokens=len(output.tokens),
+ finish_reason=output.finish_reason,
+ )
async def stream_chat(
self,
@@ -557,53 +611,53 @@ async def stream_chat(
# For MLLM, use stream_chat which yields tokens incrementally.
# Must hold _generation_lock to prevent concurrent Metal access
# (e.g. OpenCode sends title + main request simultaneously).
- async with self._generation_lock:
- accumulated_text = ""
- token_count = 0
-
- # Run stream_chat in thread pool since it's synchronous
- def run_stream():
- local_kwargs = dict(kwargs)
- if chat_template_kwargs:
- local_kwargs["chat_template_kwargs"] = chat_template_kwargs
- return list(
- self._model.stream_chat(
- messages=messages,
- max_tokens=max_tokens,
- temperature=temperature,
- tools=template_tools,
- **local_kwargs,
- )
+ accumulated_text = ""
+ token_count = 0
+
+ # Run stream_chat in thread pool since it's synchronous
+ def run_stream():
+ local_kwargs = dict(kwargs)
+ if chat_template_kwargs:
+ local_kwargs["chat_template_kwargs"] = chat_template_kwargs
+ return list(
+ self._model.stream_chat(
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ tools=template_tools,
+ **local_kwargs,
)
+ )
- chunks = await asyncio.to_thread(run_stream)
+ chunks = await self._run_blocking_serialized(run_stream)
- for chunk in chunks:
- token_count += 1
- new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
- accumulated_text += new_text
+ for chunk in chunks:
+ token_count += 1
+ new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
+ accumulated_text += new_text
- finished = chunk.finish_reason is not None
+ finished = chunk.finish_reason is not None
- yield GenerationOutput(
- text=accumulated_text,
- new_text=new_text,
- prompt_tokens=getattr(chunk, "prompt_tokens", 0),
- completion_tokens=token_count,
- finished=finished,
- finish_reason=chunk.finish_reason if finished else None,
- )
+ yield GenerationOutput(
+ text=accumulated_text,
+ new_text=new_text,
+ prompt_tokens=getattr(chunk, "prompt_tokens", 0),
+ completion_tokens=token_count,
+ finished=finished,
+ finish_reason=chunk.finish_reason if finished else None,
+ )
- if finished:
- break
+ if finished:
+ break
return
# For LLM, apply chat template and stream
tokenizer = self._model.tokenizer
if hasattr(tokenizer, "apply_chat_template"):
- # Disable thinking mode for coder models since it interferes
- # with tool call parsing (tags leak as raw text).
- enable_thinking = "coder" not in self._model_name.lower()
+ # Per-request enable_thinking override; default: True unless coder model.
+ enable_thinking = kwargs.pop("enable_thinking", None)
+ if enable_thinking is None:
+ enable_thinking = "coder" not in self._model_name.lower()
template_kwargs = {
"tokenize": False,
"add_generation_prompt": True,
@@ -661,129 +715,125 @@ async def _stream_generate_specprefill(
tokenizer = self._model.tokenizer
n_tokens = len(tokens)
- async with self._generation_lock:
-
- def _run_all():
- try:
- return _run_specprefill()
- except Exception as e:
- logger.error(
- "SpecPrefill failed, falling back to normal path: %s", e
- )
- return _run_normal()
+ def _run_all():
+ try:
+ return _run_specprefill()
+ except Exception as e:
+ logger.error("SpecPrefill failed, falling back to normal path: %s", e)
+ return _run_normal()
+
+ def _run_specprefill():
+ """Score tokens, sparse prefill, generate autoregressively."""
+ import time
+ from types import SimpleNamespace
+
+ from ..specprefill import (
+ cleanup_rope,
+ score_tokens,
+ select_chunks,
+ sparse_prefill,
+ )
- def _run_specprefill():
- """Score tokens, sparse prefill, generate autoregressively."""
- import time
- from types import SimpleNamespace
+ cache = make_prompt_cache(model)
- from ..specprefill import (
- cleanup_rope,
- score_tokens,
- select_chunks,
- sparse_prefill,
+ try:
+ # Phase 1: Score with draft model
+ t0 = time.monotonic()
+ importance = score_tokens(
+ self._draft_model,
+ tokens,
+ prefill_step_size=self._prefill_step_size,
)
+ t_score = time.monotonic() - t0
- cache = make_prompt_cache(model)
+ # Phase 2: Select important chunks
+ effective_keep = specprefill_keep_pct or self._specprefill_keep_pct
+ selected = select_chunks(importance, keep_pct=effective_keep)
+ n_selected = selected.shape[0]
- try:
- # Phase 1: Score with draft model
- t0 = time.monotonic()
- importance = score_tokens(
- self._draft_model,
- tokens,
- prefill_step_size=self._prefill_step_size,
- )
- t_score = time.monotonic() - t0
-
- # Phase 2: Select important chunks
- effective_keep = specprefill_keep_pct or self._specprefill_keep_pct
- selected = select_chunks(importance, keep_pct=effective_keep)
- n_selected = selected.shape[0]
-
- # Phase 3: Sparse prefill on target model
- t0 = time.monotonic()
- logits = sparse_prefill(
- model,
- tokens,
- selected,
- cache,
- step_size=self._prefill_step_size,
- )
- t_prefill = time.monotonic() - t0
+ # Phase 3: Sparse prefill on target model
+ t0 = time.monotonic()
+ logits = sparse_prefill(
+ model,
+ tokens,
+ selected,
+ cache,
+ step_size=self._prefill_step_size,
+ )
+ t_prefill = time.monotonic() - t0
- logger.info(
- "SpecPrefill: scored %d tokens in %.1fs, "
- "sparse prefill %d/%d (keep=%.0f%%) in %.1fs",
- n_tokens,
- t_score,
- n_selected,
- n_tokens,
- n_selected / n_tokens * 100,
- t_prefill,
- )
+ logger.info(
+ "SpecPrefill: scored %d tokens in %.1fs, "
+ "sparse prefill %d/%d (keep=%.0f%%) in %.1fs",
+ n_tokens,
+ t_score,
+ n_selected,
+ n_tokens,
+ n_selected / n_tokens * 100,
+ t_prefill,
+ )
- # Phase 4: Generate (simple autoregressive, no MTP)
- sampler = make_sampler(temp=temperature, top_p=top_p)
- eos_id = tokenizer.eos_token_id
- y = sampler(logits[:, -1, :])
- mx.eval(y)
+ # Phase 4: Generate (simple autoregressive, no MTP)
+ sampler = make_sampler(temp=temperature, top_p=top_p)
+ eos_id = tokenizer.eos_token_id
+ y = sampler(logits[:, -1, :])
+ mx.eval(y)
- results = []
- generated_ids = []
- prev_decoded = ""
+ results = []
+ generated_ids = []
+ prev_decoded = ""
- for _ in range(max_tokens):
- tok_id = y.item()
- generated_ids.append(tok_id)
+ for _ in range(max_tokens):
+ tok_id = y.item()
+ generated_ids.append(tok_id)
- decoded = tokenizer.decode(generated_ids)
- new_text = decoded[len(prev_decoded) :]
- prev_decoded = decoded
+ decoded = tokenizer.decode(generated_ids)
+ new_text = decoded[len(prev_decoded) :]
+ prev_decoded = decoded
- is_eos = tok_id == eos_id
- results.append(
- SimpleNamespace(
- text=new_text,
- finish_reason="stop" if is_eos else None,
- )
+ is_eos = tok_id == eos_id
+ results.append(
+ SimpleNamespace(
+ text=new_text,
+ finish_reason="stop" if is_eos else None,
)
+ )
- if is_eos:
- break
+ if is_eos:
+ break
- logits = model(y.reshape(1, -1), cache=cache)
- y = sampler(logits[:, -1, :])
- mx.eval(y)
+ logits = model(y.reshape(1, -1), cache=cache)
+ y = sampler(logits[:, -1, :])
+ mx.eval(y)
- return results
+ return results
- finally:
- cleanup_rope(model)
+ finally:
+ cleanup_rope(model)
- def _run_normal():
- """Fallback: normal generation without specprefill."""
- from types import SimpleNamespace
+ def _run_normal():
+ """Fallback: normal generation without specprefill."""
+ from types import SimpleNamespace
- results = []
- for chunk in self._model.stream_generate(
- prompt=prompt,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- stop=stop,
- **kwargs,
- ):
- new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
- results.append(
- SimpleNamespace(
- text=new_text,
- finish_reason=getattr(chunk, "finish_reason", None),
- )
+ results = []
+ for chunk in self._model.stream_generate(
+ prompt=prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ stop=stop,
+ **kwargs,
+ ):
+ new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
+ results.append(
+ SimpleNamespace(
+ text=new_text,
+ finish_reason=getattr(chunk, "finish_reason", None),
)
- return results
+ )
+ return results
- all_resps = await asyncio.to_thread(_run_all)
+ all_resps = await self._run_blocking_serialized(_run_all)
# Yield results as GenerationOutput
accumulated_text = ""
@@ -850,9 +900,11 @@ async def _stream_generate_text(
specprefill_keep_pct = kwargs.pop("specprefill_keep_pct", None)
chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})
- # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path)
- enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true")
- enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes")
+ # Per-request enable_thinking override; fall back to env var / default True.
+ enable_thinking = kwargs.pop("enable_thinking", None)
+ if enable_thinking is None:
+ enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true")
+ enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes")
# Apply chat template for full prompt
template_kwargs = {
@@ -1026,194 +1078,192 @@ async def _stream_generate_text(
)
use_specprefill = False
- # Run under generation lock, all Metal ops in single thread
- async with self._generation_lock:
+ # Run all Metal ops in a single serialized thread.
+ def _run_all():
+ nonlocal backbone_cache, prompt_to_send
- def _run_all():
- nonlocal backbone_cache, prompt_to_send
+ model = self._text_model
- model = self._text_model
+ # Cache MISS with valid prefix: prefill system tokens and snapshot
+ if (
+ not cache_hit
+ and system_token_count > 0
+ and system_tokens is not None
+ and suffix_tokens is not None
+ ):
+ mc = make_prompt_cache(model)
+ sys_arr = mx.array(system_tokens)
+
+ # Prefill system tokens in chunks (matching generate_step)
+ step = self._prefill_step_size
+ while sys_arr.size > step:
+ model(sys_arr[:step][None], cache=mc)
+ mx.eval([c.state for c in mc])
+ sys_arr = sys_arr[step:]
+ mx.clear_cache()
+ if sys_arr.size > 0:
+ model(sys_arr[None], cache=mc)
+ mx.eval([c.state for c in mc])
+
+ # Snapshot backbone cache (immutable mx.arrays, safe to reuse)
+ snapshot = [c.state for c in mc]
+ mx.eval([s for pair in snapshot for s in pair])
+
+ self._system_kv_snapshot = snapshot
+ self._system_kv_hash = system_hash
+ self._system_kv_token_count = system_token_count
+
+ backbone_cache = mc
+ prompt_to_send = mx.array(suffix_tokens)
+ logger.info(
+ "System KV cache: stored %d-token snapshot (%.1f MB), "
+ "prefilling %d remaining",
+ system_token_count,
+ sum(c.nbytes for c in mc) / 1e6,
+ len(suffix_tokens),
+ )
- # Cache MISS with valid prefix: prefill system tokens and snapshot
- if (
- not cache_hit
- and system_token_count > 0
- and system_tokens is not None
- and suffix_tokens is not None
- ):
- mc = make_prompt_cache(model)
- sys_arr = mx.array(system_tokens)
-
- # Prefill system tokens in chunks (matching generate_step)
- step = self._prefill_step_size
- while sys_arr.size > step:
- model(sys_arr[:step][None], cache=mc)
- mx.eval([c.state for c in mc])
- sys_arr = sys_arr[step:]
- mx.clear_cache()
- if sys_arr.size > 0:
- model(sys_arr[None], cache=mc)
- mx.eval([c.state for c in mc])
-
- # Snapshot backbone cache (immutable mx.arrays, safe to reuse)
- snapshot = [c.state for c in mc]
- mx.eval([s for pair in snapshot for s in pair])
-
- self._system_kv_snapshot = snapshot
- self._system_kv_hash = system_hash
- self._system_kv_token_count = system_token_count
-
- backbone_cache = mc
- prompt_to_send = mx.array(suffix_tokens)
- logger.info(
- "System KV cache: stored %d-token snapshot (%.1f MB), "
- "prefilling %d remaining",
- system_token_count,
- sum(c.nbytes for c in mc) / 1e6,
- len(suffix_tokens),
+ # --- SpecPrefill path (with fallback to normal on failure) ---
+ if use_specprefill:
+ try:
+ return _run_specprefill(model, backbone_cache)
+ except Exception as e:
+ logger.error(
+ "SpecPrefill failed, falling back to normal MTP path: %s",
+ e,
)
+ # Discard potentially corrupted cache
+ backbone_cache = None
+ prompt_to_send = full_prompt
+
+ # --- Normal path (MTP via mlx_lm stream_generate) ---
+ prompt_cache = None
+ if backbone_cache is not None:
+ # Add MTP cache on top of backbone
+ if hasattr(model, "make_mtp_cache"):
+ mtp_cache = model.make_mtp_cache()
+ prompt_cache = backbone_cache + mtp_cache
+ else:
+ prompt_cache = backbone_cache
- # --- SpecPrefill path (with fallback to normal on failure) ---
- if use_specprefill:
- try:
- return _run_specprefill(model, backbone_cache)
- except Exception as e:
- logger.error(
- "SpecPrefill failed, falling back to normal MTP path: %s",
- e,
- )
- # Discard potentially corrupted cache
- backbone_cache = None
- prompt_to_send = full_prompt
-
- # --- Normal path (MTP via mlx_lm stream_generate) ---
- prompt_cache = None
- if backbone_cache is not None:
- # Add MTP cache on top of backbone
- if hasattr(model, "make_mtp_cache"):
- mtp_cache = model.make_mtp_cache()
- prompt_cache = backbone_cache + mtp_cache
- else:
- prompt_cache = backbone_cache
+ results = []
+ gen_kwargs = dict(
+ max_tokens=max_tokens,
+ sampler=sampler,
+ mtp=True,
+ prefill_step_size=self._prefill_step_size,
+ )
+ if prompt_cache is not None:
+ gen_kwargs["prompt_cache"] = prompt_cache
+
+ for resp in mlx_stream_generate(
+ model,
+ self._text_tokenizer,
+ prompt=prompt_to_send,
+ **gen_kwargs,
+ ):
+ results.append(resp)
+ return results
+
+ def _run_specprefill(model, bc):
+ """Score tokens, sparse prefill, generate without MTP."""
+ from types import SimpleNamespace
+
+ from ..specprefill import (
+ cleanup_rope,
+ score_tokens,
+ select_chunks,
+ sparse_prefill,
+ )
- results = []
- gen_kwargs = dict(
- max_tokens=max_tokens,
- sampler=sampler,
- mtp=True,
+ # Create backbone cache if not already from system KV
+ if bc is None:
+ bc = make_prompt_cache(model)
+
+ try:
+ # Phase 1: Score with draft model
+ import time
+
+ t0 = time.monotonic()
+ importance = score_tokens(
+ self._draft_model,
+ specprefill_tokens,
prefill_step_size=self._prefill_step_size,
)
- if prompt_cache is not None:
- gen_kwargs["prompt_cache"] = prompt_cache
+ t_score = time.monotonic() - t0
- for resp in mlx_stream_generate(
- model,
- self._text_tokenizer,
- prompt=prompt_to_send,
- **gen_kwargs,
- ):
- results.append(resp)
- return results
+ # Phase 2: Select important chunks
+ effective_keep = specprefill_keep_pct or self._specprefill_keep_pct
+ selected = select_chunks(importance, keep_pct=effective_keep)
+ n_selected = selected.shape[0]
+ n_total = len(specprefill_tokens)
- def _run_specprefill(model, bc):
- """Score tokens, sparse prefill, generate without MTP."""
- from types import SimpleNamespace
+ # Phase 3: Sparse prefill on target model
+ t0 = time.monotonic()
+ logits = sparse_prefill(
+ model,
+ specprefill_tokens,
+ selected,
+ bc,
+ step_size=self._prefill_step_size,
+ position_offset=specprefill_offset,
+ )
+ t_prefill = time.monotonic() - t0
- from ..specprefill import (
- cleanup_rope,
- score_tokens,
- select_chunks,
- sparse_prefill,
+ logger.info(
+ "SpecPrefill: scored %d tokens in %.1fs, "
+ "sparse prefill %d/%d (keep=%.0f%%) in %.1fs "
+ "(offset=%d, effective_keep=%.2f)",
+ n_total,
+ t_score,
+ n_selected,
+ n_total,
+ n_selected / n_total * 100,
+ t_prefill,
+ specprefill_offset,
+ effective_keep,
)
- # Create backbone cache if not already from system KV
- if bc is None:
- bc = make_prompt_cache(model)
+ # Phase 4: Generate (simple autoregressive, no MTP)
+ eos_id = self._text_tokenizer.eos_token_id
+ y = sampler(logits[:, -1, :])
+ mx.eval(y)
- try:
- # Phase 1: Score with draft model
- import time
-
- t0 = time.monotonic()
- importance = score_tokens(
- self._draft_model,
- specprefill_tokens,
- prefill_step_size=self._prefill_step_size,
- )
- t_score = time.monotonic() - t0
-
- # Phase 2: Select important chunks
- effective_keep = specprefill_keep_pct or self._specprefill_keep_pct
- selected = select_chunks(importance, keep_pct=effective_keep)
- n_selected = selected.shape[0]
- n_total = len(specprefill_tokens)
-
- # Phase 3: Sparse prefill on target model
- t0 = time.monotonic()
- logits = sparse_prefill(
- model,
- specprefill_tokens,
- selected,
- bc,
- step_size=self._prefill_step_size,
- position_offset=specprefill_offset,
- )
- t_prefill = time.monotonic() - t0
+ results = []
+ generated_ids = []
+ prev_decoded = ""
- logger.info(
- "SpecPrefill: scored %d tokens in %.1fs, "
- "sparse prefill %d/%d (keep=%.0f%%) in %.1fs "
- "(offset=%d, effective_keep=%.2f)",
- n_total,
- t_score,
- n_selected,
- n_total,
- n_selected / n_total * 100,
- t_prefill,
- specprefill_offset,
- effective_keep,
- )
+ for _ in range(max_tokens):
+ tok_id = y.item()
+ generated_ids.append(tok_id)
- # Phase 4: Generate (simple autoregressive, no MTP)
- eos_id = self._text_tokenizer.eos_token_id
- y = sampler(logits[:, -1, :])
- mx.eval(y)
+ # Incremental text decode
+ decoded = self._text_tokenizer.decode(generated_ids)
+ new_text = decoded[len(prev_decoded) :]
+ prev_decoded = decoded
- results = []
- generated_ids = []
- prev_decoded = ""
-
- for _ in range(max_tokens):
- tok_id = y.item()
- generated_ids.append(tok_id)
-
- # Incremental text decode
- decoded = self._text_tokenizer.decode(generated_ids)
- new_text = decoded[len(prev_decoded) :]
- prev_decoded = decoded
-
- is_eos = tok_id == eos_id
- results.append(
- SimpleNamespace(
- text=new_text,
- finish_reason="stop" if is_eos else None,
- )
+ is_eos = tok_id == eos_id
+ results.append(
+ SimpleNamespace(
+ text=new_text,
+ finish_reason="stop" if is_eos else None,
)
+ )
- if is_eos:
- break
+ if is_eos:
+ break
- # Next token
- logits = model(y.reshape(1, -1), cache=bc)
- y = sampler(logits[:, -1, :])
- mx.eval(y)
+ # Next token
+ logits = model(y.reshape(1, -1), cache=bc)
+ y = sampler(logits[:, -1, :])
+ mx.eval(y)
- return results
+ return results
- finally:
- cleanup_rope(model)
+ finally:
+ cleanup_rope(model)
- all_resps = await asyncio.to_thread(_run_all)
+ all_resps = await self._run_blocking_serialized(_run_all)
# Yield results as GenerationOutput
accumulated_text = ""
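The cancellation argument behind `_run_blocking_serialized`, reduced to a toy standalone example: even when the caller is cancelled, the lock is only released after the worker thread has drained, so the next request cannot overlap with a still-running Metal call. The sleep stands in for a blocking MLX operation.

```python
import asyncio
import time

lock = asyncio.Lock()

async def run_serialized(func, *args):
    async with lock:
        task = asyncio.create_task(asyncio.to_thread(func, *args))
        try:
            return await asyncio.shield(task)
        except asyncio.CancelledError:
            # Wait for the thread to drain before releasing the lock,
            # so a follow-up request cannot overlap with it.
            try:
                await task
            except BaseException:
                pass
            raise

def slow_metal_op():
    time.sleep(0.2)   # stands in for a blocking MLX/Metal call
    return "done"

async def main():
    t = asyncio.create_task(run_serialized(slow_metal_op))
    await asyncio.sleep(0.05)
    t.cancel()                      # client disconnects mid-generation
    async with lock:                # only acquirable once the thread has finished
        print("safe to start the next request")

asyncio.run(main())
```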
diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py
index d21928824..ae75fd39e 100644
--- a/vllm_mlx/engine_core.py
+++ b/vllm_mlx/engine_core.py
@@ -36,6 +36,7 @@ class EngineConfig:
scheduler_config: Optional[SchedulerConfig] = None
step_interval: float = 0.001 # 1ms between steps
stream_interval: int = 1 # Tokens to batch before streaming (1=every token)
+ gpu_memory_utilization: float = 0.90 # Fraction of device memory for allocation
class EngineCore:
@@ -150,18 +151,12 @@ async def _engine_loop(self) -> None:
stream_interval = self.config.stream_interval
use_simple_streaming = stream_interval == 1
- # Emergency memory pressure threshold — use 85% of Metal's
- # max recommended working set so this scales with system RAM.
+ # Emergency memory pressure threshold — dynamic based on gpu_memory_utilization
+ _gpu_mem_util = self.config.gpu_memory_utilization
try:
- _device_info = mx.device_info()
- _max_recommended = _device_info.get(
- "max_recommended_working_set_size",
- _device_info.get("memory_size", 0),
- )
- _memory_pressure_threshold = (
- int(_max_recommended * 0.85)
- if _max_recommended > 0
- else 200 * 1024 * 1024 * 1024
+ _device_mem = mx.device_info().get("memory_size", 200 * 1024 * 1024 * 1024)
+ _memory_pressure_threshold = int(
+ _device_mem * min(_gpu_mem_util + 0.05, 0.99)
)
except Exception:
_memory_pressure_threshold = 200 * 1024 * 1024 * 1024
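Worked example of the new threshold formula (numbers illustrative): with 128 GB of device memory and the serve default of 0.90, the emergency threshold becomes min(0.95, 0.99) of device memory.

```python
device_mem = 128 * 1024**3           # e.g. a 128GB machine
gpu_memory_utilization = 0.90        # serve default

threshold = int(device_mem * min(gpu_memory_utilization + 0.05, 0.99))
print(f"{threshold / 1024**3:.1f} GiB")  # 121.6 GiB
```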
diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py
index f43763541..2668c3cec 100644
--- a/vllm_mlx/memory_cache.py
+++ b/vllm_mlx/memory_cache.py
@@ -255,44 +255,121 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry:
def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
- """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced.
+ """Create copies of cache layers with the last ``trim_by`` positions removed.
This is used when returning a cached KV state to the scheduler so that
the last N positions are "freed" and the model will recompute them on the
next forward pass (preventing duplicate KV entries).
- Supports both KVCache (keys/values are arrays) and QuantizedKVCache
- (keys/values are 3-tuples of arrays).
- """
- from mlx_lm.models.cache import KVCache
+ For plain KVCache: reduces offset (surplus data beyond offset is harmless
+ since merge slices to ``keys[:, :, :offset, :]``).
- try:
- from mlx_lm.models.cache import QuantizedKVCache
- except ImportError:
- QuantizedKVCache = None # noqa: N806
+ For RotatingKVCache: actually trims the circular buffer — reducing offset
+ alone breaks ``size()`` / ``_temporal_order`` invariants.
+
+ Supports KVCache, RotatingKVCache, and _QuantizedCacheWrapper.
+ """
+ import mlx.core as mx
+ from mlx_lm.models.cache import RotatingKVCache
trimmed: list[Any] = []
+ eval_targets: list[Any] = []
for layer_cache in cache:
- if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache):
- tc = QuantizedKVCache.__new__(QuantizedKVCache)
+ if isinstance(layer_cache, _QuantizedCacheWrapper):
+ # Shallow copy with reduced offset
+ tc = _QuantizedCacheWrapper.__new__(_QuantizedCacheWrapper)
tc.keys = layer_cache.keys
tc.values = layer_cache.values
tc.offset = max(layer_cache.offset - trim_by, 0)
- tc.group_size = layer_cache.group_size
tc.bits = layer_cache.bits
+ tc.group_size = layer_cache.group_size
+ tc.orig_type = layer_cache.orig_type
+ tc.orig_attrs = layer_cache.orig_attrs
+ trimmed.append(tc)
+ elif isinstance(layer_cache, RotatingKVCache):
+ if layer_cache.keys is None or trim_by <= 0:
+ trimmed.append(layer_cache)
+ continue
+ # RotatingKVCache: must trim buffer, not just offset.
+ # The buffer stores the last min(offset, max_size) tokens in a
+ # circular arrangement. Trimming excess positions from the END
+ # means removing the newest entries (chronologically last).
+ old_offset = layer_cache.offset
+ new_offset = max(old_offset - trim_by, 0)
+ old_size = min(old_offset, layer_cache.max_size)
+ entries_to_keep = max(0, old_size - trim_by)
+
+ orig_cls = type(layer_cache)
+ tc = orig_cls.__new__(orig_cls)
+ tc.offset = new_offset
+ tc.max_size = layer_cache.max_size
+ tc.keep = getattr(layer_cache, "keep", 0)
+ tc.step = getattr(layer_cache, "step", layer_cache.max_size)
+
+ if entries_to_keep <= 0:
+ # All buffer content is beyond the trim point — clear
+ tc.keys = None
+ tc.values = None
+ tc._idx = 0
+ elif entries_to_keep < old_size:
+ # Reorder to temporal order, keep the oldest entries
+ ordered_k = layer_cache._temporal_order(layer_cache.keys)
+ ordered_v = layer_cache._temporal_order(layer_cache.values)
+ kept_k = ordered_k[:, :, :entries_to_keep, :]
+ kept_v = ordered_v[:, :, :entries_to_keep, :]
+
+ if new_offset >= tc.max_size:
+ # Invariant: when offset >= max_size, buffer must be
+ # full (keys.shape[2] == max_size). Left-pad with
+ # zeros to restore the full buffer. Zeros represent
+ # positions evicted long ago; _idx = max_size so
+ # _temporal_order returns as-is and _update_in_place
+ # rotates to overwrite zeros first.
+ pad_n = tc.max_size - entries_to_keep
+ pad_k = mx.zeros(
+ (kept_k.shape[0], kept_k.shape[1], pad_n, kept_k.shape[3]),
+ dtype=kept_k.dtype,
+ )
+ pad_v = mx.zeros(
+ (kept_v.shape[0], kept_v.shape[1], pad_n, kept_v.shape[3]),
+ dtype=kept_v.dtype,
+ )
+ tc.keys = mx.concatenate([pad_k, kept_k], axis=2)
+ tc.values = mx.concatenate([pad_v, kept_v], axis=2)
+ tc._idx = tc.max_size
+ else:
+ tc.keys = kept_k
+ tc.values = kept_v
+ tc._idx = entries_to_keep
+ eval_targets.extend([tc.keys, tc.values])
+ else:
+ # No entries removed (trim_by == 0 already handled above,
+ # this covers entries_to_keep == old_size edge case)
+ tc.keys = layer_cache.keys
+ tc.values = layer_cache.values
+ tc._idx = layer_cache._idx
trimmed.append(tc)
elif (
hasattr(layer_cache, "offset")
and hasattr(layer_cache, "keys")
and not isinstance(layer_cache.keys, (list, tuple))
):
- tc = KVCache.__new__(KVCache)
+ orig_cls = type(layer_cache)
+ tc = orig_cls.__new__(orig_cls)
tc.keys = layer_cache.keys
tc.values = layer_cache.values
tc.offset = max(layer_cache.offset - trim_by, 0)
+ # Preserve type-specific attrs (max_size, keep, step, _idx)
+ for attr in ("max_size", "keep", "step", "_idx"):
+ if hasattr(layer_cache, attr):
+ setattr(tc, attr, getattr(layer_cache, attr))
trimmed.append(tc)
else:
trimmed.append(layer_cache)
+
+ if eval_targets:
+ mx.eval(*eval_targets)
+
return trimmed
@@ -353,28 +430,72 @@ def _trim_to_offset(cache: list[Any]) -> list[Any]:
return trimmed
+class _QuantizedCacheWrapper:
+ """Lightweight wrapper storing quantized KV arrays + original cache metadata.
+
+ Unlike ``QuantizedKVCache``, this preserves enough info to reconstruct
+ the *original* cache type (KVCache, RotatingKVCache, etc.) on dequantize.
+ """
+
+ __slots__ = (
+ "keys",
+ "values",
+ "offset",
+ "bits",
+ "group_size",
+ "orig_type",
+ "orig_attrs",
+ )
+
+ def __init__(self, layer: Any, bits: int, group_size: int):
+ import mlx.core as mx
+
+ self.keys = mx.quantize(layer.keys, group_size=group_size, bits=bits)
+ self.values = mx.quantize(layer.values, group_size=group_size, bits=bits)
+ self.offset = layer.offset
+ self.bits = bits
+ self.group_size = group_size
+ self.orig_type = type(layer)
+ # Preserve RotatingKVCache-specific attrs
+ self.orig_attrs = {}
+ for attr in ("max_size", "keep", "step", "_idx"):
+ if hasattr(layer, attr):
+ self.orig_attrs[attr] = getattr(layer, attr)
+
+
def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> list[Any]:
- """Quantize KVCache layers to reduce memory. Non-KVCache layers are kept as-is."""
+ """Quantize KV cache layers to reduce memory.
+
+ Only plain KVCache layers are quantized. RotatingKVCache (sliding window)
+ is left as-is because its internal _idx/rotation state is tightly coupled
+ with update_and_fetch logic and cannot survive quantize/dequantize roundtrip.
+ RotatingKVCache is typically small (max_size=1024) so skipping it is fine.
+ """
from mlx_lm.models.cache import KVCache
quantized = []
for layer in cache:
- if isinstance(layer, KVCache) and layer.keys is not None:
- quantized.append(layer.to_quantized(group_size=group_size, bits=bits))
+ if type(layer) is KVCache and getattr(layer, "keys", None) is not None:
+ quantized.append(_QuantizedCacheWrapper(layer, bits, group_size))
else:
quantized.append(layer)
return quantized
def _dequantize_cache(cache: list[Any]) -> list[Any]:
- """Dequantize QuantizedKVCache layers back to regular KVCache."""
+ """Dequantize _QuantizedCacheWrapper layers and copy non-quantized layers.
+
+ All layers are copied (never returned by reference) so that the model's
+ ``update_and_fetch`` mutations don't corrupt the stored cache entry.
+ """
import mlx.core as mx
- from mlx_lm.models.cache import KVCache, QuantizedKVCache
result = []
for layer in cache:
- if isinstance(layer, QuantizedKVCache) and layer.keys is not None:
- kv = KVCache()
+ if isinstance(layer, _QuantizedCacheWrapper):
+ # Reconstruct original cache type from quantized data
+ orig_cls = layer.orig_type
+ kv = orig_cls.__new__(orig_cls)
kv.keys = mx.dequantize(
*layer.keys, group_size=layer.group_size, bits=layer.bits
)
@@ -382,6 +503,21 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]:
*layer.values, group_size=layer.group_size, bits=layer.bits
)
kv.offset = layer.offset
+ # Restore type-specific attrs (max_size, keep, step, _idx)
+ for attr, val in layer.orig_attrs.items():
+ setattr(kv, attr, val)
+ result.append(kv)
+ elif hasattr(layer, "keys") and hasattr(layer, "offset"):
+ # Deep-copy non-quantized cache layers (e.g. RotatingKVCache)
+ # so model's in-place mutations don't corrupt stored entries
+ orig_cls = type(layer)
+ kv = orig_cls.__new__(orig_cls)
+ kv.keys = mx.array(layer.keys) if layer.keys is not None else None
+ kv.values = mx.array(layer.values) if layer.values is not None else None
+ kv.offset = layer.offset
+ for attr in ("max_size", "keep", "step", "_idx"):
+ if hasattr(layer, attr):
+ setattr(kv, attr, getattr(layer, attr))
result.append(kv)
else:
result.append(layer)
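
A quick sketch of the quantize/dequantize roundtrip these two helpers rely on (illustrative only; shapes simplified to 2-D, the real key/value arrays are 4-D):

import mlx.core as mx

k = mx.random.normal((128, 64))                     # stand-in for one layer's keys
packed = mx.quantize(k, group_size=64, bits=8)      # tuple: (q_weights, scales, biases)
restored = mx.dequantize(*packed, group_size=64, bits=8)
print(mx.max(mx.abs(k - restored)))                 # small quantization error
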
@@ -635,7 +771,15 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]:
f"layer_types={[type(lc).__name__ for lc in best_lcp_entry.cache[:3]]}"
)
- if not has_non_trimmable:
+ if has_non_trimmable:
+ # Hybrid model (SSM+Attention): SSM state can't be rewound.
+ # Block LCP for hybrid models — use think-suffix stripping
+ # in the engine layer to get clean PREFIX matches instead.
+ logger.debug(
+ "[cache_fetch] LCP skipped: non-trimmable cache layers "
+ "(hybrid model, SSM state can't be rewound)"
+ )
+ else:
trimmed_cache = _trim_cache_offset(best_lcp_entry.cache, excess)
self._entries.move_to_end(best_lcp_entry.tokens)
self._stats.hits += 1
diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py
index ee8d8da7b..a6a59afba 100644
--- a/vllm_mlx/mllm_batch_generator.py
+++ b/vllm_mlx/mllm_batch_generator.py
@@ -24,12 +24,21 @@
import mlx.core as mx
import mlx.nn as nn
+from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig, _trim_cache_offset
from .multimodal_processor import MultimodalProcessor
from .vision_embedding_cache import VisionEmbeddingCache
logger = logging.getLogger(__name__)
+class PrefillAbortedError(Exception):
+ """Raised when a prefill is aborted due to client disconnect."""
+
+ def __init__(self, request_id: str):
+ self.request_id = request_id
+ super().__init__(f"Prefill aborted for request {request_id}")
+
+
@dataclass
class MLLMBatchRequest:
"""
@@ -47,6 +56,10 @@ class MLLMBatchRequest:
max_tokens: int = 256
temperature: float = 0.7
top_p: float = 0.9
+ top_k: int = 0
+ min_p: float = 0.0
+ presence_penalty: float = 0.0
+ repetition_penalty: float = 1.0
# Processed inputs (set after vision preprocessing)
input_ids: Optional[mx.array] = None
@@ -55,6 +68,9 @@ class MLLMBatchRequest:
image_grid_thw: Optional[mx.array] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+ # Text-only flag (no images/videos — eligible for prefix cache)
+ is_text_only: bool = False
+
# Generation state
num_tokens: int = 0 # Tokens generated so far
output_tokens: List[int] = field(default_factory=list)
@@ -98,6 +114,8 @@ class MLLMBatch:
num_tokens: List[int] # Tokens generated per request
cache: List[Any] # BatchKVCache for language model
requests: List[MLLMBatchRequest] # Full request data
+ logits_processors: Optional[List[Optional[List[Callable]]]] = None
+ samplers: Optional[List[Optional[Callable]]] = None
def __len__(self) -> int:
return len(self.uids)
@@ -115,6 +133,10 @@ def filter(self, keep_idx: List[int]) -> None:
self.max_tokens = [self.max_tokens[k] for k in keep_idx]
self.num_tokens = [self.num_tokens[k] for k in keep_idx]
self.requests = [self.requests[k] for k in keep_idx]
+ if self.logits_processors is not None:
+ self.logits_processors = [self.logits_processors[k] for k in keep_idx]
+ if self.samplers is not None:
+ self.samplers = [self.samplers[k] for k in keep_idx]
keep_idx_array = mx.array(keep_idx, mx.int32)
self.y = self.y[keep_idx_array]
@@ -139,32 +161,73 @@ def extend(self, other: "MLLMBatch") -> None:
self.max_tokens.extend(other.max_tokens)
self.requests.extend(other.requests)
- # Extend cache - handle None and incompatible caches
+ # Extend logits_processors
+ if self.logits_processors is not None or other.logits_processors is not None:
+ # At this point self.uids already includes other.uids from extend above
+ self_len = len(self.uids) - len(other.uids)
+ self_lp = self.logits_processors or [None] * self_len
+ other_lp = other.logits_processors or [None] * len(other.uids)
+ self.logits_processors = list(self_lp) + list(other_lp)
+
+ # Extend samplers
+ if self.samplers is not None or other.samplers is not None:
+ self_len = len(self.uids) - len(other.uids)
+ self_s = self.samplers or [None] * self_len
+ other_s = other.samplers or [None] * len(other.uids)
+ self.samplers = list(self_s) + list(other_s)
+
+ # Extend cache - handle both BatchKVCache (.keys/.values) and
+ # ArraysCache (.cache list) from hybrid models like Qwen3.5
for c, o in zip(self.cache, other.cache):
if c is not None and o is not None and hasattr(c, "extend"):
try:
- # Only extend if both caches have valid keys
- if (
- hasattr(c, "keys")
- and c.keys is not None
- and hasattr(o, "keys")
- and o.keys is not None
- ):
+ has_kv = hasattr(c, "keys") and c.keys is not None
+ has_arrays = hasattr(c, "cache")
+ if has_kv or has_arrays:
c.extend(o)
except Exception as e:
logger.warning(f"Failed to extend cache: {e}")
def extract_cache(self, idx: int) -> List[Any]:
"""
- Extract cache for a single request (for caching).
-
- Args:
- idx: Index of request in batch
+ Extract cache for a single request (for prefix caching).
- Returns:
- Cache state for that request
+ Handles BatchRotatingKVCache negative left_padding bug:
+ during generation with rotation, left_padding becomes negative,
+ causing extract() to use Python negative indexing and truncate
+ the buffer to only generation tokens instead of the full window.
"""
- return [c.extract(idx) if hasattr(c, "extract") else None for c in self.cache]
+ from mlx_lm.models.cache import (
+ BatchRotatingKVCache,
+ RotatingKVCache,
+ )
+
+ result = []
+ for c in self.cache:
+ if not hasattr(c, "extract"):
+ result.append(None)
+ elif isinstance(c, BatchRotatingKVCache):
+ # Custom extraction: clamp left_padding to >= 0
+ cache = RotatingKVCache(c.max_size)
+ padding = max(0, c.left_padding[idx].item())
+ offset = c.offset[idx].item()
+ cache.keys = c.keys[idx : idx + 1]
+ cache.values = c.values[idx : idx + 1]
+ cache._idx = c._idx
+ if c.rotated:
+ cache.keys = mx.roll(cache.keys, -c._idx, axis=2)
+ cache.values = mx.roll(cache.values, -c._idx, axis=2)
+ cache._idx = c.max_size
+ cache.keys = mx.contiguous(cache.keys[:, :, padding : cache._idx])
+ cache.values = mx.contiguous(cache.values[:, :, padding : cache._idx])
+ cache.offset = offset
+ cache._idx = cache.keys.shape[2]
+ cache.step = getattr(c, "step", c.max_size)
+ cache.keep = getattr(c, "keep", 0)
+ result.append(cache)
+ else:
+ result.append(c.extract(idx))
+ return result
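
An illustrative aside on the negative-index pitfall the clamp guards against (plain Python, hypothetical numbers):

window = list(range(1024))                 # stand-in for one request's KV window
bad, clamped = -7, max(0, -7)
assert len(window[bad:1024]) == 7          # negative padding: only the 7 newest entries survive
assert len(window[clamped:1024]) == 1024   # clamped padding: full window kept
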
class MLLMBatchStats:
@@ -205,32 +268,6 @@ def to_dict(self) -> Dict[str, Any]:
}
-def _make_batch_cache(model: nn.Module, left_padding: List[int]) -> List[Any]:
- """
- Create batch-aware KV cache for the language model.
-
- Args:
- model: The language model (model.language_model from VLM)
- left_padding: Padding amounts for left-padded prompts
-
- Returns:
- List of BatchKVCache objects for each layer
- """
- from mlx_lm.models.cache import BatchKVCache, KVCache
-
- def to_batch_cache(c):
- if isinstance(c, KVCache):
- return BatchKVCache(left_padding)
- else:
- raise ValueError(f"{type(c)} does not yet support batching")
-
- if hasattr(model, "make_cache"):
- cache = model.make_cache()
- return [to_batch_cache(c) for c in cache]
- else:
- return [BatchKVCache(left_padding) for _ in model.layers]
-
-
def _left_pad_prompts(
prompts: List[List[int]], max_length: Optional[int] = None
) -> mx.array:
@@ -289,6 +326,7 @@ def __init__(
prefill_step_size: int = 1024,
enable_vision_cache: bool = True,
vision_cache_size: int = 100,
+ prefix_cache_config: Optional[MemoryCacheConfig] = None,
):
"""
Initialize MLLM batch generator.
@@ -305,6 +343,7 @@ def __init__(
prefill_step_size: Tokens to process per prefill step
enable_vision_cache: Enable vision embedding caching
vision_cache_size: Max entries in vision cache
+ prefix_cache_config: Config for KV prefix cache (text-only requests)
"""
self.model = model
self.processor = processor
@@ -324,6 +363,13 @@ def __init__(
"MLLMBatchGenerator: Model does not have language_model, using model directly"
)
+ # Patch attention for BatchKVCache compatibility
+ from .patches.qwen3_5_mllm import patch_qwen35_attention_for_batching
+ from .patches.gemma4_mllm import patch_gemma4_attention_for_batching
+
+ patch_qwen35_attention_for_batching()
+ patch_gemma4_attention_for_batching()
+
self.max_tokens = max_tokens
self.stop_tokens = stop_tokens or set()
self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
@@ -340,6 +386,18 @@ def __init__(
# Statistics
self._stats = MLLMBatchStats()
+ # Error responses for requests that failed during preprocessing
+ self._pending_error_responses: List[MLLMBatchResponse] = []
+
+ # Per-request prefill progress: request_id → (processed_tokens, total_tokens)
+ self._prefill_progress: Dict[str, Tuple[int, int]] = {}
+
+ # Aborted request IDs — checked between prefill chunks to allow
+ # early termination when a client disconnects during long prefill.
+ # Set operations are GIL-protected, safe across event-loop and
+ # executor threads.
+ self._aborted_request_ids: set = set()
+
# Vision embedding cache for repeated images
self.vision_cache = VisionEmbeddingCache(
max_pixel_entries=vision_cache_size,
@@ -351,6 +409,33 @@ def __init__(
f"MLLMBatchGenerator: Vision cache enabled (size={vision_cache_size})"
)
+ # KV prefix cache for text-only requests
+ self.prefix_cache: Optional[MemoryAwarePrefixCache] = None
+ if prefix_cache_config is not None:
+ self.prefix_cache = MemoryAwarePrefixCache(
+ model=self.language_model,
+ config=prefix_cache_config,
+ )
+ logger.info("MLLMBatchGenerator: KV prefix cache enabled")
+
+ # Normalize chat template for prefix-cache stability.
+ # Qwen3.5 chat template retroactively changes formatting of earlier
+ # assistant messages based on last_query_index (position of last
+ # non-tool user message). When a user text message is appended,
+ # last_query_index jumps forward, removing blocks from
+        # last_query_index jumps forward, removing <think> blocks from
+ # breaking prefix match. Fix: always use plain format for
+ # historical assistant turns (thinking is still added by the
+ # generation prompt at the end).
+ self._normalize_chat_template_for_prefix_cache()
+
+ # Compute think-suffix length for prefix cache key stripping.
+ # Models with enable_thinking=True add \n to the generation
+        # Models with enable_thinking=True add <think>\n to the generation
+        # prompt. This breaks prefix cache (the stored key ends with <think>
+        # but the next request has the actual response content at that position).
+ self._think_suffix_len = self._compute_think_suffix_len()
+
# Generation stream
if MLLMBatchGenerator._stream is None:
MLLMBatchGenerator._stream = mx.new_stream(mx.default_device())
@@ -362,6 +447,132 @@ def __init__(
mx.device_info()["max_recommended_working_set_size"]
)
+ def _normalize_chat_template_for_prefix_cache(self) -> None:
+ """Patch chat template so historical assistant turns are prefix-stable.
+
+ Qwen3.5's chat template computes ``last_query_index`` — the position
+ of the last non-tool-response user message — and conditionally wraps
+        assistant turns after that index in ``<think>...</think>`` blocks.
+ When a new user text message is appended, ``last_query_index`` jumps
+ forward, retroactively removing these ```` wrappers from
+        forward, retroactively removing these ``<think>`` wrappers from
+ prefix cache.
+
+ Fix: replace the conditional with the plain (ELSE) branch so ALL
+ historical assistant messages use ``<|im_start|>assistant\\ncontent``
+        without any injected ``<think>`` block. The generation prompt still
+        adds ``<think>\\n`` at the very end, so the model generates thinking.
+ """
+ if self.prefix_cache is None:
+ return # No prefix cache — no need to normalize
+
+ # Find the chat template. VLM processors (e.g. Qwen3VLProcessor)
+ # keep a SEPARATE copy of chat_template from their tokenizer — both
+ # must be patched. The processor's copy is used by
+ # BatchedEngine._apply_chat_template() (text rendering), while the
+ # tokenizer's copy is used by _compute_think_suffix_len().
+ tokenizer = getattr(self.processor, "tokenizer", self.processor)
+ # Prefer the processor's own template (it's the one used for rendering)
+ template = getattr(self.processor, "chat_template", None)
+ if not template:
+ template = getattr(tokenizer, "chat_template", None)
+ if not template or "last_query_index" not in template:
+ return # Not affected
+
+ import re
+
+ # The pattern in Qwen3.5 template:
+ # {%- if loop.index0 > ns.last_query_index %}
+        #     {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
+ # {%- else %}
+ # {{- '<|im_start|>' + message.role + '\n' + content }}
+ # {%- endif %}
+ #
+ # Replace with just the ELSE branch (always plain format).
+ pattern = (
+ r"\{%-\s*if\s+loop\.index0\s*>\s*ns\.last_query_index\s*%\}"
+ r".*?"
+ r"\{%-\s*else\s*%\}"
+ r"\s*(\{\{-.*?content.*?\}\})"
+ r"\s*\{%-\s*endif\s*%\}"
+ )
+ new_template = re.sub(pattern, r"\1", template, flags=re.DOTALL)
+ if new_template != template:
+ # Patch ALL copies: processor, tokenizer, and any dict variants.
+ if hasattr(self.processor, "chat_template"):
+ self.processor.chat_template = new_template
+ tokenizer.chat_template = new_template
+ logger.info(
+ "[prefix_cache] Normalized chat template: removed "
+ "last_query_index conditional for prefix-stable assistant turns"
+ )
+ else:
+ logger.debug(
+ "[prefix_cache] Chat template has last_query_index but "
+ "regex did not match — template may use a different pattern"
+ )
+
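
For illustration, the same substitution applied to a toy template (a simplified stand-in for the real Qwen3.5 template; the regex is the one used in the method above):

import re

toy = (
    "{%- if loop.index0 > ns.last_query_index %}"
    "{{- '<think>' + reasoning_content + '</think>' + content }}"
    "{%- else %}"
    "{{- '<|im_start|>' + message.role + content }}"
    "{%- endif %}"
)
pattern = (
    r"\{%-\s*if\s+loop\.index0\s*>\s*ns\.last_query_index\s*%\}"
    r".*?\{%-\s*else\s*%\}\s*(\{\{-.*?content.*?\}\})\s*\{%-\s*endif\s*%\}"
)
out = re.sub(pattern, r"\1", toy, flags=re.DOTALL)
assert "<think>" not in out and "last_query_index" not in out
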
+ def _compute_think_suffix_len(self) -> int:
+ """Compute how many extra tokens enable_thinking=True adds at the END.
+
+ Compares the generation prompt suffix with and without
+ ``enable_thinking`` to find the think-tag suffix length
+        (typically ``<think>\\n`` = 2 tokens for Qwen3/Qwen3.5).
+
+ Returns 0 if the template doesn't support ``enable_thinking``.
+ """
+ try:
+ # Find something with apply_chat_template
+ applicator = None
+ for candidate in [
+ getattr(self.processor, "tokenizer", None),
+ self.processor,
+ ]:
+ if candidate is not None and hasattr(candidate, "apply_chat_template"):
+ applicator = candidate
+ break
+
+ if applicator is None:
+ return 0
+
+ dummy = [{"role": "user", "content": "x"}]
+
+ try:
+ text_with = applicator.apply_chat_template(
+ dummy,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True,
+ )
+ text_without = applicator.apply_chat_template(
+ dummy,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ except TypeError:
+ return 0
+
+ # Check if enable_thinking adds a known think tag at the end.
+ # enable_thinking may also change the system prompt, so we can't
+ # simply compare lengths — we look at the ending instead.
+            for tag in ["<think>\n", "<think>"]:
+ if text_with.endswith(tag) and not text_without.endswith(tag):
+ tokenizer = getattr(self.processor, "tokenizer", self.processor)
+ suffix_tokens = tokenizer.encode(tag)
+ base_tokens = tokenizer.encode("")
+ suffix_len = len(suffix_tokens) - len(base_tokens)
+ if suffix_len > 0:
+ logger.info(
+ f"[think_suffix] Detected think tag "
+ f"'{tag.strip()}' = {suffix_len} token(s)"
+ )
+ return max(0, suffix_len)
+
+ return 0
+ except Exception:
+ return 0
+
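
A sketch of the detection idea with a stub tokenizer (everything below is hypothetical; the real processor's template and token counts differ):

class StubTok:
    def apply_chat_template(self, msgs, tokenize=False,
                            add_generation_prompt=True, enable_thinking=True):
        base = "<|im_start|>user\nx<|im_end|>\n<|im_start|>assistant\n"
        return base + ("<think>\n" if enable_thinking else "")

    def encode(self, text):
        # crude stand-in: one "token" per tag or newline
        return [t for t in text.replace("\n", " \n ").split(" ") if t]

tok = StubTok()
dummy = [{"role": "user", "content": "x"}]
with_think = tok.apply_chat_template(dummy, enable_thinking=True)
without = tok.apply_chat_template(dummy, enable_thinking=False)
assert with_think.endswith("<think>\n") and not without.endswith("<think>\n")
print(len(tok.encode("<think>\n")) - len(tok.encode("")))   # -> 2, the suffix length
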
def close(self) -> None:
"""Release resources and reset wired limit."""
if self._old_wired_limit is not None:
@@ -369,6 +580,16 @@ def close(self) -> None:
mx.set_wired_limit(self._old_wired_limit)
self._old_wired_limit = None
+ def abort_prefill(self, request_id: str) -> None:
+ """Signal that a request's prefill should be aborted.
+
+ Called from the event loop thread when a client disconnects.
+ The prefill loop checks this set between chunks and raises
+ PrefillAbortedError to exit early.
+ """
+ self._aborted_request_ids.add(request_id)
+ logger.info(f"[abort_prefill] Marked {request_id} for prefill abort")
+
def __del__(self):
try:
self.close()
@@ -545,12 +766,81 @@ def _preprocess_request(self, request: MLLMBatchRequest) -> None:
self._stats.num_images_processed += len(all_images)
self._stats.vision_encoding_time += processing_time
+ # Mark text-only requests (eligible for prefix cache)
+ request.is_text_only = not bool(all_images)
+
logger.debug(
f"Preprocessed request {request.request_id}: "
f"{len(all_images)} images, {request.input_ids.size if request.input_ids is not None else 0} tokens "
f"({processing_time:.2f}s)"
)
+ def _run_chunked_text_prefill(
+ self, request: MLLMBatchRequest, cache: List[Any]
+ ) -> mx.array:
+ """
+ Run prefill in chunks for text-only requests, reporting real progress.
+
+ Processes input_ids in prefill_step_size chunks through the language
+ model, updating ``_prefill_progress`` after each chunk so the status
+ endpoint can report accurate prefill percentage.
+
+ Returns:
+ Logits from the last chunk (same contract as _run_vision_encoding).
+ """
+ input_ids = request.input_ids
+ if input_ids.ndim == 1:
+ input_ids = input_ids[None, :]
+
+ total = input_ids.shape[1]
+ step = self.prefill_step_size
+
+ # Short prompt — process in one shot (no chunking overhead)
+ if total <= step:
+ self._prefill_progress[request.request_id] = (total, total)
+ output = self.language_model(input_ids, cache=cache)
+ request.vision_encoded = True
+ if hasattr(output, "logits"):
+ return output.logits
+ return output
+
+ # Process all chunks except the last
+ processed = 0
+ chunk_count = 0
+ while processed + step < total:
+ # Check for abort between chunks (client disconnect)
+ if request.request_id in self._aborted_request_ids:
+ self._aborted_request_ids.discard(request.request_id)
+ logger.info(
+ f"[chunked_prefill] Aborted {request.request_id} at "
+ f"{processed}/{total} tokens"
+ )
+ raise PrefillAbortedError(request.request_id)
+
+ chunk = input_ids[:, processed : processed + step]
+ self.language_model(chunk, cache=cache)
+ mx.eval([c.state for c in cache])
+ processed += step
+ chunk_count += 1
+ self._prefill_progress[request.request_id] = (processed, total)
+
+ # Release Metal buffer pool periodically. Full-attention layers
+ # produce attention score buffers that grow each chunk (1024 ×
+ # growing_context). Old smaller buffers can't be reused, so the
+ # pool accumulates O(N²) memory without clearing.
+ if chunk_count % 4 == 0:
+ mx.clear_cache()
+
+ # Last chunk — return logits for sampling
+ last_chunk = input_ids[:, processed:]
+ output = self.language_model(last_chunk, cache=cache)
+ request.vision_encoded = True
+ self._prefill_progress[request.request_id] = (total, total)
+
+ if hasattr(output, "logits"):
+ return output.logits
+ return output
+
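
A small aside on the chunk boundaries and progress values the loop above walks through (pure Python, hypothetical sizes):

def chunk_plan(total, step):
    processed, plan = 0, []
    while processed + step < total:
        plan.append((processed, processed + step))   # full chunks, cache only
        processed += step
    plan.append((processed, total))                   # last chunk, logits kept
    return plan

assert chunk_plan(2500, 1024) == [(0, 1024), (1024, 2048), (2048, 2500)]
assert chunk_plan(1024, 1024) == [(0, 1024)]          # short prompt: one shot
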
def _run_vision_encoding(
self, request: MLLMBatchRequest, cache: Optional[List[Any]] = None
) -> mx.array:
@@ -613,68 +903,305 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
tic = time.perf_counter()
- # Preprocess all requests
+ # Preprocess all requests (per-request error handling)
+ failed_requests = []
for req in requests:
- self._preprocess_request(req)
+ try:
+ self._preprocess_request(req)
+ except Exception as e:
+ logger.error(
+ f"Failed to preprocess request {req.request_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ failed_requests.append(req)
+
+ # Remove failed requests from batch and create error responses
+ if failed_requests:
+ for req in failed_requests:
+ requests.remove(req)
+ self._pending_error_responses.append(
+ MLLMBatchResponse(
+ uid=req.uid,
+ request_id=req.request_id,
+ token=0,
+ logprobs=mx.zeros(1),
+ finish_reason="error",
+ )
+ )
+
+ if not requests:
+ # All requests failed
+ return None
total_prompt_tokens = sum(
req.input_ids.size if req.input_ids is not None else 1 for req in requests
)
self._stats.prompt_tokens += total_prompt_tokens
- # Guard against excessive memory usage during cache merge.
- # Each token in the batch requires KV entries across all layers.
+ # Log large prompts for monitoring (was previously a hard check that
+ # caused infinite retry loops when requests exceeded the limit).
max_batch_tokens = self.prefill_step_size * len(requests)
if total_prompt_tokens > max_batch_tokens:
- raise ValueError(
- f"Total prompt tokens ({total_prompt_tokens}) exceeds safe limit "
- f"({max_batch_tokens}) for {len(requests)} requests. "
- f"Reduce prompt length or batch size."
+ logger.warning(
+ f"Large batch prefill: {total_prompt_tokens} tokens "
+ f"(step_size={self.prefill_step_size}, requests={len(requests)}). "
+ f"Processing may be slow."
)
# Run vision encoding for each request with its own KVCache.
# Vision encoding cannot be batched because each request may have
# different images/pixel values. We pass a per-request KVCache to
# the VLM so the language model writes its KV state directly into it.
+ #
+ # For text-only requests, we check the prefix cache first. If there's
+ # a hit, we skip the full VLM forward and run only the language model
+ # on the remaining (uncached) tokens.
first_tokens = []
all_logprobs = []
per_request_caches = []
+ aborted_requests = []
for req in requests:
- # Create a fresh KVCache for this request's language model prefill
- request_cache = make_prompt_cache(self.language_model)
-
- with mx.stream(MLLMBatchGenerator._stream):
- # Run VLM forward pass — cache= flows through to language_model
- logits = self._run_vision_encoding(req, cache=request_cache)
-
- # Extract last token logits and sample
- last_logits = logits[:, -1, :]
- logprobs = last_logits - mx.logsumexp(
- last_logits, axis=-1, keepdims=True
- )
- sampled = self.sampler(logprobs)
-
- mx.eval(sampled, logprobs)
+ try:
+ # Check abort before starting prefill
+ if req.request_id in self._aborted_request_ids:
+ self._aborted_request_ids.discard(req.request_id)
+ raise PrefillAbortedError(req.request_id)
+
+ # Try prefix cache for all requests (text-only and multimodal).
+ # VLM forward writes the same KV state as language model forward
+ # for text tokens, so cached KV from a previous VLM run is valid.
+ # However, if the remaining (uncached) tokens contain image
+ # placeholders, we must fall back to VLM forward instead of
+ # running them through the language model alone.
+ cached_kv = None
+ remaining_ids = None
+ if self.prefix_cache is not None and req.input_ids is not None:
+ input_ids_list = req.input_ids.reshape(-1).tolist()
+ # Strip think suffix from lookup key so stored entries
+ # (also stripped) match as clean PREFIX.
+ S = self._think_suffix_len
+ lookup_ids = input_ids_list[:-S] if S > 0 else input_ids_list
+ cached_kv, remaining_ids = self.prefix_cache.fetch(lookup_ids)
+ # Append think suffix back to remaining so the model
+                    # sees the full generation prompt (<think>\n).
+ if cached_kv is not None and S > 0:
+ remaining_ids = list(remaining_ids) + input_ids_list[-S:]
+
+ # If remaining tokens contain image placeholders, the
+ # language-model-only path cannot handle them — clear the
+ # cache hit so we fall through to full VLM forward.
+ if cached_kv is not None and remaining_ids:
+ img_tok = getattr(
+ getattr(self.model, "config", None),
+ "image_token_index",
+ None,
+ )
+ if img_tok is not None and img_tok in remaining_ids:
+ cached_kv = None
+ remaining_ids = None
+
+ if cached_kv is not None and remaining_ids:
+ # Prefix/LCP match — run language model on remaining tokens
+ request_cache = cached_kv
+ remaining = mx.array(remaining_ids)[None, :]
+ cached_count = len(input_ids_list) - len(remaining_ids)
+ total_tokens = len(input_ids_list)
+ remaining_count = len(remaining_ids)
+
+ with mx.stream(MLLMBatchGenerator._stream):
+ step = self.prefill_step_size
+ if remaining_count <= step:
+ # Short remaining — process in one shot
+ self._prefill_progress[req.request_id] = (
+ total_tokens,
+ total_tokens,
+ )
+ logits = self.language_model(remaining, cache=request_cache)
+ else:
+ # Chunked prefill on remaining tokens
+ self._prefill_progress[req.request_id] = (
+ cached_count,
+ total_tokens,
+ )
+ processed = 0
+ chunk_count = 0
+ while processed + step < remaining_count:
+ # Check for abort between chunks
+ if req.request_id in self._aborted_request_ids:
+ self._aborted_request_ids.discard(req.request_id)
+ logger.info(
+ f"[chunked_prefill] Aborted {req.request_id} "
+ f"at {cached_count + processed}/{total_tokens} tokens"
+ )
+ raise PrefillAbortedError(req.request_id)
+
+ chunk = remaining[:, processed : processed + step]
+ self.language_model(chunk, cache=request_cache)
+ mx.eval([c.state for c in request_cache])
+ processed += step
+ chunk_count += 1
+ self._prefill_progress[req.request_id] = (
+ cached_count + processed,
+ total_tokens,
+ )
+ if chunk_count % 4 == 0:
+ mx.clear_cache()
+ # Last chunk — return logits
+ remaining = remaining[:, processed:]
+ logits = self.language_model(remaining, cache=request_cache)
+ self._prefill_progress[req.request_id] = (
+ total_tokens,
+ total_tokens,
+ )
+
+ if hasattr(logits, "logits"):
+ logits = logits.logits
+
+ last_logits = logits[:, -1, :]
+ logprobs = last_logits - mx.logsumexp(
+ last_logits, axis=-1, keepdims=True
+ )
+ sampled = self.sampler(logprobs)
+ mx.eval(sampled, logprobs)
+
+ first_tokens.append(sampled.item())
+ all_logprobs.append(logprobs.squeeze(0))
+
+ per_request_caches.append(request_cache)
+ req.vision_encoded = True
+ logger.debug(
+ f"Prefix cache hit for {req.request_id}: "
+ f"cached={cached_count}, "
+ f"remaining={remaining_count}"
+ )
- first_tokens.append(sampled.item())
- all_logprobs.append(logprobs.squeeze(0))
+ elif cached_kv is not None and not remaining_ids:
+ # Exact/supersequence match — cache has all tokens,
+ # but we still need logits for the last token.
+ # fetch() with trim-by-1 store always returns remaining=[last_token].
+ # If we get here (empty remaining), re-run on last token.
+ request_cache = cached_kv
+ last_token = req.input_ids[:, -1:]
+ total_tokens = len(input_ids_list)
+ self._prefill_progress[req.request_id] = (
+ total_tokens,
+ total_tokens,
+ )
- per_request_caches.append(request_cache)
+ with mx.stream(MLLMBatchGenerator._stream):
+ logits = self.language_model(last_token, cache=request_cache)
+ if hasattr(logits, "logits"):
+ logits = logits.logits
+
+ last_logits = logits[:, -1, :]
+ logprobs = last_logits - mx.logsumexp(
+ last_logits, axis=-1, keepdims=True
+ )
+ sampled = self.sampler(logprobs)
+ mx.eval(sampled, logprobs)
+
+ first_tokens.append(sampled.item())
+ all_logprobs.append(logprobs.squeeze(0))
+
+ per_request_caches.append(request_cache)
+ req.vision_encoded = True
+ logger.debug(
+ f"Prefix cache exact hit for {req.request_id}: "
+ f"all {total_tokens} tokens cached"
+ )
- # Merge per-request KVCaches into a single BatchKVCache.
- # KVCache.merge() creates a BatchKVCache with proper left-padding
- # alignment, so all requests share a single batched cache for
- # subsequent generation steps.
- from mlx_lm.models.cache import KVCache
+ else:
+ # Cache miss — full forward pass
+ request_cache = make_prompt_cache(self.language_model)
+
+ with mx.stream(MLLMBatchGenerator._stream):
+ # Text-only: chunked prefill with real progress tracking
+ # Multimodal: atomic VLM forward (vision encoder needs full input)
+ if req.is_text_only:
+ logits = self._run_chunked_text_prefill(
+ req, cache=request_cache
+ )
+ else:
+ logits = self._run_vision_encoding(req, cache=request_cache)
+
+ # Extract last token logits and sample
+ last_logits = logits[:, -1, :]
+ logprobs = last_logits - mx.logsumexp(
+ last_logits, axis=-1, keepdims=True
+ )
+ sampled = self.sampler(logprobs)
+
+ mx.eval(sampled, logprobs)
+
+ first_tokens.append(sampled.item())
+ all_logprobs.append(logprobs.squeeze(0))
+
+ per_request_caches.append(request_cache)
+
+ except PrefillAbortedError:
+ aborted_requests.append(req)
+ self._prefill_progress.pop(req.request_id, None)
+ self._pending_error_responses.append(
+ MLLMBatchResponse(
+ uid=req.uid,
+ request_id=req.request_id,
+ token=0,
+ logprobs=mx.zeros(1),
+ finish_reason="abort",
+ )
+ )
- sample_cache = per_request_caches[0][0]
- if not isinstance(sample_cache, KVCache):
- raise ValueError(
- f"MLLM continuous batching requires standard KVCache but got "
- f"{type(sample_cache).__name__}. Disable --kv-cache-quantization "
- f"when using multimodal models with --continuous-batching."
- )
+ # Remove aborted requests — they have no entries in the parallel
+ # lists (first_tokens, all_logprobs, per_request_caches)
+ if aborted_requests:
+ for req in aborted_requests:
+ requests.remove(req)
+ mx.clear_cache()
+ if not requests:
+ return None
+
+ # Merge per-request caches into batched caches.
+ # Both KVCache.merge() and ArraysCache.merge() produce batch-aware
+ # caches that support filter/extend/extract for continuous batching.
+ #
+ # Fix: RotatingKVCache._update_concat does NOT trim on first call —
+ # if prompt length > max_size, the buffer grows beyond max_size.
+ # BatchRotatingKVCache.merge() then hits a shape mismatch when
+ # copying via _temporal_order (full buffer) into a max_size slice.
+ # Trim buffer to max_size before merging.
+ from mlx_lm.models.cache import RotatingKVCache
+
+ for rc in per_request_caches:
+ for layer_cache in rc:
+ if isinstance(layer_cache, RotatingKVCache):
+ if layer_cache.keys is not None:
+ buf_len = layer_cache.keys.shape[2]
+ if buf_len > layer_cache.max_size:
+ trim_size = buf_len - layer_cache.max_size
+ layer_cache.keys = layer_cache._trim(
+ trim_size, layer_cache.keys
+ )
+ layer_cache.values = layer_cache._trim(
+ trim_size, layer_cache.values
+ )
+ layer_cache._idx = layer_cache.max_size
+ # Normalize wrapped rotating cache for merge:
+ # after rotation _idx wraps around but merge()
+ # expects _idx == actual buffer size.
+ # Use keys.shape[2] (actual entries) NOT size()
+ # which can be inconsistent after prefix cache trim
+ # (size() = min(offset, max_size) but buffer may
+ # have fewer entries when trimmed).
+ actual_buf = layer_cache.keys.shape[2]
+ if layer_cache._idx != actual_buf and actual_buf > 0:
+ layer_cache.keys = layer_cache._temporal_order(
+ layer_cache.keys
+ )
+ layer_cache.values = layer_cache._temporal_order(
+ layer_cache.values
+ )
+ layer_cache._idx = actual_buf
try:
batch_cache = [
@@ -684,14 +1211,61 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
for layer_idx in range(len(per_request_caches[0]))
]
except Exception as e:
+ sample_type = type(per_request_caches[0][0]).__name__
logger.error(
- f"Failed to merge per-request KV caches: {type(e).__name__}: {e}"
+ f"Failed to merge per-request caches ({sample_type}): "
+ f"{type(e).__name__}: {e}"
)
raise
# Create initial y (first generated tokens)
y = mx.array(first_tokens)
+ # Build per-request logits processors (repetition_penalty, presence_penalty)
+ from mlx_lm.sample_utils import make_logits_processors, make_sampler
+
+ batch_logits_processors = []
+ has_any_lp = False
+ for req in requests:
+ need_rep = req.repetition_penalty and req.repetition_penalty != 1.0
+ need_pres = req.presence_penalty and req.presence_penalty != 0.0
+ if need_rep or need_pres:
+ lp_kwargs = {}
+ if need_rep:
+ lp_kwargs["repetition_penalty"] = req.repetition_penalty
+ if need_pres:
+ lp_kwargs["presence_penalty"] = req.presence_penalty
+ lp = make_logits_processors(**lp_kwargs)
+ batch_logits_processors.append(lp)
+ has_any_lp = True
+ logger.info(
+ f"[sampling] request={req.request_id[:12]} "
+ f"rep_penalty={req.repetition_penalty} "
+ f"pres_penalty={req.presence_penalty}"
+ )
+ else:
+ batch_logits_processors.append(None)
+
+ # Build per-request samplers for top_k/min_p
+ batch_samplers = []
+ has_any_sampler = False
+ for req in requests:
+ if req.top_k != 0 or req.min_p != 0.0:
+ s = make_sampler(
+ temp=req.temperature,
+ top_p=req.top_p,
+ top_k=req.top_k,
+ min_p=req.min_p,
+ )
+ batch_samplers.append(s)
+ has_any_sampler = True
+ logger.info(
+ f"[sampling] request={req.request_id[:12]} "
+ f"top_k={req.top_k} min_p={req.min_p}"
+ )
+ else:
+ batch_samplers.append(None)
+
self._stats.prompt_time += time.perf_counter() - tic
return MLLMBatch(
@@ -703,10 +1277,17 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
num_tokens=[0] * len(requests),
cache=batch_cache,
requests=requests,
+ logits_processors=batch_logits_processors if has_any_lp else None,
+ samplers=batch_samplers if has_any_sampler else None,
)
def _step(
- self, input_tokens: mx.array, cache: List[Any]
+ self,
+ input_tokens: mx.array,
+ cache: List[Any],
+ logits_processors: Optional[List[Optional[List[Callable]]]] = None,
+ output_tokens: Optional[List[List[int]]] = None,
+ samplers: Optional[List[Optional[Callable]]] = None,
) -> Tuple[mx.array, List[mx.array]]:
"""
Run one generation step through the language model.
@@ -714,6 +1295,9 @@ def _step(
Args:
input_tokens: Input tokens [batch_size, 1] or [batch_size]
cache: BatchKVCache for the language model
+ logits_processors: Per-request logits processors (e.g. repetition penalty)
+ output_tokens: Per-request generated tokens so far (needed by processors)
+ samplers: Per-request sampler functions (for top_k/min_p)
Returns:
Tuple of (sampled tokens, logprobs list)
@@ -733,9 +1317,29 @@ def _step(
logits = logits[:, -1, :]
- # Sample
+ # Apply per-request logits processors (repetition penalty etc.)
+ if logits_processors and output_tokens and any(logits_processors):
+ processed_logits = []
+ for e in range(logits.shape[0]):
+ sample_logits = logits[e : e + 1]
+ if logits_processors[e]:
+ for processor in logits_processors[e]:
+ sample_logits = processor(
+ mx.array(output_tokens[e]), sample_logits
+ )
+ processed_logits.append(sample_logits)
+ logits = mx.concatenate(processed_logits, axis=0)
+
+ # Sample — per-request samplers for top_k/min_p support
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
- sampled = self.sampler(logprobs)
+ if samplers and any(samplers):
+ sampled_list = []
+ for e in range(logprobs.shape[0]):
+ s = samplers[e] if samplers[e] else self.sampler
+ sampled_list.append(s(logprobs[e : e + 1]))
+ sampled = mx.concatenate(sampled_list, axis=0)
+ else:
+ sampled = self.sampler(logprobs)
return sampled, list(logprobs)
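
For illustration, the row-wise sampler dispatch in isolation (hypothetical samplers; the real per-request ones come from mlx_lm's make_sampler):

import mlx.core as mx

logprobs = mx.stack([mx.arange(4, dtype=mx.float32)] * 2)    # two identical rows
default = lambda lp: mx.argmax(lp, axis=-1)                   # shared default sampler
samplers = [None, lambda lp: mx.argmin(lp, axis=-1)]          # request 1 overrides
rows = [(samplers[e] or default)(logprobs[e : e + 1]) for e in range(2)]
print(mx.concatenate(rows, axis=0).tolist())                  # -> [3, 0]
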
@@ -757,6 +1361,8 @@ def _next(self) -> List[MLLMBatchResponse]:
# merged into a single BatchKVCache. Merging into an active batch
# mid-generation would cause shape mismatches in attention layers,
# so queued requests wait until the current batch finishes.
+ # Exception: text-only requests can be extended into an active batch
+ # via the elif branch below (they skip vision encoding entirely).
if num_active == 0:
requests = self.unprocessed_requests[: self.completion_batch_size]
@@ -764,18 +1370,100 @@ def _next(self) -> List[MLLMBatchResponse]:
self.active_batch = None
return []
- new_batch = self._process_prompts(requests)
- self.unprocessed_requests = self.unprocessed_requests[len(requests) :]
- self.active_batch = new_batch
- prompt_processing = True
+ try:
+ # Save count before _process_prompts which modifies
+ # `requests` in-place via .remove() for failed items.
+ num_to_consume = len(requests)
+ new_batch = self._process_prompts(requests)
+ self.unprocessed_requests = self.unprocessed_requests[num_to_consume:]
+ self.active_batch = new_batch
+ prompt_processing = True
+ except Exception as e:
+ logger.error(
+ f"Failed to process batch of {len(requests)} prompts: "
+ f"{type(e).__name__}: {e}",
+ exc_info=True,
+ )
+ # Remove failed requests to avoid infinite retry loop
+ self.unprocessed_requests = self.unprocessed_requests[len(requests) :]
+ for req in requests:
+ self._pending_error_responses.append(
+ MLLMBatchResponse(
+ uid=req.uid,
+ request_id=req.request_id,
+ token=0,
+ logprobs=mx.zeros(1),
+ finish_reason="error",
+ )
+ )
+
+ # Mid-batch extend: text-only requests can join an active batch
+ # without vision encoding (no shape mismatch risk).
+ elif self.unprocessed_requests:
+ text_only = [
+ r for r in self.unprocessed_requests if not r.images and not r.videos
+ ][: self.completion_batch_size]
+
+ if text_only:
+ try:
+ # Capture UIDs before _process_prompts modifies
+ # text_only in-place via .remove() for failed items.
+ all_uids = {r.uid for r in text_only}
+ new_batch = self._process_prompts(text_only)
+ # Remove ALL requested (both successful and failed)
+ self.unprocessed_requests = [
+ r for r in self.unprocessed_requests if r.uid not in all_uids
+ ]
+ if new_batch is not None:
+ batch.extend(new_batch)
+ prompt_processing = True
+ except Exception as e:
+ logger.warning(
+ f"Failed to extend batch with text-only requests: "
+ f"{type(e).__name__}: {e}"
+ )
+ # Remove failed requests to avoid infinite retry loop
+ processed_uids = {r.uid for r in text_only}
+ self.unprocessed_requests = [
+ r
+ for r in self.unprocessed_requests
+ if r.uid not in processed_uids
+ ]
+ for req in text_only:
+ self._pending_error_responses.append(
+ MLLMBatchResponse(
+ uid=req.uid,
+ request_id=req.request_id,
+ token=0,
+ logprobs=mx.zeros(1),
+ finish_reason="error",
+ )
+ )
+
+ # Collect any pending error responses (from failed preprocessing)
+ error_responses = []
+ if self._pending_error_responses:
+ error_responses = list(self._pending_error_responses)
+ self._pending_error_responses.clear()
# Generate next token for active batch
batch = self.active_batch
if batch is None:
- return []
+ return error_responses
y, logprobs = batch.y, batch.logprobs
- batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
+ output_tokens = (
+ [req.output_tokens for req in batch.requests]
+ if batch.logits_processors
+ else None
+ )
+ batch.y, batch.logprobs = self._step(
+ y[:, None],
+ batch.cache,
+ batch.logits_processors,
+ output_tokens,
+ batch.samplers,
+ )
mx.async_eval(batch.y, batch.logprobs)
y = y.tolist()
@@ -821,6 +1509,8 @@ def _next(self) -> List[MLLMBatchResponse]:
if finish_reason is not None:
# Extract cache for this request
cache_fn = lambda idx=i: batch.extract_cache(idx)
+ # Cleanup prefill progress tracking
+ self._prefill_progress.pop(request_id, None)
responses.append(
MLLMBatchResponse(
@@ -833,6 +1523,9 @@ def _next(self) -> List[MLLMBatchResponse]:
)
)
+ # Store caches for finished text-only requests BEFORE filtering
+ self._maybe_store_prefix_cache(batch, end_idx)
+
# Remove finished requests from batch
if end_idx:
if keep_idx:
@@ -841,7 +1534,7 @@ def _next(self) -> List[MLLMBatchResponse]:
self.active_batch = None
self._stats.generation_tokens += len(responses)
- return responses
+ return error_responses + responses
def next(self) -> List[MLLMBatchResponse]:
"""
@@ -863,10 +1556,404 @@ def stats(self) -> MLLMBatchStats:
self._stats.peak_memory = mx.get_peak_memory() / 1e9
return self._stats
+ def _maybe_store_prefix_cache(
+ self, batch: MLLMBatch, end_indices: List[int]
+ ) -> None:
+ """Store KV caches for finished text-only requests into prefix cache.
+
+ Must be called BEFORE batch.filter() so that indices are still valid.
+ """
+ if self.prefix_cache is None or not end_indices:
+ return
+ for i in end_indices:
+ req = batch.requests[i]
+ if req.input_ids is not None:
+ try:
+ extracted = batch.extract_cache(i)
+ input_ids_list = req.input_ids.reshape(-1).tolist()
+ # Store prompt-only KV (trim output tokens + 1 so next
+ # fetch returns remaining=[last_prompt_token] at minimum).
+ # Also strip think suffix from key so next request's
+ # (also stripped) key matches as a clean PREFIX.
+ output_count = batch.num_tokens[i]
+ S = self._think_suffix_len
+ total_trim = output_count + 1 + S
+ prompt_cache = _trim_cache_offset(extracted, total_trim)
+ cache_key = input_ids_list[:-S] if S > 0 else input_ids_list
+ self.prefix_cache.store(cache_key, prompt_cache)
+ except Exception as e:
+ logger.warning(
+ f"Failed to store prefix cache for {req.request_id}: {type(e).__name__}: {e}"
+ )
+
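
A numeric aside on how the trim amount and the stored key line up (hypothetical counts):

prompt_len, output_count, S = 100, 40, 2    # prompt tokens, generated tokens, <think>\n suffix
extracted_positions = prompt_len + output_count
total_trim = output_count + 1 + S           # as computed above
kept_positions = extracted_positions - total_trim
key_len = prompt_len - S                    # stored lookup key, suffix stripped
assert (key_len, kept_positions) == (98, 97)   # key covers the kept KV plus one trailing token
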
+ def get_prefill_progress(self, request_id: str) -> Optional[Tuple[int, int]]:
+ """Return (processed_tokens, total_tokens) or None."""
+ return self._prefill_progress.get(request_id)
+
def get_vision_cache_stats(self) -> Dict[str, Any]:
"""Get vision cache statistics."""
return self.vision_cache.get_stats()
+ def get_prefix_cache_stats(self) -> Dict[str, Any]:
+ """Get KV prefix cache statistics."""
+ if self.prefix_cache is not None:
+ return self.prefix_cache.get_stats()
+ return {
+ "hits": 0,
+ "misses": 0,
+ "hit_rate": 0.0,
+ "evictions": 0,
+ "tokens_saved": 0,
+ "current_memory_mb": 0.0,
+ "max_memory_mb": 0.0,
+ "memory_utilization": 0.0,
+ "entry_count": 0,
+ }
+
def has_pending(self) -> bool:
"""Check if there are pending or active requests."""
return bool(self.unprocessed_requests or self.active_batch)
+
+
+def install_mtp_mllm(
+ batch_gen: "MLLMBatchGenerator",
+ language_model: Any,
+ num_draft_tokens: int = 1,
+) -> None:
+ """Install MTP (Multi-Token Prediction) on an MLLMBatchGenerator.
+
+ Adapts the always-advance MTP strategy from scheduler._install_mtp
+ for the MLLM batched generation path. Handles hybrid model caches
+ (BatchKVCache for attention + ArraysCache for recurrent layers).
+
+ Flow per generation step:
+ 1. Use skip_state logits/hidden OR run model forward -> sample primary
+ 2. MTP head drafts one token
+ 3. Verify [primary, draft] in one model call (always advances cache)
+ 4. Accept: skip_state from pos 1, defer draft for next step emission
+ Reject: trim KV by 2 + restore RNN state + re-advance with primary
+ 5. Draft is emitted in the NEXT generation step after primary
+ """
+ from .scheduler import make_sampler
+
+ _orig_step = batch_gen._step
+ _draft_sampler = make_sampler(temp=0.0)
+
+ # Skip state: stored logits + hidden from verify pass
+ _skip_state: list = [None]
+
+ # Deferred drafts keyed by UID
+ _deferred_drafts: Dict[int, dict] = {}
+
+ # MTP stats
+ _mtp_stats = {"accepted": 0, "rejected": 0, "errors": 0}
+
+ def _mtp_step(
+ input_tokens: mx.array,
+ cache: List[Any],
+ logits_processors: Optional[List[Optional[List[Callable]]]] = None,
+ output_tokens: Optional[List[List[int]]] = None,
+ samplers: Optional[List[Optional[Callable]]] = None,
+ ) -> Tuple[mx.array, List[mx.array]]:
+ """Extended _step with MTP always-advance strategy."""
+ batch_size = input_tokens.shape[0]
+
+ # Prefill guard: skip MTP for multi-token input or when no active batch
+ # Also skip MTP when batch has multiple active requests (MTP overhead
+ # hurts aggregate throughput in concurrent scenarios)
+ if (
+ input_tokens.shape[1] > 1
+ or batch_gen.active_batch is None
+ or len(batch_gen.active_batch) > 1
+ ):
+ _skip_state[0] = None
+ return _orig_step(
+ input_tokens, cache, logits_processors, output_tokens, samplers
+ )
+
+ # Check skip state
+ skip = _skip_state[0]
+ if skip is not None and skip["logits"].shape[0] != batch_size:
+ skip = None
+ _skip_state[0] = None
+
+ if skip is not None:
+ logits = skip["logits"]
+ hidden_states = skip["hidden"]
+ _skip_state[0] = None
+ else:
+ # Normal forward with return_hidden
+ model_output = language_model(input_tokens, cache=cache, return_hidden=True)
+ if isinstance(model_output, tuple):
+ logits, hidden_states = model_output
+ else:
+ return _orig_step(
+ input_tokens, cache, logits_processors, output_tokens, samplers
+ )
+ logits = logits[:, -1, :]
+
+ # Apply logits processors before sampling
+ if logits_processors and output_tokens and any(logits_processors):
+ processed_logits = []
+ for e in range(batch_size):
+ sample_logits = logits[e : e + 1]
+ if logits_processors[e]:
+ for processor in logits_processors[e]:
+ sample_logits = processor(
+ mx.array(output_tokens[e]), sample_logits
+ )
+ processed_logits.append(sample_logits)
+ logits = mx.concatenate(processed_logits, axis=0)
+
+ # Sample primary (use per-request sampler if available)
+ logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+ if samplers and any(samplers):
+ sampled_list = []
+ for e in range(logprobs.shape[0]):
+ s = samplers[e] if samplers[e] else batch_gen.sampler
+ sampled_list.append(s(logprobs[e : e + 1]))
+ primary_tokens = mx.concatenate(sampled_list, axis=0)
+ else:
+ primary_tokens = batch_gen.sampler(logprobs)
+
+ current_uids = list(batch_gen.active_batch.uids)
+
+ # MTP draft + always-advance verify
+ try:
+ draft_logits = language_model.mtp_forward(
+ hidden_states[:, -1:, :],
+ primary_tokens[:, None],
+ mtp_cache=None,
+ )
+ draft_logits = draft_logits[:, -1, :]
+ draft_logprobs = draft_logits - mx.logsumexp(
+ draft_logits, axis=-1, keepdims=True
+ )
+ draft_tokens = _draft_sampler(draft_logprobs)
+
+ # Snapshot RNN state for hybrid models
+ _rnn_snapshots = {}
+ for _ci, _c in enumerate(cache):
+ if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()):
+ if hasattr(_c, "state"):
+ _rnn_snapshots[_ci] = [
+ mx.array(s) if s is not None else None for s in _c.state
+ ]
+
+ # Verify [primary, draft]
+ verify_input = mx.concatenate(
+ [primary_tokens[:, None], draft_tokens[:, None]], axis=1
+ )
+ verify_output = language_model(
+ verify_input, cache=cache, return_hidden=True
+ )
+ if isinstance(verify_output, tuple):
+ verify_logits, verify_hidden = verify_output
+ else:
+ verify_logits = verify_output
+ verify_hidden = None
+
+ # Verified mode: check if draft matches verify prediction
+ verify_pred = mx.argmax(verify_logits[:, 0, :], axis=-1)
+ mx.eval(verify_pred, draft_tokens)
+ pred_list = verify_pred.tolist()
+ draft_list = draft_tokens.tolist()
+ all_accepted = pred_list == draft_list
+
+ if all_accepted and verify_hidden is not None:
+ # ACCEPT
+ _skip_state[0] = {
+ "logits": verify_logits[:, 1, :],
+ "hidden": verify_hidden[:, -1:, :],
+ }
+ mx.async_eval(_skip_state[0]["logits"], _skip_state[0]["hidden"])
+ verify_lp = verify_logits[:, 0, :] - mx.logsumexp(
+ verify_logits[:, 0, :], axis=-1, keepdims=True
+ )
+ for e in range(batch_size):
+ uid = current_uids[e]
+ _deferred_drafts[uid] = {
+ "token": draft_list[e],
+ "logprobs": verify_lp[e],
+ }
+ _mtp_stats["accepted"] += 1
+
+ else:
+ # REJECT
+ if _rnn_snapshots:
+ # Hybrid model: undo entire verify, re-advance with primary
+ for c in cache:
+ if (
+ hasattr(c, "is_trimmable")
+ and c.is_trimmable()
+ and hasattr(c, "trim")
+ ):
+ c.trim(2)
+ for _ci, _snap in _rnn_snapshots.items():
+ cache[_ci].state = _snap
+ rerun_out = language_model(
+ primary_tokens[:, None],
+ cache=cache,
+ return_hidden=True,
+ )
+ if isinstance(rerun_out, tuple):
+ rerun_logits, rerun_hidden = rerun_out
+ else:
+ rerun_logits = rerun_out
+ rerun_hidden = None
+ if rerun_hidden is not None:
+ _skip_state[0] = {
+ "logits": rerun_logits[:, -1, :],
+ "hidden": rerun_hidden[:, -1:, :],
+ }
+ mx.async_eval(
+ _skip_state[0]["logits"],
+ _skip_state[0]["hidden"],
+ )
+ else:
+ _skip_state[0] = None
+ else:
+ # Pure attention model: simple trim
+ for c in cache:
+ if (
+ hasattr(c, "is_trimmable")
+ and c.is_trimmable()
+ and hasattr(c, "trim")
+ ):
+ c.trim(1)
+ if verify_hidden is not None:
+ _skip_state[0] = {
+ "logits": verify_logits[:, 0, :],
+ "hidden": verify_hidden[:, 0:1, :],
+ }
+ mx.async_eval(
+ _skip_state[0]["logits"],
+ _skip_state[0]["hidden"],
+ )
+ else:
+ _skip_state[0] = None
+ for uid in current_uids:
+ _deferred_drafts.pop(uid, None)
+ _mtp_stats["rejected"] += 1
+
+ except Exception as e:
+ logger.warning(f"[MTP-MLLM] draft/verify failed: {e}")
+ _skip_state[0] = None
+ _mtp_stats["errors"] += 1
+
+ # Log MTP stats every 50 steps
+ total = _mtp_stats["accepted"] + _mtp_stats["rejected"] + _mtp_stats["errors"]
+ if total > 0 and total % 50 == 0:
+ acc = _mtp_stats["accepted"]
+ rej = _mtp_stats["rejected"]
+ err = _mtp_stats["errors"]
+ rate = acc / (acc + rej) * 100 if (acc + rej) > 0 else 0
+ logger.info(
+ f"[MTP-MLLM] stats: accepted={acc} rejected={rej} "
+ f"errors={err} acceptance={rate:.0f}%"
+ )
+
+ return primary_tokens, list(logprobs)
+
+ # Wrap _next to emit deferred MTP drafts
+ batch_gen._inner_next = batch_gen._next
+
+ def _mtp_next() -> List[MLLMBatchResponse]:
+ """Wrapper around _next that emits deferred MTP draft tokens."""
+ if batch_gen.active_batch is None:
+ _skip_state[0] = None
+ _deferred_drafts.clear()
+
+ # Save deferred drafts from previous step
+ prev_deferred: Dict[int, dict] = {}
+ if batch_gen.active_batch is not None:
+ for uid in batch_gen.active_batch.uids:
+ if uid in _deferred_drafts:
+ prev_deferred[uid] = _deferred_drafts.pop(uid)
+
+ responses = batch_gen._inner_next()
+
+ if not prev_deferred or not responses:
+ return responses
+
+ # Augment responses with deferred drafts
+ augmented: List[MLLMBatchResponse] = []
+ draft_end_uids: set = set()
+
+ for r in responses:
+ uid = r.uid
+ augmented.append(r)
+
+ if r.finish_reason is not None:
+ _deferred_drafts.pop(uid, None)
+ prev_deferred.pop(uid, None)
+ continue
+
+ if uid in prev_deferred:
+ draft_info = prev_deferred.pop(uid)
+ draft_t = draft_info["token"]
+ draft_lp = draft_info["logprobs"]
+
+ if draft_t in batch_gen.stop_tokens:
+ augmented.append(
+ MLLMBatchResponse(
+ uid=uid,
+ request_id=r.request_id,
+ token=draft_t,
+ logprobs=draft_lp,
+ finish_reason="stop",
+ )
+ )
+ draft_end_uids.add(uid)
+ else:
+ draft_finish = None
+ batch = batch_gen.active_batch
+ if batch is not None:
+ for e, bu in enumerate(batch.uids):
+ if bu == uid:
+ batch.num_tokens[e] += 1
+ batch.requests[e].output_tokens.append(draft_t)
+ if batch.num_tokens[e] >= batch.max_tokens[e]:
+ draft_finish = "length"
+ draft_end_uids.add(uid)
+ break
+
+ augmented.append(
+ MLLMBatchResponse(
+ uid=uid,
+ request_id=r.request_id,
+ token=draft_t,
+ logprobs=draft_lp,
+ finish_reason=draft_finish,
+ )
+ )
+
+ # Store prefix caches for draft-ended sequences BEFORE filtering
+ if draft_end_uids and batch_gen.active_batch is not None:
+ end_indices = [
+ e
+ for e, u in enumerate(batch_gen.active_batch.uids)
+ if u in draft_end_uids
+ ]
+ batch_gen._maybe_store_prefix_cache(batch_gen.active_batch, end_indices)
+
+ keep = [
+ e
+ for e, u in enumerate(batch_gen.active_batch.uids)
+ if u not in draft_end_uids
+ ]
+ if keep:
+ batch_gen.active_batch.filter(keep)
+ else:
+ batch_gen.active_batch = None
+
+ return augmented
+
+ batch_gen._step = _mtp_step
+ batch_gen._next = _mtp_next
+
+ logger.info(
+ f"[MTP-MLLM] installed with num_draft_tokens={num_draft_tokens}, "
+ f"always-advance verified mode"
+ )
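
An illustrative check of the accept rule used in _mtp_step (toy logits; shapes mirror the verify pass):

import mlx.core as mx

# position 0: prediction for the token after `primary`; position 1: seeds skip_state on accept
verify_logits = mx.stack([mx.arange(8, dtype=mx.float32), mx.zeros(8)])[None]
draft_tokens = mx.array([7])                  # the MTP head happened to draft token 7
verify_pred = mx.argmax(verify_logits[:, 0, :], axis=-1)
all_accepted = verify_pred.tolist() == draft_tokens.tolist()
assert all_accepted                            # accepted: the draft is emitted on the next step
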
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 555b230f2..04c7cac2a 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -19,6 +19,7 @@
"""
import asyncio
+import concurrent.futures
import logging
import time
import uuid
@@ -35,7 +36,6 @@
MLLMBatchRequest,
MLLMBatchResponse,
)
-from .mllm_cache import MLLMCacheManager
from .multimodal_processor import MultimodalProcessor
from .request import RequestOutput, RequestStatus, SamplingParams
@@ -62,8 +62,22 @@ class MLLMSchedulerConfig:
default_max_tokens: int = 256
# Default video FPS for frame extraction
default_video_fps: float = 2.0
+ # KV cache memory limit (from --cache-memory-mb)
+ cache_memory_mb: Optional[int] = None
# Maximum video frames
max_video_frames: int = 128
+ # Enable MTP speculative decoding
+ enable_mtp: bool = False
+ # Number of draft tokens for MTP
+ mtp_num_draft_tokens: int = 1
+ # Enable KV prefix cache for text-only requests
+ enable_prefix_cache: bool = True
+ # Memory limit for prefix cache (None = auto-detect)
+ prefix_cache_memory_mb: Optional[int] = None
+ # KV cache quantization for prefix cache store/fetch
+ kv_cache_quantization: bool = False
+ kv_cache_quantization_bits: int = 8
+ kv_cache_quantization_group_size: int = 64
@dataclass
@@ -94,6 +108,9 @@ class MLLMRequest:
num_prompt_tokens: int = 0
num_output_tokens: int = 0
+ # Timing
+ first_token_time: Optional[float] = None
+
@dataclass
class MLLMSchedulerOutput:
@@ -176,13 +193,6 @@ def __init__(
config=self.model_config,
)
- # Vision cache for repeated images
- self.vision_cache: Optional[MLLMCacheManager] = None
- if self.config.enable_vision_cache:
- self.vision_cache = MLLMCacheManager(
- max_entries=self.config.vision_cache_size
- )
-
# Get stop tokens from tokenizer
self.stop_tokens = self._get_stop_tokens()
@@ -218,8 +228,12 @@ def __init__(
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
+ # Memory management: periodic mx.clear_cache() to free Metal buffers
+ self._step_count = 0
+ self._clear_cache_interval = 32
+
def _get_stop_tokens(self) -> Set[int]:
- """Get stop token IDs from tokenizer."""
+ """Get stop token IDs from tokenizer and generation_config.json."""
stop_tokens = set()
tokenizer = (
self.processor.tokenizer
@@ -239,6 +253,25 @@ def _get_stop_tokens(self) -> Set[int]:
else:
stop_tokens.add(tokenizer.eos_token_ids)
+ # Also read generation_config.json which may have additional EOS tokens
+    # (e.g., Gemma 4 has <end_of_turn>=106 and <|tool_response>=50 as EOS)
+ model_path = getattr(tokenizer, "name_or_path", None)
+ if model_path:
+ import json
+ from pathlib import Path
+
+ gc_path = Path(model_path) / "generation_config.json"
+ if gc_path.exists():
+ try:
+ gc = json.loads(gc_path.read_text())
+ gc_eos = gc.get("eos_token_id")
+ if isinstance(gc_eos, list):
+ stop_tokens.update(gc_eos)
+ elif gc_eos is not None:
+ stop_tokens.add(gc_eos)
+ except Exception:
+ pass
+
return stop_tokens
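
As a side note, the two shapes eos_token_id takes in generation_config.json (illustrative values):

import json

for raw in ('{"eos_token_id": 106}', '{"eos_token_id": [106, 50]}'):
    gc_eos = json.loads(raw).get("eos_token_id")
    extra = set(gc_eos) if isinstance(gc_eos, list) else {gc_eos}
    print(sorted(extra))   # -> [106] then [50, 106]
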
def _ensure_batch_generator(self) -> None:
@@ -246,9 +279,24 @@ def _ensure_batch_generator(self) -> None:
if self.batch_generator is None:
from mlx_lm.sample_utils import make_sampler
+ from .memory_cache import MemoryCacheConfig
+
# Default sampler (can be overridden per-request in future)
sampler = make_sampler(temp=0.7, top_p=0.9)
+ # Configure KV prefix cache for text-only requests
+ # KV cache quantization shrinks prefix cache memory (~2x for BF16→Q8).
+ # Quantization happens on store(), dequantization on fetch() —
+ # the model always receives normal KVCache with plain arrays.
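+ # Rough illustration (assuming fp16 scales/biases): an 8 MB BF16 KV
+ # slab stores as ~4 MB at 8-bit plus ~0.25 MB of per-group metadata
+ # at group_size=64.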
+ prefix_cache_config = None
+ if self.config.enable_prefix_cache:
+ prefix_cache_config = MemoryCacheConfig(
+ max_memory_mb=self.config.prefix_cache_memory_mb,
+ kv_quantize=self.config.kv_cache_quantization,
+ kv_bits=self.config.kv_cache_quantization_bits,
+ kv_group_size=self.config.kv_cache_quantization_group_size,
+ )
+
self.batch_generator = MLLMBatchGenerator(
model=self.model,
processor=self.processor,
@@ -259,8 +307,21 @@ def _ensure_batch_generator(self) -> None:
prefill_batch_size=self.config.prefill_batch_size,
completion_batch_size=self.config.completion_batch_size,
prefill_step_size=self.config.prefill_step_size,
+ prefix_cache_config=prefix_cache_config,
)
+ # Install MTP if enabled and language model supports it
+ if self.config.enable_mtp:
+ lm = self.batch_generator.language_model
+ if hasattr(lm, "mtp") and lm.mtp is not None:
+ from .mllm_batch_generator import install_mtp_mllm
+
+ install_mtp_mllm(
+ self.batch_generator,
+ lm,
+ num_draft_tokens=self.config.mtp_num_draft_tokens,
+ )
+
# ========== Sync API (step-based) ==========
def add_request(
@@ -297,6 +358,10 @@ def add_request(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
+ top_k=kwargs.pop("top_k", 0),
+ min_p=kwargs.pop("min_p", 0.0),
+ presence_penalty=kwargs.pop("presence_penalty", 0.0),
+ repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
)
request = MLLMRequest(
@@ -307,6 +372,19 @@ def add_request(
sampling_params=sampling_params,
)
+ # Estimate prompt token count for monitoring (text tokens only;
+ # vision tokens are added during prefill but this gives a useful
+ # approximation for the status endpoint).
+ tokenizer = (
+ self.processor.tokenizer
+ if hasattr(self.processor, "tokenizer")
+ else self.processor
+ )
+ try:
+ request.num_prompt_tokens = len(tokenizer.encode(prompt))
+ except Exception:
+ pass
+
self.requests[request_id] = request
self.waiting.append(request)
@@ -331,6 +409,12 @@ def abort_request(self, request_id: str) -> bool:
if request is None:
return False
+ # Signal batch generator to abort any in-progress prefill for this
+ # request. The prefill loop checks _aborted_request_ids between
+ # chunks and raises PrefillAbortedError to exit early.
+ if self.batch_generator is not None:
+ self.batch_generator.abort_prefill(request_id)
+
# Remove from waiting queue
if request.status == RequestStatus.WAITING:
try:
@@ -403,6 +487,10 @@ def _schedule_waiting(self) -> List[MLLMRequest]:
max_tokens=request.sampling_params.max_tokens,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
+ top_k=request.sampling_params.top_k,
+ min_p=request.sampling_params.min_p,
+ presence_penalty=request.sampling_params.presence_penalty,
+ repetition_penalty=request.sampling_params.repetition_penalty,
)
batch_requests.append(batch_req)
@@ -453,21 +541,41 @@ def _process_batch_responses(
if request is None:
continue
+ # Handle error responses from failed preprocessing
+ if response.finish_reason == "error":
+ output = RequestOutput(
+ request_id=request_id,
+ new_token_ids=[],
+ new_text="",
+ output_token_ids=[],
+ prompt_tokens=0,
+ completion_tokens=0,
+ finished=True,
+ finish_reason="error",
+ )
+ request.status = RequestStatus.FINISHED_ABORTED
+ request.output_text = ""
+ request.finish_reason = "error"
+ finished_ids.add(request_id)
+ self.num_requests_processed += 1
+ logger.warning(f"Request {request_id} failed during preprocessing")
+ outputs.append(output)
+ continue
+
# Append token to request
request.output_tokens.append(response.token)
request.num_output_tokens = len(request.output_tokens)
+ if request.first_token_time is None and request.num_output_tokens > 0:
+ request.first_token_time = time.time()
+
# Decode the new token using streaming detokenizer (UTF-8 safe).
# Skip stop tokens — they are not content.
if response.finish_reason == "stop":
new_text = ""
else:
if request_id not in self._detokenizer_pool:
- if hasattr(tokenizer, "detokenizer"):
- detok = tokenizer.detokenizer
- else:
- detok = NaiveStreamingDetokenizer(tokenizer)
- detok.reset()
+ detok = NaiveStreamingDetokenizer(tokenizer)
self._detokenizer_pool[request_id] = detok
detok = self._detokenizer_pool[request_id]
detok.add_token(response.token)
@@ -495,7 +603,7 @@ def _process_batch_responses(
finished_ids.add(request_id)
# Finalize streaming detokenizer and get full output
- detok = self._detokenizer_pool.get(request_id)
+ detok = self._detokenizer_pool.pop(request_id, None)
if detok is not None:
detok.finalize()
output.output_text = detok.text
@@ -503,7 +611,6 @@ def _process_batch_responses(
output.output_text = tokenizer.decode(request.output_tokens)
request.output_text = output.output_text
request.finish_reason = response.finish_reason
- self._detokenizer_pool.pop(request_id, None)
self.total_completion_tokens += request.num_output_tokens
self.num_requests_processed += 1
@@ -524,6 +631,9 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
if request_id in self.running:
del self.running[request_id]
+ # Drain from requests dict to prevent linear memory growth
+ self.requests.pop(request_id, None)
+
# Remove UID mappings
if request_id in self.request_id_to_uid:
uid = self.request_id_to_uid[request_id]
@@ -531,10 +641,17 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
del self.uid_to_request_id[uid]
del self.request_id_to_uid[request_id]
+ # Clean up detokenizer pool (handles abort/timeout cases)
+ self._detokenizer_pool.pop(request_id, None)
+
# Track as finished
self.finished_req_ids.add(request_id)
self.requests.pop(request_id, None)
+ # Clear Metal buffer pool after cleanup to release memory
+ if finished_ids:
+ mx.clear_cache()
+
def step(self) -> MLLMSchedulerOutput:
"""
Execute one scheduling step.
@@ -634,14 +751,33 @@ async def stop(self) -> None:
logger.info("MLLM Scheduler stopped")
async def _process_loop(self) -> None:
- """Main async processing loop."""
+ """Main async processing loop.
+
+ Uses a thread pool executor for steps that involve prefill
+ (waiting requests or partial prefill in progress) so that the
+ event loop stays responsive for health checks and other HTTP
+ endpoints. Decode-only steps are fast (<3 ms) and run inline.
+ """
+ _executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=1, thread_name_prefix="mllm-step"
+ )
+ loop = asyncio.get_running_loop()
+
while self._running:
try:
if self.has_requests():
- # Run one step
- self.step()
- # Yield to other tasks
- await asyncio.sleep(0)
+ has_waiting = self.get_num_waiting() > 0
+ has_partial = (
+ self.batch_generator is not None
+ and getattr(self.batch_generator, "_partial", None) is not None
+ )
+ needs_executor = has_waiting or has_partial
+
+ if needs_executor:
+ await loop.run_in_executor(_executor, self.step)
+ else:
+ self.step()
+ await asyncio.sleep(0)
else:
# No work, wait a bit
await asyncio.sleep(0.01)
@@ -649,7 +785,7 @@ async def _process_loop(self) -> None:
except asyncio.CancelledError:
break
except Exception as e:
- logger.error(f"Error in MLLM process loop: {e}")
+ logger.error(f"Error in MLLM process loop: {e}", exc_info=True)
await asyncio.sleep(0.1)
async def add_request_async(
@@ -778,6 +914,77 @@ async def generate(
# ========== Stats and utilities ==========
+ def get_running_requests_info(self) -> List[Dict[str, Any]]:
+ """Per-request details for status endpoint."""
+ now = time.time()
+ result = []
+
+ # Waiting requests
+ for req in self.waiting:
+ result.append(
+ {
+ "request_id": req.request_id,
+ "status": "waiting",
+ "phase": "queued",
+ "elapsed_s": round(now - req.arrival_time, 2),
+ "prompt_tokens": req.num_prompt_tokens,
+ "completion_tokens": 0,
+ "max_tokens": req.sampling_params.max_tokens,
+ "progress": 0.0,
+ "tokens_per_second": None,
+ "ttft_s": None,
+ "cache_hit_type": None,
+ "cached_tokens": 0,
+ }
+ )
+
+ # Running requests
+ for req in self.running.values():
+ n_out = req.num_output_tokens
+ elapsed = now - req.arrival_time
+
+ if n_out == 0:
+ phase = "prefill"
+ else:
+ phase = "generation"
+
+ tok_s = None
+ ttft = None
+ if req.first_token_time is not None:
+ ttft = round(req.first_token_time - req.arrival_time, 3)
+ gen_elapsed = now - req.first_token_time
+ if gen_elapsed > 0 and n_out > 0:
+ tok_s = round(n_out / gen_elapsed, 1)
+
+ max_tokens = req.sampling_params.max_tokens
+ if phase == "prefill" and self.batch_generator is not None:
+ pp = self.batch_generator.get_prefill_progress(req.request_id)
+ if pp is not None:
+ progress = round(pp[0] / pp[1], 3) if pp[1] > 0 else 0.0
+ else:
+ progress = 0.0
+ else:
+ progress = round(n_out / max_tokens, 3) if max_tokens > 0 else 0.0
+
+ result.append(
+ {
+ "request_id": req.request_id,
+ "status": "running",
+ "phase": phase,
+ "elapsed_s": round(elapsed, 2),
+ "prompt_tokens": req.num_prompt_tokens,
+ "completion_tokens": n_out,
+ "max_tokens": max_tokens,
+ "progress": min(progress, 1.0),
+ "tokens_per_second": tok_s,
+ "ttft_s": ttft,
+ "cache_hit_type": None,
+ "cached_tokens": 0,
+ }
+ )
+
+ return result
+
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
stats = {
@@ -787,27 +994,45 @@ def get_stats(self) -> Dict[str, Any]:
"num_requests_processed": self.num_requests_processed,
"total_prompt_tokens": self.total_prompt_tokens,
"total_completion_tokens": self.total_completion_tokens,
+ "requests": self.get_running_requests_info(),
}
if self.batch_generator is not None:
batch_stats = self.batch_generator.stats()
stats["batch_generator"] = batch_stats.to_dict()
- # Add vision embedding cache stats from batch generator
- stats["vision_embedding_cache"] = (
- self.batch_generator.get_vision_cache_stats()
- )
-
- if self.vision_cache:
- stats["vision_cache"] = self.vision_cache.get_stats()
+ # Vision embedding cache stats from batch generator
+ vec_stats = self.batch_generator.get_vision_cache_stats()
+ stats["vision_embedding_cache"] = vec_stats
# Include Metal memory stats
try:
if mx.metal.is_available():
- stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2)
- stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2)
- stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2)
+ active_gb = round(mx.get_active_memory() / 1e9, 2)
+ peak_gb = round(mx.get_peak_memory() / 1e9, 2)
+ cache_gb = round(mx.get_cache_memory() / 1e9, 2)
+ stats["metal_active_memory_gb"] = active_gb
+ stats["metal_peak_memory_gb"] = peak_gb
+ stats["metal_cache_memory_gb"] = cache_gb
except Exception:
- pass
+ active_gb = 0
+ cache_gb = 0
+
+ # KV prefix cache stats for /v1/status and monitoring UI.
+ if self.batch_generator is not None:
+ prefix_stats = self.batch_generator.get_prefix_cache_stats()
+ else:
+ prefix_stats = {
+ "hits": 0,
+ "misses": 0,
+ "hit_rate": 0.0,
+ "evictions": 0,
+ "tokens_saved": 0,
+ "current_memory_mb": 0.0,
+ "max_memory_mb": 0.0,
+ "memory_utilization": 0.0,
+ "entry_count": 0,
+ }
+ stats["memory_aware_cache"] = prefix_stats
return stats
diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py
index 811a6d4da..46e26a744 100644
--- a/vllm_mlx/models/llm.py
+++ b/vllm_mlx/models/llm.py
@@ -111,6 +111,8 @@ def _create_sampler(
self,
temperature: float = 0.7,
top_p: float = 0.9,
+ top_k: int = 0,
+ min_p: float = 0.0,
):
"""Create a sampler for text generation."""
from mlx_lm.sample_utils import make_sampler
@@ -118,16 +120,38 @@ def _create_sampler(
return make_sampler(
temp=temperature,
top_p=top_p,
+ top_k=top_k,
+ min_p=min_p,
)
+ def _create_logits_processors(
+ self,
+ presence_penalty: float = 0.0,
+ repetition_penalty: float = 1.0,
+ ):
+ """Create logits processors for penalty-based sampling."""
+ from mlx_lm.sample_utils import make_logits_processors
+
+ processors = make_logits_processors(
+ repetition_penalty=(
+ repetition_penalty if repetition_penalty != 1.0 else None
+ ),
+ presence_penalty=presence_penalty if presence_penalty != 0.0 else None,
+ )
+ return processors if processors else None
+
def generate(
self,
prompt: str,
max_tokens: int = 256,
temperature: float = 0.7,
top_p: float = 0.9,
+ top_k: int = 0,
+ min_p: float = 0.0,
+ presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
stop: list[str] | None = None,
+ **kwargs,
) -> GenerationOutput:
"""
Generate text from a prompt.
@@ -137,7 +161,10 @@ def generate(
max_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (0 = greedy)
top_p: Top-p (nucleus) sampling parameter
- repetition_penalty: Penalty for repeating tokens
+ top_k: Top-k sampling (0 = disabled)
+ min_p: Minimum probability threshold
+ presence_penalty: Additive penalty for token presence
+ repetition_penalty: Multiplicative penalty for repeating tokens
stop: List of stop sequences
Returns:
@@ -148,8 +175,11 @@ def generate(
from mlx_lm import generate
- # Create sampler with parameters
- sampler = self._create_sampler(temperature, top_p)
+ # Create sampler and logits processors with full Unsloth params
+ sampler = self._create_sampler(temperature, top_p, top_k, min_p)
+ logits_processors = self._create_logits_processors(
+ presence_penalty, repetition_penalty
+ )
# Generate text
output_text = generate(
@@ -158,6 +188,7 @@ def generate(
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
+ logits_processors=logits_processors,
verbose=False,
)
@@ -179,8 +210,13 @@ def stream_generate(
max_tokens: int = 256,
temperature: float = 0.7,
top_p: float = 0.9,
+ top_k: int = 0,
+ min_p: float = 0.0,
+ presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
stop: list[str] | None = None,
+ logits_processors: list | None = None,
+ **kwargs,
) -> Iterator[StreamingOutput]:
"""
Stream text generation token by token.
@@ -190,7 +226,10 @@ def stream_generate(
max_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (0 = greedy)
top_p: Top-p (nucleus) sampling parameter
- repetition_penalty: Penalty for repeating tokens
+ top_k: Top-k sampling (0 = disabled)
+ min_p: Minimum probability threshold
+ presence_penalty: Additive penalty for token presence
+ repetition_penalty: Multiplicative penalty for repeating tokens
stop: List of stop sequences
Yields:
@@ -201,8 +240,15 @@ def stream_generate(
from mlx_lm import stream_generate
- # Create sampler with parameters
- sampler = self._create_sampler(temperature, top_p)
+ # Create sampler and logits processors with full Unsloth params
+ sampler = self._create_sampler(temperature, top_p, top_k, min_p)
+ penalty_processors = self._create_logits_processors(
+ presence_penalty, repetition_penalty
+ )
+ # Merge any externally-provided logits_processors with penalty processors
+ all_processors = None
+ if penalty_processors or logits_processors:
+ all_processors = (logits_processors or []) + (penalty_processors or [])
# Count prompt tokens once upfront
num_prompt_tokens = len(self.tokenizer.encode(prompt))
@@ -220,6 +266,7 @@ def stream_generate(
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
+ logits_processors=all_processors,
**mtp_kwargs,
):
token_count += 1
diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py
index fcf3537f4..5a3551eb1 100644
--- a/vllm_mlx/models/mllm.py
+++ b/vllm_mlx/models/mllm.py
@@ -465,8 +465,9 @@ def save_base64_image(base64_string: str) -> str:
"""Save base64 image to temp file and return path. Caches identical images."""
import hashlib
- # Hash the base64 string to check cache
- image_hash = hashlib.md5(base64_string.encode()).hexdigest()
+ # Hash the full base64 string to prevent collisions between images
+ # with identical headers (e.g. JPEG images sharing first 1000 chars)
+ image_hash = hashlib.sha256(base64_string.encode()).hexdigest()
# Return cached path if available and file still exists
if image_hash in _base64_image_cache:
@@ -1328,6 +1329,7 @@ def chat(
video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES)
tools = kwargs.pop("tools", None)
use_cache = kwargs.pop("use_cache", True)
+ enable_thinking = kwargs.pop("enable_thinking", True)
# Collect video inputs from messages
_msg_video_inputs = self._collect_video_inputs(messages)
@@ -1453,11 +1455,11 @@ def chat(
template_extra_kwargs["tools"] = tools
try:
- # Use get_chat_template directly since messages are already properly formatted
formatted_prompt = get_chat_template(
self.processor,
chat_messages,
add_generation_prompt=True,
+ enable_thinking=enable_thinking,
**template_extra_kwargs,
)
except Exception as e:
@@ -1724,6 +1726,7 @@ def stream_chat(
video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES)
tools = kwargs.pop("tools", None)
use_cache = kwargs.pop("use_cache", True)
+ enable_thinking = kwargs.pop("enable_thinking", True)
# Collect video inputs from messages
_msg_video_inputs = self._collect_video_inputs(messages)
@@ -1838,6 +1841,7 @@ def stream_chat(
self.processor,
chat_messages,
add_generation_prompt=True,
+ enable_thinking=enable_thinking,
**template_extra_kwargs,
)
except Exception as e:
diff --git a/vllm_mlx/patches/gemma4_mllm.py b/vllm_mlx/patches/gemma4_mllm.py
new file mode 100644
index 000000000..dc041cf31
--- /dev/null
+++ b/vllm_mlx/patches/gemma4_mllm.py
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Runtime patch for mlx-vlm's Gemma 4 attention to support BatchKVCache.
+
+Gemma 4 Attention reads cache.offset into a local variable before calling
+update_and_fetch, then uses the same variable later for RoPE on queries:
+
+ offset = cache.offset # reference to mx.array([22])
+ keys = self.rope(keys, offset=offset)
+ keys, values = cache.update_and_fetch(keys, values)
+ # ^^^ self.offset += 1 mutates the SAME mx.array in-place!
+ queries = self.rope(queries, offset=offset) # offset is now 23!
+
+For KVCache, cache.offset is a Python int (immutable), so the local copy
+is unaffected. For BatchKVCache, cache.offset is an mx.array and
+mx.array.__iadd__ is *in-place*, so the local reference is silently
+mutated by update_and_fetch, giving queries the wrong RoPE position.
+
+This patch replaces Gemma4 Attention.__call__ with a version that
+snapshots cache.offset as a defensive copy before any mutation can occur.
+The mx.array copy preserves per-sequence offsets needed for correct RoPE
+in continuous batching (unlike int conversion which would lose this info).
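+
+A minimal sketch of the aliasing behavior (hypothetical values):
+
+    off = mx.array([22])
+    alias = off            # same underlying array, not a copy
+    off += 1               # in-place: `alias` now also reads 23
+    copy = off + 0         # new array; later in-place updates leave it unchanged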
+"""
+
+import logging
+from typing import Any, Optional
+
+import mlx.core as mx
+
+logger = logging.getLogger(__name__)
+
+
+def _snapshot_cache_offset(cache):
+ """Snapshot cache offset, making a defensive copy if it's an mx.array.
+
+ BatchKVCache stores offset as mx.array (per-batch-item).
+ mx.array.__iadd__ is in-place, so update_and_fetch mutates the original.
+ We return a copy to preserve the pre-update value for RoPE on queries.
+ """
+ if cache is None:
+ return 0
+ off = cache.offset
+ if isinstance(off, int):
+ return off
+ if isinstance(off, mx.array):
+ return off + 0 # defensive copy — new array, same values
+ return off
+
+
+def patch_gemma4_attention_for_batching() -> bool:
+ """Monkey-patch Gemma4 Attention.__call__ to snapshot offset before update.
+
+ Returns True if patch was applied, False if mlx-vlm is not installed
+ or Gemma 4 module not available.
+ """
+ try:
+ from mlx_vlm.models.gemma4.language import Attention as Gemma4Attention
+ from mlx_vlm.models.base import scaled_dot_product_attention
+ except ImportError:
+ logger.debug("[Gemma4 patch] mlx-vlm Gemma4 module not available")
+ return False
+
+ if getattr(Gemma4Attention, "_batch_patched", False):
+ logger.debug("[Gemma4 patch] Already patched")
+ return True
+
+ _orig_call = Gemma4Attention.__call__
+
+ def _patched_call(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[Any] = None,
+ ) -> mx.array:
+ B, L, _ = x.shape
+
+ queries = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim)
+ queries = self.q_norm(queries)
+
+ # Snapshot offset BEFORE update_and_fetch can mutate it in-place.
+ # Preserves per-sequence mx.array offsets for correct batched RoPE.
+ offset = _snapshot_cache_offset(cache)
+
+ if self.is_kv_shared_layer and cache is not None:
+ state = cache.state
+ keys, values = state[0], state[1]
+ else:
+ keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)
+
+ if self.use_k_eq_v:
+ values = keys
+ else:
+ values = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)
+
+ keys = self.k_norm(keys)
+ values = self.v_norm(values)
+ values = values.transpose(0, 2, 1, 3)
+
+ keys = keys.transpose(0, 2, 1, 3)
+ keys = self.rope(keys, offset=offset)
+
+ if cache is not None:
+ keys, values = cache.update_and_fetch(keys, values)
+
+ queries = queries.transpose(0, 2, 1, 3)
+ queries = self.rope(queries, offset=offset)
+
+ if mask is not None and isinstance(mask, mx.array):
+ if mask.shape[-1] != keys.shape[-2]:
+ mask = mask[..., -keys.shape[-2] :]
+
+ output = scaled_dot_product_attention(
+ queries, keys, values, cache=cache, scale=self.scale, mask=mask
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.o_proj(output)
+
+ Gemma4Attention.__call__ = _patched_call
+ Gemma4Attention._batch_patched = True
+ logger.info("[Gemma4 patch] Attention patched for BatchKVCache support")
+ return True
diff --git a/vllm_mlx/patches/qwen3_5_mllm.py b/vllm_mlx/patches/qwen3_5_mllm.py
new file mode 100644
index 000000000..c592928da
--- /dev/null
+++ b/vllm_mlx/patches/qwen3_5_mllm.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Runtime patch for mlx-vlm's Qwen3.5 attention to support BatchKVCache.
+
+mlx-vlm's Qwen3_5Attention uses cache.offset directly for kv_seq_len
+computation and mask slicing. BatchKVCache stores offset as mx.array
+(per-batch-item), not int, causing:
+
+ mask = mask[..., :kv_seq_len]
+ ValueError: Slice indices must be integers or None.
+
+This patch replaces Qwen3_5Attention.__call__ with a version that
+converts cache.offset to int before using it for arithmetic/slicing,
+while leaving the actual cache.offset untouched so update_and_fetch
+still works correctly with per-batch offsets.
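+
+Illustrative sketch (hypothetical per-sequence offsets):
+
+    off = mx.array([7, 12])              # BatchKVCache offsets for two sequences
+    kv_seq_len = int(off.max().item())   # 12, a plain int, safe for mask slicing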
+"""
+
+import logging
+from typing import Optional
+
+import mlx.core as mx
+
+logger = logging.getLogger(__name__)
+
+
+def _cache_offset_to_int(cache) -> int:
+ """Extract cache offset as int, handling BatchKVCache mx.array offset."""
+ if cache is None:
+ return 0
+ off = cache.offset
+ if isinstance(off, int):
+ return off
+ if isinstance(off, mx.array):
+ return int(off.max().item()) if off.ndim > 0 else int(off.item())
+ return int(off)
+
+
+def patch_qwen35_attention_for_batching() -> bool:
+ """Monkey-patch Qwen3_5Attention.__call__ to handle BatchKVCache.
+
+ Returns True if patch was applied, False if mlx-vlm is not installed
+ or Qwen3.5 module not available.
+ """
+ try:
+ from mlx_vlm.models.qwen3_5.language import (
+ Qwen3_5Attention,
+ apply_multimodal_rotary_pos_emb,
+ )
+ from mlx_lm.models.base import scaled_dot_product_attention
+ except ImportError:
+ logger.debug("[Qwen3.5 patch] mlx-vlm Qwen3.5 module not available")
+ return False
+
+ if getattr(Qwen3_5Attention, "_batch_patched", False):
+ logger.debug("[Qwen3.5 patch] Already patched")
+ return True
+
+ def _patched_call(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache=None,
+ position_ids: Optional[mx.array] = None,
+ ) -> mx.array:
+ B, L, D = x.shape
+
+ q_proj_output = self.q_proj(x)
+ queries, gate = mx.split(
+ q_proj_output.reshape(B, L, self.num_attention_heads, -1),
+ 2,
+ axis=-1,
+ )
+ gate = gate.reshape(B, L, -1)
+
+ keys, values = self.k_proj(x), self.v_proj(x)
+
+ queries = self.q_norm(queries).transpose(0, 2, 1, 3)
+ keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose(
+ 0, 2, 1, 3
+ )
+ values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+ 0, 2, 1, 3
+ )
+
+ kv_seq_len = keys.shape[-2]
+
+ # Convert cache.offset to int for slice compatibility.
+ # BatchKVCache stores offset as mx.array (per-batch-item),
+ # but kv_seq_len must be int for mask[..., :kv_seq_len].
+ _offset = _cache_offset_to_int(cache)
+
+ if position_ids is None:
+ kv_seq_len += _offset + 1
+ position_ids = mx.arange(_offset, _offset + L)
+ position_ids = mx.expand_dims(position_ids, axis=0)
+ position_ids = mx.tile(position_ids, (3, 1, 1))
+ else:
+ kv_seq_len += _offset + 1 if cache is not None else 0
+
+ cos, sin = self.rotary_emb(values, position_ids)
+
+ if mask is not None and isinstance(mask, mx.array):
+ mask = mask[..., :kv_seq_len]
+
+ queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin)
+
+ if cache is not None:
+ keys, values = cache.update_and_fetch(keys, values)
+
+ output = scaled_dot_product_attention(
+ queries, keys, values, cache=cache, scale=self.scale, mask=mask
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+ return self.o_proj(output * mx.sigmoid(gate))
+
+ Qwen3_5Attention.__call__ = _patched_call
+ Qwen3_5Attention._batch_patched = True
+ logger.info("[Qwen3.5 patch] Attention patched for BatchKVCache support")
+ return True
diff --git a/vllm_mlx/patches/qwen3_5_mtp.py b/vllm_mlx/patches/qwen3_5_mtp.py
new file mode 100644
index 000000000..3d5f3e632
--- /dev/null
+++ b/vllm_mlx/patches/qwen3_5_mtp.py
@@ -0,0 +1,399 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Runtime MTP (Multi-Token Prediction) support for Qwen3.5 models.
+
+Qwen3.5 models may include a built-in MTP head that predicts token n+2
+from hidden states + token n+1. MTP weights are added to the quantized
+MLX model via scripts/add_mtp_weights_qwen35.py.
+
+Since mlx_lm's qwen3_5.py does NOT define MTP module/methods, this
+module provides:
+ - inject_mtp_support(): dynamically creates MTP module, loads weights,
+ and monkey-patches the model class with return_hidden, mtp_forward,
+ and make_mtp_cache
+ - validate_mtp_support(): checks whether a loaded model has working MTP
+
+Supports both Dense (27B) and MoE (122B-A10B, 35B-A3B) architectures.
+
+The actual MTP scheduling logic lives in:
+ - vllm_mlx/scheduler.py (_install_mtp, _mtp_step, _mtp_next)
+"""
+
+import logging
+from pathlib import Path
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+def _fixup_moe_mtp(mtp, inner_model, loaded_keys: set, mx) -> None:
+ """Fix missing weights in MoE MTP module.
+
+ MoE MTP checkpoints (122B, 35B) only contain: fc, q_proj, o_proj,
+ shared_expert.*, and per-expert weights. Missing:
+ - k_proj, v_proj → zero out (attention becomes no-op)
+ - gate, shared_expert_gate → copy from main model's last full-attn layer
+ - norms → already at identity (weight=1.0), no action needed
+ """
+ import mlx.utils
+
+ mtp_layer = mtp.layers[0]
+
+ # Find last full-attention layer in main model for gate weights
+ last_fa_layer = None
+ for layer in reversed(inner_model.layers):
+ if not layer.is_linear:
+ last_fa_layer = layer
+ break
+
+ if last_fa_layer is None:
+ logger.warning("[MTP fixup] No full-attention layer found in main model")
+ return
+
+ # Copy expert routing gate if not in checkpoint
+ if "layers.0.mlp.gate.weight" not in loaded_keys:
+ src = getattr(last_fa_layer.mlp, "gate", None)
+ dst = getattr(mtp_layer.mlp, "gate", None)
+ if src is not None and dst is not None:
+ src_params = mlx.utils.tree_flatten(src.parameters())
+ dst.load_weights(src_params)
+ mx.eval(dst.parameters())
+ logger.info("[MTP fixup] Copied mlp.gate from main model last layer")
+
+ # Copy shared_expert_gate if not in checkpoint
+ if "layers.0.mlp.shared_expert_gate.weight" not in loaded_keys:
+ src = getattr(last_fa_layer.mlp, "shared_expert_gate", None)
+ dst = getattr(mtp_layer.mlp, "shared_expert_gate", None)
+ if src is not None and dst is not None:
+ src_params = mlx.utils.tree_flatten(src.parameters())
+ dst.load_weights(src_params)
+ mx.eval(dst.parameters())
+ logger.info(
+ "[MTP fixup] Copied shared_expert_gate from main model last layer"
+ )
+
+ # Zero out k_proj and v_proj → attention becomes no-op
+ attn = getattr(mtp_layer, "self_attn", None)
+ if attn is None:
+ return
+
+ for proj_name in ("k_proj", "v_proj"):
+ key = f"layers.0.self_attn.{proj_name}.weight"
+ if key not in loaded_keys:
+ proj = getattr(attn, proj_name, None)
+ if proj is None:
+ continue
+ # For quantized layers: zero scales+biases → dequantized = 0
+ if hasattr(proj, "scales"):
+ proj.scales = mx.zeros_like(proj.scales)
+ proj.biases = mx.zeros_like(proj.biases)
+ else:
+ proj.weight = mx.zeros_like(proj.weight)
+ mx.eval(proj.parameters())
+ logger.info(f"[MTP fixup] Zeroed {proj_name} (not in checkpoint)")
+
+
+def inject_mtp_support(model: Any, model_path, config: dict) -> bool:
+ """Inject MTP module into a loaded Qwen3.5 model.
+
+ mlx_lm's qwen3_5.py does not define MTP layers, so we:
+ 1. Create MTP module matching the weight structure
+ 2. Load MTP weights from model-mtp.safetensors (dequantized to BF16 when
+    the checkpoint is quantized)
+ 3. Fix up missing MoE MTP weights where needed
+ 4. Monkey-patch Model with return_hidden, mtp_forward, make_mtp_cache
+
+ Args:
+ model: A model loaded via mlx_lm (strict=False, MTP weights ignored)
+ model_path: Path to model directory (contains model-mtp.safetensors)
+ config: Parsed config.json dict
+
+ Returns:
+ True if MTP was successfully injected, False otherwise.
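+
+ Illustrative usage (a sketch; `model`, `model_dir`, and the json/pathlib
+ imports are assumed to already exist in the caller):
+
+     cfg = json.loads((Path(model_dir) / "config.json").read_text())
+     if inject_mtp_support(model, model_dir, cfg):
+         validate_mtp_support(model)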
+ """
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ # Navigate nested config: text_config for VLM wrappers
+ text_config = config.get("text_config", config)
+ num_mtp_layers = text_config.get("mtp_num_hidden_layers", 0)
+ if num_mtp_layers == 0:
+ # Fallback: check flat config for num_nextn_predict_layers
+ num_mtp_layers = text_config.get(
+ "num_nextn_predict_layers",
+ config.get("num_nextn_predict_layers", 0),
+ )
+ if num_mtp_layers == 0:
+ logger.info("[MTP inject] No MTP layers configured, skipping")
+ return False
+
+ model_path = Path(model_path)
+ # Look for MTP weights in mtp/ subdirectory first (avoids mlx_vlm glob),
+ # then fall back to model-mtp.safetensors in model dir.
+ mtp_file = model_path / "mtp" / "weights.safetensors"
+ if not mtp_file.exists():
+ mtp_file = model_path / "model-mtp.safetensors"
+ if not mtp_file.exists():
+ logger.warning(f"[MTP inject] MTP weights not found in {model_path}")
+ return False
+
+ # Get model args — navigate VLM wrapper if needed
+ # Model hierarchy: Model → language_model (TextModel) → model (Qwen3_5TextModel)
+ text_model = model
+ if hasattr(model, "language_model"):
+ text_model = model.language_model
+
+ args = text_model.args
+
+ # When loaded via mlx_vlm, args may be a TextConfig object missing fields
+ # that mlx_lm's TextModelArgs defines (rope_theta, partial_rotary_factor,
+ # rope_scaling, etc.). Build a proper TextModelArgs from the config dict.
+ from mlx_lm.models.qwen3_5 import TextModelArgs
+
+ if not isinstance(args, TextModelArgs):
+ logger.info("[MTP inject] Building TextModelArgs from config dict")
+ args = TextModelArgs.from_dict(text_config)
+
+ # Detect MoE vs Dense from args
+ num_experts = getattr(args, "num_experts", 0)
+ is_moe = num_experts > 0
+
+ # Import model components
+ from mlx_lm.models.base import create_attention_mask, create_ssm_mask
+ from mlx_lm.models.cache import KVCache
+ from mlx_lm.models.qwen3_5 import DecoderLayer
+
+ logger.info(
+ f"[MTP inject] Creating MTP module ({num_mtp_layers} layers, "
+ f"{'MoE' if is_moe else 'Dense'})"
+ )
+
+ # MTP decoder uses full attention (not GatedDeltaNet).
+ # layer_idx = full_attention_interval - 1 ensures is_linear=False.
+ fa_idx = args.full_attention_interval - 1
+
+ class _MTPModule(nn.Module):
+ def __init__(self, args, n_layers):
+ super().__init__()
+ self.pre_fc_norm_hidden = nn.RMSNorm(
+ args.hidden_size, eps=args.rms_norm_eps
+ )
+ self.pre_fc_norm_embedding = nn.RMSNorm(
+ args.hidden_size, eps=args.rms_norm_eps
+ )
+ self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False)
+ self.layers = [
+ DecoderLayer(args, layer_idx=fa_idx) for _ in range(n_layers)
+ ]
+ self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+ mtp = _MTPModule(args, num_mtp_layers)
+
+ # --- Load MTP weights in BF16 (no quantization) ---
+ # MTP head is extremely sensitive to quantization — even 4-bit destroys
+ # prediction quality (0% acceptance). Keep MTP in full precision.
+ # See: https://github.com/vllm-project/vllm/issues/36331
+ quant_config = text_config.get("quantization", config.get("quantization", {}))
+ bits = quant_config.get("bits", 4) if quant_config else 4
+ group_size = quant_config.get("group_size", 64) if quant_config else 64
+
+ logger.info(
+ f"[MTP inject] Loading weights from {mtp_file.name} (BF16, no quantization)"
+ )
+ raw = mx.load(str(mtp_file))
+ raw_mtp = {
+ k.removeprefix("mtp."): v for k, v in raw.items() if k.startswith("mtp.")
+ }
+ del raw
+
+ # Dequantize any quantized weight triplets (weight + scales + biases)
+ mtp_weights: dict[str, mx.array] = {}
+ processed = set()
+ for key in sorted(raw_mtp.keys()):
+ if key in processed:
+ continue
+ if key.endswith(".scales") or key.endswith(".biases"):
+ continue
+
+ scales_key = key.replace(".weight", ".scales")
+ biases_key = key.replace(".weight", ".biases")
+
+ if scales_key in raw_mtp and biases_key in raw_mtp:
+ # Quantized triplet → dequantize to BF16
+ dq = mx.dequantize(
+ raw_mtp[key],
+ raw_mtp[scales_key],
+ raw_mtp[biases_key],
+ group_size=group_size,
+ bits=bits,
+ )
+ mtp_weights[key] = dq
+ processed.update([key, scales_key, biases_key])
+ else:
+ # Already FP (norms, fc, shared_expert_gate)
+ mtp_weights[key] = raw_mtp[key]
+ processed.add(key)
+ del raw_mtp
+
+ mtp.load_weights(list(mtp_weights.items()), strict=False)
+ mx.eval(mtp.parameters())
+
+ dq_count = sum(1 for k in mtp_weights if not k.endswith((".scales", ".biases")))
+ has_quantized = any(k.endswith(".scales") for k in processed)
+ mode = "dequantized from quantized" if has_quantized else "native BF16"
+ logger.info(f"[MTP inject] Loaded {dq_count} MTP weight tensors ({mode})")
+
+ # --- Fix missing MoE MTP weights ---
+ # MoE checkpoints lack: k_proj, v_proj, gate, shared_expert_gate, norms.
+ # Norms default to identity (weight=1.0) which is correct.
+ # k_proj/v_proj: zero out → attention becomes no-op, MLP does prediction.
+ # gate/shared_expert_gate: copy from main model's last full-attention layer.
+ if is_moe:
+ loaded_key_set = set(mtp_weights.keys())
+ _fixup_moe_mtp(mtp, text_model.model, loaded_key_set, mx)
+
+ # --- Attach MTP and monkey-patch model class ---
+ text_model.mtp = mtp
+
+ original_class = text_model.__class__
+
+ class _Qwen3_5MTP(original_class):
+ """Qwen3.5 with MTP support (injected at runtime)."""
+
+ def __call__(
+ self,
+ inputs,
+ cache=None,
+ return_hidden: bool = False,
+ input_embeddings=None,
+ **kwargs,
+ ):
+ inner = self.model
+ if input_embeddings is not None:
+ hidden_states = input_embeddings
+ else:
+ hidden_states = inner.embed_tokens(inputs)
+
+ if cache is None:
+ cache = [None] * len(inner.layers)
+
+ fa_mask = create_attention_mask(hidden_states, cache[inner.fa_idx])
+ ssm_mask = create_ssm_mask(hidden_states, cache[inner.ssm_idx])
+
+ for layer, c in zip(inner.layers, cache):
+ mask = ssm_mask if layer.is_linear else fa_mask
+ hidden_states = layer(hidden_states, mask=mask, cache=c)
+
+ normed = inner.norm(hidden_states)
+
+ if self.args.tie_word_embeddings:
+ out = inner.embed_tokens.as_linear(normed)
+ else:
+ out = self.lm_head(normed)
+
+ if return_hidden:
+ return out, normed # post-norm hidden states (MTP expects post-norm)
+ return out
+
+ def mtp_forward(
+ self,
+ hidden_states,
+ next_token_ids,
+ cache=None,
+ mtp_cache=None,
+ ):
+ """Run MTP head: predict token n+2 from hidden states + token n+1."""
+ input_embeds = self.model.embed_tokens(next_token_ids)
+ e = self.mtp.pre_fc_norm_embedding(input_embeds)
+ h = self.mtp.pre_fc_norm_hidden(hidden_states)
+ x = self.mtp.fc(mx.concatenate([e, h], axis=-1))
+
+ layer = self.mtp.layers[0]
+ c = mtp_cache[0] if mtp_cache else None
+ mask = create_attention_mask(x, c)
+ x = layer(x, mask=mask, cache=c)
+
+ x = self.mtp.norm(x)
+
+ if self.args.tie_word_embeddings:
+ return self.model.embed_tokens.as_linear(x)
+ return self.lm_head(x)
+
+ def make_mtp_cache(self):
+ """Create KV cache for MTP layers."""
+ if self.mtp is None:
+ return None
+ return [KVCache() for _ in self.mtp.layers]
+
+ text_model.__class__ = _Qwen3_5MTP
+ logger.info("[MTP inject] Model class patched with MTP support")
+
+ # If we patched the inner language_model, also expose MTP on the outer Model
+ if hasattr(model, "language_model") and model.language_model is text_model:
+ model.mtp = mtp
+
+ return True
+
+
+def validate_mtp_support(model: Any) -> bool:
+ """Validate that a loaded model has working MTP support.
+
+ Checks:
+ 1. model.mtp exists and is not None
+ 2. model.mtp has layers with loaded weights
+ 3. model has return_hidden support in __call__
+ 4. model has mtp_forward method
+ 5. model has make_mtp_cache method
+
+ Args:
+ model: A model loaded via mlx_lm.load()
+
+ Returns:
+ True if MTP is fully functional, False otherwise.
+ """
+ # Navigate to text model if VLM wrapper
+ text_model = model
+ if hasattr(model, "language_model"):
+ text_model = model.language_model
+
+ mtp = getattr(text_model, "mtp", None)
+ if mtp is None:
+ args = getattr(text_model, "args", None)
+ if args is not None:
+ num_mtp = getattr(args, "mtp_num_hidden_layers", 0)
+ if num_mtp == 0:
+ num_mtp = getattr(args, "num_nextn_predict_layers", 0)
+ if num_mtp > 0:
+ logger.warning(
+ "[MTP] Model config has MTP layers=%d but model.mtp is None. "
+ "Run scripts/add_mtp_weights_qwen35.py to add weights.",
+ num_mtp,
+ )
+ return False
+
+ mtp_layers = getattr(mtp, "layers", [])
+ if not mtp_layers:
+ logger.warning("[MTP] model.mtp exists but has no layers.")
+ return False
+
+ import inspect
+
+ call_sig = inspect.signature(type(text_model).__call__)
+ if "return_hidden" not in call_sig.parameters:
+ logger.warning("[MTP] Model.__call__ does not accept return_hidden parameter.")
+ return False
+
+ if not hasattr(text_model, "mtp_forward") or not callable(text_model.mtp_forward):
+ logger.warning("[MTP] Model does not have mtp_forward() method.")
+ return False
+
+ if not hasattr(text_model, "make_mtp_cache") or not callable(
+ text_model.make_mtp_cache
+ ):
+ logger.warning("[MTP] Model does not have make_mtp_cache() method.")
+ return False
+
+ logger.info(
+ "[MTP] Qwen3.5 model has working MTP support: %d MTP layer(s)",
+ len(mtp_layers),
+ )
+ return True
diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py
index e8f47a324..a419f3973 100644
--- a/vllm_mlx/prefix_cache.py
+++ b/vllm_mlx/prefix_cache.py
@@ -586,7 +586,7 @@ def store_cache(
# Extract and store actual tensor slices for this block
if is_tensor_data and HAS_MLX:
block_kv_data = self._extract_block_tensor_slice(
- cache_data, global_start, global_end
+ cache_data, global_start, global_end, len(tokens)
)
if block_kv_data:
block.cache_data = block_kv_data
@@ -629,56 +629,122 @@ def _extract_block_tensor_slice(
cache_data: List[Dict[str, Any]],
start_idx: int,
end_idx: int,
- ) -> Optional[List[Tuple[Any, Any]]]:
+ total_tokens: int,
+ ) -> Optional[List[Optional[Dict[str, Any]]]]:
"""
- Extract tensor slices for a single block from cache data.
+ Extract per-layer cache data for a single block.
Args:
- cache_data: List of layer states, each containing 'state': (keys, values)
+ cache_data: List of extracted layer states
start_idx: Start token index in the sequence
end_idx: End token index in the sequence
+ total_tokens: Total number of tokens covered by cache_data
Returns:
- List of (keys_slice, values_slice) for each layer, or None on failure
+ Per-layer block cache state, or None on failure
"""
if not HAS_MLX or not cache_data:
return None
try:
- block_slices = []
+ block_slices: List[Optional[Dict[str, Any]]] = []
for layer_state in cache_data:
if "state" not in layer_state:
+ block_slices.append(None)
continue
- keys, values = layer_state["state"]
+ state = layer_state["state"]
+ meta_state = layer_state.get("meta_state")
+ class_ref = layer_state.get("class_ref")
+ class_name = layer_state.get("class_name")
- # KV cache shape: (batch, n_kv_heads, seq_len, head_dim)
- # Slice along seq_len dimension (axis 2)
- seq_len = keys.shape[2] if hasattr(keys, "shape") else 0
+ if self._can_concatenate_cache_state(state):
+ state_slice = self._slice_concat_cache_state(
+ state, start_idx, end_idx
+ )
+ block_slices.append(
+ {
+ "state": state_slice,
+ "meta_state": meta_state,
+ "class_ref": class_ref,
+ "class_name": class_name,
+ "storage": "concat",
+ "seq_axis": 2,
+ }
+ )
+ continue
- if end_idx > seq_len:
- # Requested range extends beyond available data
- logger.debug(
- f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}"
+ if end_idx == total_tokens:
+ block_slices.append(
+ {
+ "state": state,
+ "meta_state": meta_state,
+ "class_ref": class_ref,
+ "class_name": class_name,
+ "storage": "latest",
+ }
)
- # Use whatever is available
- actual_end = min(end_idx, seq_len)
- if start_idx >= actual_end:
- continue
- keys_slice = keys[:, :, start_idx:actual_end, :]
- values_slice = values[:, :, start_idx:actual_end, :]
else:
- keys_slice = keys[:, :, start_idx:end_idx, :]
- values_slice = values[:, :, start_idx:end_idx, :]
+ block_slices.append(None)
- block_slices.append((keys_slice, values_slice))
-
- return block_slices if block_slices else None
+ return (
+ block_slices
+ if any(entry is not None for entry in block_slices)
+ else None
+ )
except Exception as e:
logger.warning(f"Failed to extract block tensor slice: {e}")
return None
+ def _can_concatenate_cache_state(self, state: Any) -> bool:
+ """Return True when cache state can be concatenated block-by-block."""
+ if not isinstance(state, (list, tuple)) or not state:
+ return False
+ return all(
+ tensor is not None and hasattr(tensor, "shape") and len(tensor.shape) == 4
+ for tensor in state
+ )
+
+ def _slice_concat_cache_state(
+ self,
+ state: Tuple[Any, ...] | List[Any],
+ start_idx: int,
+ end_idx: int,
+ ) -> Tuple[Any, ...] | List[Any]:
+ """Slice a sequence-backed cache state across the token axis."""
+ seq_len = state[0].shape[2]
+ actual_end = min(end_idx, seq_len)
+ if start_idx >= actual_end:
+ raise ValueError(
+ f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}"
+ )
+
+ def _slice_tensor(tensor: Any) -> Any:
+ slices = [slice(None)] * len(tensor.shape)
+ slices[2] = slice(start_idx, actual_end)
+ return tensor[tuple(slices)]
+
+ sliced = [_slice_tensor(tensor) for tensor in state]
+ return tuple(sliced) if isinstance(state, tuple) else sliced
+
+ def _concat_cache_states(
+ self,
+ states: List[Tuple[Any, ...] | List[Any]],
+ seq_axis: int,
+ ) -> Optional[Tuple[Any, ...] | List[Any]]:
+ """Concatenate state fragments for a sequence-backed cache layer."""
+ if not states:
+ return None
+ arity = len(states[0])
+ concatenated = []
+ for idx in range(arity):
+ parts = [state[idx] for state in states]
+ if any(part is None for part in parts):
+ return None
+ concatenated.append(mx.concatenate(parts, axis=seq_axis))
+ return tuple(concatenated) if isinstance(states[0], tuple) else concatenated
+
def get_cache_for_generation(
self,
request_id: str,
@@ -763,10 +829,11 @@ def reconstruct_cache(
block_table: BlockTable,
) -> Optional[List[Any]]:
"""
- Reconstruct KVCache objects from stored block tensor data.
+ Reconstruct cache objects from stored block tensor data.
- This method concatenates tensor slices from all blocks and
- creates new KVCache objects that can be used for inference.
+ Sequence-backed caches are concatenated block-by-block. Recurrent
+ caches such as ArraysCache are restored from the latest sequence
+ boundary snapshot that was actually stored.
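+
+ For example (illustrative shapes): three 16-token blocks each holding a
+ (1, n_kv_heads, 16, head_dim) key/value slice are concatenated along
+ axis 2 into a single KVCache with offset 48.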
Args:
block_table: BlockTable containing block IDs to reconstruct from
@@ -800,67 +867,62 @@ def reconstruct_cache(
if not all_block_data:
return None
- # Get number of layers from first block
- num_layers = len(all_block_data[0])
+ # Get number of layers from the richest block
+ num_layers = max(len(block_data) for block_data in all_block_data)
if num_layers == 0:
return None
- # Concatenate tensors for each layer
reconstructed_caches = []
-
for layer_idx in range(num_layers):
- layer_keys = []
- layer_values = []
+ layer_entries = [
+ block_data[layer_idx]
+ for block_data in all_block_data
+ if layer_idx < len(block_data)
+ ]
+ layer_entries = [entry for entry in layer_entries if entry is not None]
+ if not layer_entries:
+ return None
- for block_data in all_block_data:
- if layer_idx < len(block_data):
- keys_slice, values_slice = block_data[layer_idx]
- layer_keys.append(keys_slice)
- layer_values.append(values_slice)
+ layer_meta = layer_entries[-1]
+ state = layer_meta["state"]
+ if layer_meta["storage"] == "concat":
+ state = self._concat_cache_states(
+ [entry["state"] for entry in layer_entries],
+ layer_meta["seq_axis"],
+ )
+ elif layer_meta["storage"] == "latest":
+ state = layer_entries[-1]["state"]
- if not layer_keys:
- continue
+ if state is None:
+ return None
- # Concatenate along sequence dimension (axis 2)
- # Shape: (batch, n_kv_heads, seq_len, head_dim)
- concat_keys = mx.concatenate(layer_keys, axis=2)
- concat_values = mx.concatenate(layer_values, axis=2)
+ cache_cls = layer_meta.get("class_ref")
+ meta_state = layer_meta.get("meta_state")
- # Create KVCache object
- # Try to use mlx_lm's KVCache.from_state if available
- try:
+ if cache_cls is not None and hasattr(cache_cls, "from_state"):
+ from mlx_lm.models.cache import (
+ BatchKVCache as _BatchKVCache,
+ KVCache as _KVCache,
+ )
+
+ if cache_cls is _BatchKVCache:
+ keys, values = state[0], state[1]
+ cache = _KVCache()
+ cache.keys = keys
+ cache.values = values
+ cache.offset = keys.shape[2]
+ else:
+ cache = cache_cls.from_state(state, meta_state)
+ else:
from mlx_lm.models.cache import KVCache
- # Create new cache and set its state
+ if len(state) != 2:
+ return None
cache = KVCache()
- seq_len = concat_keys.shape[2]
-
- # Set internal state directly
- # KVCache stores keys/values and offset
- cache.keys = concat_keys
- cache.values = concat_values
- cache.offset = seq_len
-
- reconstructed_caches.append(cache)
-
- except ImportError:
- # Fallback: create a simple cache-like object
- class SimpleKVCache:
- def __init__(self, keys, values):
- self.keys = keys
- self.values = values
- self.offset = keys.shape[2]
-
- @property
- def state(self):
- return (self.keys, self.values)
-
- @property
- def meta_state(self):
- return (str(self.offset),)
-
- cache = SimpleKVCache(concat_keys, concat_values)
- reconstructed_caches.append(cache)
+ cache.keys, cache.values = state
+ cache.offset = cache.keys.shape[2]
+
+ reconstructed_caches.append(cache)
if not reconstructed_caches:
return None
diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py
index f138796ff..49d13a26b 100644
--- a/vllm_mlx/reasoning/__init__.py
+++ b/vllm_mlx/reasoning/__init__.py
@@ -76,6 +76,7 @@ def list_parsers() -> list[str]:
def _register_builtin_parsers():
"""Register built-in parsers."""
from .deepseek_r1_parser import DeepSeekR1ReasoningParser
+ from .gemma4_parser import Gemma4ReasoningParser
from .gpt_oss_parser import GptOssReasoningParser
from .harmony_parser import HarmonyReasoningParser
from .qwen3_parser import Qwen3ReasoningParser
@@ -84,6 +85,7 @@ def _register_builtin_parsers():
register_parser("deepseek_r1", DeepSeekR1ReasoningParser)
register_parser("gpt_oss", GptOssReasoningParser)
register_parser("harmony", HarmonyReasoningParser)
+ register_parser("gemma4", Gemma4ReasoningParser)
# Register built-in parsers on module load
diff --git a/vllm_mlx/reasoning/gemma4_parser.py b/vllm_mlx/reasoning/gemma4_parser.py
new file mode 100644
index 000000000..8b6dd8149
--- /dev/null
+++ b/vllm_mlx/reasoning/gemma4_parser.py
@@ -0,0 +1,170 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Reasoning parser for Gemma 4 models.
+
+Gemma 4 uses a channel-based protocol for reasoning:
+
+ <|channel>thought
+ ...thinking content...
+
+ ...response content...
+
+Where:
+ <|channel> = token 100 (channel switch marker)
+ end-of-channel marker = token 101
+
+The channel names "thought" and "response" appear as text after the
+special tokens and should be stripped from the output.
+
+Some model variants may use <|channel>response instead of the end-of-channel
+marker to transition from thinking to response mode. This parser handles both.
+
+When thinking is disabled or not triggered, output contains no tags.
+"""
+
+from .base import DeltaMessage
+from .think_parser import BaseThinkingReasoningParser
+
+# Channel names that follow <|channel> — stripped from output
+_THOUGHT_PREFIX = "thought"
+_RESPONSE_MARKER = "<|channel>response"
+
+
+def _strip_channel_name(text: str, prefix: str) -> str:
+ """Strip channel name and leading whitespace/newline from text start."""
+ if text.startswith(prefix):
+ text = text[len(prefix) :]
+ return text.lstrip("\n")
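+
+# Illustrative behavior of the helper above (hypothetical input):
+#   _strip_channel_name("thought\nLet me think", "thought") -> "Let me think"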
+
+
+class Gemma4ReasoningParser(BaseThinkingReasoningParser):
+ """
+ Reasoning parser for Gemma 4 models.
+
+ Handles two transition formats:
+ 1. <|channel>thought ... end-of-channel marker ... response (standard: token 100 + 101)
+ 2. <|channel>thought...<|channel>response (alternative: token 100 + 100)
+
+ Channel names ("thought", "response") are stripped from output.
+
+ Example:
+ Input: "<|channel>thought\\nLet me think...The answer is 42."
+ Output: reasoning="Let me think...", content="The answer is 42."
+
+ When no tags are present, the entire output is treated as content.
+ """
+
+ @property
+ def start_token(self) -> str:
+ return "<|channel>"
+
+ @property
+ def end_token(self) -> str:
+ return ""
+
+ def extract_reasoning(
+ self,
+ model_output: str,
+ ) -> tuple[str | None, str | None]:
+ """
+ Extract reasoning from complete output.
+
+ Handles both the end-of-channel marker and <|channel>response as transition markers.
+ Strips channel names ("thought", "response") from output.
+ """
+ text = model_output
+
+ # Try standard format first: <|channel>thought ... end marker ... response
+ if self.start_token in text and self.end_token in text:
+ _, _, after_start = text.partition(self.start_token)
+ reasoning, _, content = after_start.partition(self.end_token)
+ reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX)
+ content = content.strip()
+ return reasoning or None, content or None
+
+ # Try alternative format: <|channel>thought...<|channel>response...
+ if text.count(self.start_token) >= 2 and _RESPONSE_MARKER in text:
+ _, _, after_start = text.partition(self.start_token)
+ reasoning, _, content = after_start.partition(_RESPONSE_MARKER)
+ reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX)
+ content = content.lstrip("\n").strip()
+ return reasoning or None, content or None
+
+ # Only closing tag (think injected in prompt)
+ if self.end_token in text:
+ reasoning, _, content = text.partition(self.end_token)
+ reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX)
+ content = content.strip()
+ return reasoning or None, content or None
+
+ # Only start tag (incomplete reasoning, no end yet)
+ if self.start_token in text:
+ _, _, reasoning = text.partition(self.start_token)
+ reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX)
+ return reasoning or None, None
+
+ # No tags at all — pure content
+ return None, model_output
+
+ def extract_reasoning_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ ) -> DeltaMessage | None:
+ """
+ Extract reasoning from streaming delta.
+
+ Handles:
+ - No tags: treat as content (Gemma 4 doesn't inject tags in prompt)
+ - <|channel>thought: enter reasoning mode, strip channel name
+ - end-of-channel marker or <|channel>response: transition to content mode
+ """
+ # No channel tokens at all — plain content
+ if self.start_token not in current_text and self.end_token not in current_text:
+ return DeltaMessage(content=delta_text)
+
+ # Check for alternative transition: <|channel>response
+ if _RESPONSE_MARKER in current_text:
+ if _RESPONSE_MARKER not in previous_text:
+ # Transition happening in this delta
+ # Find what (if any) content comes after the marker
+ marker_pos = current_text.find(_RESPONSE_MARKER)
+ after_marker = current_text[marker_pos + len(_RESPONSE_MARKER) :]
+ after_marker = after_marker.lstrip("\n")
+ if after_marker:
+ return DeltaMessage(content=after_marker)
+ return None # Suppress the marker itself
+ else:
+ # Already past transition — pure content
+ # But we need to only emit the NEW text (delta)
+ return DeltaMessage(content=delta_text)
+
+ # Delegate to base class for standard <|channel> / end-marker handling
+ result = super().extract_reasoning_streaming(
+ previous_text, current_text, delta_text
+ )
+
+ # Strip "thought" channel name from initial reasoning
+ if result is not None and result.reasoning is not None:
+ r = result.reasoning
+ # First reasoning delta after <|channel> will be "thought" or "thought\n"
+ if self.start_token in current_text:
+ # Check if this is the very first reasoning content
+ after_channel = current_text.split(self.start_token, 1)[1]
+ if after_channel.startswith(_THOUGHT_PREFIX):
+ # Remove "thought" prefix from the accumulated reasoning so far
+ clean = after_channel[len(_THOUGHT_PREFIX) :].lstrip("\n")
+ # Compute what portion of clean text is in this delta
+ prev_after = ""
+ if self.start_token in previous_text:
+ prev_after = previous_text.split(self.start_token, 1)[1]
+ if prev_after.startswith(_THOUGHT_PREFIX):
+ prev_after = prev_after[len(_THOUGHT_PREFIX) :].lstrip("\n")
+ # The new reasoning text is clean minus what was already emitted
+ new_reasoning = clean[len(prev_after) :]
+ if new_reasoning:
+ return DeltaMessage(reasoning=new_reasoning)
+ return None # Suppress channel name token
+
+ return result
diff --git a/vllm_mlx/reasoning/think_parser.py b/vllm_mlx/reasoning/think_parser.py
index 136348206..a2e9cb727 100644
--- a/vllm_mlx/reasoning/think_parser.py
+++ b/vllm_mlx/reasoning/think_parser.py
@@ -9,6 +9,12 @@
1. Both tags in output: <think>reasoning</think>content
2. Only closing tag (think injected in prompt): reasoning</think>content
3. No tags: pure content
+
+Performance: The streaming parser uses a simple state machine to track the
+current phase (pre-think / thinking / content). Tag completion is detected
+against the accumulated text for correctness when `<think>` / `</think>` are
+split across delta boundaries, but phase tracking still avoids the old
+whole-output rescanning behavior.
"""
from abc import abstractmethod
@@ -27,8 +33,12 @@ class BaseThinkingReasoningParser(ReasoningParser):
and only appears in the model output. This is common with AI agents
like OpenCode that force models to reason by injecting thinking tags.
- The parser tracks state during streaming to correctly separate reasoning
- from content as tokens arrive incrementally.
+ The streaming parser uses a state machine with three phases:
+
+ pre_think -> thinking -> content
+
+ Transitions are tracked by parser state. Accumulated text is consulted only
+ to detect when a start/end tag has completed across delta boundaries.
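+
+ For example, for a subclass using <think> / </think>, streaming the deltas
+ "<think>ab" and then "c</think>d" should yield a reasoning delta "ab",
+ followed by a delta carrying reasoning "c" and content "d".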
"""
@property
@@ -43,6 +53,12 @@ def end_token(self) -> str:
def __init__(self, tokenizer=None):
super().__init__(tokenizer)
+ # Streaming state — reset per request via reset_state()
+ self._phase: str = "pre_think" # "pre_think" | "thinking" | "content"
+
+ def reset_state(self):
+ """Reset state machine for a new streaming request."""
+ self._phase = "pre_think"
def extract_reasoning(
self,
@@ -66,14 +82,11 @@ def extract_reasoning(
# Case 1: Both tags present (normal case)
if self.start_token in text and self.end_token in text:
- # Get everything after start token
_, _, after_start = text.partition(self.start_token)
- # Split on end token
reasoning, _, content = after_start.partition(self.end_token)
return reasoning.strip() or None, content.strip() or None
# Case 2: Only closing tag (think was injected in prompt)
- # Everything before is reasoning
if self.end_token in text:
reasoning, _, content = text.partition(self.end_token)
return reasoning.strip() or None, content.strip() or None
@@ -83,7 +96,7 @@ def extract_reasoning(
_, _, reasoning = text.partition(self.start_token)
return reasoning.strip() or None, None
- # Case 4: No tags at all - pure content
+ # Case 4: No tags at all — pure content
return None, model_output
def extract_reasoning_streaming(
@@ -93,123 +106,99 @@ def extract_reasoning_streaming(
delta_text: str,
) -> DeltaMessage | None:
"""
- Extract reasoning from streaming delta using text-based detection.
+ Extract reasoning from a streaming delta using state-machine tracking.
+
+ Instead of rescanning the full accumulated text on every token, this
+ method tracks the current phase (pre_think / thinking / content) and
+ only consults accumulated text to detect completed start/end tags that
+ were split across delta boundaries.
- Handles implicit reasoning mode where <think> was in the prompt
- and only </think> appears in the output.
+ Handles three scenarios:
+ 1. Explicit <think>...</think> in model output
+ 2. Implicit mode (<think> in prompt, only </think> in output)
+ 3. No tags at all (pure content after first token with no reasoning)
Args:
previous_text: Text accumulated before this delta.
current_text: Text including this delta.
- delta_text: Just the new text.
+ delta_text: Just the new text in this chunk.
Returns:
- DeltaMessage with reasoning/content, or None to skip.
+ DeltaMessage with reasoning and/or content, or None to skip.
"""
- # Skip if delta is just the special tokens themselves
- stripped_delta = delta_text.strip()
- if stripped_delta == self.start_token:
- return None
- if stripped_delta == self.end_token:
+ if not delta_text:
return None
- # Check token positions in text (stateless text-based detection)
- start_in_prev = self.start_token in previous_text
- start_in_current = self.start_token in current_text
- end_in_prev = self.end_token in previous_text
- end_in_delta = self.end_token in delta_text
-
- # Case 1: Explicit <think> found in text - standard behavior
- if start_in_current:
- return self._handle_explicit_think(
- previous_text, delta_text, start_in_prev, end_in_prev, end_in_delta
- )
-
- # Case 2: No <think> but </think> found - implicit reasoning mode
- # This handles when <think> was injected in the prompt
- if self.end_token in current_text:
- return self._handle_implicit_think(delta_text, end_in_prev, end_in_delta)
-
- # Case 3: No think tags seen yet
- # We can't know if <think> was in the prompt, so we must make a choice:
- # - Treat as content (safe, but loses reasoning if think was in prompt)
- # - Treat as reasoning (risky, wrong if no thinking at all)
- # We choose to treat as reasoning IF we haven't seen </think> yet,
- # because if think was in prompt, we want to capture the reasoning.
- # This will be corrected once </think> is seen.
- return DeltaMessage(reasoning=delta_text)
-
- def _handle_explicit_think(
- self,
- previous_text: str,
- delta_text: str,
- start_in_prev: bool,
- end_in_prev: bool,
- end_in_delta: bool,
- ) -> DeltaMessage | None:
- """Handle case where tag is explicitly in the output."""
- start_in_delta = self.start_token in delta_text
-
- if start_in_prev:
- # We're after the start token
- if end_in_delta:
- # Transition: end token in this delta
- idx = delta_text.find(self.end_token)
- reasoning_part = delta_text[:idx]
- content_part = delta_text[idx + len(self.end_token) :]
+ start_tok = self.start_token
+ end_tok = self.end_token
+
+ # ── Phase: pre_think ──────────────────────────────────────
+ # Haven't seen a completed tag yet. Could be:
+ # - About to see <think> (explicit reasoning)
+ # - Already inside implicit reasoning (think was in prompt)
+ # - No reasoning at all (pure content model)
+ if self._phase == "pre_think":
+ if start_tok in current_text:
+ self._phase = "thinking"
+ idx = delta_text.find(start_tok)
+ after = delta_text[idx + len(start_tok) :] if idx >= 0 else delta_text
+
+ if end_tok in after:
+ self._phase = "content"
+ eidx = after.find(end_tok)
+ reasoning = after[:eidx]
+ content = after[eidx + len(end_tok) :]
+ if not reasoning and not content:
+ return None
+ return DeltaMessage(
+ reasoning=reasoning or None,
+ content=content or None,
+ )
+ return DeltaMessage(reasoning=after) if after else None
+
+ # Implicit mode: </think> completed without an explicit <think>.
+ if end_tok in current_text:
+ self._phase = "content"
+ idx = delta_text.find(end_tok)
+ if idx >= 0:
+ reasoning = delta_text[:idx]
+ content = delta_text[idx + len(end_tok) :]
+ else:
+ reasoning = None
+ content = delta_text
+ if not reasoning and not content:
+ return None
return DeltaMessage(
- reasoning=reasoning_part if reasoning_part else None,
- content=content_part if content_part else None,
+ reasoning=reasoning or None,
+ content=content or None,
)
- elif end_in_prev:
- # Already past reasoning phase - pure content
- return DeltaMessage(content=delta_text)
- else:
- # Still in reasoning phase
- return DeltaMessage(reasoning=delta_text)
-
- elif start_in_delta:
- # Start token is in this delta
- start_idx = delta_text.find(self.start_token)
-
- if end_in_delta:
- # Both tokens in this delta
- end_idx = delta_text.find(self.end_token)
- reasoning_part = delta_text[start_idx + len(self.start_token) : end_idx]
- content_part = delta_text[end_idx + len(self.end_token) :]
- return DeltaMessage(
- reasoning=reasoning_part if reasoning_part else None,
- content=content_part if content_part else None,
- )
- else:
- # Only start token - beginning of reasoning
- reasoning_part = delta_text[start_idx + len(self.start_token) :]
+
+ # No tags — default to reasoning (implicit mode assumption).
+ # If the model doesn't use thinking at all, the server's
+ # non-parser path handles it. This path only activates when
+ # a reasoning parser is explicitly configured.
+ return DeltaMessage(reasoning=delta_text)
+
+ # ── Phase: thinking ───────────────────────────────────────
+ # Inside a reasoning block, waiting for end tag.
+ if self._phase == "thinking":
+ if end_tok in current_text and end_tok not in previous_text:
+ self._phase = "content"
+ idx = delta_text.find(end_tok)
+ if idx >= 0:
+ reasoning = delta_text[:idx]
+ content = delta_text[idx + len(end_tok) :]
+ else:
+ reasoning = delta_text
+ content = None
+ if not reasoning and not content:
+ return None
return DeltaMessage(
- reasoning=reasoning_part if reasoning_part else None
+ reasoning=reasoning or None,
+ content=content or None,
)
+ return DeltaMessage(reasoning=delta_text)
- # Fallback - treat as content
+ # ── Phase: content ────────────────────────────────────────
+ # Past the reasoning block — everything is content.
return DeltaMessage(content=delta_text)
-
- def _handle_implicit_think(
- self,
- delta_text: str,
- end_in_prev: bool,
- end_in_delta: bool,
- ) -> DeltaMessage | None:
- """Handle case where was in prompt (only in output)."""
- if end_in_delta:
- # Transition: end token in this delta
- idx = delta_text.find(self.end_token)
- reasoning_part = delta_text[:idx]
- content_part = delta_text[idx + len(self.end_token) :]
- return DeltaMessage(
- reasoning=reasoning_part if reasoning_part else None,
- content=content_part if content_part else None,
- )
- elif end_in_prev:
- # Already past reasoning phase - pure content
- return DeltaMessage(content=delta_text)
- else:
- # Still in implicit reasoning phase
- return DeltaMessage(reasoning=delta_text)
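As a quick illustration of the new streaming contract (this driver loop is a sketch, not part of the patch — ThinkReasoningParser stands in for a concrete subclass that defines start_token="<think>" and end_token="</think>"):

    parser = ThinkReasoningParser(tokenizer=None)
    parser.reset_state()

    accumulated = ""
    for delta in ["<think>", "plan the answer", "</think>", "final reply"]:
        previous = accumulated
        accumulated += delta
        msg = parser.extract_reasoning_streaming(previous, accumulated, delta)
        if msg is None:
            continue  # the bare "<think>" / "</think>" deltas are swallowed
        if msg.reasoning:
            print("reasoning:", msg.reasoning)  # "plan the answer"
        if msg.content:
            print("content:", msg.content)      # "final reply"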
diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py
index 41679c0ba..f18b238d8 100644
--- a/vllm_mlx/request.py
+++ b/vllm_mlx/request.py
@@ -57,6 +57,7 @@ class SamplingParams:
top_p: float = 0.9
top_k: int = 0 # 0 means disabled
min_p: float = 0.0
+ presence_penalty: float = 0.0
repetition_penalty: float = 1.0
stop: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
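The two penalty fields are deliberately separate knobs. A minimal sketch of their semantics (assuming the remaining SamplingParams fields keep their defaults, as the ones visible above do):

    # presence_penalty: additive, OpenAI-style; 0.0 disables it.
    # repetition_penalty: multiplicative, mlx-lm-style; 1.0 disables it and
    # values above 1.0 penalize repeated tokens.
    params = SamplingParams(
        top_p=0.9,
        presence_penalty=0.5,
        repetition_penalty=1.1,
    )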
diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index ec4684049..c706c85b5 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -19,7 +19,7 @@
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
-from mlx_lm.sample_utils import make_sampler
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer
from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
@@ -62,6 +62,8 @@ class SchedulerConfig:
prefill_batch_size: int = 8
completion_batch_size: int = 32
prefill_step_size: int = 2048
+ # Optional override for MLLM prefill guard (None = use MLLM default).
+ mllm_prefill_step_size: Optional[int] = None
# Prefix cache settings
enable_prefix_cache: bool = True
@@ -102,6 +104,10 @@ class SchedulerConfig:
mtp_num_draft_tokens: int = 1 # Number of draft tokens from MTP head
mtp_optimistic: bool = False # Skip acceptance check for max speed
+ def __post_init__(self) -> None:
+ if self.mllm_prefill_step_size is not None and self.mllm_prefill_step_size <= 0:
+ raise ValueError("mllm_prefill_step_size must be > 0 when provided")
+
@dataclass
class SchedulerOutput:
@@ -148,13 +154,66 @@ def _install_chunked_prefill(
import time as _time
from mlx_lm.generate import (
- Batch,
_left_pad_prompts,
_make_cache,
_merge_caches,
_right_pad_prompts,
)
+ try:
+ from mlx_lm.generate import _lazy_extract_cache
+ except ImportError:
+
+ def _lazy_extract_cache(cache, idx):
+ return (c.extract(idx) for c in cache)
+
+ try:
+ from mlx_lm.generate import Batch as _batch_cls
+ except ImportError:
+
+ @dataclass
+ class _batch_cls:
+ uids: List[int]
+ y: Any
+ logprobs: List[Any]
+ max_tokens: List[int]
+ num_tokens: List[int]
+ cache: List[Any]
+ samplers: List[Any]
+ logits_processors: List[Any]
+ tokens: List[Any]
+
+ def __len__(self):
+ return len(self.uids)
+
+ def filter(self, keep_idx: List[int]):
+ self.uids = [self.uids[k] for k in keep_idx]
+ self.logprobs = [self.logprobs[k] for k in keep_idx]
+ self.max_tokens = [self.max_tokens[k] for k in keep_idx]
+ self.num_tokens = [self.num_tokens[k] for k in keep_idx]
+ self.samplers = [self.samplers[k] for k in keep_idx]
+ self.logits_processors = [self.logits_processors[k] for k in keep_idx]
+ self.tokens = [self.tokens[k] for k in keep_idx]
+ keep_idx_mx = mx.array(keep_idx, mx.int32)
+ self.y = self.y[keep_idx_mx]
+ for c in self.cache:
+ c.filter(keep_idx_mx)
+
+ def extend(self, other):
+ self.uids.extend(other.uids)
+ self.y = mx.concatenate([self.y, other.y])
+ self.logprobs.extend(other.logprobs)
+ self.num_tokens.extend(other.num_tokens)
+ self.max_tokens.extend(other.max_tokens)
+ self.samplers.extend(other.samplers)
+ self.logits_processors.extend(other.logits_processors)
+ self.tokens.extend(other.tokens)
+ for c, o in zip(self.cache, other.cache):
+ c.extend(o)
+
+ def extract_cache(self, idx):
+ return [c.extract(idx) for c in self.cache]
+
# Keep references to originals
_orig_next = batch_gen._next
_orig_remove = batch_gen.remove
@@ -201,6 +260,10 @@ def _generation_step(self=batch_gen):
batch.tokens,
)
mx.async_eval(batch.y, batch.logprobs)
+ # Evaluate accumulated tokens to prevent Metal buffer buildup
+ # from lazy mx.concatenate() chains holding AGXAllocation handles
+ if batch.tokens:
+ mx.async_eval(*batch.tokens)
y = y.tolist()
self._stats.generation_time += _time.perf_counter() - tic_gen
@@ -268,8 +331,13 @@ def _chunked_next(self=batch_gen): # noqa: C901
inputs = partial["inputs"]
prompt_cache = partial["cache"]
remaining = inputs.shape[1]
+ prompt_checkpoint = max(1, int(partial.get("prompt_checkpoint", 1)))
- n_to_process = min(budget, remaining - 1) if remaining > 1 else 0
+ n_to_process = (
+ min(budget, remaining - prompt_checkpoint)
+ if remaining > prompt_checkpoint
+ else 0
+ )
if n_to_process > 0:
self.model(mx.contiguous(inputs[:, :n_to_process]), cache=prompt_cache)
@@ -294,8 +362,8 @@ def _chunked_next(self=batch_gen): # noqa: C901
if partial.get("is_cached"):
mx.clear_cache()
- # Check if prefill is done (only 1 token left or 0)
- if inputs.shape[1] <= 1:
+ # Check if prefill is done once only the checkpoint tail remains.
+ if inputs.shape[1] <= prompt_checkpoint:
# Finalize
if partial.get("is_cached"):
mx.eval([c.state for c in prompt_cache])
@@ -303,8 +371,31 @@ def _chunked_next(self=batch_gen): # noqa: C901
for c in prompt_cache:
c.finalize()
+
+ if self.prompt_checkpoint_callback is not None:
+ self.prompt_checkpoint_callback(
+ [
+ (
+ uid,
+ prompt_checkpoint,
+ _lazy_extract_cache(prompt_cache, i),
+ )
+ for i, uid in enumerate(partial["uids"])
+ ]
+ )
mx.clear_cache()
+ # Mirror upstream BatchGenerator semantics: after finalize() and
+ # the checkpoint callback, replay the remaining checkpoint tail
+ # except for the final token, which _step() consumes.
+ if prompt_checkpoint > 1:
+ self.model(
+ mx.contiguous(inputs[:, : prompt_checkpoint - 1]),
+ cache=prompt_cache,
+ )
+ mx.eval([c.state for c in prompt_cache])
+ mx.clear_cache()
+
y, logprobs = self._step(
inputs,
prompt_cache,
@@ -314,10 +405,10 @@ def _chunked_next(self=batch_gen): # noqa: C901
)
mx.async_eval(y, logprobs)
- new_batch = Batch(
+ new_batch = _batch_cls(
list(partial["uids"]),
y,
- logprobs,
+ list(logprobs),
list(partial["max_tokens"]),
[0] * len(partial["uids"]),
prompt_cache,
@@ -393,12 +484,20 @@ def _chunked_next(self=batch_gen): # noqa: C901
caches,
samplers,
logits_processors,
- _prompt_checkpoints,
+ prompt_checkpoints,
) = zip(*batch_prompts)
lengths = [len(p) for p in inputs_raw]
max_length = max(lengths)
padding = [max_length - ln for ln in lengths]
tokens = [mx.array(inp) for inp in inputs_raw]
+ # Match mlx-lm's prompt_checkpoint contract: positive values
+ # name the checkpoint token position in the prompt, while
+ # non-positive values already encode an offset from the end.
+ checkpoint_offsets = [
+ (ln - pc if pc > 0 else -pc)
+ for ln, pc in zip(lengths, prompt_checkpoints)
+ ]
+ prompt_checkpoint = max(1, max(checkpoint_offsets))
is_cached = not all(c[0].empty() for c in caches)
self._stats.prompt_tokens += sum(lengths)
@@ -409,12 +508,14 @@ def _chunked_next(self=batch_gen): # noqa: C901
self.model, padding, self.max_kv_size
)
else:
- last_inputs = mx.array([p[-1:] for p in inputs_raw])
+ last_inputs = mx.array(
+ [p[-prompt_checkpoint:] for p in inputs_raw]
+ )
padded = _right_pad_prompts(inputs_raw, max_length=max_length)
prompt_cache = _merge_caches(caches)
for c in prompt_cache:
c.prepare(
- lengths=[ln - 1 for ln in lengths],
+ lengths=[ln - prompt_checkpoint for ln in lengths],
right_padding=padding,
)
@@ -437,9 +538,11 @@ def _chunked_next(self=batch_gen): # noqa: C901
_pb = getattr(_req0, "prefix_boundary", 0) if _req0 else 0
_cached = getattr(_req0, "cached_tokens", 0) if _req0 else 0
_adjusted_pb = _pb - _cached
- if 0 < _adjusted_pb < padded.shape[1]:
+ if 0 < _adjusted_pb < padded.shape[1] - prompt_checkpoint + 1:
_first_chunk = _adjusted_pb
- n_to_process = min(_first_chunk, padded.shape[1] - 1)
+ n_to_process = min(
+ _first_chunk, padded.shape[1] - prompt_checkpoint
+ )
if n_to_process > 0:
self.model(
mx.contiguous(padded[:, :n_to_process]),
@@ -458,6 +561,7 @@ def _chunked_next(self=batch_gen): # noqa: C901
"max_tokens": list(max_tokens_list),
"samplers": list(samplers),
"logits_processors": list(logits_processors),
+ "prompt_checkpoint": prompt_checkpoint,
"processed": n_to_process,
"total": max_length,
"is_cached": is_cached,
@@ -648,6 +752,10 @@ def _mtp_step(
# --- Apply logits processors + sample primary ---
if any(logits_processors):
+ logger.debug(
+ f"[logits_proc] applying {sum(len(lp) for lp in logits_processors)} "
+ f"processors to batch_size={batch_size}"
+ )
processed_logits = []
for e in range(batch_size):
sample_logits = logits[e : e + 1]
@@ -698,12 +806,13 @@ def _mtp_step(
# RNN snapshot, then re-advance with just P so both cache
# types end up consistent at [..., P].
_rnn_snapshots = {}
- for _ci, _c in enumerate(prompt_cache):
- if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()):
- if hasattr(_c, "state"):
- _rnn_snapshots[_ci] = [
- s.copy() if s is not None else None for s in _c.state
- ]
+ if not optimistic:
+ for _ci, _c in enumerate(prompt_cache):
+ if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()):
+ if hasattr(_c, "state"):
+ _rnn_snapshots[_ci] = [
+ s.copy() if s is not None else None for s in _c.state
+ ]
verify_input = mx.concatenate(
[primary_tokens[:, None], draft_tokens[:, None]], axis=1
@@ -1094,11 +1203,7 @@ def _decode_tokens(self, token_ids: List[int]) -> str:
def _get_detokenizer(self, request_id: str) -> Any:
"""Get or create a streaming detokenizer for a request."""
if request_id not in self._detokenizer_pool:
- if hasattr(self.tokenizer, "detokenizer"):
- detok = self.tokenizer.detokenizer
- else:
- detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
- detok.reset()
+ detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
self._detokenizer_pool[request_id] = detok
return self._detokenizer_pool[request_id]
@@ -1158,15 +1263,25 @@ def _prefill_progress(progress_list):
prefill_batch_size=self.config.prefill_batch_size,
completion_batch_size=self.config.completion_batch_size,
prefill_step_size=self.config.prefill_step_size,
- prompt_progress_callback=_prefill_progress,
)
+ # Set callback as attribute — used by _install_chunked_prefill
+ # monkey-patch. Not a BatchGenerator constructor parameter.
+ bg.prompt_progress_callback = _prefill_progress
# Install chunked prefill when explicitly configured OR when
# memory-aware cache is active (needed for prefix_boundary saves
# in agentic multi-turn workloads with hybrid Mamba+Transformer models).
chunked_budget = self.config.chunked_prefill_tokens
need_chunked = chunked_budget > 0 or self.memory_aware_cache is not None
- if need_chunked:
+
+ # The chunked prefill monkey-patch relies on BatchGenerator internals
+ # (_process_prompts, active_batch, _step, etc.) that were refactored
+ # in mlx-lm 0.31.x. Skip gracefully when the required API is absent.
+ chunked_compatible = hasattr(bg, "_process_prompts") and hasattr(
+ bg, "active_batch"
+ )
+
+ if need_chunked and chunked_compatible:
if chunked_budget <= 0:
# No explicit budget — use a very large value so normal
# prompts pass through unchanged. Prefix boundary splits
@@ -1189,6 +1304,12 @@ def _prefill_progress(progress_list):
uid_to_request_id=self.uid_to_request_id,
requests=self.requests,
)
+ elif need_chunked and not chunked_compatible:
+ logger.warning(
+ "Chunked prefill disabled: mlx-lm BatchGenerator lacks required "
+ "internals (_process_prompts, active_batch). Upgrade mlx-lm or "
+ "check compatibility."
+ )
# Install MTP if the model supports it
if self.config.enable_mtp:
@@ -1791,15 +1912,30 @@ def _schedule_waiting(self) -> List[Request]:
request.remaining_tokens = request.prompt_token_ids
tokens_to_process = request.prompt_token_ids
+ # Build per-request logits_processors from repetition_penalty
+ rep_penalty = request.sampling_params.repetition_penalty
+ lp = None
+ if rep_penalty and rep_penalty != 1.0:
+ lp = make_logits_processors(repetition_penalty=rep_penalty)
+ logger.info(
+ f"[rep_penalty] request={request.request_id[:12]} "
+ f"penalty={rep_penalty} processors={len(lp)}"
+ )
+
# Insert into BatchGenerator with optional cache.
# Wrap in try/except: if cache shapes are incompatible
# (e.g. stale entry after BatchGenerator recreation),
# fall back to no-cache insert instead of crashing.
+ insert_kwargs = {
+ "max_tokens": [request.sampling_params.max_tokens],
+ "caches": [cache_to_use] if cache_to_use else None,
+ }
+ if lp:
+ insert_kwargs["logits_processors"] = [lp]
try:
uids = self.batch_generator.insert(
[tokens_to_process],
- max_tokens=[request.sampling_params.max_tokens],
- caches=[cache_to_use] if cache_to_use else None,
+ **insert_kwargs,
)
except Exception as e:
if cache_to_use is not None:
@@ -1812,10 +1948,10 @@ def _schedule_waiting(self) -> List[Request]:
request.cached_tokens = 0
request.remaining_tokens = request.prompt_token_ids
tokens_to_process = request.prompt_token_ids
+ insert_kwargs["caches"] = None
uids = self.batch_generator.insert(
[tokens_to_process],
- max_tokens=[request.sampling_params.max_tokens],
- caches=None,
+ **insert_kwargs,
)
else:
raise
@@ -1836,11 +1972,16 @@ def _schedule_waiting(self) -> List[Request]:
else ""
)
tokens_to_prefill = len(tokens_to_process)
+ rep_info = (
+ f" rep_penalty={rep_penalty}"
+ if rep_penalty and rep_penalty != 1.0
+ else ""
+ )
logger.info(
f"[schedule] request={request.request_id[:12]} uid={uid} "
f"prompt_tokens={request.num_prompt_tokens} "
f"tokens_to_prefill={tokens_to_prefill}{cache_info} "
- f"max_tokens={request.sampling_params.max_tokens} "
+ f"max_tokens={request.sampling_params.max_tokens}{rep_info} "
f"running={len(self.running)} waiting={len(self.waiting)}"
)
@@ -2216,9 +2357,16 @@ def step(self, max_retries: int = 1) -> SchedulerOutput:
# Run generation step if we have running requests
if self.batch_generator is not None and self.running:
- responses = self.batch_generator.next()
+ result = self.batch_generator.next()
output.has_work = True
+ # mlx-lm >=0.31.x returns (prompt_responses, generation_responses);
+ # older versions returned a flat list.
+ if isinstance(result, tuple):
+ responses = result[1] # generation_responses only
+ else:
+ responses = result
+
if responses:
outputs, finished_ids = self._process_batch_responses(responses)
output.outputs = outputs
@@ -2285,6 +2433,7 @@ def step(self, max_retries: int = 1) -> SchedulerOutput:
# Evaluate batch tokens to collapse lazy concatenation chains
if (
self.batch_generator is not None
+ and hasattr(self.batch_generator, "active_batch")
and self.batch_generator.active_batch is not None
and hasattr(self.batch_generator.active_batch, "tokens")
):
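Condensed sketch of the per-request repetition-penalty plumbing shown in the _schedule_waiting hunk above — `request`, `cache_to_use`, and `batch_generator` (self.batch_generator in the real code) are assumed to be in scope exactly as there:

    from mlx_lm.sample_utils import make_logits_processors

    rep_penalty = request.sampling_params.repetition_penalty
    insert_kwargs = {
        "max_tokens": [request.sampling_params.max_tokens],
        "caches": [cache_to_use] if cache_to_use else None,
    }
    if rep_penalty and rep_penalty != 1.0:
        # one list of logits processors per inserted prompt
        insert_kwargs["logits_processors"] = [
            make_logits_processors(repetition_penalty=rep_penalty)
        ]
    uids = batch_generator.insert([request.prompt_token_ids], **insert_kwargs)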
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 18af96438..6a749fffb 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -42,6 +42,7 @@
import json
import logging
import os
+import re
import secrets
import tempfile
import threading
@@ -57,8 +58,13 @@
# Import from new modular API
# Re-export for backwards compatibility with tests
-from .api.anthropic_adapter import anthropic_to_openai, openai_to_anthropic
-from .api.anthropic_models import AnthropicRequest
+from .api.anthropic_adapter import anthropic_to_openai
+from .api.anthropic_models import (
+ AnthropicRequest,
+ AnthropicResponse,
+ AnthropicResponseContentBlock,
+ AnthropicUsage,
+)
from .api.models import (
AssistantMessage, # noqa: F401
ChatCompletionChoice, # noqa: F401
@@ -98,8 +104,6 @@
)
from .api.utils import (
SPECIAL_TOKENS_PATTERN,
- StreamingThinkRouter,
- StreamingToolCallFilter,
clean_output_text,
extract_multimodal_content,
is_mllm_model, # noqa: F401
@@ -163,6 +167,11 @@ def _resolve_top_p(request_value: float | None) -> float:
_tool_call_parser: str | None = None # Parser name: auto, mistral, qwen, llama, hermes
_tool_parser_instance = None # Instantiated parser
+# Pattern to strip leaked tool call markup from content output.
+# Safety net: the tool parser should consume these, but if it doesn't
+# (e.g. malformed JSON, stray closing tags), strip them before emitting.
+_TOOL_MARKUP_PATTERN = re.compile(r"</?tool_call>|</?tool_call_reasoning>")
+
def _load_prefix_cache_from_disk() -> None:
"""Load prefix cache from disk during startup."""
@@ -343,6 +352,53 @@ def get_engine() -> BaseEngine:
return _engine
+def _coerce_tool_arguments(
+ arguments_json: str, tool_name: str, tools: list[dict] | None
+) -> str:
+ """
+ Coerce tool call arguments to match the tool schema.
+
+ If a schema field expects "string" but the model produced an object/array,
+ JSON-stringify the value. This fixes a common LLM failure mode where models
+ output raw JSON objects instead of JSON strings for file content, etc.
+ """
+ if not tools:
+ return arguments_json
+
+ # Find the schema for this tool
+ schema = None
+ for tool in tools:
+ if isinstance(tool, dict) and tool.get("function", {}).get("name") == tool_name:
+ schema = tool["function"].get("parameters", {})
+ break
+
+ if not schema or "properties" not in schema:
+ return arguments_json
+
+ try:
+ arguments = json.loads(arguments_json)
+ except (json.JSONDecodeError, TypeError):
+ return arguments_json
+
+ if not isinstance(arguments, dict):
+ return arguments_json
+
+ properties = schema.get("properties", {})
+ changed = False
+
+ for key, value in arguments.items():
+ if key in properties:
+ expected_type = properties[key].get("type")
+ if expected_type == "string" and isinstance(value, (dict, list)):
+ arguments[key] = json.dumps(value, ensure_ascii=False, indent=2)
+ changed = True
+
+ if changed:
+ return json.dumps(arguments, ensure_ascii=False)
+
+ return arguments_json
+
+
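A short example of what the coercion does (a sketch; "write_file" is a made-up tool name):

    tools = [{
        "function": {
            "name": "write_file",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string"},
                    "content": {"type": "string"},
                },
            },
        }
    }]
    raw_args = '{"path": "config.json", "content": {"debug": true}}'
    fixed = _coerce_tool_arguments(raw_args, "write_file", tools)
    # The schema declares "content" as a string, but the model emitted an
    # object, so json.loads(fixed)["content"] is now the string
    # '{\n  "debug": true\n}' rather than a nested dict.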
def _validate_model_name(request_model: str) -> None:
"""Validate that the request model name matches the served model."""
if _model_name and request_model != _model_name:
@@ -373,6 +429,14 @@ def _parse_tool_calls_with_parser(
request_dict = request.model_dump() if request else None
+ # tool_choice="none" means never return tool calls — skip all parsing
+ if request is not None:
+ tool_choice = getattr(request, "tool_choice", None)
+ if tool_choice is None and request_dict:
+ tool_choice = request_dict.get("tool_choice")
+ if tool_choice == "none":
+ return output_text, None
+
# If auto tool choice is not enabled, use the generic parser
if not _enable_auto_tool_choice or not _tool_call_parser:
return parse_tool_calls(output_text, request_dict)
@@ -400,13 +464,16 @@ def _parse_tool_calls_with_parser(
_tool_parser_instance.reset()
result = _tool_parser_instance.extract_tool_calls(output_text, request_dict)
if result.tools_called:
+ tools = request_dict.get("tools") if request_dict else None
tool_calls = [
ToolCall(
id=tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
type="function",
function=FunctionCall(
name=tc["name"],
- arguments=tc["arguments"],
+ arguments=_coerce_tool_arguments(
+ tc["arguments"], tc["name"], tools
+ ),
),
)
for tc in result.tool_calls
@@ -485,6 +552,7 @@ def load_model(
stream_interval: int = 1,
max_tokens: int = 32768,
force_mllm: bool = False,
+ gpu_memory_utilization: float = 0.90,
served_model_name: str | None = None,
mtp: bool = False,
prefill_step_size: int = 2048,
@@ -528,6 +596,7 @@ def load_model(
scheduler_config=scheduler_config,
stream_interval=stream_interval,
force_mllm=force_mllm,
+ gpu_memory_utilization=gpu_memory_utilization,
)
# BatchedEngine will be started in lifespan (uvicorn's event loop)
# Just log for now
@@ -591,14 +660,11 @@ async def health():
"tools_available": len(_mcp_manager.get_all_tools()),
}
- engine_stats = _engine.get_stats() if _engine else {}
-
return {
"status": "healthy",
"model_loaded": _engine is not None,
"model_name": _model_name,
"model_type": "mllm" if (_engine and _engine.is_mllm) else "llm",
- "engine_type": engine_stats.get("engine_type", "unknown"),
"mcp": mcp_info,
}
@@ -1027,15 +1093,19 @@ async def _disconnect_guard(
generator: AsyncIterator[str],
raw_request: Request,
poll_interval: float = 0.5,
+ heartbeat_interval: float = 5.0,
) -> AsyncIterator[str]:
"""Wrap streaming generator to abort on client disconnect.
Uses asyncio racing: each __anext__() on the inner generator is
- raced against a disconnect poller. This catches disconnects even
- during prefill when no chunks are being yielded for tens of seconds.
-
- On disconnect, aclose() propagates down the generator chain to
- engine_core.stream_outputs() finally-block → abort_request().
+ raced against a disconnect poller. When neither completes within
+ ``heartbeat_interval`` seconds, an SSE comment is yielded as a
+ heartbeat. This forces an ASGI write which triggers broken-pipe
+ detection — without heartbeats, ``is_disconnected()`` stays False
+ during long prefill because no data is written to the socket.
+
+ On disconnect, the cancellation propagates to stream_outputs()
+ finally-block → abort_request() → abort_prefill().
"""
import time as _time
@@ -1044,7 +1114,9 @@ async def _disconnect_guard(
def _elapsed():
return f"{_time.monotonic() - _t0:.1f}s"
- logger.info(f"[disconnect_guard] START poll_interval={poll_interval}s")
+ logger.info(
+ f"[disconnect_guard] START poll={poll_interval}s heartbeat={heartbeat_interval}s"
+ )
async def _wait_disconnect():
poll_count = 0
@@ -1061,21 +1133,28 @@ async def _wait_disconnect():
return
chunk_count = 0
+ heartbeat_count = 0
disconnect_task: asyncio.Task | None = None
anext_task: asyncio.Task | None = None
try:
aiter = generator.__aiter__()
disconnect_task = asyncio.create_task(_wait_disconnect())
+ anext_task = None
while True:
- anext_task = asyncio.ensure_future(aiter.__anext__())
+ if anext_task is None:
+ anext_task = asyncio.ensure_future(aiter.__anext__())
+
done, _ = await asyncio.wait(
[anext_task, disconnect_task],
return_when=asyncio.FIRST_COMPLETED,
+ timeout=heartbeat_interval,
)
+
if disconnect_task in done:
logger.info(
f"[disconnect_guard] CLIENT DISCONNECTED after "
- f"{chunk_count} chunks, elapsed={_elapsed()}"
+ f"{chunk_count} chunks, {heartbeat_count} heartbeats, "
+ f"elapsed={_elapsed()}"
)
anext_task.cancel()
try:
@@ -1083,20 +1162,32 @@ async def _wait_disconnect():
except (asyncio.CancelledError, StopAsyncIteration):
pass
break
- try:
- chunk = anext_task.result()
- except StopAsyncIteration:
- logger.info(
- f"[disconnect_guard] generator exhausted normally, "
- f"{chunk_count} chunks, elapsed={_elapsed()}"
- )
- break
- chunk_count += 1
- if chunk_count == 1:
- logger.info(
- f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}"
- )
- yield chunk
+
+ if anext_task in done:
+ try:
+ chunk = anext_task.result()
+ except StopAsyncIteration:
+ logger.info(
+ f"[disconnect_guard] generator exhausted normally, "
+ f"{chunk_count} chunks, elapsed={_elapsed()}"
+ )
+ break
+ chunk_count += 1
+ if chunk_count == 1:
+ logger.info(
+ f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}"
+ )
+ yield chunk
+ anext_task = None
+ continue
+
+ # Timeout — no chunk and no disconnect detected yet.
+ # Send SSE comment as heartbeat to force an ASGI write.
+ # If the client has disconnected, this write will fail and
+ # the next is_disconnected() poll will return True.
+ heartbeat_count += 1
+ yield ": heartbeat\n\n"
+
except GeneratorExit:
logger.info(
f"[disconnect_guard] GeneratorExit after {chunk_count} chunks, elapsed={_elapsed()}"
@@ -1116,7 +1207,8 @@ async def _wait_disconnect():
# anext_task.cancel() → CancelledError in stream_outputs()
# → finally block → abort_request() → request removed from scheduler
logger.info(
- f"[disconnect_guard] CLEANUP done, {chunk_count} chunks total, elapsed={_elapsed()}"
+ f"[disconnect_guard] CLEANUP done, {chunk_count} chunks, "
+ f"{heartbeat_count} heartbeats, elapsed={_elapsed()}"
)
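The racing-plus-heartbeat pattern the docstring describes boils down to the following sketch (not the full _disconnect_guard; is_disconnected is assumed to be a coroutine that only completes when the client goes away, whereas the real guard polls Starlette's request.is_disconnected() in a loop):

    import asyncio

    async def guarded(inner, is_disconnected, heartbeat_interval=5.0):
        disconnect_task = asyncio.create_task(is_disconnected())
        anext_task = None
        try:
            while True:
                if anext_task is None:
                    anext_task = asyncio.ensure_future(inner.__anext__())
                done, _ = await asyncio.wait(
                    [anext_task, disconnect_task],
                    return_when=asyncio.FIRST_COMPLETED,
                    timeout=heartbeat_interval,
                )
                if disconnect_task in done:
                    anext_task.cancel()
                    return
                if anext_task in done:
                    try:
                        chunk = anext_task.result()
                    except StopAsyncIteration:
                        return
                    anext_task = None
                    yield chunk
                    continue
                # Timed out with neither task done: emit an SSE comment so the
                # ASGI layer writes to the socket and a broken pipe surfaces.
                yield ": heartbeat\n\n"
        finally:
            disconnect_task.cancel()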
@@ -1218,13 +1310,24 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
logger.info(
f"[REQUEST] POST /v1/completions stream={request.stream} "
f"max_tokens={request.max_tokens} temp={request.temperature} "
+ f"top_p={request.top_p} top_k={request.top_k} min_p={request.min_p} "
+ f"presence_penalty={request.presence_penalty} "
+ f"repetition_penalty={request.repetition_penalty} "
f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}"
)
+ # Resolve repetition penalty for completions
+ comp_rep_penalty = request.repetition_penalty
+
if request.stream:
return StreamingResponse(
_disconnect_guard(
- stream_completion(engine, prompts[0], request),
+ stream_completion(
+ engine,
+ prompts[0],
+ request,
+ repetition_penalty=comp_rep_penalty,
+ ),
raw_request,
),
media_type="text/event-stream",
@@ -1238,14 +1341,25 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
total_prompt_tokens = 0
for i, prompt in enumerate(prompts):
+ generate_kwargs = {
+ "prompt": prompt,
+ "max_tokens": request.max_tokens or _default_max_tokens,
+ "temperature": _resolve_temperature(request.temperature),
+ "top_p": _resolve_top_p(request.top_p),
+ "top_k": request.top_k or 0,
+ "min_p": request.min_p or 0.0,
+ "presence_penalty": request.presence_penalty or 0.0,
+ "stop": request.stop,
+ }
+ if comp_rep_penalty is not None:
+ generate_kwargs["repetition_penalty"] = comp_rep_penalty
+ if request.specprefill is not None:
+ generate_kwargs["specprefill"] = request.specprefill
+ if request.specprefill_keep_pct is not None:
+ generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct
+
output = await _wait_with_disconnect(
- engine.generate(
- prompt=prompt,
- max_tokens=request.max_tokens or _default_max_tokens,
- temperature=_resolve_temperature(request.temperature),
- top_p=_resolve_top_p(request.top_p),
- stop=request.stop,
- ),
+ engine.generate(**generate_kwargs),
raw_request,
timeout=timeout,
)
@@ -1345,7 +1459,11 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
logger.info(
f"[REQUEST] POST /v1/chat/completions stream={request.stream} "
f"model={request.model!r} max_tokens={request.max_tokens} "
- f"temp={request.temperature} msgs={n_msgs} roles={msg_roles} "
+ f"temp={request.temperature} top_p={request.top_p} "
+ f"top_k={request.top_k} min_p={request.min_p} "
+ f"presence_penalty={request.presence_penalty} "
+ f"repetition_penalty={request.repetition_penalty} "
+ f"msgs={n_msgs} roles={msg_roles} "
f"total_chars={total_chars} tools={n_tools} "
f"response_format={request.response_format}"
)
@@ -1367,12 +1485,14 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
messages.append(msg_dict)
images, videos = [], [] # MLLM extracts these from messages
logger.debug(f"MLLM: Processing {len(messages)} messages")
+ messages = _normalize_messages(messages)
else:
# For LLM, extract text, images, and videos separately
messages, images, videos = extract_multimodal_content(
request.messages,
preserve_native_format=engine.preserve_native_tool_format,
)
+ messages = _normalize_messages(messages)
has_media = bool(images or videos)
if engine.is_mllm and not has_media:
@@ -1401,12 +1521,21 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
# Inject JSON instruction into messages
messages = _inject_json_instruction(messages, json_instruction)
+ # Resolve repetition penalty
+ rep_penalty = request.repetition_penalty
+
# Prepare kwargs
chat_kwargs = {
"max_tokens": request.max_tokens or _default_max_tokens,
"temperature": _resolve_temperature(request.temperature),
"top_p": _resolve_top_p(request.top_p),
+ "top_k": request.top_k or 0,
+ "min_p": request.min_p or 0.0,
+ "presence_penalty": request.presence_penalty or 0.0,
+ "repetition_penalty": request.repetition_penalty or 1.0,
}
+ if rep_penalty is not None:
+ chat_kwargs["repetition_penalty"] = rep_penalty
# Add multimodal content
if has_media:
@@ -1425,8 +1554,12 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
if request.chat_template_kwargs:
chat_kwargs["chat_template_kwargs"] = dict(request.chat_template_kwargs)
+ # Enable/disable thinking mode per request
+ if request.enable_thinking is not None:
+ chat_kwargs["enable_thinking"] = request.enable_thinking
+
# Add tools if provided
- if request.tools:
+ if request.tools and request.tool_choice != "none":
chat_kwargs["tools"] = convert_tools_for_template(request.tools)
if request.stream:
@@ -1460,8 +1593,9 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, request)
# Extract reasoning content FIRST (strips channel tokens before JSON extraction)
+ # Skip reasoning parser when enable_thinking=False (no think tags expected)
reasoning_text = None
- if _reasoning_parser and not tool_calls:
+ if _reasoning_parser and not tool_calls and request.enable_thinking is not False:
text_to_parse = cleaned_text or output.text
reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning(
text_to_parse
@@ -1500,6 +1634,64 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
)
+def _normalize_messages(messages: list[dict]) -> list[dict]:
+ """Normalize message roles and merge consecutive same-role messages.
+
+ 1. Maps non-standard roles to standard ones (e.g. ``developer`` -> ``system``).
+ 2. Merges consecutive same-role messages to satisfy chat template constraints
+ (Qwen 3.5, Llama, etc. require alternating roles).
+
+ Only merges when both messages have string content. Messages with list
+ content (multimodal) are left as-is to preserve image/video attachments.
+
+ Args:
+ messages: List of message dicts with 'role' and 'content' keys.
+
+ Returns:
+ New list with normalized roles and consecutive same-role messages merged.
+ """
+ # OpenAI Responses API uses "developer" instead of "system".
+ # Map it so chat templates don't fail and fall back to raw prefill.
+ _ROLE_MAP = {"developer": "system"}
+
+ if not messages:
+ return messages
+
+ merged = [messages[0].copy()]
+ if merged[0]["role"] in _ROLE_MAP:
+ merged[0]["role"] = _ROLE_MAP[merged[0]["role"]]
+ for msg in messages[1:]:
+ prev = merged[-1]
+ role = _ROLE_MAP.get(msg["role"], msg["role"])
+ if (
+ role == prev["role"]
+ and isinstance(prev.get("content"), str)
+ and isinstance(msg.get("content"), str)
+ ):
+ # Merge string content with double newline separator
+ prev["content"] = prev["content"] + "\n\n" + msg["content"]
+ logger.debug(
+ f"Merged consecutive {role} messages "
+ f"({len(prev['content'])} chars total)"
+ )
+ else:
+ copy = msg.copy()
+ copy["role"] = role
+ merged.append(copy)
+
+ mapped_roles = sum(1 for m in messages if m["role"] in _ROLE_MAP)
+ merged_count = len(messages) - len(merged)
+ if mapped_roles or merged_count:
+ parts = []
+ if mapped_roles:
+ parts.append(f"mapped {mapped_roles} role(s)")
+ if merged_count:
+ parts.append(f"merged {len(messages)} -> {len(merged)}")
+ logger.info(f"Normalized messages: {', '.join(parts)}")
+
+ return merged
+
+
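Example of what the normalization does (a sketch with made-up messages):

    messages = [
        {"role": "developer", "content": "You are terse."},
        {"role": "system", "content": "Answer in English."},
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": "Are you there?"},
    ]
    normalized = _normalize_messages(messages)
    # -> [
    #   {"role": "system", "content": "You are terse.\n\nAnswer in English."},
    #   {"role": "user", "content": "Hi\n\nAre you there?"},
    # ]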
def _inject_json_instruction(messages: list, instruction: str) -> list:
"""
Inject JSON instruction into messages.
@@ -1537,6 +1729,17 @@ def _inject_json_instruction(messages: list, instruction: str) -> list:
# =============================================================================
+def _convert_anthropic_stop_reason(openai_reason: str | None) -> str:
+ """Convert OpenAI finish_reason to Anthropic stop_reason."""
+ mapping = {
+ "stop": "end_turn",
+ "tool_calls": "tool_use",
+ "length": "max_tokens",
+ "content_filter": "end_turn",
+ }
+ return mapping.get(openai_reason or "", "end_turn")
+
+
@app.post("/v1/messages")
async def create_anthropic_message(
request: Request,
@@ -1551,8 +1754,19 @@ async def create_anthropic_message(
"""
engine = get_engine()
- # Parse the raw body to handle Anthropic request format
- body = await request.json()
+ # Parse the raw body to handle Anthropic request format.
+ # Some clients (e.g. Claude Code) may send JSON with invalid escape
+ # sequences like \s, \d in regex patterns within tool definitions.
+ # Python's json.loads is strict per RFC 8259 and rejects these.
+ try:
+ body = await request.json()
+ except json.JSONDecodeError as e:
+ if "Invalid \\escape" in str(e):
+ raw = await request.body()
+ # Replace lone backslashes (not valid JSON escapes) with \\
+ body = json.loads(re.sub(rb'\\(?!["\\/bfnrtu])', rb"\\\\", raw))
+ else:
+ raise
anthropic_request = AnthropicRequest(**body)
_validate_model_name(anthropic_request.model)
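The escape repair can be reproduced in isolation — a sketch with a body that carries a regex-style \d inside a JSON string, which strict json.loads rejects:

    import json
    import re

    raw = rb'{"pattern": "^\d+$"}'
    try:
        body = json.loads(raw)
    except json.JSONDecodeError:
        # Double any backslash that does not start a valid JSON escape
        # (\" \\ \/ \b \f \n \r \t \uXXXX), then parse again.
        body = json.loads(re.sub(rb'\\(?!["\\/bfnrtu])', rb"\\\\", raw))
    print(body)  # {'pattern': '^\\d+$'}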
@@ -1597,14 +1811,19 @@ async def create_anthropic_message(
openai_request.messages,
preserve_native_format=engine.preserve_native_tool_format,
)
+ messages = _normalize_messages(messages)
chat_kwargs = {
"max_tokens": openai_request.max_tokens or _default_max_tokens,
"temperature": openai_request.temperature,
"top_p": openai_request.top_p,
+ "top_k": openai_request.top_k or 0,
+ "min_p": openai_request.min_p or 0.0,
+ "presence_penalty": openai_request.presence_penalty or 0.0,
+ "repetition_penalty": openai_request.repetition_penalty or 1.0,
}
- if openai_request.tools:
+ if openai_request.tools and openai_request.tool_choice != "none":
chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools)
start_time = time.perf_counter()
@@ -1629,35 +1848,63 @@ async def create_anthropic_message(
output.text, openai_request
)
+ # Extract reasoning if parser is configured
+ reasoning_text = None
+ if _reasoning_parser and not tool_calls:
+ text_to_parse = cleaned_text or output.text
+ reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning(
+ text_to_parse
+ )
+
# Clean output text
final_content = None
if cleaned_text:
final_content = clean_output_text(cleaned_text)
- # Determine finish reason
- finish_reason = "tool_calls" if tool_calls else output.finish_reason
+ # Build Anthropic content blocks directly (with thinking support)
+ content_blocks = []
- # Build OpenAI response to convert
- openai_response = ChatCompletionResponse(
- model=_model_name,
- choices=[
- ChatCompletionChoice(
- message=AssistantMessage(
- content=final_content,
- tool_calls=tool_calls,
- ),
- finish_reason=finish_reason,
+ if reasoning_text:
+ content_blocks.append(
+ AnthropicResponseContentBlock(type="thinking", thinking=reasoning_text)
+ )
+
+ if final_content:
+ content_blocks.append(
+ AnthropicResponseContentBlock(type="text", text=final_content)
+ )
+
+ if tool_calls:
+ for tc in tool_calls:
+ try:
+ tool_input = json.loads(tc.function.arguments)
+ except (json.JSONDecodeError, AttributeError):
+ tool_input = {}
+ content_blocks.append(
+ AnthropicResponseContentBlock(
+ type="tool_use",
+ id=tc.id,
+ name=tc.function.name,
+ input=tool_input,
+ )
)
- ],
- usage=Usage(
- prompt_tokens=output.prompt_tokens,
- completion_tokens=output.completion_tokens,
- total_tokens=output.prompt_tokens + output.completion_tokens,
- ),
+
+ if not content_blocks:
+ content_blocks.append(AnthropicResponseContentBlock(type="text", text=""))
+
+ stop_reason = _convert_anthropic_stop_reason(
+ "tool_calls" if tool_calls else output.finish_reason
)
- # Convert to Anthropic response
- anthropic_response = openai_to_anthropic(openai_response, _model_name)
+ anthropic_response = AnthropicResponse(
+ model=_model_name,
+ content=content_blocks,
+ stop_reason=stop_reason,
+ usage=AnthropicUsage(
+ input_tokens=output.prompt_tokens,
+ output_tokens=output.completion_tokens,
+ ),
+ )
return Response(
content=anthropic_response.model_dump_json(exclude_none=True),
media_type="application/json",
@@ -1798,6 +2045,10 @@ async def _stream_anthropic_messages(
Converts OpenAI streaming chunks to Anthropic event format:
message_start -> content_block_start -> content_block_delta* ->
content_block_stop -> message_delta -> message_stop
+
+ When a reasoning parser is active, emits a ``thinking`` content block
+ (index 0) for reasoning tokens and a ``text`` content block (index 1)
+ for the actual response, matching the Anthropic extended thinking format.
"""
msg_id = f"msg_{uuid.uuid4().hex[:24]}"
start_time = time.perf_counter()
@@ -1807,14 +2058,19 @@ async def _stream_anthropic_messages(
openai_request.messages,
preserve_native_format=engine.preserve_native_tool_format,
)
+ messages = _normalize_messages(messages)
chat_kwargs = {
"max_tokens": openai_request.max_tokens or _default_max_tokens,
"temperature": openai_request.temperature,
"top_p": openai_request.top_p,
+ "top_k": openai_request.top_k or 0,
+ "min_p": openai_request.min_p or 0.0,
+ "presence_penalty": openai_request.presence_penalty or 0.0,
+ "repetition_penalty": openai_request.repetition_penalty or 1.0,
}
- if openai_request.tools:
+ if openai_request.tools and openai_request.tool_choice != "none":
chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools)
# Emit message_start
@@ -1836,115 +2092,171 @@ async def _stream_anthropic_messages(
}
yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
- # Stream pipeline: raw text → tool call filter → think router → emit
- # - Tool call filter strips tool call markup (emitted as structured blocks later)
- # - Think router separates content into Anthropic thinking blocks
+ use_reasoning = _reasoning_parser is not None
+
+ if use_reasoning:
+ _reasoning_parser.reset_state()
+
+ # Block index tracking: with reasoning parser we use index 0 for
+ # thinking and index 1 for text; without parser, index 0 for text.
+ thinking_block_started = False
+ text_block_started = False
+ thinking_index = 0
+ text_index = 1 if use_reasoning else 0
+
+ if not use_reasoning:
+ # No reasoning parser — start text block immediately
+ yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
+ text_block_started = True
+
+ # Stream content deltas
accumulated_text = ""
- tool_filter = StreamingToolCallFilter()
- # Detect if the model's chat template injects <think> into the
- # generation prompt. If so, the model starts in thinking mode and
- # the opening tag never appears in the output stream.
- _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None
- _chat_template = ""
- if _tokenizer and hasattr(_tokenizer, "chat_template"):
- _chat_template = _tokenizer.chat_template or ""
- _starts_thinking = (
- "" in _chat_template and "add_generation_prompt" in _chat_template
- )
- think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking)
- prompt_tokens = 0
completion_tokens = 0
- # Track which content blocks we've started
- current_block_type = None # "thinking" or "text"
- block_index = 0
+ # Tool call streaming suppression — prevents raw tool markup from leaking
+ # as text_delta events. Mirrors the OpenAI streaming path logic.
+ global _tool_parser_instance
+ tool_parser = None
+ tool_accumulated_text = ""
+ tool_markup_possible = False
+ tool_choice = getattr(openai_request, "tool_choice", None)
+ if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none":
+ if _tool_parser_instance is None:
+ try:
+ parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
+ tokenizer = None
+ if _engine is not None and hasattr(_engine, "_tokenizer"):
+ tokenizer = _engine._tokenizer
+ _tool_parser_instance = parser_cls(tokenizer)
+ except Exception:
+ pass
+ if _tool_parser_instance is not None:
+ tool_parser = _tool_parser_instance
+ tool_parser.reset()
async for output in engine.stream_chat(messages=messages, **chat_kwargs):
delta_text = output.new_text
# Track token counts
- if hasattr(output, "prompt_tokens") and output.prompt_tokens:
- prompt_tokens = output.prompt_tokens
if hasattr(output, "completion_tokens") and output.completion_tokens:
completion_tokens = output.completion_tokens
- if delta_text:
- # Accumulate raw text BEFORE special token cleaning for tool parsing
- accumulated_text += delta_text
+ if not delta_text:
+ continue
- # Filter special tokens for display
- content = SPECIAL_TOKENS_PATTERN.sub("", delta_text)
+ # Filter special tokens
+ filtered = SPECIAL_TOKENS_PATTERN.sub("", delta_text)
+ if not filtered:
+ continue
- if content:
- # Stage 1: strip tool call markup
- filtered = tool_filter.process(content)
- if not filtered:
- continue
- # Stage 2: route thinking vs text
- pieces = think_router.process(filtered)
- events, current_block_type, block_index = _emit_content_pieces(
- pieces, current_block_type, block_index
- )
- for event in events:
- yield event
-
- # Flush remaining from both filters
- remaining = tool_filter.flush()
- if remaining:
- events, current_block_type, block_index = _emit_content_pieces(
- think_router.process(remaining), current_block_type, block_index
- )
- for event in events:
- yield event
+ if not use_reasoning:
+ # Simple path — no reasoning parsing
+ accumulated_text += filtered
+ content_to_emit = filtered
- flush_pieces = think_router.flush()
- if flush_pieces:
- events, current_block_type, block_index = _emit_content_pieces(
- flush_pieces, current_block_type, block_index
+ # Filter tool call markup during streaming
+ if tool_parser and content_to_emit:
+ if not tool_markup_possible and "<" not in content_to_emit:
+ tool_accumulated_text += content_to_emit
+ else:
+ if not tool_markup_possible:
+ tool_markup_possible = True
+ tool_previous = tool_accumulated_text
+ tool_accumulated_text += content_to_emit
+ tool_result = tool_parser.extract_tool_calls_streaming(
+ tool_previous, tool_accumulated_text, content_to_emit
+ )
+ if tool_result is None or "tool_calls" in tool_result:
+ # Inside tool markup or tool calls detected — suppress
+ continue
+ content_to_emit = tool_result.get("content", "")
+ if content_to_emit:
+ content_to_emit = _TOOL_MARKUP_PATTERN.sub("", content_to_emit)
+ if not content_to_emit:
+ continue
+
+ yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': content_to_emit}})}\n\n"
+ continue
+
+ # Reasoning parser path
+ previous_text = accumulated_text
+ accumulated_text += filtered
+ delta_msg = _reasoning_parser.extract_reasoning_streaming(
+ previous_text, accumulated_text, filtered
)
- for event in events:
- yield event
- # Close final content block
- if current_block_type is not None:
- yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
- block_index += 1
+ if delta_msg is None:
+ continue
+
+ if delta_msg.reasoning:
+ if not thinking_block_started:
+ yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': thinking_index, 'content_block': {'type': 'thinking', 'thinking': ''}})}\n\n"
+ thinking_block_started = True
+ yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': thinking_index, 'delta': {'type': 'thinking_delta', 'thinking': delta_msg.reasoning}})}\n\n"
+
+ if delta_msg.content:
+ content_to_emit = delta_msg.content
+
+ # Filter tool call markup during streaming
+ if tool_parser and content_to_emit:
+ if not tool_markup_possible and "<" not in content_to_emit:
+ tool_accumulated_text += content_to_emit
+ else:
+ if not tool_markup_possible:
+ tool_markup_possible = True
+ tool_previous = tool_accumulated_text
+ tool_accumulated_text += content_to_emit
+ tool_result = tool_parser.extract_tool_calls_streaming(
+ tool_previous, tool_accumulated_text, content_to_emit
+ )
+ if tool_result is None or "tool_calls" in tool_result:
+ # Inside tool markup or tool calls detected — suppress
+ continue
+ content_to_emit = tool_result.get("content", "")
+ if content_to_emit:
+ content_to_emit = _TOOL_MARKUP_PATTERN.sub("", content_to_emit)
+ if not content_to_emit:
+ continue
+
+ if thinking_block_started and not text_block_started:
+ # Close thinking block, open text block
+ yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': thinking_index})}\n\n"
+ yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
+ text_block_started = True
+ elif not text_block_started:
+ # No thinking was emitted, start text block at index 0
+ text_index = 0
+ yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
+ text_block_started = True
+ yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': text_index, 'delta': {'type': 'text_delta', 'text': content_to_emit}})}\n\n"
+
+ # Close any open thinking block that was never followed by text
+ if thinking_block_started and not text_block_started:
+ yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': thinking_index})}\n\n"
+ # Emit empty text block so response always has text content
+ text_index = thinking_index + 1
+ yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
+ text_block_started = True
# Check for tool calls in accumulated text
_, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request)
+ # Close text block
+ if text_block_started:
+ yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': text_index})}\n\n"
+
# If there are tool calls, emit tool_use blocks
+ next_index = (text_index + 1) if text_block_started else 0
if tool_calls:
for i, tc in enumerate(tool_calls):
- tool_index = block_index + i
+ tool_index = next_index + i
try:
tool_input = json.loads(tc.function.arguments)
except (json.JSONDecodeError, AttributeError):
tool_input = {}
- # content_block_start for tool_use
- tool_block_start = {
- "type": "content_block_start",
- "index": tool_index,
- "content_block": {
- "type": "tool_use",
- "id": tc.id,
- "name": tc.function.name,
- "input": {},
- },
- }
- yield f"event: content_block_start\ndata: {json.dumps(tool_block_start)}\n\n"
-
- # Send input as a single delta
- input_json = json.dumps(tool_input)
- input_delta = {
- "type": "content_block_delta",
- "index": tool_index,
- "delta": {"type": "input_json_delta", "partial_json": input_json},
- }
- yield f"event: content_block_delta\ndata: {json.dumps(input_delta)}\n\n"
-
- # content_block_stop
+ yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': tool_index, 'content_block': {'type': 'tool_use', 'id': tc.id, 'name': tc.function.name, 'input': {}}})}\n\n"
+ yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': tool_index, 'delta': {'type': 'input_json_delta', 'partial_json': json.dumps(tool_input)}})}\n\n"
yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': tool_index})}\n\n"
# Determine stop reason
@@ -1954,7 +2266,7 @@ async def _stream_anthropic_messages(
message_delta = {
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
- "usage": {"input_tokens": prompt_tokens, "output_tokens": completion_tokens},
+ "usage": {"output_tokens": completion_tokens},
}
yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n"
@@ -1962,7 +2274,7 @@ async def _stream_anthropic_messages(
elapsed = time.perf_counter() - start_time
tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0
logger.info(
- f"Anthropic messages (stream): prompt={prompt_tokens} + completion={completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+ f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
)
# Emit message_stop
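For a response that contains thinking, text, and one tool call, the streaming path above produces the following event order (payloads elided; a sketch of the expected trace, with indices following the thinking=0 / text=1 convention used when a reasoning parser is active):

    expected_events = [
        "message_start",
        "content_block_start",  # index 0, type "thinking"
        "content_block_delta",  # thinking_delta chunks ...
        "content_block_stop",   # index 0
        "content_block_start",  # index 1, type "text"
        "content_block_delta",  # text_delta chunks ...
        "content_block_stop",   # index 1
        "content_block_start",  # index 2, type "tool_use"
        "content_block_delta",  # one input_json_delta with the full arguments
        "content_block_stop",   # index 2
        "message_delta",        # stop_reason + output_tokens usage
        "message_stop",
    ]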
@@ -1978,15 +2290,27 @@ async def stream_completion(
engine: BaseEngine,
prompt: str,
request: CompletionRequest,
+ repetition_penalty: float | None = None,
) -> AsyncIterator[str]:
"""Stream completion response."""
- async for output in engine.stream_generate(
- prompt=prompt,
- max_tokens=request.max_tokens or _default_max_tokens,
- temperature=_resolve_temperature(request.temperature),
- top_p=_resolve_top_p(request.top_p),
- stop=request.stop,
- ):
+ generate_kwargs = {
+ "prompt": prompt,
+ "max_tokens": request.max_tokens or _default_max_tokens,
+ "temperature": _resolve_temperature(request.temperature),
+ "top_p": _resolve_top_p(request.top_p),
+ "top_k": request.top_k or 0,
+ "min_p": request.min_p or 0.0,
+ "presence_penalty": request.presence_penalty or 0.0,
+ "stop": request.stop,
+ }
+ if repetition_penalty is not None:
+ generate_kwargs["repetition_penalty"] = repetition_penalty
+ if request.specprefill is not None:
+ generate_kwargs["specprefill"] = request.specprefill
+ if request.specprefill_keep_pct is not None:
+ generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct
+
+ async for output in engine.stream_generate(**generate_kwargs):
data = {
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
"object": "text_completion",
@@ -2057,7 +2381,8 @@ async def stream_chat_completion(
tool_accumulated_text = ""
tool_calls_detected = False
tool_markup_possible = False # Fast path: skip parsing until '<' seen
- if _enable_auto_tool_choice and _tool_call_parser:
+ tool_choice = getattr(request, "tool_choice", None)
+ if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none":
# Initialize parser if needed (same as _parse_tool_calls_with_parser)
if _tool_parser_instance is None:
try:
@@ -2084,8 +2409,8 @@ async def stream_chat_completion(
if hasattr(output, "completion_tokens") and output.completion_tokens:
completion_tokens = output.completion_tokens
- # Use reasoning parser if enabled
- if _reasoning_parser and delta_text:
+ # Use reasoning parser if enabled (skip when enable_thinking=False)
+ if _reasoning_parser and delta_text and request.enable_thinking is not False:
previous_text = accumulated_text
accumulated_text += delta_text
delta_msg = _reasoning_parser.extract_reasoning_streaming(
@@ -2096,16 +2421,115 @@ async def stream_chat_completion(
# Skip this chunk (e.g., token itself)
continue
+ content = delta_msg.content
+ reasoning = delta_msg.reasoning
+
+ # Some models (e.g. MiniMax) wrap tool calls in <think>
+ # blocks, so reasoning parser captures tool call XML as
+ # reasoning while content stays None. Redirect reasoning
+ # to the content stream so the tool parser can handle it.
+ if tool_parser and reasoning and not content:
+ _check = tool_accumulated_text + reasoning
+ if (
+ "" in _check
+ or "" in _check
+ or ' never arrived - incomplete tool call)
+ # (e.g., never arrived, or " in tool_accumulated_text
+ and (
+ "" in tool_accumulated_text
+ or "<|tool_call>" in tool_accumulated_text
+ or " 0:
+ if max_rotating_size > 0 and M > max_rotating_size:
tail_start = max(0, M - max_rotating_size)
tail_indices = set(range(tail_start, M))
existing = set(selected_indices.tolist())
diff --git a/vllm_mlx/text_model_from_vlm.py b/vllm_mlx/text_model_from_vlm.py
index b1130fdc5..082ccf43b 100644
--- a/vllm_mlx/text_model_from_vlm.py
+++ b/vllm_mlx/text_model_from_vlm.py
@@ -94,15 +94,27 @@ def _class_predicate(path, module):
else:
logger.warning("No MTP weights found in %s", model_path.name)
- # Verify MTP is functional
+ # Inject MTP if TextModel doesn't have native MTP support.
+ # mlx_lm's qwen3_5.TextModel strips MTP weights in sanitize(),
+ # so we inject MTP module + methods at runtime.
+ if not hasattr(text_model, "mtp") or text_model.mtp is None:
+ num_mtp = text_config.get("mtp_num_hidden_layers", 0)
+ if num_mtp == 0:
+ num_mtp = text_config.get("num_nextn_predict_layers", 0)
+ if num_mtp > 0:
+ from .patches.qwen3_5_mtp import inject_mtp_support
+
+ inject_mtp_support(text_model, model_path, config)
+
if hasattr(text_model, "mtp") and text_model.mtp is not None:
mx.eval(text_model.mtp.parameters())
- logger.info(
- "TextModel built with MTP support (%d layers)",
- args.mtp_num_hidden_layers,
+ num_mtp = text_config.get(
+ "mtp_num_hidden_layers",
+ text_config.get("num_nextn_predict_layers", 0),
)
+ logger.info("TextModel built with MTP support (%d layers)", num_mtp)
else:
- logger.info("TextModel built without MTP (mtp_num_hidden_layers=0)")
+ logger.info("TextModel built without MTP")
return text_model
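The MTP layer count can live in two places: flat as num_nextn_predict_layers (Qwen3-Next style) or nested under text_config as mtp_num_hidden_layers (Qwen3.5 style). A minimal sketch of that lookup order, using a hypothetical resolve_mtp_layers helper that is not part of this diff:

    def resolve_mtp_layers(config: dict) -> int:
        """Return the declared MTP layer count, or 0 if the model has none."""
        text_config = config.get("text_config", config)
        # Qwen3.5 style: nested under text_config
        num_mtp = text_config.get("mtp_num_hidden_layers", 0)
        # Qwen3-Next style: flat key (sometimes mirrored into text_config)
        if num_mtp == 0:
            num_mtp = text_config.get("num_nextn_predict_layers", 0)
        if num_mtp == 0:
            num_mtp = config.get("num_nextn_predict_layers", 0)
        return num_mtp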
diff --git a/vllm_mlx/tool_parsers/__init__.py b/vllm_mlx/tool_parsers/__init__.py
index 16f744080..cd76ad418 100644
--- a/vllm_mlx/tool_parsers/__init__.py
+++ b/vllm_mlx/tool_parsers/__init__.py
@@ -10,6 +10,7 @@
- mistral: Mistral models ([TOOL_CALLS] format)
- qwen/qwen3: Qwen models (<tool_call> and [Calling tool:] formats)
- llama/llama3/llama4: Llama models (<function=...> format)
+- gemma4/gemma_4: Google Gemma 4 models (<|tool_call>call:name{} format)
- hermes/nous: Hermes/NousResearch models
- deepseek/deepseek_v3/deepseek_r1: DeepSeek models (unicode tokens)
- kimi/kimi_k2/moonshot: Kimi/Moonshot models
@@ -19,6 +20,7 @@
- functionary/meetkai: MeetKai Functionary models
- glm47/glm4: GLM-4.7 and GLM-4.7-Flash models
- harmony/gpt-oss: GPT-OSS models (Harmony format with channels)
+- minimax: MiniMax-M2 models
Usage:
from vllm_mlx.tool_parsers import ToolParserManager
@@ -47,6 +49,7 @@
from .auto_tool_parser import AutoToolParser
from .deepseek_tool_parser import DeepSeekToolParser
from .functionary_tool_parser import FunctionaryToolParser
+from .gemma4_tool_parser import Gemma4ToolParser
from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import HermesToolParser
from .kimi_tool_parser import KimiToolParser
@@ -57,6 +60,7 @@
from .xlam_tool_parser import xLAMToolParser
from .glm47_tool_parser import Glm47ToolParser
from .harmony_tool_parser import HarmonyToolParser
+from .minimax_tool_parser import MiniMaxToolParser
__all__ = [
# Base classes
@@ -65,6 +69,7 @@
"ExtractedToolCallInformation",
# Specific parsers
"AutoToolParser",
+ "Gemma4ToolParser",
"MistralToolParser",
"QwenToolParser",
"LlamaToolParser",
@@ -77,4 +82,5 @@
"FunctionaryToolParser",
"Glm47ToolParser",
"HarmonyToolParser",
+ "MiniMaxToolParser",
]
diff --git a/vllm_mlx/tool_parsers/auto_tool_parser.py b/vllm_mlx/tool_parsers/auto_tool_parser.py
index fc02d8fc6..37ab10d74 100644
--- a/vllm_mlx/tool_parsers/auto_tool_parser.py
+++ b/vllm_mlx/tool_parsers/auto_tool_parser.py
@@ -16,6 +16,7 @@
ToolParser,
ToolParserManager,
)
+from .gemma4_tool_parser import Gemma4ToolParser
def generate_tool_id() -> str:
@@ -29,12 +30,13 @@ class AutoToolParser(ToolParser):
Auto-detecting tool call parser.
Tries multiple formats in order:
- 1. Mistral: [TOOL_CALLS] ...
- 2. Qwen bracket: [Calling tool: func_name({...})]
- 3. Qwen/Hermes XML: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
- 4. Llama: <function=name>{"arg": "value"}</function>
- 5. Nemotron: ...
- 6. Raw JSON: {"name": "...", "arguments": {...}}
+ 1. Gemma 4: <|tool_call>call:name{...}
+ 2. Mistral: [TOOL_CALLS] ...
+ 3. Qwen bracket: [Calling tool: func_name({...})]
+ 4. Qwen/Hermes XML: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
+ 5. Llama: <function=name>{"arg": "value"}</function>
+ 6. Nemotron: ...
+ 7. Raw JSON: {"name": "...", "arguments": {...}}
This is the default parser when no specific parser is selected.
"""
@@ -63,7 +65,14 @@ def extract_tool_calls(
tool_calls: list[dict[str, Any]] = []
cleaned_text = model_output
- # 1. Try Mistral format
+ # 1. Try Gemma 4 format (most distinctive marker)
+ if "<|tool_call>" in model_output:
+ gemma_parser = Gemma4ToolParser()
+ result = gemma_parser.extract_tool_calls(model_output, request)
+ if result.tools_called:
+ return result
+
+ # 2. Try Mistral format
if self.MISTRAL_TOKEN in model_output:
parts = model_output.split(self.MISTRAL_TOKEN)
content = parts[0].strip()
@@ -113,7 +122,7 @@ def extract_tool_calls(
content=content if content else None,
)
- # 2. Try Qwen bracket pattern
+ # 3. Try Qwen bracket pattern
bracket_matches = self.QWEN_BRACKET_PATTERN.findall(model_output)
for name, args_str in bracket_matches:
try:
@@ -141,7 +150,7 @@ def extract_tool_calls(
if bracket_matches:
cleaned_text = self.QWEN_BRACKET_PATTERN.sub("", cleaned_text).strip()
- # 3. Try Nemotron pattern (before Qwen XML as it's more specific)
+ # 4. Try Nemotron pattern (before Qwen XML as it's more specific)
nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text)
for name, params_block in nemotron_matches:
params = self.NEMOTRON_PARAM_PATTERN.findall(params_block)
@@ -157,7 +166,7 @@ def extract_tool_calls(
if nemotron_matches:
cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip()
- # 4. Try Qwen/Hermes XML pattern
+ # 5. Try Qwen/Hermes XML pattern
xml_matches = self.QWEN_XML_PATTERN.findall(cleaned_text)
for match in xml_matches:
try:
@@ -182,7 +191,7 @@ def extract_tool_calls(
if xml_matches:
cleaned_text = self.QWEN_XML_PATTERN.sub("", cleaned_text).strip()
- # 5. Try Llama pattern
+ # 6. Try Llama pattern
llama_matches = self.LLAMA_PATTERN.findall(cleaned_text)
for name, args_str in llama_matches:
try:
@@ -210,7 +219,7 @@ def extract_tool_calls(
if llama_matches:
cleaned_text = self.LLAMA_PATTERN.sub("", cleaned_text).strip()
- # 6. Fallback: Try raw JSON
+ # 7. Fallback: Try raw JSON
if not tool_calls:
raw_calls = self._parse_raw_json_tool_calls(cleaned_text)
if raw_calls:
@@ -327,6 +336,7 @@ def extract_tool_calls_streaming(
"""
# Check for any tool call markers
markers = [
+ "<|tool_call>",
self.MISTRAL_TOKEN,
"[Calling tool:",
"",
@@ -339,7 +349,7 @@ def extract_tool_calls_streaming(
return {"content": delta_text}
# Check for completion markers
- end_markers = ["", "", ")]"]
+ end_markers = ["", "", "", ")]"]
if any(m in delta_text for m in end_markers):
result = self.extract_tool_calls(current_text)
if result.tools_called:
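As a quick sketch of the new priority order, an invented Gemma-style completion (end-of-call token omitted; the Gemma parser also accepts an unterminated block) goes straight to Gemma4ToolParser instead of the JSON fallbacks:

    from vllm_mlx.tool_parsers import AutoToolParser

    parser = AutoToolParser()
    out = '<|tool_call>call:get_time{<|"|>tz<|"|>: <|"|>UTC<|"|>}'
    info = parser.extract_tool_calls(out)
    # tools_called is True and arguments is the JSON string {"tz": "UTC"}
    print(info.tools_called, info.tool_calls[0]["arguments"])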
diff --git a/vllm_mlx/tool_parsers/gemma4_tool_parser.py b/vllm_mlx/tool_parsers/gemma4_tool_parser.py
new file mode 100644
index 000000000..a32fd90cf
--- /dev/null
+++ b/vllm_mlx/tool_parsers/gemma4_tool_parser.py
@@ -0,0 +1,237 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Gemma 4 tool call parser for vllm-mlx.
+
+Handles Gemma 4's native tool call format:
+ <|tool_call>call:func_name{<|"|>key<|"|>: <|"|>value<|"|>, num: 42}
+
+Gemma 4 uses special tokens instead of JSON:
+- <|tool_call> / delimit tool call blocks
+- <|"|> replaces " for string values
+- Keys are unquoted bare identifiers
+- Multiple call:name{...} can appear in a single block
+
+Reference: mlx-lm PR #1105, vllm PR #38837
+"""
+
+import json
+import logging
+import re
+import uuid
+from collections.abc import Sequence
+from typing import Any
+
+from .abstract_tool_parser import (
+ ExtractedToolCallInformation,
+ ToolParser,
+ ToolParserManager,
+)
+
+logger = logging.getLogger(__name__)
+
+# Delimiters
+TOOL_CALL_START = "<|tool_call>"
+TOOL_CALL_END = ""
+
+# Placeholder token used during <|"|> extraction. Matches \x00 + digits + \x00.
+_PLACEHOLDER_RE = re.compile(r"\x00(\d+)\x00")
+
+# Pattern to extract <|"|>-delimited strings (non-greedy, supports multiline)
+_STRING_DELIM_RE = re.compile(r'<\|"\|>(.*?)<\|"\|>', re.DOTALL)
+
+# Pattern to match call:name followed by a { (we extract balanced braces manually)
+_CALL_PREFIX = re.compile(r"call:(\w+)\s*\{")
+
+# Pattern to quote bare keys: word followed by : at start or after , or {
+_BARE_KEY = re.compile(r"(?<=[{,])\s*(\w+)\s*:")
+
+# Max arg block length to prevent runaway parsing on malformed input (1 MB)
+_MAX_ARG_BLOCK_LEN = 1_048_576
+
+
+def _find_balanced_brace(text: str, start: int) -> int:
+ """Find the index of the closing } that balances the { at `start`.
+
+ Before counting braces, <|"|>-delimited strings are conceptually opaque --
+ we skip over <|"|>...<|"|> regions so that braces inside string values
+ (e.g. code snippets) don't affect depth counting.
+
+ Args:
+ text: The string to search (may contain <|"|> tokens)
+ start: Index of the opening {
+
+ Returns:
+ Index of the matching } in the ORIGINAL text, or -1 if not found
+ """
+ if len(text) - start > _MAX_ARG_BLOCK_LEN:
+ return -1
+
+ depth = 0
+ i = start
+ in_string = False
+ while i < len(text):
+ if text.startswith('<|"|>', i):
+ in_string = not in_string
+ i += 5
+ continue
+ if not in_string:
+ if text[i] == "{":
+ depth += 1
+ elif text[i] == "}":
+ depth -= 1
+ if depth == 0:
+ return i
+ i += 1
+ return -1
+
+
+def _gemma4_args_to_json(text: str) -> str:
+ """Convert Gemma 4 tool call args to valid JSON.
+
+ Three-step conversion (ORDER MATTERS):
+ 1. Extract <|"|>-delimited strings into numbered \\x00N\\x00 placeholders.
+ This protects string contents from step 2's bare-key quoting -- without
+ this, a string value like "key: value" would be corrupted.
+ 2. Quote bare keys (word: -> "word":) now that strings are safe.
+ 3. Restore placeholders as properly JSON-escaped strings via json.dumps().
+ Uses a single re.sub pass (O(len(text))) instead of per-placeholder replace.
+ """
+ strings: list[str] = []
+
+ def _capture(m: re.Match) -> str:
+ strings.append(m.group(1))
+ return f"\x00{len(strings) - 1}\x00"
+
+ # Step 1: Extract <|"|>-delimited strings
+ text = _STRING_DELIM_RE.sub(_capture, text)
+
+ # Step 2: Quote bare keys
+ text = _BARE_KEY.sub(r'"\1":', text)
+
+ # Step 3: Restore captured strings as properly escaped JSON strings
+ def _restore(m: re.Match) -> str:
+ idx = int(m.group(1))
+ return json.dumps(strings[idx]) if idx < len(strings) else m.group(0)
+
+ text = _PLACEHOLDER_RE.sub(_restore, text)
+
+ return text
+
+
+def generate_tool_id() -> str:
+ """Generate a unique tool call ID."""
+ return f"call_{uuid.uuid4().hex[:8]}"
+
+
+@ToolParserManager.register_module("gemma4")
+class Gemma4ToolParser(ToolParser):
+ """
+ Tool call parser for Gemma 4 models.
+
+ Parses: <|tool_call>call:func{<|"|>key<|"|>: <|"|>val<|"|>}
+
+ Used when --enable-auto-tool-choice --tool-call-parser gemma4 are set.
+ """
+
+ def extract_tool_calls(
+ self, model_output: str, request: dict[str, Any] | None = None
+ ) -> ExtractedToolCallInformation:
+ """Extract tool calls from a complete Gemma 4 model response."""
+ cleaned = self.strip_think_tags(model_output)
+
+ start_idx = cleaned.find(TOOL_CALL_START)
+ if start_idx == -1:
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ content_before = cleaned[:start_idx].strip() or None
+
+ block_start = start_idx + len(TOOL_CALL_START)
+ end_idx = cleaned.find(TOOL_CALL_END, block_start)
+ if end_idx == -1:
+ block = cleaned[block_start:]
+ else:
+ block = cleaned[block_start:end_idx]
+
+ tool_calls: list[dict[str, Any]] = []
+
+ pos = 0
+ while pos < len(block):
+ m = _CALL_PREFIX.search(block, pos)
+ if not m:
+ break
+
+ func_name = m.group(1)
+ brace_start = m.end() - 1
+
+ brace_end = _find_balanced_brace(block, brace_start)
+ if brace_end == -1:
+ pos = m.end()
+ continue
+
+ args_raw = block[brace_start : brace_end + 1]
+ try:
+ args_json = _gemma4_args_to_json(args_raw)
+ json.loads(args_json)
+ tool_calls.append(
+ {
+ "id": generate_tool_id(),
+ "name": func_name,
+ "arguments": args_json,
+ }
+ )
+ except (json.JSONDecodeError, ValueError) as e:
+ logger.warning(
+ f"Gemma 4 tool parser: failed to parse args for "
+ f"call:{func_name}: {e}"
+ )
+
+ pos = brace_end + 1
+
+ if tool_calls:
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content_before,
+ )
+ else:
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int] | None = None,
+ current_token_ids: Sequence[int] | None = None,
+ delta_token_ids: Sequence[int] | None = None,
+ request: dict[str, Any] | None = None,
+ ) -> dict[str, Any] | None:
+ """Extract tool calls from streaming Gemma 4 model output."""
+ has_start = TOOL_CALL_START in current_text
+
+ if not has_start:
+ return {"content": delta_text}
+
+ if TOOL_CALL_END in delta_text:
+ result = self.extract_tool_calls(current_text)
+ if result.tools_called:
+ return {
+ "tool_calls": [
+ {
+ "index": i,
+ "id": tc["id"],
+ "type": "function",
+ "function": {
+ "name": tc["name"],
+ "arguments": tc["arguments"],
+ },
+ }
+ for i, tc in enumerate(result.tool_calls)
+ ]
+ }
+
+ return None
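To make the three-step conversion in _gemma4_args_to_json concrete, a small sketch with an invented argument block (the import path assumes the module lands where this diff places it):

    from vllm_mlx.tool_parsers.gemma4_tool_parser import _gemma4_args_to_json

    raw = '{<|"|>city<|"|>: <|"|>San Francisco, CA<|"|>, days: 3}'
    # Step 1 lifts the two <|"|>-delimited strings into placeholders,
    # step 2 quotes the bare key `days`, step 3 restores the strings
    # JSON-escaped, giving JSON equivalent to
    # {"city": "San Francisco, CA", "days": 3}.
    print(_gemma4_args_to_json(raw))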
diff --git a/vllm_mlx/tool_parsers/minimax_tool_parser.py b/vllm_mlx/tool_parsers/minimax_tool_parser.py
new file mode 100644
index 000000000..7459fe97f
--- /dev/null
+++ b/vllm_mlx/tool_parsers/minimax_tool_parser.py
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+MiniMax tool call parser for vllm-mlx.
+
+Parses the MiniMax-M2 native XML tool call format:
+<minimax:tool_call>
+<invoke name="tool-name">
+<parameter name="param-name">param-value</parameter>
+</invoke>
+</minimax:tool_call>
+"""
+
+import json
+import re
+import uuid
+from collections.abc import Sequence
+from typing import Any
+
+from .abstract_tool_parser import (
+ ExtractedToolCallInformation,
+ ToolParser,
+ ToolParserManager,
+)
+
+
+def generate_tool_id() -> str:
+ return f"call_{uuid.uuid4().hex[:8]}"
+
+
+@ToolParserManager.register_module(["minimax", "minimax_m2"])
+class MiniMaxToolParser(ToolParser):
+ """
+ Parser for MiniMax-M2 tool call format.
+
+ Format:
+ <minimax:tool_call>
+ <invoke name="func_name">
+ <parameter name="param">value</parameter>
+ </invoke>
+ </minimax:tool_call>
+ """
+
+ TOOL_CALL_BLOCK = re.compile(
+ r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
+ )
+ INVOKE_PATTERN = re.compile(r'<invoke name="([^"]+)">(.*?)</invoke>', re.DOTALL)
+ PARAM_PATTERN = re.compile(
+ r'<parameter name="([^"]+)">(.*?)</parameter>', re.DOTALL
+ )
+ THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
+
+ def _extract_invokes(self, text: str) -> list[dict[str, Any]]:
+ """Extract tool calls from invoke elements, with or without wrapper."""
+ tool_calls: list[dict[str, Any]] = []
+ invokes = self.INVOKE_PATTERN.findall(text)
+ for func_name, params_block in invokes:
+ params = self.PARAM_PATTERN.findall(params_block)
+ # Skip bare <invoke> tags without parameters (hallucinated junk)
+ if not params:
+ continue
+ arguments = {}
+ for p_name, p_value in params:
+ p_value = p_value.strip()
+ try:
+ arguments[p_name] = json.loads(p_value)
+ except (json.JSONDecodeError, ValueError):
+ arguments[p_name] = p_value
+
+ tool_calls.append(
+ {
+ "id": generate_tool_id(),
+ "name": func_name.strip(),
+ "arguments": json.dumps(arguments, ensure_ascii=False),
+ }
+ )
+ return tool_calls
+
+ def extract_tool_calls(
+ self, model_output: str, request: dict[str, Any] | None = None
+ ) -> ExtractedToolCallInformation:
+ # Try wrapped format first: <minimax:tool_call>...</minimax:tool_call>
+ blocks = self.TOOL_CALL_BLOCK.findall(model_output)
+ if blocks:
+ tool_calls: list[dict[str, Any]] = []
+ for block in blocks:
+ tool_calls.extend(self._extract_invokes(block))
+
+ cleaned = self.TOOL_CALL_BLOCK.sub("", model_output).strip()
+ cleaned = self.THINK_PATTERN.sub("", cleaned).strip()
+ cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip()
+
+ return ExtractedToolCallInformation(
+ tools_called=bool(tool_calls),
+ tool_calls=tool_calls,
+ content=cleaned if cleaned else None,
+ )
+
+ # Fallback: bare <invoke> without <minimax:tool_call> wrapper
+ # (model sometimes emits tool calls inside <think> without wrapper)
+ tool_calls = self._extract_invokes(model_output)
+ if tool_calls:
+ # Strip matched invoke blocks and thinking from content
+ cleaned = self.INVOKE_PATTERN.sub("", model_output).strip()
+ cleaned = self.THINK_PATTERN.sub("", cleaned).strip()
+ cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip()
+ # Remove leftover closing tags
+ cleaned = cleaned.replace("", "").strip()
+
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=cleaned if cleaned else None,
+ )
+
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ def _has_tool_start(self, text: str) -> bool:
+ """Check if text contains the start of a tool call block."""
+ return "" in text or (
+ '" in current:
+ return (
+ "" in current
+ and "" not in previous
+ )
+ # Bare invoke: just appeared
+ if "" in current and "" not in previous:
+ return True
+ return False
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int] | None = None,
+ current_token_ids: Sequence[int] | None = None,
+ delta_token_ids: Sequence[int] | None = None,
+ request: dict[str, Any] | None = None,
+ ) -> dict[str, Any] | None:
+ # Not inside a tool call block yet — pass content through
+ if not self._has_tool_start(current_text):
+ return {"content": delta_text}
+
+ # Tool call block just completed
+ if self._has_tool_end(current_text, previous_text):
+ result = self.extract_tool_calls(current_text)
+ if result.tools_called:
+ return {
+ "tool_calls": [
+ {
+ "index": i,
+ "id": tc["id"],
+ "type": "function",
+ "function": {
+ "name": tc["name"],
+ "arguments": tc["arguments"],
+ },
+ }
+ for i, tc in enumerate(result.tool_calls)
+ ]
+ }
+
+ # Inside tool call block but not yet complete — suppress output
+ return None
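A usage sketch against an invented completion that follows the format shown in the module docstring; per-parameter values go through json.loads where possible and otherwise stay plain strings:

    from vllm_mlx.tool_parsers.minimax_tool_parser import MiniMaxToolParser

    output = (
        "<minimax:tool_call>\n"
        '<invoke name="get_weather">\n'
        '<parameter name="city">Tokyo</parameter>\n'
        '<parameter name="days">3</parameter>\n'
        "</invoke>\n"
        "</minimax:tool_call>"
    )
    info = MiniMaxToolParser().extract_tool_calls(output)
    # tools_called is True; "days" decodes to the integer 3,
    # "city" stays a string because json.loads("Tokyo") fails.
    print(info.tool_calls[0]["name"], info.tool_calls[0]["arguments"])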
diff --git a/vllm_mlx/tool_parsers/qwen_tool_parser.py b/vllm_mlx/tool_parsers/qwen_tool_parser.py
index fd69b96c0..e235a3c7d 100644
--- a/vllm_mlx/tool_parsers/qwen_tool_parser.py
+++ b/vllm_mlx/tool_parsers/qwen_tool_parser.py
@@ -5,8 +5,10 @@
Handles Qwen's tool calling formats:
- XML style: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
- Bracket style: [Calling tool: func_name({"arg": "value"})]
+- Function style: <function=name><parameter=key>value</parameter></function>
"""
+import ast
import json
import re
import uuid
@@ -20,6 +22,24 @@
)
+def _parse_param_value(val: str) -> Any:
+ """Parse a parameter value, handling JSON literals and plain strings."""
+ try:
+ return json.loads(val)
+ except (json.JSONDecodeError, ValueError):
+ pass
+ try:
+ python_val = ast.literal_eval(val)
+ if isinstance(python_val, set):
+ python_val = sorted(python_val, key=str)
+ if isinstance(python_val, (complex, bytes)):
+ return val
+ json.dumps(python_val)
+ return python_val
+ except (ValueError, SyntaxError, TypeError):
+ return val
+
+
def generate_tool_id() -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:8]}"
@@ -33,6 +53,7 @@ class QwenToolParser(ToolParser):
Supports multiple Qwen tool call formats:
- XML: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
- Bracket: [Calling tool: func_name({"arg": "value"})]
+ - Function: <function=name><parameter=key>value</parameter></function>
Used when --enable-auto-tool-choice --tool-call-parser qwen are set.
"""
@@ -43,6 +64,12 @@ class QwenToolParser(ToolParser):
# Pattern for bracket-style: [Calling tool: func_name({...})]
BRACKET_PATTERN = re.compile(r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]", re.DOTALL)
+ # Pattern for function-style: <function=name>...</function>
+ FUNCTION_PATTERN = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
+
+ # Pattern for parameter extraction: <parameter=key>value</parameter>
+ PARAM_PATTERN = re.compile(r"<parameter=([^>]+)>\s*(.*?)\s*</parameter>", re.DOTALL)
+
def extract_tool_calls(
self, model_output: str, request: dict[str, Any] | None = None
) -> ExtractedToolCallInformation:
@@ -101,6 +128,41 @@ def extract_tool_calls(
if xml_matches:
cleaned_text = self.XML_PATTERN.sub("", cleaned_text).strip()
+ # Try function-style: <function=name><parameter=key>value</parameter></function>
+ # Qwen3.5 generates this format natively.
+ if not tool_calls:
+ func_matches = self.FUNCTION_PATTERN.findall(cleaned_text)
+ for name, params_block in func_matches:
+ # Try JSON arguments first (e.g. <function=name>{"key": "val"})
+ params_block_stripped = params_block.strip()
+ if params_block_stripped.startswith("{"):
+ try:
+ arguments = json.loads(params_block_stripped)
+ tool_calls.append(
+ {
+ "id": generate_tool_id(),
+ "name": name.strip(),
+ "arguments": json.dumps(arguments, ensure_ascii=False),
+ }
+ )
+ continue
+ except json.JSONDecodeError:
+ pass
+ # Parse <parameter=key>value</parameter> tags
+ params = self.PARAM_PATTERN.findall(params_block)
+ arguments = {}
+ for p_name, p_value in params:
+ arguments[p_name.strip()] = _parse_param_value(p_value.strip())
+ tool_calls.append(
+ {
+ "id": generate_tool_id(),
+ "name": name.strip(),
+ "arguments": json.dumps(arguments, ensure_ascii=False),
+ }
+ )
+ if func_matches:
+ cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip()
+
if tool_calls:
return ExtractedToolCallInformation(
tools_called=True,
@@ -112,6 +174,30 @@ def extract_tool_calls(
tools_called=False, tool_calls=[], content=model_output
)
+ # Partial marker prefixes — when current_text ends with one of these,
+ # we suppress output until the next token confirms or denies a tool call.
+ # These are long enough to avoid false positives on normal text.
+ _PARTIAL_MARKERS = ("<tool_call>", "<function=")
+
+ def _has_partial_marker(self, text: str) -> bool:
+ """Check if text ends with an incomplete tool call marker prefix."""
+ return self._get_partial_marker_len(text) > 0
+
+ def _get_partial_marker_len(self, text: str) -> int:
+ """Return the length of a partial tool call marker suffix at end of text."""
+ tail = text[-20:]
+ best = 0
+ for marker in self._PARTIAL_MARKERS:
+ for length in range(len(marker), 0, -1):
+ if tail.endswith(marker[:length]) and length > best:
+ best = length
+ break
+ return best
+
+ def _was_buffering(self, previous_text: str) -> bool:
+ """Check if the previous call was buffering a partial marker."""
+ return self._has_partial_marker(previous_text)
+
def extract_tool_calls_streaming(
self,
previous_text: str,
@@ -125,14 +211,67 @@ def extract_tool_calls_streaming(
"""
Extract tool calls from streaming Qwen model output.
"""
- # Check for tool call markers
+ # Check for complete tool call markers
has_tool_marker = (
- "" in current_text or "[Calling tool:" in current_text
+ "" in current_text
+ or "[Calling tool:" in current_text
+ or " 0:
+ return {"content": delta_text[:safe_chars]}
+ return None
+ # If we were buffering before but the marker didn't complete,
+ # emit the buffered marker prefix together with the new delta.
+ if self._was_buffering(previous_text):
+ for marker in self._PARTIAL_MARKERS:
+ for length in range(len(marker), 0, -1):
+ prefix = marker[:length]
+ if previous_text.endswith(prefix):
+ return {"content": prefix + delta_text}
+ return {"content": delta_text}
return {"content": delta_text}
+ # Handle <function=...>...</function> (Qwen3.5 native format)
+ if "<function=" in current_text:
+ func_close_count = current_text.count("</function>")
+ prev_func_close = previous_text.count("</function>")
+
+ if current_text.count("<function=") > func_close_count:
+ # Inside an incomplete function block, suppress output
+ return None
+
+ if func_close_count > prev_func_close:
+ # New function block(s) completed
+ result = self.extract_tool_calls(current_text)
+ if result.tools_called:
+ new_calls = result.tool_calls[prev_func_close:]
+ if new_calls:
+ return {
+ "tool_calls": [
+ {
+ "index": prev_func_close + i,
+ "id": tc["id"],
+ "type": "function",
+ "function": {
+ "name": tc["name"],
+ "arguments": tc["arguments"],
+ },
+ }
+ for i, tc in enumerate(new_calls)
+ ]
+ }
+
+ return None
+
# If we're in a tool call, accumulate and parse at the end
# For simplicity, return None during accumulation
if "" in delta_text or ")]" in delta_text:
diff --git a/vllm_mlx/utils/__init__.py b/vllm_mlx/utils/__init__.py
index e808515ad..14d5de5c8 100644
--- a/vllm_mlx/utils/__init__.py
+++ b/vllm_mlx/utils/__init__.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility modules for vllm-mlx."""
+from .download import DownloadConfig, ensure_model_downloaded
from .tokenizer import load_model_with_fallback
-__all__ = ["load_model_with_fallback"]
+__all__ = ["DownloadConfig", "ensure_model_downloaded", "load_model_with_fallback"]
diff --git a/vllm_mlx/utils/download.py b/vllm_mlx/utils/download.py
new file mode 100644
index 000000000..39941c7af
--- /dev/null
+++ b/vllm_mlx/utils/download.py
@@ -0,0 +1,144 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Resumable model download with retry/timeout support.
+
+Pre-downloads models via huggingface_hub.snapshot_download() with
+configurable timeout and retry logic before passing to mlx-lm/mlx-vlm.
+"""
+
+import logging
+import os
+import time
+from dataclasses import dataclass
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+logger = logging.getLogger(__name__)
+
+# Mirrors mlx_lm.utils._download() default allow_patterns
+LLM_ALLOW_PATTERNS = [
+ "*.json",
+ "model*.safetensors",
+ "*.py",
+ "tokenizer.model",
+ "*.tiktoken",
+ "tiktoken.model",
+ "*.txt",
+ "*.jsonl",
+ "*.jinja",
+]
+
+# Mirrors mlx_vlm.utils.get_model_path() allow_patterns
+MLLM_ALLOW_PATTERNS = [
+ "*.json",
+ "*.safetensors",
+ "*.py",
+ "*.model",
+ "*.tiktoken",
+ "*.txt",
+ "*.jinja",
+]
+
+
+@dataclass
+class DownloadConfig:
+ """Configuration for model download behavior."""
+
+ download_timeout: int = 300
+ max_retries: int = 3
+ retry_backoff_base: float = 2.0
+ offline: bool = False
+
+
+def ensure_model_downloaded(
+ model_name: str,
+ config: DownloadConfig | None = None,
+ is_mllm: bool = False,
+) -> Path:
+ """
+ Ensure a model is available locally, downloading with retry if needed.
+
+ Args:
+ model_name: HuggingFace model name or local path.
+ config: Download configuration. Uses defaults if None.
+ is_mllm: If True, use MLLM download patterns (broader file set).
+
+ Returns:
+ Path to the local model directory.
+
+ Raises:
+ RuntimeError: If download fails after all retries.
+ KeyboardInterrupt: Propagated immediately without retry.
+ """
+ if config is None:
+ config = DownloadConfig()
+
+ model_path = Path(model_name)
+ if model_path.exists():
+ logger.info(f"Model found at local path: {model_path}")
+ return model_path
+
+ if config.offline:
+ logger.info(f"Offline mode: looking for cached {model_name}")
+ try:
+ result = Path(snapshot_download(model_name, local_files_only=True))
+ logger.info(f"Found cached model at {result}")
+ return result
+ except Exception as e:
+ raise RuntimeError(
+ f"Model '{model_name}' not found in local cache. "
+ f"Download it first without --offline flag."
+ ) from e
+
+ allow_patterns = MLLM_ALLOW_PATTERNS if is_mllm else LLM_ALLOW_PATTERNS
+
+ # Set HF download timeout via environment variable
+ old_timeout = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")
+ os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = str(config.download_timeout)
+
+ last_error = None
+ try:
+ for attempt in range(1, config.max_retries + 1):
+ try:
+ logger.info(
+ f"Downloading model {model_name} "
+ f"(attempt {attempt}/{config.max_retries}, "
+ f"timeout={config.download_timeout}s)"
+ )
+ result = Path(
+ snapshot_download(
+ model_name,
+ allow_patterns=allow_patterns,
+ )
+ )
+ logger.info(f"Model downloaded successfully to {result}")
+ return result
+ except KeyboardInterrupt:
+ logger.warning("Download interrupted by user.")
+ raise
+ except Exception as e:
+ last_error = e
+ if attempt < config.max_retries:
+ wait = config.retry_backoff_base**attempt
+ logger.warning(
+ f"Download attempt {attempt} failed: {e}. "
+ f"Retrying in {wait:.0f}s..."
+ )
+ time.sleep(wait)
+ else:
+ logger.error(
+ f"Download failed after {config.max_retries} attempts."
+ )
+
+ raise RuntimeError(
+ f"Failed to download '{model_name}' after {config.max_retries} "
+ f"attempts. Last error: {last_error}\n"
+ f"Run the same command again to resume the download."
+ )
+ finally:
+ # Restore original env var
+ if old_timeout is None:
+ os.environ.pop("HF_HUB_DOWNLOAD_TIMEOUT", None)
+ else:
+ os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = old_timeout
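As a usage sketch (the repo id is invented for illustration), the new helper can be called before engine construction to get resumable downloads with backoff:

    from vllm_mlx.utils import DownloadConfig, ensure_model_downloaded

    # 10-minute per-request timeout, up to 5 attempts with exponential
    # backoff (2s, 4s, 8s, ...); a rerun resumes the partial snapshot.
    cfg = DownloadConfig(download_timeout=600, max_retries=5)
    path = ensure_model_downloaded(
        "mlx-community/Example-Model-4bit",
        config=cfg,
        is_mllm=False,
    )
    print(path)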
diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py
index a50883951..9d200ab9f 100644
--- a/vllm_mlx/utils/tokenizer.py
+++ b/vllm_mlx/utils/tokenizer.py
@@ -28,6 +28,27 @@ def _needs_tokenizer_fallback(model_name: str) -> bool:
return any(pattern.lower() in model_lower for pattern in FALLBACK_MODELS)
+def _needs_strict_false(model_name: str) -> bool:
+ """Check if model needs strict=False loading (VLM models with extra weights).
+
+ VLM models (e.g., Qwen3.5) have vision_tower weights that don't match
+ the text-only model class. Loading with strict=True fails and wastes
+ memory by loading all weights (~100 GB) before raising ValueError.
+ Detect these models up-front to avoid the double-load penalty.
+ """
+ from mlx_lm.utils import _download, load_config
+
+ try:
+ model_path = _download(model_name)
+ config = load_config(model_path)
+ except Exception:
+ return False
+ # VLM models have vision_config or text_config with a separate model_type
+ if "vision_config" in config and "text_config" in config:
+ return True
+ return False
+
+
def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
"""
Load model and tokenizer with fallback for non-standard tokenizers.
@@ -50,6 +71,15 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
)
return _load_with_tokenizer_fallback(model_name)
+ # VLM models (e.g., Qwen3.5) have extra vision weights that cause
+ # strict=True to fail. Skip the first load attempt to avoid loading
+ # ~100 GB of weights twice (which can cause OOM on 256 GB systems).
+ if _needs_strict_false(model_name):
+ logger.info(
+ f"Model {model_name} detected as VLM, loading directly with strict=False"
+ )
+ return _load_strict_false(model_name, tokenizer_config)
+
try:
model, tokenizer = load(model_name, tokenizer_config=tokenizer_config)
except ValueError as e:
@@ -59,42 +89,89 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
return _load_with_tokenizer_fallback(model_name)
# Fallback for models with extra weights (e.g., vision tower, MTP layers).
# Retry with strict=False to discard extra weights.
- if "parameters not in model" in str(e):
+ elif "parameters not in model" in str(e):
logger.warning(
f"Extra parameters found (e.g., vision tower / MTP weights), "
f"retrying with strict=False: {e}"
)
+ # Clear traceback references to free memory from the failed first load.
+ # Without this, large models (200GB+) cause OOM during retry because
+ # the traceback holds references to the first load's weight tensors.
+ e.__traceback__ = None
+ del e
+ import gc
+
+ gc.collect()
return _load_strict_false(model_name, tokenizer_config)
- raise
+ else:
+ raise
+ # After successful load, check if MTP weights exist but were stripped by sanitize()
+ _try_inject_mtp_post_load(model, model_name)
+ return model, tokenizer
-def _load_strict_false(model_name: str, tokenizer_config: dict = None):
- """Load model with strict=False to discard extra weights (e.g., vision tower, MTP)."""
- from mlx_lm.utils import load_model, load_tokenizer
- local_path = Path(model_name)
- if local_path.is_dir():
- model_path = local_path
- else:
- from huggingface_hub import snapshot_download
+def _load_strict_false(model_name: str, tokenizer_config: dict = None):
+ """Load model with strict=False to discard extra weights.
- model_path = Path(snapshot_download(model_name))
+ Handles models with extra parameters that the text-only model class
+ doesn't define (e.g., vision tower weights in VLM models like Qwen3.5,
+ or MTP layers). The model's own sanitize() handles key remapping
+ (e.g., language_model.* prefix), and strict=False silently drops
+ unmatched keys.
+ """
+ import mlx.core as mx
+ from mlx_lm.utils import _download, load_model, load_tokenizer
+ model_path = _download(model_name)
model, config = load_model(model_path, strict=False)
+
+ # Verify weights loaded correctly
+ from mlx.utils import tree_flatten
+
+ params = tree_flatten(model.parameters())
+ total_params = len(params)
+ zero_params = sum(1 for _, v in params if mx.all(v == 0).item())
+ logger.info(
+ f"[strict=False] Loaded {total_params} parameters, "
+ f"{zero_params} all-zero tensors"
+ )
+ # Spot-check embedding weights
+ if hasattr(model, "language_model"):
+ emb = model.language_model.model.embed_tokens.weight
+ logger.info(
+ f"[strict=False] embed_tokens: shape={emb.shape}, "
+ f"dtype={emb.dtype}, mean={mx.mean(emb.astype(mx.float32)).item():.4f}"
+ )
+
tokenizer = load_tokenizer(
model_path,
tokenizer_config or {},
eos_token_ids=config.get("eos_token_id", None),
)
- # Inject MTP support if model has MTP config + weights
_try_inject_mtp(model, model_path, config)
return model, tokenizer
def _try_inject_mtp(model, model_path, config):
"""Inject MTP support if model has MTP config + weights."""
+ # Qwen3-Next: flat num_nextn_predict_layers
if config.get("num_nextn_predict_layers", 0) > 0:
- from ..patches.qwen3_next_mtp import inject_mtp_support
+ # Detect Qwen3.5 vs Qwen3-Next by checking text_config or model_type
+ text_config = config.get("text_config", config)
+ model_type = text_config.get("model_type", config.get("model_type", ""))
+ if "qwen3_5" in model_type:
+ from ..patches.qwen3_5_mtp import inject_mtp_support
+ else:
+ from ..patches.qwen3_next_mtp import inject_mtp_support
+ inject_mtp_support(model, model_path, config)
+ return
+
+ # Qwen3.5: mtp_num_hidden_layers in text_config
+ text_config = config.get("text_config", config)
+ num_mtp = text_config.get("mtp_num_hidden_layers", 0)
+ if num_mtp > 0:
+ from ..patches.qwen3_5_mtp import inject_mtp_support
inject_mtp_support(model, model_path, config)
@@ -111,13 +188,21 @@ def _try_inject_mtp_post_load(model, model_name):
return
with open(config_path) as f:
config = json.load(f)
- # Also check text_config for nested configs
+ # Check for MTP in flat config and nested text_config
+ text_config = config.get("text_config", {})
num_mtp = config.get("num_nextn_predict_layers", 0)
if num_mtp == 0:
- text_config = config.get("text_config", {})
num_mtp = text_config.get("num_nextn_predict_layers", 0)
- if num_mtp > 0 and getattr(model, "mtp", None) is None:
- mtp_file = Path(model_path) / "model-mtp.safetensors"
+ if num_mtp == 0:
+ num_mtp = text_config.get("mtp_num_hidden_layers", 0)
+ # Also check mtp attribute on language_model for VLM wrappers
+ check_model = model
+ if hasattr(model, "language_model"):
+ check_model = model.language_model
+ if num_mtp > 0 and getattr(check_model, "mtp", None) is None:
+ mtp_file = Path(model_path) / "mtp" / "weights.safetensors"
+ if not mtp_file.exists():
+ mtp_file = Path(model_path) / "model-mtp.safetensors"
if mtp_file.exists():
logger.info(
f"[MTP] Found MTP config (layers={num_mtp}) and weights, injecting..."
@@ -126,7 +211,7 @@ def _try_inject_mtp_post_load(model, model_name):
else:
logger.info(
f"[MTP] Config has num_nextn_predict_layers={num_mtp} "
- "but model-mtp.safetensors not found, skipping MTP."
+ "but MTP weights not found, skipping MTP."
)
@@ -134,16 +219,12 @@ def _load_with_tokenizer_fallback(model_name: str):
"""Load model with fallback tokenizer for non-standard models like Nemotron."""
from mlx_lm.utils import load_model
- logger.info("Loading with tokenizer fallback...")
+ from .download import ensure_model_downloaded
- # Get model path - use local path if it exists, otherwise download from Hub
- local_path = Path(model_name)
- if local_path.is_dir():
- model_path = local_path
- else:
- from huggingface_hub import snapshot_download
+ logger.info("Loading with tokenizer fallback...")
- model_path = Path(snapshot_download(model_name))
+ # Get model path (with retry/timeout support)
+ model_path = ensure_model_downloaded(model_name, is_mllm=False)
# Load model
model, _ = load_model(model_path)