36 changes: 31 additions & 5 deletions vllm_mlx/mllm_scheduler.py
@@ -28,15 +28,16 @@
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple

from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

from .mllm_batch_generator import (
MLLMBatchGenerator,
MLLMBatchRequest,
MLLMBatchResponse,
)
from .mllm_cache import MLLMCacheManager
from .multimodal_processor import MultimodalProcessor
from .request import RequestOutput, RequestStatus, SamplingParams
from .mllm_cache import MLLMCacheManager

logger = logging.getLogger(__name__)

@@ -198,6 +199,9 @@ def __init__(
self.request_id_to_uid: Dict[str, int] = {}
self.uid_to_request_id: Dict[int, str] = {}

# Per-request streaming detokenizers for UTF-8-safe incremental decode
self._detokenizer_pool: Dict[str, Any] = {}

# Output queues for async streaming
self.output_queues: Dict[str, asyncio.Queue] = {}

@@ -345,6 +349,8 @@ def abort_request(self, request_id: str) -> bool:
request.status = RequestStatus.FINISHED_ABORTED
self.finished_req_ids.add(request_id)

self._detokenizer_pool.pop(request_id, None)

# Signal output queue
if request_id in self.output_queues:
try:
@@ -446,8 +452,21 @@ def _process_batch_responses(
request.output_tokens.append(response.token)
request.num_output_tokens = len(request.output_tokens)

# Decode the new token
new_text = tokenizer.decode([response.token])
# Decode the new token using streaming detokenizer (UTF-8 safe).
# Skip stop tokens — they are not content.
if response.finish_reason == "stop":
new_text = ""
else:
if request_id not in self._detokenizer_pool:
if hasattr(tokenizer, "detokenizer"):
detok = tokenizer.detokenizer
else:
detok = NaiveStreamingDetokenizer(tokenizer)
detok.reset()
self._detokenizer_pool[request_id] = detok
detok = self._detokenizer_pool[request_id]
detok.add_token(response.token)
new_text = detok.last_segment

# Create output
output = RequestOutput(
@@ -470,10 +489,16 @@ def _process_batch_responses(
output.finish_reason = response.finish_reason
finished_ids.add(request_id)

# Decode full output
output.output_text = tokenizer.decode(request.output_tokens)
# Finalize streaming detokenizer and get full output
detok = self._detokenizer_pool.get(request_id)
if detok is not None:
detok.finalize()
output.output_text = detok.text
else:
output.output_text = tokenizer.decode(request.output_tokens)
request.output_text = output.output_text
request.finish_reason = response.finish_reason
self._detokenizer_pool.pop(request_id, None)

self.total_completion_tokens += request.num_output_tokens
self.num_requests_processed += 1
@@ -778,6 +803,7 @@ def reset(self) -> None:
self.finished_req_ids.clear()
self.request_id_to_uid.clear()
self.uid_to_request_id.clear()
self._detokenizer_pool.clear()

if self.batch_generator is not None:
self.batch_generator.close()
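The two decode hunks above follow the same pattern: instead of decoding each new token in isolation with tokenizer.decode([response.token]), the scheduler keeps one streaming detokenizer per request, feeds it tokens with add_token, and emits only last_segment, which holds back bytes until they form complete characters. A minimal sketch of that flow, assuming an mlx_lm-style tokenizer; the stream_decode helper and its emit callback are illustrative, not part of this PR:

from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

def stream_decode(tokenizer, token_ids, emit=print):
    """Incrementally decode token_ids, calling emit() once per UTF-8-safe segment."""
    # Prefer the tokenizer's own streaming detokenizer when available
    # (mlx_lm's TokenizerWrapper exposes one); otherwise fall back to the
    # naive implementation -- the same branch the scheduler uses above.
    if hasattr(tokenizer, "detokenizer"):
        detok = tokenizer.detokenizer
    else:
        detok = NaiveStreamingDetokenizer(tokenizer)
    detok.reset()

    for token in token_ids:
        detok.add_token(token)
        # last_segment withholds bytes that would split a multi-byte
        # character, so a partial emoji/CJK char never reaches the client.
        segment = detok.last_segment
        if segment:
            emit(segment)

    # On finish the scheduler finalizes and takes detok.text as the full
    # output (output.output_text in the hunk above).
    detok.finalize()
    return detok.text

With the old per-token decode, a character that byte-level BPE splits across several tokens surfaced as one replacement character per token; with the pooled detokenizer the segment simply stays empty until the character is complete.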
38 changes: 34 additions & 4 deletions vllm_mlx/scheduler.py
@@ -20,6 +20,7 @@
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
from .paged_cache import PagedCacheManager
@@ -976,6 +977,9 @@ def __init__(
# Detect if tokenizer is a processor (MLLM) and get the actual tokenizer
self._actual_tokenizer = self._get_actual_tokenizer(tokenizer)

# Per-request streaming detokenizers for UTF-8-safe incremental decode
self._detokenizer_pool: Dict[str, Any] = {}

# Request management - following vLLM's design
self.waiting: deque[Request] = deque() # Waiting queue (FCFS)
self.running: Dict[str, Request] = {} # Running requests by ID
@@ -1076,6 +1080,21 @@ def _decode_tokens(self, token_ids: List[int]) -> str:
"""
return self._actual_tokenizer.decode(token_ids)

def _get_detokenizer(self, request_id: str) -> Any:
"""Get or create a streaming detokenizer for a request."""
if request_id not in self._detokenizer_pool:
if hasattr(self.tokenizer, "detokenizer"):
detok = self.tokenizer.detokenizer
else:
detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
detok.reset()
self._detokenizer_pool[request_id] = detok
return self._detokenizer_pool[request_id]

def _cleanup_detokenizer(self, request_id: str) -> None:
"""Remove the streaming detokenizer for a finished request."""
self._detokenizer_pool.pop(request_id, None)

def _get_stop_tokens(self) -> Set[int]:
"""Get stop token IDs from tokenizer or processor."""
stop_tokens = set()
@@ -1688,6 +1707,7 @@ def _do_abort_request(self, request_id: str) -> bool:
if request is not None:
request.set_finished(RequestStatus.FINISHED_ABORTED)
self.finished_req_ids.add(request_id)
self._cleanup_detokenizer(request_id)

# Flush Metal encoders after removing arrays from batch
mx.clear_cache()
@@ -1848,11 +1868,13 @@ def _process_batch_responses(

request.first_token_time = _time.time()

# Decode the new token (skip stop tokens — they are not content)
# Decode the new token using streaming detokenizer (UTF-8 safe)
if response.finish_reason == "stop":
new_text = ""
else:
new_text = self._decode_tokens([response.token])
detok = self._get_detokenizer(request_id)
detok.add_token(response.token)
new_text = detok.last_segment

# Create output
output = RequestOutput(
@@ -1875,9 +1897,15 @@
output.finish_reason = response.finish_reason
finished_ids.add(request_id)

# Decode full output
output.output_text = self._decode_tokens(request.output_token_ids)
# Finalize streaming detokenizer and get full output
detok = self._detokenizer_pool.get(request_id)
if detok is not None:
detok.finalize()
output.output_text = detok.text
else:
output.output_text = self._decode_tokens(request.output_token_ids)
request.output_text = output.output_text
self._cleanup_detokenizer(request_id)

# Extract cache for future reuse (critical for agentic multi-turn)
if hasattr(response, "prompt_cache"):
@@ -2112,6 +2140,7 @@ def _recover_from_generation_error(self) -> Set[str]:
aborted_ids.add(request_id)
self.finished_req_ids.add(request_id)
self.running.clear()
self._detokenizer_pool.clear()

# Clear UID mappings (batch generator is gone)
self.request_id_to_uid.clear()
@@ -2383,6 +2412,7 @@ def reset(self) -> None:
self.finished_req_ids.clear()
self.request_id_to_uid.clear()
self.uid_to_request_id.clear()
self._detokenizer_pool.clear()
self._close_batch_generator()
self._current_sampler_params = None

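One way to see the difference the streaming path makes (not part of the PR): compare the old per-token decode with the streaming segments for a string whose characters span several byte-level BPE tokens. The use of transformers.AutoTokenizer and the Qwen/Qwen2.5-0.5B-Instruct checkpoint are assumptions for this sketch; any byte-level BPE tokenizer splits such characters the same way.

# Illustrative comparison only; the tokenizer choice is an assumption.
from transformers import AutoTokenizer
from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
text = "naïve 日本語 👍🏽"
token_ids = tokenizer.encode(text, add_special_tokens=False)

detok = NaiveStreamingDetokenizer(tokenizer)
detok.reset()

for tid in token_ids:
    old = tokenizer.decode([tid])   # previous behaviour: may yield "\ufffd"
    detok.add_token(tid)
    new = detok.last_segment        # streaming: empty until the char is complete
    print(f"token {tid:>8}  old={old!r:<14} new={new!r}")

detok.finalize()
# After finalize, the accumulated text matches a one-shot decode.
assert detok.text == tokenizer.decode(token_ids)

The same lifecycle applies on the error paths: _do_abort_request, _recover_from_generation_error, and reset all drop the pooled detokenizer, so detokenizer state is never leaked or carried over to a reused request_id.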