diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index 88d144cb7..38dbdde75 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -20,6 +20,7 @@
 import mlx.core as mx
 from mlx_lm.generate import BatchGenerator
 from mlx_lm.sample_utils import make_sampler
+from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer
 from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
 from .paged_cache import PagedCacheManager
 
@@ -1043,6 +1044,12 @@ def __init__(
         # CPython GIL guarantees set.add() and `x in set` are atomic.
         self._pending_abort_ids: Set[str] = set()
 
+        # Per-request streaming detokenizers for UTF-8-safe incremental decode.
+        # Raw tokenizer.decode([token]) turns split multi-byte codepoints
+        # (emoji, CJK) into U+FFFD replacement characters. NaiveStreamingDetokenizer
+        # buffers incomplete byte sequences and only emits valid UTF-8 segments.
+        self._detokenizer_pool: Dict[str, Any] = {}
+
         # Statistics
         self.num_requests_processed = 0
         self.total_prompt_tokens = 0
@@ -1688,6 +1695,7 @@ def _do_abort_request(self, request_id: str) -> bool:
         if request is not None:
             request.set_finished(RequestStatus.FINISHED_ABORTED)
             self.finished_req_ids.add(request_id)
+            self._detokenizer_pool.pop(request_id, None)
 
             # Flush Metal encoders after removing arrays from batch
             mx.clear_cache()
@@ -1848,11 +1856,23 @@ def _process_batch_responses(
             if request.first_token_time is None:
                 request.first_token_time = _time.time()
 
-            # Decode the new token (skip stop tokens — they are not content)
+            # Decode the new token using a streaming detokenizer (UTF-8 safe).
+            # Raw tokenizer.decode([token]) yields U+FFFD replacement chars for
+            # split multi-byte codepoints (emoji, CJK) in continuous batching.
             if response.finish_reason == "stop":
                 new_text = ""
             else:
-                new_text = self._decode_tokens([response.token])
+                if request_id not in self._detokenizer_pool:
+                    tokenizer = self.tokenizer
+                    if hasattr(tokenizer, "detokenizer"):
+                        detok = tokenizer.detokenizer
+                    else:
+                        detok = NaiveStreamingDetokenizer(tokenizer)
+                    detok.reset()
+                    self._detokenizer_pool[request_id] = detok
+                detok = self._detokenizer_pool[request_id]
+                detok.add_token(response.token)
+                new_text = detok.last_segment
 
             # Create output
             output = RequestOutput(
@@ -1875,8 +1895,13 @@
                 output.finish_reason = response.finish_reason
                 finished_ids.add(request_id)
 
-                # Decode full output
-                output.output_text = self._decode_tokens(request.output_token_ids)
+                # Finalize the streaming detokenizer for the full output
+                detok = self._detokenizer_pool.pop(request_id, None)
+                if detok is not None:
+                    detok.finalize()
+                    output.output_text = detok.text
+                else:
+                    output.output_text = self._decode_tokens(request.output_token_ids)
                 request.output_text = output.output_text
 
                 # Extract cache for future reuse (critical for agentic multi-turn)
@@ -1920,6 +1945,8 @@
     def _cleanup_finished(self, finished_ids: Set[str]) -> None:
         """Clean up finished requests and store caches for reuse."""
         for request_id in finished_ids:
+            # Safety-net detokenizer cleanup (normally popped in _process_batch_responses)
+            self._detokenizer_pool.pop(request_id, None)
             request = self.running.get(request_id)
 
             # Store cache for future reuse
@@ -2399,6 +2426,7 @@ def reset(self) -> None:
         self.finished_req_ids.clear()
         self.request_id_to_uid.clear()
         self.uid_to_request_id.clear()
+        self._detokenizer_pool.clear()
         self._close_batch_generator()
         self._current_sampler_params = None
 
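
For reviewers, a minimal standalone sketch of the failure mode this patch addresses and the streaming-detokenizer pattern it adopts. The model name and sample string below are illustrative assumptions, not part of the change; the detokenizer calls (`reset`/`add_token`/`last_segment`/`finalize`/`text`) mirror the ones used in the diff above.

```python
from mlx_lm import load
from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

# Illustrative checkpoint; any mlx_lm model exhibits the same behavior.
_, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

tokens = tokenizer.encode("👋 世界")  # multi-byte codepoints may span several tokens

# Old path: decoding one token at a time typically yields U+FFFD
# replacement characters wherever a codepoint is split across tokens.
per_token = "".join(tokenizer.decode([t]) for t in tokens)

# New path: buffer tokens and surface text only at valid UTF-8 boundaries.
detok = NaiveStreamingDetokenizer(tokenizer)
detok.reset()
chunks = []
for t in tokens:
    detok.add_token(t)
    chunks.append(detok.last_segment)  # "" while bytes are still buffered
detok.finalize()
chunks.append(detok.last_segment)  # flush the final segment

print(repr(per_token))        # may contain '\ufffd'
print(repr("".join(chunks)))  # clean text, no replacement characters
```

The per-request pool is what makes this pattern safe under continuous batching: each request streams through its own buffer, and the `pop(request_id, None)` calls on abort and cleanup plus `clear()` in `reset()` keep the pool from leaking entries for dead requests.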