diff --git a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py index e452b578ba22..4c59e6e8b9f6 100644 --- a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py +++ b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py @@ -139,3 +139,32 @@ async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI): assert len(tool_calls.choices[0].message.reasoning_content) > 0 assert tool_calls.choices[0].message.tool_calls[0].function.name == FUNC_NAME assert tool_calls.choices[0].message.tool_calls[0].function.arguments == FUNC_ARGS + +@pytest.mark.asyncio +async def test_stop_str_with_reasoning(client: openai.AsyncOpenAI): + # check that the response is correctly stopped at "9.8" + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "9.11 and 9.8, which is greater?" + }], + temperature=1.0, + stop="9.8", + ) + + assert response.choices[0].message.reasoning_content.find("9.8") != -1 + assert response.choices[0].message.content.find("9.8") == -1 + + # check no stop string + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "9.11 and 9.8, which is greater?" 
+ }], + temperature=1.0, + ) + assert response.choices[0].message.reasoning_content.find("9.8") != -1 + # check that the response is not stopped at "9.8" + assert response.choices[0].message.content.find("9.8") != -1 diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 14dcab7707d4..c75f651e07f3 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -7,6 +7,7 @@ import pytest from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest @@ -60,25 +61,28 @@ def _run_incremental_decode( skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) - request = EngineCoreRequest( - request_id="", - prompt_token_ids=prompt_token_ids, - mm_features=None, - sampling_params=params, - pooling_params=None, - eos_token_id=None, - arrival_time=0.0, - lora_request=None, - cache_salt=None, - data_parallel_rank=None, - ) - + request = EngineCoreRequest(request_id="", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None) + vllm_config = VllmConfig() if fast is None: - detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request) + detokenizer = IncrementalDetokenizer.from_new_request( + vllm_config=vllm_config, tokenizer=tokenizer, request=request) elif fast: - detokenizer = FastIncrementalDetokenizer(tokenizer, request) + detokenizer = FastIncrementalDetokenizer(vllm_config=vllm_config, + tokenizer=tokenizer, + request=request) else: - detokenizer = SlowIncrementalDetokenizer(tokenizer, request) + detokenizer = SlowIncrementalDetokenizer(vllm_config=vllm_config, + 
tokenizer=tokenizer, + request=request) output_text = "" for i, token_id in enumerate(all_input_ids[starting_index:]): diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index 77e67d54e587..336b0ce9361e 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -3,6 +3,7 @@ from transformers import AutoTokenizer +from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import IncrementalDetokenizer @@ -21,7 +22,7 @@ def test_fast_inc_detok_invalid_utf8_err_case(): https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3. """ tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") - + vllm_config = VllmConfig() # Create a test request prompt_token_ids = [107, 4606, 236787, 107] params = SamplingParams(skip_special_tokens=True) @@ -38,7 +39,8 @@ def test_fast_inc_detok_invalid_utf8_err_case(): data_parallel_rank=None, ) - detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request) + detokenizer = IncrementalDetokenizer.from_new_request( + vllm_config, tokenizer, request) assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", ( "Should use FastIncrementalDetokenizer by default" diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 9ebf7f09503e..57b758653f9e 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -15,6 +15,7 @@ MockEngineCore, ) from vllm import PoolingParams +from vllm.config import StructuredOutputsConfig, VllmConfig from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -41,13 +43,17 @@ def
_ref_convert_id_to_token( @pytest.mark.parametrize( - "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] -) -def test_incremental_detokenization( - request_output_kind: RequestOutputKind, dummy_test_vectors -): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) - engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens) + "request_output_kind", + [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +def test_incremental_detokenization(request_output_kind: RequestOutputKind, + dummy_test_vectors): + vllm_config = VllmConfig( + structured_outputs_config=StructuredOutputsConfig()) + output_processor = OutputProcessor(vllm_config=vllm_config, + tokenizer=dummy_test_vectors.tokenizer, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens) # Make N requests. requests = [ @@ -407,17 +413,21 @@ def _validate_logprobs( @pytest.mark.parametrize( - "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] -) -@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) -@pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) -def test_logprobs_processor( - request_output_kind: RequestOutputKind, - num_sample_logprobs: Optional[int], - num_prompt_logprobs: Optional[int], - dummy_test_vectors, -): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) + "request_output_kind", + [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", + [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_logprobs_processor(request_output_kind: RequestOutputKind, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], + dummy_test_vectors): + vllm_config = VllmConfig( + 
structured_outputs_config=StructuredOutputsConfig()) + output_processor = OutputProcessor(vllm_config=vllm_config, + tokenizer=dummy_test_vectors.tokenizer, + log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, generated_logprobs_raw=None @@ -588,8 +598,11 @@ def test_stop_token( dummy_test_vectors.tokenizer.eos_token_id if is_eos_test else None ) # '<|end_of_text|>' stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>' - - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) + vllm_config = VllmConfig( + structured_outputs_config=StructuredOutputsConfig()) + output_processor = OutputProcessor(vllm_config=vllm_config, + tokenizer=dummy_test_vectors.tokenizer, + log_stats=False) # Dummy engine core outputs, with control tokens suffixed to test stops suffix_token = [eos_token_id] if is_eos_test else stop_token_ids assert suffix_token is not None and isinstance(suffix_token[0], int) @@ -693,13 +706,15 @@ def test_stop_token( @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) -def test_stop_string( - include_stop_str_in_output: bool, - num_sample_logprobs: Optional[int], - dummy_test_vectors, -): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +def test_stop_string(include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], dummy_test_vectors): + vllm_config = VllmConfig( + structured_outputs_config=StructuredOutputsConfig()) + output_processor = OutputProcessor(vllm_config=vllm_config, + tokenizer=dummy_test_vectors.tokenizer, + log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, generated_logprobs_raw=dummy_test_vectors.generation_logprobs @@ -827,7 +842,11 @@ def test_stop_string( def 
test_iteration_stats(dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) + vllm_config = VllmConfig( + structured_outputs_config=StructuredOutputsConfig()) + output_processor = OutputProcessor(vllm_config=vllm_config, + tokenizer=dummy_test_vectors.tokenizer, + log_stats=True) engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) engine_core_timestamp = time.monotonic() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 112ec92b3af8..c6d004bdb13e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -116,9 +116,9 @@ def __init__( ) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor( - self.tokenizer, log_stats=self.log_stats - ) + self.output_processor = OutputProcessor(self.vllm_config, + self.tokenizer, + log_stats=self.log_stats) if self.observability_config.otlp_traces_endpoint is not None: tracer = init_tracer( "vllm.llm_engine", self.observability_config.otlp_traces_endpoint diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 5efde9e2ff87..51752a754d8c 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -9,7 +9,9 @@ from tokenizers.decoders import DecodeStream from transformers import PreTrainedTokenizerFast +from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, @@ -29,6 +31,7 @@ class IncrementalDetokenizer: + def __init__(self): self.token_ids: list[int] = [] @@ -46,6 +49,7 @@ def get_next_output_text(self, finished: bool, delta: bool) -> str: @classmethod def from_new_request( cls, + vllm_config: VllmConfig, tokenizer: Optional[AnyTokenizer], request: EngineCoreRequest, ) -> "IncrementalDetokenizer": @@ -57,14 +61,20 @@ def from_new_request( if 
USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast): # Fast tokenizer => use tokenizers library DecodeStream. - return FastIncrementalDetokenizer(tokenizer, request) + return FastIncrementalDetokenizer(vllm_config=vllm_config, + tokenizer=tokenizer, + request=request) # Fall back to slow python-based incremental detokenization. - return SlowIncrementalDetokenizer(tokenizer, request) + return SlowIncrementalDetokenizer(vllm_config=vllm_config, + tokenizer=tokenizer, + request=request) class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): - def __init__(self, request: EngineCoreRequest): + + def __init__(self, vllm_config: VllmConfig, tokenizer: AnyTokenizer, + request: EngineCoreRequest): super().__init__() # Stop strings @@ -84,6 +94,7 @@ def __init__(self, request: EngineCoreRequest): # Generation data self.output_text = "" + self.stop_checker = StopChecker(vllm_config, tokenizer) def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: """ @@ -123,7 +134,8 @@ def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[st # 2) Evaluate stop strings. 
stop_string = None if self.stop and len(self.output_token_ids) > self.min_tokens: - stop = check_stop_strings( + stop = self.stop_checker.check_stop_strings( + token_ids=self.token_ids, output_text=self.output_text, new_char_count=len(self.output_text) - stop_check_offset, stop=self.stop, @@ -161,8 +173,13 @@ def get_next_output_text(self, finished: bool, delta: bool) -> str: class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): - def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest): - super().__init__(request) + + def __init__(self, vllm_config: VllmConfig, + tokenizer: PreTrainedTokenizerFast, + request: EngineCoreRequest): + super().__init__(vllm_config=vllm_config, + tokenizer=tokenizer, + request=request) sampling_params = request.sampling_params assert sampling_params is not None @@ -250,8 +267,12 @@ def _protected_step(self, next_token_id: int) -> Optional[str]: class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): - def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): - super().__init__(request) + + def __init__(self, vllm_config: VllmConfig, tokenizer: AnyTokenizer, + request: EngineCoreRequest): + super().__init__(vllm_config=vllm_config, + tokenizer=tokenizer, + request=request) self.tokenizer = tokenizer params = request.sampling_params @@ -307,39 +328,72 @@ def decode_next(self, next_token_id: int) -> str: return decoded_text -def check_stop_strings( - output_text: str, - new_char_count: int, - stop: list[str], - include_in_output: bool, -) -> Optional[tuple[str, int]]: - """Check if any stop strings are matched and truncate sequence - output text accordingly. - - Returns tuple (stop_string, offset) if matched or else None. - - Where stop_string is the matched stop string and offset is the - length to which output_text should be truncated, or -1 for no - truncation.
-    """
-    if not new_char_count or not stop:
+class StopChecker:
+    """Stop-string matcher that stays silent while the model is reasoning.
+
+    A reasoning parser is created only when
+    ``vllm_config.structured_outputs_config.reasoning_parser`` is set;
+    otherwise stop strings are checked on every call.
+    """
+
+    def __init__(self, vllm_config: VllmConfig, tokenizer: AnyTokenizer):
+        self.reasoning_parser: Optional[ReasoningParser] = None
+        if vllm_config.structured_outputs_config.reasoning_parser:
+            reasoning_parser_cls = ReasoningParserManager.get_reasoning_parser(
+                vllm_config.structured_outputs_config.reasoning_parser)
+            self.reasoning_parser = reasoning_parser_cls(tokenizer)
+        # Latched to True once the parser reports the end of reasoning.
+        self.reasoning_ended: bool = False
+
+    def check_stop_strings(
+        self,
+        token_ids: list[int],
+        output_text: str,
+        new_char_count: int,
+        stop: list[str],
+        include_in_output: bool,
+    ) -> Optional[tuple[str, int]]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+
+        While a reasoning parser is configured and has not yet reported
+        the end of reasoning for token_ids, nothing is matched.
+
+        Returns tuple (stop_string, offset) if matched or else None.
+
+        Where stop_string is the matched stop string and offset is the
+        length to which output_text should be truncated, or -1 for no
+        truncation.
+        """
+        if (rp := self.reasoning_parser) is not None:
+            # Reasoning not ended => do not check stop strings.
+            if not self.reasoning_ended:
+                self.reasoning_ended = rp.is_reasoning_end(token_ids)
+            if not self.reasoning_ended:
+                return None
+
+        # Nothing new to scan, or nothing to look for. With no new chars a
+        # one-character stop string would make the start index below 0 and
+        # rescan all of output_text, including text skipped while reasoning.
+        if not new_char_count or not stop:
+            return None
+
+        for stop_str in stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = output_text.find(stop_str,
+                                          1 - new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if include_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(output_text):
+                    # No truncation required.
+                    return stop_str, -1
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            return stop_str, stop_index
         return None
-
-    for stop_str in stop:
-        stop_string_len = len(stop_str)
-        # Avoid searching already-searched text.
- stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len) - if stop_index == -1: - continue - - if include_in_output: - # Truncate to end of stop string. - stop_index += stop_string_len - if stop_index >= len(output_text): - # No truncation required. - return stop_str, -1 - - # Truncate the output text to either the beginning - # or end of the stop string. - return stop_str, stop_index - return None diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index b2261855d125..68f78dd4e717 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -108,9 +108,9 @@ def __init__( ) # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor( - self.tokenizer, log_stats=self.log_stats - ) + self.output_processor = OutputProcessor(vllm_config=self.vllm_config, + tokenizer=self.tokenizer, + log_stats=self.log_stats) if self.observability_config.otlp_traces_endpoint is not None: tracer = init_tracer( "vllm.llm_engine", self.observability_config.otlp_traces_endpoint diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index eb65b68969e3..d8eafea5237a 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,6 +8,7 @@ import torch +from vllm.config import VllmConfig from vllm.outputs import ( CompletionOutput, PoolingOutput, @@ -127,6 +128,7 @@ def __init__( @classmethod def from_new_request( cls, + vllm_config: VllmConfig, tokenizer: AnyTokenizer, request: EngineCoreRequest, prompt: Optional[str], @@ -144,6 +146,7 @@ def from_new_request( request=request, ) detokenizer = IncrementalDetokenizer.from_new_request( + vllm_config=vllm_config, tokenizer=tokenizer, request=request, ) @@ -302,13 +305,15 @@ def _new_pooling_output( class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" - def __init__(self, tokenizer: AnyTokenizer, log_stats: bool): + def __init__(self, vllm_config: 
VllmConfig, tokenizer: AnyTokenizer, + log_stats: bool): self.log_stats = log_stats self.tokenizer = tokenizer self.request_states: dict[str, RequestState] = {} self.parent_requests: dict[str, ParentRequest] = {} self.lora_states = LoRARequestStates() self.tracer: Optional[Tracer] = None + self.vllm_config = vllm_config def get_num_unfinished_requests(self): return len(self.request_states) @@ -369,15 +374,14 @@ def add_request( if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - req_state = RequestState.from_new_request( - tokenizer=self.tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats, - ) + req_state = RequestState.from_new_request(vllm_config=self.vllm_config, + tokenizer=self.tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: