diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 815c3e44..567d3596 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -28,15 +28,16 @@
 
 from dataclasses import dataclass, field
 from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple
+from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer
 
 from .mllm_batch_generator import (
     MLLMBatchGenerator,
     MLLMBatchRequest,
     MLLMBatchResponse,
 )
+from .mllm_cache import MLLMCacheManager
 from .multimodal_processor import MultimodalProcessor
 from .request import RequestOutput, RequestStatus, SamplingParams
-from .mllm_cache import MLLMCacheManager
 
 logger = logging.getLogger(__name__)
 
@@ -198,6 +199,9 @@ def __init__(
         self.request_id_to_uid: Dict[str, int] = {}
         self.uid_to_request_id: Dict[int, str] = {}
 
+        # Per-request streaming detokenizers for UTF-8-safe incremental decode
+        self._detokenizer_pool: Dict[str, Any] = {}
+
         # Output queues for async streaming
         self.output_queues: Dict[str, asyncio.Queue] = {}
 
@@ -345,6 +349,8 @@ def abort_request(self, request_id: str) -> bool:
         request.status = RequestStatus.FINISHED_ABORTED
         self.finished_req_ids.add(request_id)
 
+        self._detokenizer_pool.pop(request_id, None)
+
         # Signal output queue
         if request_id in self.output_queues:
             try:
@@ -446,8 +452,21 @@ def _process_batch_responses(
             request.output_tokens.append(response.token)
             request.num_output_tokens = len(request.output_tokens)
 
-            # Decode the new token
-            new_text = tokenizer.decode([response.token])
+            # Decode the new token using streaming detokenizer (UTF-8 safe).
+            # Skip stop tokens — they are not content.
+            if response.finish_reason == "stop":
+                new_text = ""
+            else:
+                if request_id not in self._detokenizer_pool:
+                    if hasattr(tokenizer, "detokenizer"):
+                        detok = tokenizer.detokenizer
+                    else:
+                        detok = NaiveStreamingDetokenizer(tokenizer)
+                    detok.reset()
+                    self._detokenizer_pool[request_id] = detok
+                detok = self._detokenizer_pool[request_id]
+                detok.add_token(response.token)
+                new_text = detok.last_segment
 
             # Create output
             output = RequestOutput(
@@ -470,10 +489,16 @@ def _process_batch_responses(
                 output.finish_reason = response.finish_reason
                 finished_ids.add(request_id)
 
-                # Decode full output
-                output.output_text = tokenizer.decode(request.output_tokens)
+                # Finalize streaming detokenizer and get full output
+                detok = self._detokenizer_pool.get(request_id)
+                if detok is not None:
+                    detok.finalize()
+                    output.output_text = detok.text
+                else:
+                    output.output_text = tokenizer.decode(request.output_tokens)
                 request.output_text = output.output_text
                 request.finish_reason = response.finish_reason
+                self._detokenizer_pool.pop(request_id, None)
 
                 self.total_completion_tokens += request.num_output_tokens
                 self.num_requests_processed += 1
@@ -778,6 +803,7 @@ def reset(self) -> None:
         self.finished_req_ids.clear()
         self.request_id_to_uid.clear()
         self.uid_to_request_id.clear()
+        self._detokenizer_pool.clear()
 
         if self.batch_generator is not None:
             self.batch_generator.close()
diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index 233b3109..81a2b72a 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -20,6 +20,7 @@
 import mlx.core as mx
 from mlx_lm.generate import BatchGenerator
 from mlx_lm.sample_utils import make_sampler
+from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer
 
 from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
 from .paged_cache import PagedCacheManager
@@ -976,6 +977,9 @@ def __init__(
         # Detect if tokenizer is a processor (MLLM) and get the actual tokenizer
         self._actual_tokenizer = self._get_actual_tokenizer(tokenizer)
 
+        # Per-request streaming detokenizers for UTF-8-safe incremental decode
+        self._detokenizer_pool: Dict[str, Any] = {}
+
         # Request management - following vLLM's design
         self.waiting: deque[Request] = deque()  # Waiting queue (FCFS)
         self.running: Dict[str, Request] = {}  # Running requests by ID
@@ -1076,6 +1080,21 @@ def _decode_tokens(self, token_ids: List[int]) -> str:
         """
         return self._actual_tokenizer.decode(token_ids)
 
+    def _get_detokenizer(self, request_id: str) -> Any:
+        """Get or create a streaming detokenizer for a request."""
+        if request_id not in self._detokenizer_pool:
+            if hasattr(self.tokenizer, "detokenizer"):
+                detok = self.tokenizer.detokenizer
+            else:
+                detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
+            detok.reset()
+            self._detokenizer_pool[request_id] = detok
+        return self._detokenizer_pool[request_id]
+
+    def _cleanup_detokenizer(self, request_id: str) -> None:
+        """Remove the streaming detokenizer for a finished request."""
+        self._detokenizer_pool.pop(request_id, None)
+
     def _get_stop_tokens(self) -> Set[int]:
         """Get stop token IDs from tokenizer or processor."""
         stop_tokens = set()
@@ -1688,6 +1707,7 @@ def _do_abort_request(self, request_id: str) -> bool:
         if request is not None:
             request.set_finished(RequestStatus.FINISHED_ABORTED)
         self.finished_req_ids.add(request_id)
+        self._cleanup_detokenizer(request_id)
 
         # Flush Metal encoders after removing arrays from batch
         mx.clear_cache()
@@ -1848,11 +1868,13 @@ def _process_batch_responses(
 
                 request.first_token_time = _time.time()
 
-            # Decode the new token (skip stop tokens — they are not content)
+            # Decode the new token using streaming detokenizer (UTF-8 safe)
             if response.finish_reason == "stop":
                 new_text = ""
             else:
-                new_text = self._decode_tokens([response.token])
+                detok = self._get_detokenizer(request_id)
+                detok.add_token(response.token)
+                new_text = detok.last_segment
 
             # Create output
             output = RequestOutput(
@@ -1875,9 +1897,15 @@ def _process_batch_responses(
                 output.finish_reason = response.finish_reason
                 finished_ids.add(request_id)
 
-                # Decode full output
-                output.output_text = self._decode_tokens(request.output_token_ids)
+                # Finalize streaming detokenizer and get full output
+                detok = self._detokenizer_pool.get(request_id)
+                if detok is not None:
+                    detok.finalize()
+                    output.output_text = detok.text
+                else:
+                    output.output_text = self._decode_tokens(request.output_token_ids)
                 request.output_text = output.output_text
+                self._cleanup_detokenizer(request_id)
 
                 # Extract cache for future reuse (critical for agentic multi-turn)
                 if hasattr(response, "prompt_cache"):
@@ -2112,6 +2140,7 @@ def _recover_from_generation_error(self) -> Set[str]:
             aborted_ids.add(request_id)
             self.finished_req_ids.add(request_id)
         self.running.clear()
+        self._detokenizer_pool.clear()
 
         # Clear UID mappings (batch generator is gone)
         self.request_id_to_uid.clear()
@@ -2383,6 +2412,7 @@ def reset(self) -> None:
         self.finished_req_ids.clear()
         self.request_id_to_uid.clear()
         self.uid_to_request_id.clear()
+        self._detokenizer_pool.clear()
 
         self._close_batch_generator()
         self._current_sampler_params = None
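
Reviewer note: both schedulers now share the same per-request streaming-decode pattern. The sketch below shows that pattern in isolation, assuming mlx_lm's NaiveStreamingDetokenizer interface (reset, add_token, last_segment, finalize, text) exactly as exercised in the diff; the stream_decode helper itself is illustrative and not part of the patch.

    # Sketch only: per-request streaming decode, one detokenizer per request.
    from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

    def stream_decode(tokenizer, token_ids):
        """Yield UTF-8-safe text segments for one request's tokens (sketch)."""
        detok = NaiveStreamingDetokenizer(tokenizer)
        detok.reset()
        for tok in token_ids:
            detok.add_token(tok)
            # Reading last_segment consumes it; it stays empty while a
            # multi-byte character is incomplete, so partial UTF-8 is never
            # streamed to the client.
            segment = detok.last_segment
            if segment:
                yield segment
        detok.finalize()
        # finalize() completes the decode; detok.text then holds the full
        # output, which the schedulers assign to output.output_text.
        segment = detok.last_segment
        if segment:
            yield segment

Keeping one detokenizer per request_id (the _detokenizer_pool above) preserves this incremental state across scheduler steps, and the pop/clear calls on abort, finish, reset, and error recovery keep the pool from leaking entries.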