diff --git a/tests/entrypoints/openai/chat_completion/test_serving_chat.py b/tests/entrypoints/openai/chat_completion/test_serving_chat.py index df4d5ad47ca2..c44e07a4c10d 100644 --- a/tests/entrypoints/openai/chat_completion/test_serving_chat.py +++ b/tests/entrypoints/openai/chat_completion/test_serving_chat.py @@ -807,6 +807,57 @@ async def test_serving_chat_should_set_correct_max_tokens(): assert mock_engine.generate.call_args.args[1].max_tokens == 5 +@pytest.mark.asyncio +async def test_serving_chat_truncate_prompt_tokens_max_token_accounting(): + """When truncate_prompt_tokens is set, max_tokens must be calculated using + the truncated prompt length, not the original prompt length. + + Regression: without the fix, get_max_tokens received the untruncated prompt + length, causing the output budget to be underestimated. + """ + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) + + serving_chat = _build_serving_chat(mock_engine) + + # "what is 1+1?" tokenizes to 7 tokens with the test chat template + # (max_model_len=100 -> max_tokens = 93 without truncation, confirmed by + # test_serving_chat_should_set_correct_max_tokens above). + messages = [{"role": "user", "content": "what is 1+1?"}] + + # Baseline: no truncation -> max_tokens = 100 - 7 = 93. + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + with suppress(Exception): + await serving_chat.create_chat_completion(req) + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + # With truncate_prompt_tokens=5 (less than 7): the effective prompt length + # is 5, so max_tokens should be 100 - 5 = 95, not 93. 
+ req = ChatCompletionRequest( + model=MODEL_NAME, + messages=messages, + truncate_prompt_tokens=5, + ) + with suppress(Exception): + await serving_chat.create_chat_completion(req) + assert mock_engine.generate.call_args.args[1].max_tokens == 95 + + # With truncate_prompt_tokens=-1 (the truncation limit is max_model_len, + # so a 7-token prompt that already fits the window is left untouched): + # effective length = min(7, 100) = 7 -> max_tokens = 93 again. + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=messages, + truncate_prompt_tokens=-1, + ) + with suppress(Exception): + await serving_chat.create_chat_completion(req) + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + @pytest.mark.asyncio async def test_serving_chat_mistral_token_ids_prompt_is_validated(): """Regression test: when the Mistral tokenizer path returns token IDs diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 694ff80047c7..b50db5d84484 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -289,6 +289,7 @@ async def create_chat_completion( self._extract_prompt_len(engine_input), self.default_sampling_params, self.override_max_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, ) sampling_params: SamplingParams | BeamSearchParams diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index 454b170a5fa5..816c62163992 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -151,6 +151,7 @@ async def create_completion( self._extract_prompt_len(engine_input), self.default_sampling_params, self.override_max_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, ) sampling_params: SamplingParams | BeamSearchParams diff --git a/vllm/entrypoints/openai/responses/serving.py 
b/vllm/entrypoints/openai/responses/serving.py index a6a9355aea88..9c4dc48589ff 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -416,6 +416,9 @@ async def create_responses( self._extract_prompt_len(engine_input), self.default_sampling_params, self.override_max_tokens, + truncate_prompt_tokens=( + -1 if request.truncation != "disabled" else None + ), ) sampling_params = request.to_sampling_params( @@ -700,6 +703,9 @@ async def _generate_with_builtin_tools( self._extract_prompt_len(engine_input), self.default_sampling_params, # type: ignore self.override_max_tokens, # type: ignore + truncate_prompt_tokens=( + -1 if context.request.truncation != "disabled" else None + ), ) # OPTIMIZATION diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 967899229ada..782b2eaea24b 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -164,6 +164,7 @@ async def render_chat_request( input_length, self.default_sampling_params, self.override_max_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, ) params = request.to_sampling_params(max_tokens, self.default_sampling_params) @@ -298,6 +299,7 @@ async def render_completion_request( input_length, self.default_sampling_params, self.override_max_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, ) params = request.to_sampling_params( max_tokens, self.default_sampling_params diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index e3682280ec50..cd1010457d98 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -177,7 +177,14 @@ def get_max_tokens( input_length: int, default_sampling_params: dict, override_max_tokens: int | None = None, + truncate_prompt_tokens: int | None = None, ) -> int: + if truncate_prompt_tokens is not None: + limit = truncate_prompt_tokens + input_length = min( + input_length, + max_model_len if 
limit == -1 else limit, + ) if max_model_len < input_length: raise ValueError( f"Input length ({input_length}) exceeds model's maximum "