vllm_mlx/scheduler.py: 36 changes (32 additions, 4 deletions)
```diff
@@ -20,6 +20,7 @@
 import mlx.core as mx
 from mlx_lm.generate import BatchGenerator
 from mlx_lm.sample_utils import make_sampler
+from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

 from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
 from .paged_cache import PagedCacheManager
```
```diff
@@ -1043,6 +1044,12 @@ def __init__(
         # CPython GIL guarantees set.add() and `x in set` are atomic.
         self._pending_abort_ids: Set[str] = set()

+        # Per-request streaming detokenizers for UTF-8-safe incremental decode.
+        # Raw tokenizer.decode([token]) decodes multi-byte codepoints (emoji,
+        # CJK) piecemeal, yielding U+FFFD replacement characters; the streaming
+        # detokenizer buffers bytes and emits only complete UTF-8 sequences.
+        self._detokenizer_pool: Dict[str, Any] = {}
+
         # Statistics
         self.num_requests_processed = 0
         self.total_prompt_tokens = 0
```
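For context on the failure mode this pool addresses: byte-level BPE splits a single emoji across several tokens, so decoding one token at a time yields U+FFFD replacement characters until all bytes have arrived. A minimal sketch of the contrast, assuming an mlx_lm model is available locally (the repo id is illustrative, not part of this PR):

```python
# Sketch: per-token decode vs. streaming detokenization (repo id illustrative).
from mlx_lm import load
from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

_, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
ids = tokenizer.encode("hi 👋", add_special_tokens=False)

# Naive: each partial decode of the emoji's bytes yields U+FFFD.
print("".join(tokenizer.decode([t]) for t in ids))

# Streaming: bytes are buffered until a complete codepoint is available.
detok = NaiveStreamingDetokenizer(tokenizer)
detok.reset()
for t in ids:
    detok.add_token(t)
    print(repr(detok.last_segment))  # "" while a codepoint is still incomplete
detok.finalize()
print(detok.text)  # full, valid UTF-8: "hi 👋"
```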
```diff
@@ -1688,6 +1695,7 @@ def _do_abort_request(self, request_id: str) -> bool:
         if request is not None:
             request.set_finished(RequestStatus.FINISHED_ABORTED)
             self.finished_req_ids.add(request_id)
+            self._detokenizer_pool.pop(request_id, None)

         # Flush Metal encoders after removing arrays from batch
         mx.clear_cache()
```
```diff
@@ -1848,11 +1856,23 @@ def _process_batch_responses(

             request.first_token_time = _time.time()

-            # Decode the new token (skip stop tokens — they are not content)
+            # Decode the new token with a streaming detokenizer (UTF-8 safe):
+            # raw tokenizer.decode([token]) yields U+FFFD replacement chars
+            # for multi-byte codepoints (emoji, CJK) in continuous batching.
             if response.finish_reason == "stop":
                 new_text = ""
             else:
-                new_text = self._decode_tokens([response.token])
+                if request_id not in self._detokenizer_pool:
+                    tokenizer = self.tokenizer
+                    if hasattr(tokenizer, "detokenizer"):
+                        detok = tokenizer.detokenizer
+                    else:
+                        detok = NaiveStreamingDetokenizer(tokenizer)
+                    detok.reset()
+                    self._detokenizer_pool[request_id] = detok
+                detok = self._detokenizer_pool[request_id]
+                detok.add_token(response.token)
+                new_text = detok.last_segment

             # Create output
             output = RequestOutput(
```
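Continuous batching interleaves tokens from many requests in one decode loop, so each request needs its own detokenizer state; a shared byte buffer would mix fragments across requests. A standalone sketch of the get-or-create pattern above (`pool` and `stream_token` are illustrative names, not part of the scheduler):

```python
from typing import Any, Dict

from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer

pool: Dict[str, Any] = {}  # request_id -> streaming detokenizer

def stream_token(tokenizer, request_id: str, token: int) -> str:
    """Return the newly completed UTF-8 segment for this request, if any."""
    detok = pool.get(request_id)
    if detok is None:
        # Prefer the tokenizer's own streaming detokenizer when it has one,
        # mirroring the scheduler's get-or-create logic above.
        detok = getattr(tokenizer, "detokenizer", None)
        if detok is None:
            detok = NaiveStreamingDetokenizer(tokenizer)
        detok.reset()
        pool[request_id] = detok
    detok.add_token(token)
    return detok.last_segment  # "" when the token ends mid-codepoint
```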
```diff
@@ -1875,8 +1895,13 @@
                 output.finish_reason = response.finish_reason
                 finished_ids.add(request_id)

-                # Decode full output
-                output.output_text = self._decode_tokens(request.output_token_ids)
+                # Finalize streaming detokenizer for the full output
+                detok = self._detokenizer_pool.pop(request_id, None)
+                if detok is not None:
+                    detok.finalize()
+                    output.output_text = detok.text
+                else:
+                    output.output_text = self._decode_tokens(request.output_token_ids)
                 request.output_text = output.output_text

                 # Extract cache for future reuse (critical for agentic multi-turn)
```
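The `finalize()` call matters when generation ends mid-codepoint (e.g. cut off by a token budget): it flushes the detokenizer's buffered bytes so `.text` holds everything decodable. A sketch of the finish path, continuing the illustrative `pool` helper above:

```python
def finish(tokenizer, request_id: str, output_token_ids: list[int]) -> str:
    """Pop the request's detokenizer and return the full decoded text."""
    detok = pool.pop(request_id, None)
    if detok is None:
        # Fallback mirrors the scheduler: decode the whole sequence at once.
        return tokenizer.decode(output_token_ids)
    detok.finalize()  # flush any buffered, incomplete byte sequence
    return detok.text
```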
```diff
@@ -1920,6 +1945,8 @@ def _process_batch_responses(
     def _cleanup_finished(self, finished_ids: Set[str]) -> None:
         """Clean up finished requests and store caches for reuse."""
         for request_id in finished_ids:
+            # Safety-net cleanup (normally popped in _process_batch_responses)
+            self._detokenizer_pool.pop(request_id, None)
             request = self.running.get(request_id)

             # Store cache for future reuse
```
```diff
@@ -2399,6 +2426,7 @@ def reset(self) -> None:
         self.finished_req_ids.clear()
         self.request_id_to_uid.clear()
         self.uid_to_request_id.clear()
+        self._detokenizer_pool.clear()
         self._close_batch_generator()
         self._current_sampler_params = None
```