Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/entrypoints/openai/chat_completion/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,57 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 5


@pytest.mark.asyncio
async def test_serving_chat_truncate_prompt_tokens_max_token_accounting():
    """Check that the output budget is derived from the truncated prompt.

    When ``truncate_prompt_tokens`` is set on the request, ``max_tokens`` has
    to be computed from the prompt length *after* truncation rather than from
    the original prompt length.

    Regression: before the fix, ``get_max_tokens`` was handed the untruncated
    prompt length, so the output budget came out smaller than it should.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.renderer = _build_renderer(engine.model_config)

    serving_chat = _build_serving_chat(engine)

    # "what is 1+1?" tokenizes to 7 tokens with the test chat template; with
    # max_model_len=100 that leaves 93 output tokens when nothing is truncated
    # (same figure as test_serving_chat_should_set_correct_max_tokens above).
    messages = [{"role": "user", "content": "what is 1+1?"}]

    async def _max_tokens_for(**request_overrides):
        """Send one chat request and return the max_tokens given to generate."""
        request = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=messages,
            **request_overrides,
        )
        # Only the sampling params handed to the engine matter here, so any
        # error raised further down the mocked pipeline is ignored.
        with suppress(Exception):
            await serving_chat.create_chat_completion(request)
        return engine.generate.call_args.args[1].max_tokens

    # Baseline, no truncation: 100 - 7 = 93.
    assert await _max_tokens_for() == 93

    # truncate_prompt_tokens=5 is below the 7-token prompt, so the effective
    # prompt length is 5 and the budget is 100 - 5 = 95, not 93.
    assert await _max_tokens_for(truncate_prompt_tokens=5) == 95

    # truncate_prompt_tokens=-1 means "truncate to max_model_len", which does
    # not shorten a 7-token prompt: min(7, 100) = 7, so the budget is 93 again.
    assert await _max_tokens_for(truncate_prompt_tokens=-1) == 93


@pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_is_validated():
"""Regression test: when the Mistral tokenizer path returns token IDs
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/chat_completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ async def create_chat_completion(
self._extract_prompt_len(engine_input),
self.default_sampling_params,
self.override_max_tokens,
truncate_prompt_tokens=request.truncate_prompt_tokens,
)

sampling_params: SamplingParams | BeamSearchParams
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ async def create_completion(
self._extract_prompt_len(engine_input),
self.default_sampling_params,
self.override_max_tokens,
truncate_prompt_tokens=request.truncate_prompt_tokens,
)

sampling_params: SamplingParams | BeamSearchParams
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/openai/responses/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,9 @@ async def create_responses(
self._extract_prompt_len(engine_input),
self.default_sampling_params,
self.override_max_tokens,
truncate_prompt_tokens=(
-1 if request.truncation != "disabled" else None
),
)

sampling_params = request.to_sampling_params(
Expand Down Expand Up @@ -700,6 +703,9 @@ async def _generate_with_builtin_tools(
self._extract_prompt_len(engine_input),
self.default_sampling_params, # type: ignore
self.override_max_tokens, # type: ignore
truncate_prompt_tokens=(
-1 if context.request.truncation != "disabled" else None
),
)

# OPTIMIZATION
Expand Down
2 changes: 2 additions & 0 deletions vllm/entrypoints/serve/render/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ async def render_chat_request(
input_length,
self.default_sampling_params,
self.override_max_tokens,
truncate_prompt_tokens=request.truncate_prompt_tokens,
)
params = request.to_sampling_params(max_tokens, self.default_sampling_params)

Expand Down Expand Up @@ -298,6 +299,7 @@ async def render_completion_request(
input_length,
self.default_sampling_params,
self.override_max_tokens,
truncate_prompt_tokens=request.truncate_prompt_tokens,
)
params = request.to_sampling_params(
max_tokens, self.default_sampling_params
Expand Down
7 changes: 7 additions & 0 deletions vllm/entrypoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,14 @@ def get_max_tokens(
input_length: int,
default_sampling_params: dict,
override_max_tokens: int | None = None,
truncate_prompt_tokens: int | None = None,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also called in the Responses API

) -> int:
if truncate_prompt_tokens is not None:
limit = truncate_prompt_tokens
input_length = min(
input_length,
max_model_len if limit == -1 else limit,
)
if max_model_len < input_length:
raise ValueError(
f"Input length ({input_length}) exceeds model's maximum "
Expand Down
Loading