diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py
index 18e4c98f8829..da950c2a0810 100644
--- a/vllm/v1/engine/detokenizer.py
+++ b/vllm/v1/engine/detokenizer.py
@@ -19,9 +19,9 @@
 
 logger = init_logger(__name__)
 
-# Only tokenizers >= 0.21.1 supports DecodeStream used for
-# FastIncrementalDetokenizer.
-USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1")
+# Only tokenizers >= 0.22.0 supports DecodeStream with native prefill
+# (ids parameter) used for FastIncrementalDetokenizer.
+USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.22.0")
 
 # Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
 INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
@@ -154,11 +154,10 @@ def get_next_output_text(self, finished: bool, delta: bool) -> str:
         # We return the full output text if the sequence is finished.
         buffer_length = 0 if finished else self.stop_buffer_length
         if not delta:
-            return (
-                self.output_text[:-buffer_length]
-                if buffer_length
-                else (self.output_text)
-            )
+            if not buffer_length:
+                return self.output_text
+            return self.output_text[:-buffer_length]
+
         length = len(self.output_text) - buffer_length
         last_offset = self._last_output_text_offset
         if last_offset < length:
@@ -176,24 +175,14 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreReques
 
         self.request_id = request.request_id
         self.skip_special_tokens = sampling_params.skip_special_tokens
 
-        self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
         self.tokenizer: Tokenizer = tokenizer._tokenizer
 
-        # Find a safe place to start.
-        prompt_token_ids = request.prompt_token_ids or []
-        prompt_suffix = prompt_token_ids
-        prompt_len = len(prompt_suffix)
-        if prompt_len > 4:
-            for i in range(4, min(prompt_len + 1, 24)):
-                suffix = prompt_token_ids[-i:]
-                if "�" not in self.tokenizer.decode(suffix):
-                    prompt_suffix = suffix
-                    break
-
-        # Prime the stream.
-        for tid in prompt_suffix:
-            self._protected_step(tid)
+        # Use native prefill to prime the decode stream with prompt tokens.
+        self.stream = DecodeStream(
+            ids=request.prompt_token_ids,
+            skip_special_tokens=self.skip_special_tokens,
+        )
 
         self.spaces_between_special_tokens = (
             sampling_params.skip_special_tokens
@@ -203,9 +192,8 @@
         if not self.spaces_between_special_tokens:
             # Store dict of added token ids so that we can suppress
             # the spaces between them.
-            if (
-                added_token_ids := getattr(self.tokenizer, "added_token_ids", None)
-            ) is None:
+            added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
+            if added_token_ids is None:
                 self.tokenizer.added_token_ids = added_token_ids = {
                     tid: tok.content
                     for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
@@ -290,11 +278,9 @@ def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
 
     @property
     def output_token_ids(self) -> list[int]:
-        return (
-            self.token_ids
-            if not self.prompt_len
-            else (self.token_ids[self.prompt_len :])
-        )
+        if self.prompt_len:
+            return self.token_ids[self.prompt_len :]
+        return self.token_ids
 
     def num_output_tokens(self) -> int:
         return len(self.token_ids) - self.prompt_len
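
For reviewers unfamiliar with the `DecodeStream` prefill API this change relies on, here is a minimal standalone sketch of the before/after behavior. It assumes tokenizers >= 0.22.0 as gated above; the `gpt2` checkpoint and the sample strings are illustrative only and are not part of this change.

```python
# Minimal sketch of DecodeStream native prefill (the `ids` parameter),
# assuming tokenizers >= 0.22.0. The "gpt2" model and sample text are
# illustrative, not part of the diff.
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")
prompt_ids = tokenizer.encode("Hello, world").ids

# Previously the stream had to be primed by step()-ing through a "safe"
# suffix of the prompt (scanning backwards past partial-UTF-8 boundaries,
# i.e. the removed "�" check). With native prefill, the prompt tokens are
# consumed in the constructor and produce no output text themselves.
stream = DecodeStream(ids=prompt_ids, skip_special_tokens=True)

# Generated token ids are then decoded incrementally; step() returns None
# while the accumulated bytes do not yet form a complete UTF-8 sequence.
for tid in tokenizer.encode("! How are you?").ids:
    piece = stream.step(tokenizer, tid)
    if piece is not None:
        print(piece, end="")
```

This is why the suffix-scan loop and the per-token priming via `_protected_step` can be deleted outright: the constructor-level prefill subsumes both, and the version gate bump to 0.22.0 keeps older tokenizers on the slow path.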