23 changes: 23 additions & 0 deletions vllm_mlx/cli.py
@@ -84,6 +84,7 @@ def serve_command(args):

parser_cls = get_parser(args.reasoning_parser)
server._reasoning_parser = parser_cls()
server._reasoning_parser_name = args.reasoning_parser
logger.info(f"Reasoning parser enabled: {args.reasoning_parser}")
except KeyError as e:
print(f"Error: {e}")
@@ -216,6 +217,9 @@ def serve_command(args):
gpu_memory_utilization=args.gpu_memory_utilization,
draft_model=args.draft_model,
num_draft_tokens=args.num_draft_tokens,
prefill_step_size=args.prefill_step_size,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
)

# Start server
@@ -845,6 +849,25 @@ def main():
default=4,
help="Number of tokens to generate speculatively per step (default: 4)",
)
serve_parser.add_argument(
"--prefill-step-size",
type=int,
default=2048,
help="Tokens to process per prefill chunk in simple mode (default: 2048)",
)
serve_parser.add_argument(
"--kv-bits",
type=int,
default=None,
choices=[4, 8],
help="KV cache quantization bits for simple mode (4 or 8). Reduces memory for long contexts.",
)
serve_parser.add_argument(
"--kv-group-size",
type=int,
default=64,
help="Group size for KV cache quantization in simple mode (default: 64)",
)
# Reasoning parser options - choices loaded dynamically from registry
from .reasoning import list_parsers

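Taken together, the serve command gains three knobs for the simple engine's prefill and KV-cache behaviour. A minimal, self-contained sketch of how the new flags parse (flag names, types, defaults, and choices copied from the serve_parser additions above; the tiny parser below is illustrative, not the project's actual serve_parser):

import argparse

# Illustrative parser: only the three new flags, with the same types,
# defaults, and choices as the serve_parser additions above.
parser = argparse.ArgumentParser()
parser.add_argument("--prefill-step-size", type=int, default=2048)
parser.add_argument("--kv-bits", type=int, default=None, choices=[4, 8])
parser.add_argument("--kv-group-size", type=int, default=64)

args = parser.parse_args(["--prefill-step-size", "4096", "--kv-bits", "8"])
print(args.prefill_step_size, args.kv_bits, args.kv_group_size)  # -> 4096 8 64

Passing a value outside choices=[4, 8] to --kv-bits is rejected at parse time, so only supported quantization widths reach the engine.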
21 changes: 14 additions & 7 deletions vllm_mlx/engine/simple.py
@@ -43,6 +43,9 @@ def __init__(
force_mllm: bool = False,
draft_model: str | None = None,
num_draft_tokens: int = 4,
prefill_step_size: int = 2048,
kv_bits: int | None = None,
kv_group_size: int = 64,
):
"""
Initialize the simple engine.
@@ -54,13 +57,19 @@
force_mllm: Force loading as MLLM even if not auto-detected
draft_model: Optional draft model path for speculative decoding
num_draft_tokens: Number of tokens to generate speculatively per step
prefill_step_size: Tokens to process per prefill chunk (default: 2048)
kv_bits: KV cache quantization bits (None=no quantization, 4 or 8)
kv_group_size: Group size for KV cache quantization (default: 64)
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._enable_cache = enable_cache
self._is_mllm = force_mllm or is_mllm_model(model_name)
self._draft_model_name = draft_model
self._num_draft_tokens = num_draft_tokens
self._prefill_step_size = prefill_step_size
self._kv_bits = kv_bits
self._kv_group_size = kv_group_size

self._model = None
self._loaded = False
@@ -110,6 +119,9 @@ async def start(self) -> None:
trust_remote_code=self._trust_remote_code,
draft_model=self._draft_model_name,
num_draft_tokens=self._num_draft_tokens,
prefill_step_size=self._prefill_step_size,
kv_bits=self._kv_bits,
kv_group_size=self._kv_group_size,
)

self._model.load()
@@ -207,8 +219,7 @@ async def stream_generate(

async with self._generation_lock:
accumulated_text = ""
# Compute prompt tokens upfront since StreamingOutput doesn't carry them
prompt_tokens = len(self._model.tokenizer.encode(prompt))
prompt_tokens = 0
completion_tokens = 0
finished = False

@@ -220,11 +231,7 @@
stop=stop,
**kwargs,
):
prompt_tokens = (
chunk.prompt_tokens
if hasattr(chunk, "prompt_tokens")
else prompt_tokens
)
prompt_tokens = getattr(chunk, "prompt_tokens", 0) or prompt_tokens
completion_tokens += 1
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text
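For context, a rough construction sketch of the engine with the new options (keyword names mirror the __init__ additions above; the model id is a placeholder and the import path is assumed from the file layout):

import asyncio

from vllm_mlx.engine.simple import SimpleEngine  # path assumed from vllm_mlx/engine/simple.py


async def demo() -> None:
    # Keyword names mirror the diff; the model id is a placeholder.
    engine = SimpleEngine(
        model_name="mlx-community/placeholder-model",
        prefill_step_size=4096,  # prompt processed in 4096-token chunks
        kv_bits=8,               # 8-bit KV cache quantization
        kv_group_size=64,
    )
    # start() constructs the underlying MLXLanguageModel and forwards the
    # three new keyword arguments to it, as the hunk above shows.
    await engine.start()


asyncio.run(demo())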
47 changes: 42 additions & 5 deletions vllm_mlx/models/llm.py
@@ -35,6 +35,7 @@ class StreamingOutput:
finished: bool = False
finish_reason: str | None = None
logprobs: Any = None # mx.array of shape [vocab_size] from mlx-lm
prompt_tokens: int = 0


class MLXLanguageModel:
@@ -57,6 +58,9 @@ def __init__(
trust_remote_code: bool = False,
draft_model: str | None = None,
num_draft_tokens: int = 4,
prefill_step_size: int = 2048,
kv_bits: int | None = None,
kv_group_size: int = 64,
):
"""
Initialize the MLX language model.
@@ -67,12 +71,18 @@
trust_remote_code: Whether to trust remote code
draft_model: Optional draft model path for speculative decoding
num_draft_tokens: Number of tokens to generate speculatively per step
prefill_step_size: Tokens to process per prefill chunk (default: 2048)
kv_bits: KV cache quantization bits (None=no quantization, 4 or 8)
kv_group_size: Group size for KV cache quantization (default: 64)
"""
self.model_name = model_name
self.tokenizer_name = tokenizer_name or model_name
self.trust_remote_code = trust_remote_code
self.draft_model_name = draft_model
self.num_draft_tokens = num_draft_tokens
self.prefill_step_size = prefill_step_size
self.kv_bits = kv_bits
self.kv_group_size = kv_group_size

self.model = None
self.tokenizer = None
@@ -320,8 +330,12 @@ def stream_generate(
if not self._loaded:
self.load()

import time as _time

from mlx_lm import stream_generate

t0 = _time.perf_counter()

# Tokenize the full prompt
add_special_tokens = (
self.tokenizer.bos_token is None
@@ -331,6 +345,8 @@
prompt, add_special_tokens=add_special_tokens
)

t_tokenize = _time.perf_counter()

# Prepare cache and get only the tokens that need processing
suffix_tokens = self._prepare_cache_for_prompt(full_token_ids)
prefix_len = len(full_token_ids) - len(suffix_tokens)
@@ -341,6 +357,10 @@
f"{len(suffix_tokens)} new tokens "
f"(saved {prefix_len} tokens of prefill)"
)
else:
logger.info(
f"Prompt cache miss: {len(full_token_ids)} tokens to prefill"
)

# Create sampler with parameters
sampler = self._create_sampler(temperature, top_p)
@@ -353,8 +373,14 @@
"max_tokens": max_tokens,
"sampler": sampler,
"prompt_cache": self._prompt_cache,
"prefill_step_size": self.prefill_step_size,
}

# KV cache quantization reduces memory pressure for long prompts
if self.kv_bits is not None:
gen_kwargs["kv_bits"] = self.kv_bits
gen_kwargs["kv_group_size"] = self.kv_group_size

# Add draft model for speculative decoding if available
if self.draft_model is not None:
gen_kwargs["draft_model"] = self.draft_model
@@ -373,13 +399,23 @@
else:
prompt_to_send = suffix_tokens

t_first_token = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=prompt_to_send,
**gen_kwargs,
):
token_count += 1
if token_count == 1:
t_first_token = _time.perf_counter()
logger.info(
f"TTFT breakdown: tokenize={t_tokenize - t0:.3f}s, "
f"prefill+decode={t_first_token - t_tokenize:.3f}s, "
f"total={t_first_token - t0:.3f}s "
f"(prompt={len(full_token_ids)} tokens, "
f"prefilled={len(prompt_to_send)} tokens)"
)
# response.text is the new token text (not accumulated)
new_text = response.text
accumulated_text += new_text
@@ -396,23 +432,24 @@
finish_reason = None
if finished:
finish_reason = "stop" if should_stop else "length"
# Save cache BEFORE yielding the finished chunk.
# The caller may break/abandon this generator after
# receiving the finished chunk, so code after yield
# would never execute.
self._save_cache_snapshot(full_token_ids)

yield StreamingOutput(
text=new_text,
token=response.token if hasattr(response, "token") else 0,
finished=finished,
finish_reason=finish_reason,
logprobs=getattr(response, "logprobs", None),
prompt_tokens=len(full_token_ids),
)

if finished:
break

# Save cache state: prompt tokens only (not generated tokens)
# The cache now has prompt + generated tokens; we save the prompt part
# so next request can match against it
self._save_cache_snapshot(full_token_ids)

def chat(
self,
messages: list[dict],
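The net effect in MLXLanguageModel.stream_generate is that prefill_step_size is always placed in gen_kwargs, while kv_bits and kv_group_size are added only when quantization is requested. A rough standalone equivalent, assuming the installed mlx-lm version accepts these keyword arguments (the model id is a placeholder):

from mlx_lm import load, stream_generate

# Placeholder model id; keyword support for prefill_step_size / kv_bits /
# kv_group_size depends on the installed mlx-lm version.
model, tokenizer = load("mlx-community/placeholder-model")

for response in stream_generate(
    model,
    tokenizer,
    prompt="Explain KV cache quantization in one sentence.",
    max_tokens=64,
    prefill_step_size=4096,  # tokens processed per prefill chunk
    kv_bits=8,               # quantize the KV cache to 8 bits
    kv_group_size=64,        # quantization group size
):
    print(response.text, end="", flush=True)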
32 changes: 32 additions & 0 deletions vllm_mlx/server.py
@@ -622,6 +622,9 @@ def load_model(
gpu_memory_utilization: float = 0.90,
draft_model: str | None = None,
num_draft_tokens: int = 4,
prefill_step_size: int = 2048,
kv_bits: int | None = None,
kv_group_size: int = 64,
):
"""
Load a model (auto-detects MLLM vs LLM).
@@ -637,6 +640,9 @@
limit and emergency threshold (0.0-1.0, default 0.90)
draft_model: Optional draft model for speculative decoding
num_draft_tokens: Number of tokens to generate speculatively per step
prefill_step_size: Tokens to process per prefill chunk (default: 2048)
kv_bits: KV cache quantization bits (None=no quantization, 4 or 8)
kv_group_size: Group size for KV cache quantization (default: 64)
"""
global _engine, _model_name, _default_max_tokens, _tool_parser_instance

@@ -688,6 +694,9 @@
force_mllm=force_mllm,
draft_model=draft_model,
num_draft_tokens=num_draft_tokens,
prefill_step_size=prefill_step_size,
kv_bits=kv_bits,
kv_group_size=kv_group_size,
)
# Start SimpleEngine synchronously (no background loop)
# Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated)
@@ -2798,6 +2807,26 @@ def main():
default=4,
help="Number of tokens to generate speculatively per step (default: 4)",
)
parser.add_argument(
"--prefill-step-size",
type=int,
default=2048,
help="Tokens to process per prefill chunk (default: 2048). "
"Larger values may improve TTFT on Apple Silicon with sufficient memory.",
)
parser.add_argument(
"--kv-bits",
type=int,
default=None,
choices=[4, 8],
help="KV cache quantization bits (4 or 8). Reduces memory for long contexts.",
)
parser.add_argument(
"--kv-group-size",
type=int,
default=64,
help="Group size for KV cache quantization (default: 64)",
)

args = parser.parse_args()

@@ -2858,6 +2887,9 @@
force_mllm=args.mllm,
draft_model=args.draft_model,
num_draft_tokens=args.num_draft_tokens,
prefill_step_size=args.prefill_step_size,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
)

# Start server
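Finally, the same options can be supplied when loading a model through the server module directly. A hedged sketch (the first positional argument is assumed to be the model id, since load_model's full signature is only partially shown in this diff; parameters not shown keep their defaults):

from vllm_mlx import server

# Assumption: load_model()'s first positional argument is the model id.
# The id below is a placeholder.
server.load_model(
    "mlx-community/placeholder-model",
    prefill_step_size=4096,
    kv_bits=8,
    kv_group_size=64,
)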