148 changes: 133 additions & 15 deletions vllm_mlx/mllm_batch_generator.py
@@ -24,6 +24,7 @@
import mlx.core as mx
import mlx.nn as nn

from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig, _trim_cache_offset
from .multimodal_processor import MultimodalProcessor
from .vision_embedding_cache import VisionEmbeddingCache

@@ -59,6 +60,9 @@ class MLLMBatchRequest:
num_tokens: int = 0 # Tokens generated so far
output_tokens: List[int] = field(default_factory=list)

# Whether the request is text-only (no images/videos) — eligible for prefix cache
is_text_only: bool = False

# Vision state (populated after initial VLM forward pass)
vision_encoded: bool = False
cross_attention_states: Optional[Any] = None # For models that use cross-attention
@@ -289,6 +293,7 @@ def __init__(
prefill_step_size: int = 1024,
enable_vision_cache: bool = True,
vision_cache_size: int = 100,
prefix_cache_config: Optional[MemoryCacheConfig] = None,
):
"""
Initialize MLLM batch generator.
@@ -351,6 +356,14 @@ def __init__(
f"MLLMBatchGenerator: Vision cache enabled (size={vision_cache_size})"
)

# KV prefix cache for text-only requests
self.prefix_cache: Optional[MemoryAwarePrefixCache] = None
if prefix_cache_config is not None:
self.prefix_cache = MemoryAwarePrefixCache(
model=self.language_model,
config=prefix_cache_config,
)

# Generation stream
if MLLMBatchGenerator._stream is None:
MLLMBatchGenerator._stream = mx.new_stream(mx.default_device())
@@ -545,6 +558,9 @@ def _preprocess_request(self, request: MLLMBatchRequest) -> None:
self._stats.num_images_processed += len(all_images)
self._stats.vision_encoding_time += processing_time

# Mark text-only requests as eligible for prefix cache
request.is_text_only = not bool(all_images)

logger.debug(
f"Preprocessed request {request.request_id}: "
f"{len(all_images)} images, {request.input_ids.size if request.input_ids is not None else 0} tokens "
@@ -641,26 +657,80 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
per_request_caches = []

for req in requests:
    # Try prefix cache for text-only requests
    cached_kv = None
    remaining_ids = None
    if (
        self.prefix_cache is not None
        and req.is_text_only
        and req.input_ids is not None
    ):
        input_ids_list = req.input_ids.reshape(-1).tolist()
        cached_kv, remaining_ids = self.prefix_cache.fetch(input_ids_list)

    if cached_kv is not None and remaining_ids:
        # Prefix/LCP match — run language model on remaining tokens only
        request_cache = cached_kv
        remaining = mx.array(remaining_ids)[None, :]

        with mx.stream(MLLMBatchGenerator._stream):
            logits = self.language_model(remaining, cache=request_cache)
            if hasattr(logits, "logits"):
                logits = logits.logits
            last_logits = logits[:, -1, :]
            logprobs = last_logits - mx.logsumexp(
                last_logits, axis=-1, keepdims=True
            )
            sampled = self.sampler(logprobs)
            mx.eval(sampled, logprobs)
            first_tokens.append(sampled.item())
            all_logprobs.append(logprobs.squeeze(0))

        per_request_caches.append(request_cache)

    elif cached_kv is not None and not remaining_ids:
        # Exact/supersequence match — run on last token to get logits
        request_cache = cached_kv
        # Normalize to 2-D before slicing off the last prompt token
        last_token = req.input_ids
        if last_token.ndim == 1:
            last_token = last_token[None, :]
        last_token = last_token[:, -1:]

        with mx.stream(MLLMBatchGenerator._stream):
            logits = self.language_model(last_token, cache=request_cache)
            if hasattr(logits, "logits"):
                logits = logits.logits
            last_logits = logits[:, -1, :]
            logprobs = last_logits - mx.logsumexp(
                last_logits, axis=-1, keepdims=True
            )
            sampled = self.sampler(logprobs)
            mx.eval(sampled, logprobs)
            first_tokens.append(sampled.item())
            all_logprobs.append(logprobs.squeeze(0))

        per_request_caches.append(request_cache)

    else:
        # Cache miss — full VLM forward pass
        request_cache = make_prompt_cache(self.language_model)

        with mx.stream(MLLMBatchGenerator._stream):
            # Run VLM forward pass — cache= flows through to language_model
            logits = self._run_vision_encoding(req, cache=request_cache)

            # Extract last token logits and sample
            last_logits = logits[:, -1, :]
            logprobs = last_logits - mx.logsumexp(
                last_logits, axis=-1, keepdims=True
            )
            sampled = self.sampler(logprobs)
            mx.eval(sampled, logprobs)

        first_tokens.append(sampled.item())
        all_logprobs.append(logprobs.squeeze(0))

        per_request_caches.append(request_cache)

# Merge per-request KVCaches into a single BatchKVCache.
# KVCache.merge() creates a BatchKVCache with proper left-padding
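
The prefill branching above reduces to a three-way decision on the `(cached_kv, remaining_ids)` pair returned by `MemoryAwarePrefixCache.fetch()`. A self-contained restatement of that decision table; `classify_fetch` is a hypothetical helper written for illustration, not part of this diff:

```python
from typing import List, Optional

def classify_fetch(cached_kv: Optional[object], remaining_ids: List[int]) -> str:
    """Mirror the three-way branch in _process_prompts."""
    if cached_kv is not None and remaining_ids:
        return "lcp-hit"    # prefill only remaining_ids on top of cached_kv
    if cached_kv is not None:
        return "exact-hit"  # rerun the last prompt token to recover logits
    return "miss"           # fresh make_prompt_cache + full VLM forward pass

assert classify_fetch(None, [1, 2, 3]) == "miss"
assert classify_fetch(object(), [3]) == "lcp-hit"
assert classify_fetch(object(), []) == "exact-hit"
```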
@@ -739,6 +809,35 @@ def _step(

return sampled, list(logprobs)

def _maybe_store_prefix_cache(
self, batch: MLLMBatch, end_indices: List[int]
) -> None:
"""
Extract and store KV caches for finished text-only requests.

Must be called BEFORE batch.filter() since indices reference
the current batch layout.

Args:
batch: The active batch
end_indices: Indices of finished requests in the batch
"""
if self.prefix_cache is None or not end_indices:
return
for i in end_indices:
req = batch.requests[i]
if req.is_text_only and req.input_ids is not None:
try:
extracted = batch.extract_cache(i)
input_ids_list = req.input_ids.reshape(-1).tolist()
# Trim off output_count + 1 entries so a later fetch always
# returns at least one remaining token (the last prompt token)
output_count = batch.num_tokens[i]
prompt_cache = _trim_cache_offset(extracted, output_count + 1)
self.prefix_cache.store(input_ids_list, prompt_cache)
except Exception as e:
logger.warning(f"[prefix_store] FAILED: {type(e).__name__}: {e}")
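
The `output_count + 1` trim is what keeps stored entries on the fast path: it cuts the cache back to one entry short of the full prompt, so a later `fetch()` of the identical prompt lands in the LCP branch with exactly one token left to prefill. A self-contained sketch of the arithmetic, assuming the extracted cache holds one KV entry per processed token (the numbers are hypothetical):

```python
# Hypothetical sizes; assumes the finished request's cache holds
# prompt_len + output_count KV entries (one per processed token).
prompt_len = 10                        # len(req.input_ids)
output_count = 5                       # batch.num_tokens[i] at finish
cache_len = prompt_len + output_count

trimmed_len = cache_len - (output_count + 1)
assert trimmed_len == prompt_len - 1
# A later fetch() of the identical prompt therefore always leaves at
# least one uncovered token (the last prompt token); its forward pass
# yields the logits needed to sample the first output token.
```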

def _next(self) -> List[MLLMBatchResponse]:
"""
Internal next() implementation.
@@ -833,6 +932,9 @@ def _next(self) -> List[MLLMBatchResponse]:
)
)

# Store prefix caches BEFORE filtering (indices must still be valid)
self._maybe_store_prefix_cache(batch, end_idx)

# Remove finished requests from batch
if end_idx:
if keep_idx:
@@ -867,6 +969,22 @@ def get_vision_cache_stats(self) -> Dict[str, Any]:
"""Get vision cache statistics."""
return self.vision_cache.get_stats()

def get_prefix_cache_stats(self) -> Dict[str, Any]:
"""Get prefix cache statistics."""
if self.prefix_cache is not None:
return self.prefix_cache.get_stats()
return {
"hits": 0,
"misses": 0,
"hit_rate": 0.0,
"evictions": 0,
"tokens_saved": 0,
"current_memory_mb": 0.0,
"max_memory_mb": 0.0,
"memory_utilization": 0.0,
"entry_count": 0,
}
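
Returning a zeroed dict rather than None keeps the stats schema stable when the cache is disabled, so callers can read the keys unconditionally; for example, with a hypothetical `generator` instance:

```python
stats = generator.get_prefix_cache_stats()
print(f"prefix cache: {stats['hits']} hits, hit rate {stats['hit_rate']:.1%}")
```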

def has_pending(self) -> bool:
"""Check if there are pending or active requests."""
return bool(self.unprocessed_requests or self.active_batch)
16 changes: 16 additions & 0 deletions vllm_mlx/mllm_scheduler.py
@@ -64,6 +64,10 @@ class MLLMSchedulerConfig:
default_video_fps: float = 2.0
# Maximum video frames
max_video_frames: int = 128
# Enable KV prefix cache for text-only requests
enable_prefix_cache: bool = True
# Maximum memory for prefix cache (MB). None = auto-detect.
prefix_cache_memory_mb: Optional[int] = None


@dataclass
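
On the scheduler side, these two fields are the whole user-facing surface of the feature. A minimal sketch, assuming the remaining `MLLMSchedulerConfig` fields have usable defaults (the 512 MB budget is an arbitrary example):

```python
from vllm_mlx.mllm_scheduler import MLLMSchedulerConfig

config = MLLMSchedulerConfig(
    enable_prefix_cache=True,     # on by default per this diff
    prefix_cache_memory_mb=512,   # cap in MB; None = auto-detect
)
```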
@@ -246,9 +250,18 @@ def _ensure_batch_generator(self) -> None:
if self.batch_generator is None:
from mlx_lm.sample_utils import make_sampler

from .memory_cache import MemoryCacheConfig

# Default sampler (can be overridden per-request in future)
sampler = make_sampler(temp=0.7, top_p=0.9)

# Configure KV prefix cache for text-only requests
prefix_cache_config = None
if self.config.enable_prefix_cache:
prefix_cache_config = MemoryCacheConfig(
max_memory_mb=self.config.prefix_cache_memory_mb,
)

self.batch_generator = MLLMBatchGenerator(
model=self.model,
processor=self.processor,
@@ -259,6 +272,7 @@
prefill_batch_size=self.config.prefill_batch_size,
completion_batch_size=self.config.completion_batch_size,
prefill_step_size=self.config.prefill_step_size,
prefix_cache_config=prefix_cache_config,
)

# ========== Sync API (step-based) ==========
@@ -796,6 +810,8 @@ def get_stats(self) -> Dict[str, Any]:
stats["vision_embedding_cache"] = (
self.batch_generator.get_vision_cache_stats()
)
# KV prefix cache stats (text-only request reuse)
stats["memory_aware_cache"] = self.batch_generator.get_prefix_cache_stats()

if self.vision_cache:
stats["vision_cache"] = self.vision_cache.get_stats()