diff --git a/examples/online_serving/prompt_embed_inference_with_openai_client.py b/examples/online_serving/prompt_embed_inference_with_openai_client.py index 889be6820e70..ab616facff99 100644 --- a/examples/online_serving/prompt_embed_inference_with_openai_client.py +++ b/examples/online_serving/prompt_embed_inference_with_openai_client.py @@ -60,9 +60,7 @@ def main(): completion = client.completions.create( model=model_name, - # NOTE: The OpenAI client does not allow `None` as an input to - # `prompt`. Use an empty string if you have no text prompts. - prompt="", + prompt=None, max_tokens=5, temperature=0.0, # NOTE: The OpenAI client allows passing in extra JSON body via the diff --git a/tests/engine/test_short_mm_context.py b/tests/engine/test_short_mm_context.py index 54a88586d8ed..23489c213332 100644 --- a/tests/engine/test_short_mm_context.py +++ b/tests/engine/test_short_mm_context.py @@ -22,7 +22,11 @@ def test_context_length_too_short(vllm_runner, image_assets, model): with pytest.raises(ValueError, match="longer than the maximum model length"): vllm_model = vllm_runner( model, - max_model_len=128, # LLaVA has a feature size of 576 + # LLaVA has a feature size of 576 + # For the HF processor to execute successfully while still + # failing the overall context length check, max_model_len + # needs to be large enough to contain all the image tokens + max_model_len=579, enforce_eager=True, load_format="dummy", ) diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index a9698632b82e..dc72ffa0e81e 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -205,7 +205,7 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test): valid_msg, ] sampling_params = SamplingParams(temperature=0, max_tokens=10) - with pytest.raises(ValueError, match="longer than the maximum model length"): + with pytest.raises(ValueError, match="context length is only"): llm.chat(batch_1, sampling_params=sampling_params) outputs_2 = llm.chat(batch_2, sampling_params=sampling_params) assert len(outputs_2) == len(batch_2) diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py index d42ae25573ed..7b15421fb550 100644 --- a/tests/entrypoints/openai/test_chat_error.py +++ b/tests/entrypoints/openai/test_chat_error.py @@ -15,7 +15,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput -from vllm.tokenizers import get_tokenizer +from vllm.renderers.hf import HfRenderer +from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "openai-community/gpt2" @@ -57,6 +58,15 @@ def get_diff_sampling_param(self): return self.diff_sampling_param or {} +def _build_renderer(model_config: MockModelConfig): + _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) + + return HfRenderer( + model_config, + tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, + ) + + def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: models = OpenAIServingModels( engine_client=engine, base_model_paths=BASE_MODEL_PATHS, ) @@ -71,18 +81,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: chat_template_content_format="auto", ) - async def _fake_process_inputs( - request_id, - engine_prompt, - sampling_params, - *, - lora_request, - trace_headers, - priority, - data_parallel_rank, - ): - return dict(engine_prompt), {} - async def 
_fake_preprocess_chat(*args, **kwargs): # return conversation, engine_prompts return ( @@ -90,7 +88,6 @@ async def _fake_preprocess_chat(*args, **kwargs): [{"prompt_token_ids": [1, 2, 3]}], ) - serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs) serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat) return serving_chat @@ -99,11 +96,11 @@ async def _fake_preprocess_chat(*args, **kwargs): async def test_chat_error_non_stream(): """test finish_reason='error' returns 500 InternalServerError (non-streaming)""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -153,11 +150,11 @@ async def mock_generate(*args, **kwargs): async def test_chat_error_stream(): """test finish_reason='error' returns 500 InternalServerError (streaming)""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/test_completion_error.py index 7dd0448dea6c..01c4e567c9fd 100644 --- a/tests/entrypoints/openai/test_completion_error.py +++ b/tests/entrypoints/openai/test_completion_error.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from http import HTTPStatus from typing import Any -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest @@ -15,7 +15,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput -from vllm.tokenizers import get_tokenizer +from vllm.renderers.hf import HfRenderer +from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "openai-community/gpt2" @@ -61,37 +62,31 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: engine_client=engine, base_model_paths=BASE_MODEL_PATHS, ) - serving_completion = OpenAIServingCompletion( + return OpenAIServingCompletion( engine, models, request_logger=None, ) - async def _fake_process_inputs( - request_id, - engine_prompt, - sampling_params, - *, - lora_request, - trace_headers, - priority, - data_parallel_rank, - ): - return dict(engine_prompt), {} - serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs) - return serving_completion +def _build_renderer(model_config: MockModelConfig): + _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) + + return HfRenderer( + model_config, + tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, + ) @pytest.mark.asyncio async def test_completion_error_non_stream(): """test finish_reason='error' returns 500 InternalServerError (non-streaming)""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() 
mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_completion = _build_serving_completion(mock_engine) @@ -141,11 +136,11 @@ async def mock_generate(*args, **kwargs): async def test_completion_error_stream(): """test finish_reason='error' returns 500 InternalServerError (streaming)""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_completion = _build_serving_completion(mock_engine) diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 0a057b1848ad..f8a19e40b539 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -110,7 +110,7 @@ async def test_completions_with_prompt_embeds( # Test case: Single prompt embeds input completion = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", # Add empty prompt as required parameter + prompt=None, max_tokens=5, temperature=0.0, extra_body={"prompt_embeds": encoded_embeds}, @@ -121,7 +121,7 @@ async def test_completions_with_prompt_embeds( # Test case: batch completion with prompt_embeds completion = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", # Add empty prompt as required parameter + prompt=None, max_tokens=5, temperature=0.0, extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, @@ -133,7 +133,7 @@ async def test_completions_with_prompt_embeds( # Test case: streaming with prompt_embeds single_completion = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", # Add empty prompt as required parameter + prompt=None, max_tokens=5, temperature=0.0, extra_body={"prompt_embeds": encoded_embeds}, @@ -142,7 +142,7 @@ async def test_completions_with_prompt_embeds( stream = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", # Add empty prompt as required parameter + prompt=None, max_tokens=5, temperature=0.0, stream=True, @@ -162,7 +162,7 @@ async def test_completions_with_prompt_embeds( # Test case: batch streaming with prompt_embeds stream = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", # Add empty prompt as required parameter + prompt=None, max_tokens=5, temperature=0.0, stream=True, @@ -197,7 +197,7 @@ async def test_completions_with_prompt_embeds( ) completion_embeds_only = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", + prompt=None, max_tokens=5, temperature=0.0, extra_body={"prompt_embeds": encoded_embeds}, @@ -215,7 +215,7 @@ async def test_completions_errors_with_prompt_embeds( # Test error case: invalid prompt_embeds with pytest.raises(BadRequestError): await client_with_prompt_embeds.completions.create( - prompt="", + prompt=None, model=model_name, max_tokens=5, temperature=0.0, @@ -237,7 +237,7 @@ async def test_completions_with_logprobs_and_prompt_embeds( # Test case: Logprobs using prompt_embeds completion = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", # Add empty prompt as required parameter + prompt=None, max_tokens=5, 
temperature=0.0, echo=False, @@ -257,7 +257,7 @@ async def test_completions_with_logprobs_and_prompt_embeds( # Test case: Log probs with batch completion and prompt_embeds completion = await client_with_prompt_embeds.completions.create( model=model_name, - prompt="", # Add empty prompt as required parameter + prompt=None, max_tokens=5, temperature=0.0, echo=False, @@ -287,7 +287,7 @@ async def test_prompt_logprobs_raises_error( with pytest.raises(BadRequestError, match="not compatible"): await client_with_prompt_embeds.completions.create( model=MODEL_NAME, - prompt="", + prompt=None, max_tokens=5, temperature=0.0, extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True}, diff --git a/tests/entrypoints/openai/test_embedding_shape_validation.py b/tests/entrypoints/openai/test_embedding_shape_validation.py index 27060e0be5ae..fcac5be5b5a7 100644 --- a/tests/entrypoints/openai/test_embedding_shape_validation.py +++ b/tests/entrypoints/openai/test_embedding_shape_validation.py @@ -7,7 +7,7 @@ are rejected before they can cause crashes during model inference. Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems -classes, not by CompletionRenderer or MediaIO classes. +classes, not by MediaIO classes. """ import pytest diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index bfadf51e43af..b43577406b84 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -16,7 +16,8 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry -from vllm.tokenizers import get_tokenizer +from vllm.renderers.hf import HfRenderer +from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "openai-community/gpt2" @@ -35,6 +36,7 @@ class MockModelConfig: """Minimal mock ModelConfig for testing.""" model: str = MODEL_NAME + runner_type = "generate" tokenizer: str = MODEL_NAME trust_remote_code: bool = False tokenizer_mode: str = "auto" @@ -85,15 +87,21 @@ def register_mock_resolver(): del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME] +def _build_renderer(model_config: MockModelConfig): + _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) + + return HfRenderer( + model_config, + tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, + ) + + @pytest.fixture def mock_serving_setup(): """Provides a mocked engine and serving completion instance.""" mock_engine = MagicMock(spec=AsyncLLM) mock_engine.errored = False - tokenizer = get_tokenizer(MODEL_NAME) - mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer) - async def mock_add_lora_side_effect(lora_request: LoRARequest): """Simulate engine behavior when adding LoRAs.""" if lora_request.lora_name == "test-lora": @@ -118,6 +126,7 @@ async def mock_generate(*args, **kwargs): mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) models = OpenAIServingModels( engine_client=mock_engine, @@ -128,10 +137,6 @@ async def mock_generate(*args, **kwargs): mock_engine, models, request_logger=None ) - serving_completion._process_inputs = AsyncMock( - return_value=(MagicMock(name="engine_request"), {}) - ) - return mock_engine, serving_completion diff --git 
a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index cd5661e5739f..5aff3b3c7bd9 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -12,7 +12,7 @@ import torch from vllm.config import ModelConfig -from vllm.entrypoints.renderer import CompletionRenderer +from vllm.renderers.embed_utils import safe_load_prompt_embeds from ...utils import RemoteOpenAIServer @@ -30,7 +30,7 @@ async def test_empty_prompt(): ): await client.completions.create( model=model_name, - prompt="", + prompt=None, max_tokens=5, temperature=0.0, extra_body={"prompt_embeds": []}, @@ -63,7 +63,6 @@ def test_load_prompt_embeds( ): model_config = Mock(spec=ModelConfig) model_config.enable_prompt_embeds = True - renderer = CompletionRenderer(model_config, tokenizer=None) # construct arbitrary tensors of various dtypes, layouts, and sizes. # We need to check against different layouts to make sure that if a user @@ -89,9 +88,7 @@ def test_load_prompt_embeds( buffer.seek(0) encoded_tensor = pybase64.b64encode(buffer.getvalue()) - loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor) - assert len(loaded_prompt_embeds) == 1 - loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] + loaded_tensor = safe_load_prompt_embeds(model_config, encoded_tensor) assert loaded_tensor.device.type == "cpu" assert loaded_tensor.layout == torch.strided torch.testing.assert_close( @@ -105,7 +102,6 @@ def test_load_prompt_embeds( def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int): model_config = Mock(spec=ModelConfig) model_config.enable_prompt_embeds = False - renderer = CompletionRenderer(model_config, tokenizer=None) tensor = torch.randn((seq_len, hidden_size), dtype=dtype) @@ -115,4 +111,4 @@ def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: in encoded_tensor = pybase64.b64encode(buffer.getvalue()) with pytest.raises(ValueError, match="--enable-prompt-embeds"): - renderer.load_prompt_embeds(encoded_tensor) + safe_load_prompt_embeds(model_config, encoded_tensor) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index fa29b31be365..b966e7dd7e31 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -556,19 +556,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: request_logger=None, ) - async def _fake_process_inputs( - request_id, - engine_prompt, - sampling_params, - *, - lora_request, - trace_headers, - priority, - data_parallel_rank, - ): - return dict(engine_prompt), {} - - serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs) return serving_chat @@ -784,7 +771,7 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(): resp = await serving_chat.create_chat_completion(req) assert isinstance(resp, ErrorResponse) - assert "max_tokens" in resp.error.message + assert "context length is only" in resp.error.message @pytest.mark.asyncio @@ -824,7 +811,7 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(): resp = await serving_chat.create_chat_completion(req) assert isinstance(resp, ErrorResponse) - assert "maximum context length" in resp.error.message + assert "context length is only" in resp.error.message @pytest.mark.asyncio @@ -890,6 +877,20 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): serving_chat = 
_build_serving_chat(mock_engine) + orig_render_chat_request = serving_chat.render_chat_request + captured_prompts = [] + + async def render_chat_request(request): + result = await orig_render_chat_request(request) + + assert isinstance(result, tuple) + conversation, engine_prompts = result + captured_prompts.extend(engine_prompts) + + return result + + serving_chat.render_chat_request = render_chat_request + # Test cache_salt req = ChatCompletionRequest( model=MODEL_NAME, @@ -899,15 +900,19 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): # By default, cache_salt in the engine prompt is not set with suppress(Exception): await serving_chat.create_chat_completion(req) - engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1] - assert "cache_salt" not in engine_prompt + + assert len(captured_prompts) == 1 + assert "cache_salt" not in captured_prompts[0] + + captured_prompts.clear() # Test with certain cache_salt req.cache_salt = "test_salt" with suppress(Exception): await serving_chat.create_chat_completion(req) - engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1] - assert engine_prompt.get("cache_salt") == "test_salt" + + assert len(captured_prompts) == 1 + assert captured_prompts[0]["cache_salt"] == "test_salt" @pytest.mark.asyncio @@ -1007,11 +1012,11 @@ def stream(self, request) -> bool: @pytest.fixture() def mock_engine(self) -> AsyncLLM: mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) return mock_engine @pytest.fixture() @@ -1618,11 +1623,11 @@ async def test_tool_choice_validation_without_parser(): """Test that tool_choice='required' or named tool without tool_parser returns an appropriate error message.""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) models = OpenAIServingModels( engine_client=mock_engine, diff --git a/tests/entrypoints/pooling/basic/test_truncation.py b/tests/entrypoints/pooling/basic/test_truncation.py index 5d099dd1f439..fcaead0e254c 100644 --- a/tests/entrypoints/pooling/basic/test_truncation.py +++ b/tests/entrypoints/pooling/basic/test_truncation.py @@ -67,20 +67,6 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): assert response["usage"]["prompt_tokens"] == truncation_size -@pytest.mark.asyncio -async def test_zero_truncation_size(client: openai.AsyncOpenAI): - truncation_size = 0 - kwargs: dict[str, Any] = { - "model": MODEL_NAME, - "input": input, - "truncate_prompt_tokens": truncation_size, - } - - response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) - - assert response["usage"]["prompt_tokens"] == truncation_size - - @pytest.mark.asyncio async def test_bigger_truncation_size(client: openai.AsyncOpenAI): truncation_size = max_model_len + 1 diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py index 84b0173933c8..45712c8425bf 100644 --- a/tests/entrypoints/pooling/classify/test_online.py +++ 
b/tests/entrypoints/pooling/classify/test_online.py @@ -128,12 +128,10 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): server.url_for("classify"), json={"model": model_name, "input": []}, ) - classification_response.raise_for_status() - output = ClassificationResponse.model_validate(classification_response.json()) - assert output.object == "list" - assert isinstance(output.data, list) - assert len(output.data) == 0 + error = classification_response.json() + assert classification_response.status_code == 400 + assert "error" in error @pytest.mark.parametrize("model_name", [MODEL_NAME]) diff --git a/tests/entrypoints/pooling/score/test_online_score.py b/tests/entrypoints/pooling/score/test_online_score.py index 1c74cf297f6a..e1cc074e885d 100644 --- a/tests/entrypoints/pooling/score/test_online_score.py +++ b/tests/entrypoints/pooling/score/test_online_score.py @@ -247,7 +247,7 @@ def test_score_max_model_len( }, ) assert score_response.status_code == 400 - assert "Please, select a smaller truncation size." in score_response.text + assert "Please request a smaller truncation size." in score_response.text def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]): queries = "What is the capital of France?" diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py deleted file mode 100644 index b0ef3dd045bd..000000000000 --- a/tests/entrypoints/test_renderer.py +++ /dev/null @@ -1,325 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import io -from dataclasses import dataclass -from unittest.mock import AsyncMock, MagicMock - -import pybase64 -import pytest -import torch - -from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig -from vllm.inputs.data import is_embeds_prompt - - -@dataclass -class MockModelConfig: - max_model_len: int = 100 - encoder_config: dict | None = None - enable_prompt_embeds: bool = True - - -class MockTokenizerResult: - def __init__(self, input_ids): - self.input_ids = input_ids - - -@pytest.fixture -def mock_model_config(): - return MockModelConfig() - - -@pytest.fixture -def mock_tokenizer(): - tokenizer = MagicMock() - return tokenizer - - -@pytest.fixture -def mock_async_tokenizer(): - async_tokenizer = AsyncMock() - return async_tokenizer - - -@pytest.fixture -def renderer(mock_model_config, mock_tokenizer): - return CompletionRenderer( - model_config=mock_model_config, - tokenizer=mock_tokenizer, - async_tokenizer_pool={}, - ) - - -class TestRenderPrompt: - """Test Category A: Basic Functionality Tests""" - - @pytest.mark.asyncio - async def test_token_input(self, renderer): - tokens = [101, 7592, 2088] - results = await renderer.render_prompt( - prompt_or_prompts=tokens, config=RenderConfig(max_length=100) - ) - - assert len(results) == 1 - assert results[0]["prompt_token_ids"] == tokens - - @pytest.mark.asyncio - async def test_token_list_input(self, renderer): - token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] - results = await renderer.render_prompt( - prompt_or_prompts=token_lists, config=RenderConfig(max_length=100) - ) - - assert len(results) == 3 - assert results[0]["prompt_token_ids"] == [101, 7592, 2088] - assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012] - assert results[2]["prompt_token_ids"] == [103, 4567] - - @pytest.mark.asyncio - async def test_text_input(self, renderer, mock_async_tokenizer): - mock_async_tokenizer.return_value = 
MockTokenizerResult([101, 7592, 2088]) - renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - - results = await renderer.render_prompt( - prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) - ) - - assert len(results) == 1 - assert results[0]["prompt_token_ids"] == [101, 7592, 2088] - mock_async_tokenizer.assert_called_once() - - @pytest.mark.asyncio - async def test_text_list_input(self, renderer, mock_async_tokenizer): - mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) - renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - - text_list_input = ["Hello world", "How are you?", "Good morning"] - results = await renderer.render_prompt( - prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100) - ) - - assert len(results) == 3 - for result in results: - assert result["prompt_token_ids"] == [101, 7592, 2088] - assert mock_async_tokenizer.call_count == 3 - - @pytest.mark.asyncio - async def test_no_truncation(self, renderer, mock_async_tokenizer): - mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) - renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - - results = await renderer.render_prompt( - prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) - ) - - assert len(results) == 1 - call_args = mock_async_tokenizer.call_args - assert ( - "truncation" not in call_args.kwargs - or call_args.kwargs["truncation"] is False - ) - - @pytest.mark.asyncio - async def test_truncation_positive(self, renderer, mock_async_tokenizer): - mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 7592, 2088] - ) # Truncated - renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - - results = await renderer.render_prompt( - prompt_or_prompts="Hello world", - config=RenderConfig(max_length=100, truncate_prompt_tokens=50), - ) - - assert len(results) == 1 - call_args = mock_async_tokenizer.call_args - assert call_args.kwargs["truncation"] is True - assert call_args.kwargs["max_length"] == 50 - - @pytest.mark.asyncio - async def test_truncation_negative(self, renderer, mock_async_tokenizer): - # Test that negative truncation uses model's max_model_len - mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 7592, 2088] - ) # Truncated to max_model_len - renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - - results = await renderer.render_prompt( - prompt_or_prompts="Hello world", - config=RenderConfig(max_length=200, truncate_prompt_tokens=-1), - ) - - assert len(results) == 1 - call_args = mock_async_tokenizer.call_args - assert call_args.kwargs["truncation"] is True - assert call_args.kwargs["max_length"] == 100 # model's max_model_len - - @pytest.mark.asyncio - async def test_token_truncation_last_elements(self, renderer): - # Test that token truncation keeps the last N elements - long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens - results = await renderer.render_prompt( - prompt_or_prompts=long_tokens, - config=RenderConfig(max_length=100, truncate_prompt_tokens=5), - ) - - assert len(results) == 1 - # Should keep the last 5 tokens: [105, 106, 107, 108, 109] - assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109] - - @pytest.mark.asyncio - async def test_max_length_exceeded(self, renderer): - long_tokens = list(range(150)) # Exceeds max_model_len=100 - - with pytest.raises(ValueError, match="maximum context length"): - await renderer.render_prompt( - 
prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100) - ) - - @pytest.mark.asyncio - async def test_no_tokenizer_for_text(self, mock_model_config): - renderer_no_tokenizer = CompletionRenderer( - model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={} - ) - - with pytest.raises(ValueError, match="No tokenizer available"): - await renderer_no_tokenizer.render_prompt( - prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) - ) - - @pytest.mark.asyncio - async def test_token_input_with_needs_detokenization( - self, renderer, mock_async_tokenizer - ): - # When needs_detokenization=True for token inputs, renderer should - # use the async tokenizer to decode and include the original text - # in the returned prompt object. - mock_async_tokenizer.decode = AsyncMock(return_value="decoded text") - renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - - tokens = [1, 2, 3, 4] - results = await renderer.render_prompt( - prompt_or_prompts=tokens, - config=RenderConfig(needs_detokenization=True), - ) - - assert len(results) == 1 - assert results[0]["prompt_token_ids"] == tokens - assert results[0]["prompt"] == "decoded text" - mock_async_tokenizer.decode.assert_awaited_once() - - -class TestRenderEmbedPrompt: - def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes: - """Helper to create base64-encoded tensor bytes""" - buffer = io.BytesIO() - torch.save(tensor, buffer) - buffer.seek(0) - return pybase64.b64encode(buffer.read()) - - @pytest.mark.asyncio - async def test_single_prompt_embed(self, renderer): - # Create a test tensor - test_tensor = torch.randn(10, 768, dtype=torch.float32) - embed_bytes = self._create_test_embed_bytes(test_tensor) - - results = await renderer.render_prompt_and_embeds( - prompt_embeds=embed_bytes, - config=RenderConfig(cache_salt="test_salt"), - ) - - assert len(results) == 1 - assert is_embeds_prompt(results[0]) - assert torch.allclose(results[0]["prompt_embeds"], test_tensor) - assert results[0]["cache_salt"] == "test_salt" - - @pytest.mark.asyncio - async def test_multiple_prompt_embeds(self, renderer): - # Create multiple test tensors - test_tensors = [ - torch.randn(8, 512, dtype=torch.float32), - torch.randn(12, 512, dtype=torch.float32), - ] - embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors] - - results = await renderer.render_prompt_and_embeds( - prompt_embeds=embed_bytes_list, - config=RenderConfig(), - ) - - assert len(results) == 2 - for i, result in enumerate(results): - assert is_embeds_prompt(result) - assert torch.allclose(result["prompt_embeds"], test_tensors[i]) - - @pytest.mark.asyncio - async def test_prompt_embed_truncation(self, renderer): - # Create tensor with more tokens than truncation limit - test_tensor = torch.randn(20, 768, dtype=torch.float32) - embed_bytes = self._create_test_embed_bytes(test_tensor) - - results = await renderer.render_prompt_and_embeds( - prompt_embeds=embed_bytes, - config=RenderConfig(truncate_prompt_tokens=10), - ) - - assert len(results) == 1 - # Should keep last 10 tokens - expected = test_tensor[-10:] - assert torch.allclose(results[0]["prompt_embeds"], expected) - - @pytest.mark.asyncio - async def test_prompt_embed_different_dtypes(self, renderer): - # Test different supported dtypes - dtypes = [torch.float32, torch.float16, torch.bfloat16] - - for dtype in dtypes: - test_tensor = torch.randn(5, 256, dtype=dtype) - embed_bytes = self._create_test_embed_bytes(test_tensor) - - results = await 
renderer.render_prompt_and_embeds( - prompt_embeds=embed_bytes, - config=RenderConfig(), - ) - - assert len(results) == 1 - assert results[0]["prompt_embeds"].dtype == dtype - - @pytest.mark.asyncio - async def test_prompt_embed_squeeze_batch_dim(self, renderer): - # Test tensor with batch dimension gets squeezed - test_tensor = torch.randn(1, 10, 768, dtype=torch.float32) - embed_bytes = self._create_test_embed_bytes(test_tensor) - - results = await renderer.render_prompt_and_embeds( - prompt_embeds=embed_bytes, - config=RenderConfig(), - ) - - assert len(results) == 1 - # Should be squeezed to 2D - assert results[0]["prompt_embeds"].shape == (10, 768) - - @pytest.mark.asyncio - async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer): - # Set up text tokenization - mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103]) - renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - - # Create embed - test_tensor = torch.randn(5, 256, dtype=torch.float32) - embed_bytes = self._create_test_embed_bytes(test_tensor) - - results = await renderer.render_prompt_and_embeds( - prompt_or_prompts="Hello world", - prompt_embeds=embed_bytes, - config=RenderConfig(), - ) - - assert len(results) == 2 - # First should be embed prompt - assert is_embeds_prompt(results[0]) - # Second should be tokens prompt - assert "prompt_token_ids" in results[1] - assert results[1]["prompt_token_ids"] == [101, 102, 103] diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py index 631fd394f719..78448de5945f 100644 --- a/tests/models/language/pooling/test_mm_classifier_conversion.py +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -96,7 +96,7 @@ def test_gemma_multimodal( dtype="bfloat16", ) as vllm_model: llm = vllm_model.get_llm() - prompts = llm.preprocess_chat(messages) + prompts = llm._preprocess_chat([messages]) result = llm.classify(prompts) assert result[0].outputs.probs[0] > 0.95 diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py index f1870ddbee51..d41a3379dc0f 100644 --- a/tests/models/language/pooling/test_truncation_control.py +++ b/tests/models/language/pooling/test_truncation_control.py @@ -29,7 +29,8 @@ def test_smaller_truncation_size( model_name, runner="pooling", max_model_len=max_model_len ) as vllm_model: vllm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens + input_str, + tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens), ) prompt_tokens = vllm_output[0].prompt_token_ids @@ -44,7 +45,8 @@ def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input model_name, runner="pooling", max_model_len=max_model_len ) as vllm_model: vllm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens + input_str, + tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens), ) prompt_tokens = vllm_output[0].prompt_token_ids @@ -64,7 +66,8 @@ def test_bigger_truncation_size( ) as vllm_model, ): llm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens + input_str, + tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens), ) assert ( diff --git a/tests/models/language/pooling_mteb_test/mteb_embed_utils.py b/tests/models/language/pooling_mteb_test/mteb_embed_utils.py index a736b991d4d5..da0b16449a6e 100644 --- 
a/tests/models/language/pooling_mteb_test/mteb_embed_utils.py +++ b/tests/models/language/pooling_mteb_test/mteb_embed_utils.py @@ -187,7 +187,10 @@ def mteb_test_embed_models( head_dtype = model_config.head_dtype # Test embedding_size, isnan and whether to use normalize - vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1) + vllm_outputs = vllm_model.embed( + example_prompts, + tokenization_kwargs=dict(truncate_prompt_tokens=-1), + ) outputs_tensor = torch.tensor(vllm_outputs) assert not torch.any(torch.isnan(outputs_tensor)) embedding_size = model_config.embedding_size diff --git a/tests/models/language/pooling_mteb_test/mteb_score_utils.py b/tests/models/language/pooling_mteb_test/mteb_score_utils.py index d9e3521d936f..ad32880390e9 100644 --- a/tests/models/language/pooling_mteb_test/mteb_score_utils.py +++ b/tests/models/language/pooling_mteb_test/mteb_score_utils.py @@ -79,9 +79,9 @@ def predict( outputs = self.llm.score( queries, corpus, - truncate_prompt_tokens=-1, use_tqdm=False, chat_template=self.chat_template, + tokenization_kwargs={"truncate_prompt_tokens": -1}, ) scores = np.array(outputs) scores = scores[np.argsort(r)] diff --git a/tests/renderers/test_completions.py b/tests/renderers/test_completions.py new file mode 100644 index 000000000000..7e33a1f9fbb9 --- /dev/null +++ b/tests/renderers/test_completions.py @@ -0,0 +1,426 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import io +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock + +import pybase64 +import pytest +import torch + +from vllm.inputs.data import is_embeds_prompt +from vllm.renderers import TokenizeParams +from vllm.renderers.hf import HfRenderer +from vllm.tokenizers.registry import tokenizer_args_from_config + +MODEL_NAME = "openai-community/gpt2" + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + runner_type = "generate" + model: str = MODEL_NAME + tokenizer: str = MODEL_NAME + trust_remote_code: bool = False + max_model_len: int = 100 + tokenizer_revision = None + tokenizer_mode = "auto" + hf_config = MockHFConfig() + encoder_config: dict[str, Any] | None = None + enable_prompt_embeds: bool = True + skip_tokenizer_init: bool = False + + +@pytest.fixture +def mock_model_config(): + return MockModelConfig() + + +@pytest.fixture +def mock_async_tokenizer(): + return AsyncMock() + + +@pytest.fixture +def renderer(mock_model_config): + _, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config) + + return HfRenderer( + mock_model_config, + tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, + ) + + +class TestValidatePrompt: + STRING_INPUTS = [ + "", + "foo", + "foo bar", + "foo baz bar", + "foo bar qux baz", + ] + + TOKEN_INPUTS = [ + [-1], + [1], + [1, 2], + [1, 3, 4], + [1, 2, 4, 3], + ] + + INPUTS_SLICES = [ + slice(None, None, -1), + slice(None, None, 2), + slice(None, None, -2), + ] + 
+ def test_empty_input(self, renderer): + with pytest.raises(ValueError, match="at least one prompt"): + renderer.render_completions([]) + + # Test that a nested mixed-type list of lists raises a TypeError. + def test_invalid_type(self, renderer): + with pytest.raises(TypeError, match="string or an array of tokens"): + renderer.render_completions([[1, 2], ["foo", "bar"]]) + + @pytest.mark.parametrize("string_input", STRING_INPUTS) + def test_string_consistent(self, renderer, string_input: str): + assert renderer.render_completions(string_input) == renderer.render_completions( + [string_input] + ) + + @pytest.mark.parametrize("token_input", TOKEN_INPUTS) + def test_token_consistent(self, renderer, token_input: list[int]): + assert renderer.render_completions(token_input) == renderer.render_completions( + [token_input] + ) + + @pytest.mark.parametrize("inputs_slice", INPUTS_SLICES) + def test_string_slice(self, renderer, inputs_slice: slice): + assert renderer.render_completions(self.STRING_INPUTS)[ + inputs_slice + ] == renderer.render_completions(self.STRING_INPUTS[inputs_slice]) + + +class TestRenderPrompt: + @pytest.mark.asyncio + async def test_token_input(self, renderer): + tokens = [101, 7592, 2088] + prompts = await renderer.render_completions_async(tokens) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=100), + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + + @pytest.mark.asyncio + async def test_token_list_input(self, renderer): + token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] + prompts = await renderer.render_completions_async(token_lists) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=100), + ) + + assert len(results) == 3 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012] + assert results[2]["prompt_token_ids"] == [103, 4567] + + @pytest.mark.asyncio + async def test_text_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.encode.return_value = [101, 7592, 2088] + renderer._async_tokenizer = mock_async_tokenizer + + prompts = await renderer.render_completions_async("Hello world") + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=100), + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + mock_async_tokenizer.encode.assert_called_once() + + @pytest.mark.asyncio + async def test_text_list_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.encode.return_value = [101, 7592, 2088] + renderer._async_tokenizer = mock_async_tokenizer + + text_list_input = ["Hello world", "How are you?", "Good morning"] + prompts = await renderer.render_completions_async(text_list_input) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=100), + ) + + assert len(results) == 3 + for result in results: + assert result["prompt_token_ids"] == [101, 7592, 2088] + assert mock_async_tokenizer.encode.call_count == 3 + + @pytest.mark.asyncio + async def test_no_truncation(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.encode.return_value = [101, 7592, 2088] + renderer._async_tokenizer = mock_async_tokenizer + + prompts = await renderer.render_completions_async("Hello world") + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=100), + ) + + assert len(results) == 1 + call_args = mock_async_tokenizer.encode.call_args + assert ( 
"truncation" not in call_args.kwargs + or call_args.kwargs["truncation"] is False + ) + + @pytest.mark.asyncio + async def test_truncation_positive(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.encode.return_value = [101, 7592, 2088] # Truncated + renderer._async_tokenizer = mock_async_tokenizer + + prompts = await renderer.render_completions_async("Hello world") + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams( + max_total_tokens=200, + truncate_prompt_tokens=50, + ), + ) + + assert len(results) == 1 + call_args = mock_async_tokenizer.encode.call_args + assert call_args.kwargs["truncation"] is True + assert call_args.kwargs["max_length"] == 50 + + @pytest.mark.asyncio + async def test_truncation_negative(self, renderer, mock_async_tokenizer): + # Test that negative truncation uses model's max_model_len + mock_async_tokenizer.encode.return_value = [ + 101, + 7592, + 2088, + ] # Truncated to max_model_len + renderer._async_tokenizer = mock_async_tokenizer + + prompts = await renderer.render_completions_async("Hello world") + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams( + max_total_tokens=200, + truncate_prompt_tokens=-1, + ), + ) + + assert len(results) == 1 + call_args = mock_async_tokenizer.encode.call_args + assert call_args.kwargs["truncation"] is True + assert call_args.kwargs["max_length"] == 200 + + @pytest.mark.asyncio + async def test_token_truncation_last_elements(self, renderer): + # Test that token truncation keeps the last N elements + long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens + prompts = await renderer.render_completions_async(long_tokens) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams( + max_total_tokens=100, + truncate_prompt_tokens=5, + ), + ) + + assert len(results) == 1 + # Should keep the last 5 tokens: [105, 106, 107, 108, 109] + assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109] + + @pytest.mark.asyncio + async def test_max_length_exceeded(self, renderer): + long_tokens = list(range(150)) # Exceeds max_model_len=100 + + prompts = await renderer.render_completions_async(long_tokens) + + with pytest.raises(ValueError, match="context length is only"): + await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=100), + ) + + @pytest.mark.asyncio + async def test_no_tokenizer_for_text(self, renderer): + renderer_no_tokenizer = HfRenderer.from_config( + MockModelConfig(skip_tokenizer_init=True), + tokenizer_kwargs={}, + ) + + prompts = await renderer_no_tokenizer.render_completions_async("Hello world") + + with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"): + await renderer_no_tokenizer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=100), + ) + + @pytest.mark.asyncio + async def test_token_input_with_needs_detokenization( + self, renderer, mock_async_tokenizer + ): + # When needs_detokenization=True for token inputs, renderer should + # use the async tokenizer to decode and include the original text + # in the returned prompt object. 
+ mock_async_tokenizer.decode = AsyncMock(return_value="decoded text") + renderer._async_tokenizer = mock_async_tokenizer + + tokens = [1, 2, 3, 4] + prompts = await renderer.render_completions_async(tokens) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams( + max_total_tokens=renderer.config.max_model_len, + needs_detokenization=True, + ), + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + assert results[0]["prompt"] == "decoded text" + mock_async_tokenizer.decode.assert_awaited_once() + + +class TestRenderEmbedPrompt: + def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes: + """Helper to create base64-encoded tensor bytes""" + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + return pybase64.b64encode(buffer.read()) + + @pytest.mark.asyncio + async def test_single_prompt_embed(self, renderer): + # Create a test tensor + test_tensor = torch.randn(10, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=renderer.config.max_model_len), + ) + + assert len(results) == 1 + assert is_embeds_prompt(results[0]) + assert torch.allclose(results[0]["prompt_embeds"], test_tensor) + + @pytest.mark.asyncio + async def test_multiple_prompt_embeds(self, renderer): + # Create multiple test tensors + test_tensors = [ + torch.randn(8, 512, dtype=torch.float32), + torch.randn(12, 512, dtype=torch.float32), + ] + embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors] + + prompts = await renderer.render_completions_async( + prompt_embeds=embed_bytes_list + ) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=renderer.config.max_model_len), + ) + + assert len(results) == 2 + for i, result in enumerate(results): + assert is_embeds_prompt(result) + assert torch.allclose(result["prompt_embeds"], test_tensors[i]) + + @pytest.mark.asyncio + async def test_prompt_embed_truncation(self, renderer): + # Create tensor with more tokens than truncation limit + test_tensor = torch.randn(20, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams( + max_total_tokens=renderer.config.max_model_len, + truncate_prompt_tokens=10, + ), + ) + + assert len(results) == 1 + # Should keep last 10 tokens + expected = test_tensor[-10:] + assert torch.allclose(results[0]["prompt_embeds"], expected) + + @pytest.mark.asyncio + async def test_prompt_embed_different_dtypes(self, renderer): + # Test different supported dtypes + dtypes = [torch.float32, torch.float16, torch.bfloat16] + + for dtype in dtypes: + test_tensor = torch.randn(5, 256, dtype=dtype) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=renderer.config.max_model_len), + ) + + assert len(results) == 1 + assert results[0]["prompt_embeds"].dtype == dtype + + @pytest.mark.asyncio + async def test_prompt_embed_squeeze_batch_dim(self, renderer): + # Test tensor with batch dimension gets squeezed + test_tensor = torch.randn(1, 10, 768, 
dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=renderer.config.max_model_len), + ) + + assert len(results) == 1 + # Should be squeezed to 2D + assert results[0]["prompt_embeds"].shape == (10, 768) + + @pytest.mark.asyncio + async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer): + # Set up text tokenization + mock_async_tokenizer.encode.return_value = [101, 102, 103] + renderer._async_tokenizer = mock_async_tokenizer + + # Create embed + test_tensor = torch.randn(5, 256, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + prompts = await renderer.render_completions_async( + "Hello world", + prompt_embeds=embed_bytes, + ) + results = await renderer.tokenize_prompts_async( + prompts, + TokenizeParams(max_total_tokens=renderer.config.max_model_len), + ) + + assert len(results) == 2 + # First should be embed prompt + assert is_embeds_prompt(results[0]) + # Second should be tokens prompt + assert "prompt_token_ids" in results[1] + assert results[1]["prompt_token_ids"] == [101, 102, 103] diff --git a/tests/renderers/test_mistral.py b/tests/renderers/test_mistral.py index 0dc214ae939b..9346582bf75a 100644 --- a/tests/renderers/test_mistral.py +++ b/tests/renderers/test_mistral.py @@ -9,6 +9,7 @@ from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.config import ModelConfig +from vllm.renderers import ChatParams from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template from vllm.tokenizers.mistral import MistralTokenizer @@ -27,7 +28,7 @@ def mocked_apply_chat_template(*_args, **_kwargs): mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={}) mock_renderer._tokenizer = mock_tokenizer - task = mock_renderer.render_messages_async([]) + task = mock_renderer.render_messages_async([], ChatParams()) # Ensure the event loop is not blocked blocked_count = 0 diff --git a/tests/entrypoints/openai/test_sparse_tensor_validation.py b/tests/renderers/test_sparse_tensor_validation.py similarity index 83% rename from tests/entrypoints/openai/test_sparse_tensor_validation.py rename to tests/renderers/test_sparse_tensor_validation.py index c9bd156272ff..a90eac4782f7 100644 --- a/tests/entrypoints/openai/test_sparse_tensor_validation.py +++ b/tests/renderers/test_sparse_tensor_validation.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Sparse tensor validation in embedding APIs. - Tests verify that malicious sparse tensors are rejected before they can trigger out-of-bounds memory writes during to_dense() operations. 
""" @@ -13,8 +11,24 @@ import pytest import torch -from vllm.entrypoints.renderer import CompletionRenderer from vllm.multimodal.media import AudioEmbeddingMediaIO, ImageEmbeddingMediaIO +from vllm.renderers.embed_utils import safe_load_prompt_embeds + + +@pytest.fixture +def model_config(): + """Mock ModelConfig for testing.""" + from vllm.config import ModelConfig + + return ModelConfig( + model="facebook/opt-125m", + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float32", + seed=0, + enable_prompt_embeds=True, # Required for prompt embeds tests + ) def _encode_tensor(tensor: torch.Tensor) -> bytes: @@ -63,15 +77,12 @@ class TestPromptEmbedsValidation: def test_valid_dense_tensor_accepted(self, model_config): """Baseline: Valid dense tensors should work normally.""" - renderer = CompletionRenderer(model_config) - valid_tensor = _create_valid_dense_tensor() encoded = _encode_tensor(valid_tensor) # Should not raise any exception - result = renderer.load_prompt_embeds(encoded) - assert len(result) == 1 - assert result[0]["prompt_embeds"].shape == valid_tensor.shape + result = safe_load_prompt_embeds(model_config, encoded) + assert result.shape == valid_tensor.shape def test_valid_sparse_tensor_accepted(self): """Baseline: Valid sparse tensors should load successfully.""" @@ -86,14 +97,12 @@ def test_valid_sparse_tensor_accepted(self): def test_malicious_sparse_tensor_rejected(self, model_config): """Security: Malicious sparse tensors should be rejected.""" - renderer = CompletionRenderer(model_config) - malicious_tensor = _create_malicious_sparse_tensor() encoded = _encode_tensor(malicious_tensor) # Should raise RuntimeError due to invalid sparse tensor with pytest.raises((RuntimeError, ValueError)) as exc_info: - renderer.load_prompt_embeds(encoded) + safe_load_prompt_embeds(model_config, encoded) # Error should indicate sparse tensor validation failure error_msg = str(exc_info.value).lower() @@ -101,8 +110,6 @@ def test_malicious_sparse_tensor_rejected(self, model_config): def test_extremely_large_indices_rejected(self, model_config): """Security: Sparse tensors with extremely large indices should be rejected.""" - renderer = CompletionRenderer(model_config) - # Create tensor with indices far beyond reasonable bounds indices = torch.tensor([[999999], [999999]]) values = torch.tensor([1.0]) @@ -114,12 +121,10 @@ def test_extremely_large_indices_rejected(self, model_config): encoded = _encode_tensor(malicious_tensor) with pytest.raises((RuntimeError, ValueError)): - renderer.load_prompt_embeds(encoded) + safe_load_prompt_embeds(model_config, encoded) def test_negative_indices_rejected(self, model_config): """Security: Sparse tensors with negative indices should be rejected.""" - renderer = CompletionRenderer(model_config) - # Create tensor with negative indices indices = torch.tensor([[-1], [-1]]) values = torch.tensor([1.0]) @@ -131,7 +136,7 @@ def test_negative_indices_rejected(self, model_config): encoded = _encode_tensor(malicious_tensor) with pytest.raises((RuntimeError, ValueError)): - renderer.load_prompt_embeds(encoded) + safe_load_prompt_embeds(model_config, encoded) class TestImageEmbedsValidation: @@ -253,14 +258,12 @@ def test_attack_scenario_completions_api(self, model_config): 3. Sends to /v1/completions with prompt_embeds parameter 4. 
Server should reject before memory corruption occurs """ - renderer = CompletionRenderer(model_config) - # Step 1-2: Attacker creates malicious payload attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) # Step 3-4: Server processes and should reject with pytest.raises((RuntimeError, ValueError)): - renderer.load_prompt_embeds(attack_payload) + safe_load_prompt_embeds(model_config, attack_payload) def test_attack_scenario_chat_api_image(self): """ @@ -285,57 +288,3 @@ def test_attack_scenario_chat_api_audio(self): with pytest.raises((RuntimeError, ValueError)): io_handler.load_base64("", attack_payload.decode("utf-8")) - - def test_multiple_valid_embeddings_in_batch(self, model_config): - """ - Regression test: Multiple valid embeddings should still work. - - Ensures the fix doesn't break legitimate batch processing. - """ - renderer = CompletionRenderer(model_config) - - valid_tensors = [ - _encode_tensor(_create_valid_dense_tensor()), - _encode_tensor(_create_valid_dense_tensor()), - _encode_tensor(_create_valid_dense_tensor()), - ] - - # Should process all without error - result = renderer.load_prompt_embeds(valid_tensors) - assert len(result) == 3 - - def test_mixed_valid_and_malicious_rejected(self, model_config): - """ - Security: Batch with one malicious tensor should be rejected. - - Even if most tensors are valid, a single malicious one should - cause rejection of the entire batch. - """ - renderer = CompletionRenderer(model_config) - - mixed_batch = [ - _encode_tensor(_create_valid_dense_tensor()), - _encode_tensor(_create_malicious_sparse_tensor()), # Malicious - _encode_tensor(_create_valid_dense_tensor()), - ] - - # Should fail on the malicious tensor - with pytest.raises((RuntimeError, ValueError)): - renderer.load_prompt_embeds(mixed_batch) - - -# Pytest fixtures -@pytest.fixture -def model_config(): - """Mock ModelConfig for testing.""" - from vllm.config import ModelConfig - - return ModelConfig( - model="facebook/opt-125m", - tokenizer="facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, - dtype="float32", - seed=0, - enable_prompt_embeds=True, # Required for prompt embeds tests - ) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 6ea4f465cdff..a051bc54b818 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -5,65 +5,10 @@ from vllm.config import ModelConfig from vllm.inputs import zip_enc_dec_prompts -from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.preprocess import InputPreprocessor pytestmark = pytest.mark.cpu_test -STRING_INPUTS = [ - "", - "foo", - "foo bar", - "foo baz bar", - "foo bar qux baz", -] - -TOKEN_INPUTS = [ - [-1], - [1], - [1, 2], - [1, 3, 4], - [1, 2, 4, 3], -] - -INPUTS_SLICES = [ - slice(None, None, -1), - slice(None, None, 2), - slice(None, None, -2), -] - - -# Test that a nested mixed-type list of lists raises a TypeError. 
-@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]]) -def test_invalid_input_raise_type_error(invalid_input): - with pytest.raises(TypeError): - parse_raw_prompts(invalid_input) - - -def test_parse_raw_single_batch_empty(): - with pytest.raises(ValueError, match="at least one prompt"): - parse_raw_prompts([]) - - with pytest.raises(ValueError, match="at least one prompt"): - parse_raw_prompts([[]]) - - -@pytest.mark.parametrize("string_input", STRING_INPUTS) -def test_parse_raw_single_batch_string_consistent(string_input: str): - assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input]) - - -@pytest.mark.parametrize("token_input", TOKEN_INPUTS) -def test_parse_raw_single_batch_token_consistent(token_input: list[int]): - assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input]) - - -@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES) -def test_parse_raw_single_batch_string_slice(inputs_slice: slice): - assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts( - STRING_INPUTS[inputs_slice] - ) - @pytest.mark.parametrize( "mm_processor_kwargs,expected_mm_kwargs", diff --git a/vllm/config/model.py b/vllm/config/model.py index 527af4c54c7f..b0640f1b42f7 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -767,7 +767,7 @@ def maybe_pull_model_tokenizer_for_runai(self, model: str, tokenizer: str) -> No ) self.tokenizer = object_storage_tokenizer.dir - def _get_encoder_config(self): + def _get_encoder_config(self) -> dict[str, Any] | None: model = self.model if is_remote_gguf(model): model, _ = split_remote_gguf(model) @@ -1916,7 +1916,7 @@ def _get_and_verify_max_len( disable_sliding_window: bool, sliding_window: int | None, spec_target_max_model_len: int | None = None, - encoder_config: Any | None = None, + encoder_config: dict[str, Any] | None = None, ) -> int: """Get and verify the model's maximum length.""" (derived_max_model_len, max_len_key) = ( diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 205efd1d582e..282dbf9910b9 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -72,14 +72,9 @@ def encode( lora_request: LoRARequest | None = None, trace_headers: Mapping[str, str] | None = None, priority: int = 0, - truncate_prompt_tokens: int | None = None, tokenization_kwargs: dict[str, Any] | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: - """Generate outputs for a request from a pooling model. - - NOTE: truncate_prompt_tokens is deprecated in v0.14. - TODO: Remove this argument in v0.15. - """ + """Generate outputs for a request from a pooling model.""" ... 
@abstractmethod diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9ef7af09843b..5ee86ee72f6c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypeAlias, cast import cloudpickle import torch.nn as nn @@ -46,15 +47,17 @@ compress_token_type_ids, get_score_prompt, ) -from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args +from vllm.entrypoints.utils import log_non_default_args from vllm.inputs import ( DataPrompt, + EmbedsPrompt, + ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, ) -from vllm.inputs.parse import get_prompt_components +from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods @@ -67,6 +70,7 @@ ) from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams +from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask from vllm.tokenizers import TokenizerLike @@ -74,7 +78,6 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.counter import Counter -from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.sample.logits_processor import LogitsProcessor @@ -85,6 +88,9 @@ _R = TypeVar("_R", default=Any) +EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt +EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt] + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -372,6 +378,7 @@ def generate( use_tqdm: bool | Callable[..., tqdm] = True, lora_request: list[LoRARequest] | LoRARequest | None = None, priority: list[int] | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[RequestOutput]: """Generates the completions for the input prompts. @@ -398,15 +405,11 @@ def generate( If provided, must be a list of integers matching the length of `prompts`, where each priority value corresponds to the prompt at the same index. + tokenization_kwargs: Overrides for `tokenizer.encode`. Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. - - Note: - Using `prompts` and `prompt_token_ids` as keyword parameters is - considered legacy and may be deprecated in the future. You should - instead pass them via the `inputs` parameter. """ model_config = self.model_config runner_type = model_config.runner_type @@ -418,17 +421,14 @@ def generate( ) if sampling_params is None: - # Use default sampling params. 
            sampling_params = self.get_default_sampling_params()
 
-        # Add any modality specific loras to the corresponding prompts
-        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)
-
         self._validate_and_add_requests(
             prompts=prompts,
             params=sampling_params,
             use_tqdm=use_tqdm,
-            lora_request=lora_request,
+            lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
+            tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
 
@@ -771,65 +771,169 @@ def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
 
         return outputs
 
-    def preprocess_chat(
+    def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
+        model_config = self.model_config
+        encoder_config = model_config.encoder_config or {}
+
+        return TokenizeParams(
+            max_total_tokens=model_config.max_model_len,
+            do_lower_case=encoder_config.get("do_lower_case", False),
+            # For Whisper, special tokens should be provided by the user based
+            # on the task and language of their request. Also needed to avoid
+            # appending an EOS token to the prompt which disrupts generation.
+            add_special_tokens=not model_config.is_encoder_decoder,
+        ).with_kwargs(tokenization_kwargs)
+
+    def _normalize_prompts(
         self,
-        messages: list[ChatCompletionMessageParam]
+        prompts: PromptType | Sequence[PromptType],
+    ) -> list[EnginePrompt | EngineEncDecPrompt]:
+        if isinstance(prompts, str):
+            prompts = TextPrompt(prompt=prompts)
+
+        return prompts if isinstance(prompts, Sequence) else [prompts]  # type: ignore[return-value]
+
+    def _preprocess_cmpl_singleton(
+        self,
+        prompt: SingletonPrompt,
+        tok_params: TokenizeParams,
+        *,
+        tokenize: bool,
+    ) -> EnginePrompt:
+        renderer = self.llm_engine.renderer
+
+        if not isinstance(prompt, dict):
+            prompt = renderer.render_completion(prompt)
+
+        return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
+
+    def _preprocess_cmpl_enc_dec(
+        self,
+        prompt: ExplicitEncoderDecoderPrompt,
+        tok_params: TokenizeParams,
+    ) -> EngineEncDecPrompt:
+        enc_prompt = prompt["encoder_prompt"]
+        dec_prompt = prompt["decoder_prompt"]
+
+        return EngineEncDecPrompt(
+            encoder_prompt=self._preprocess_cmpl_singleton(
+                enc_prompt,
+                tok_params,
+                # TODO: Move multi-modal processor into tokenization
+                tokenize=not self.model_config.is_multimodal_model,
+            ),
+            decoder_prompt=(
+                None
+                if dec_prompt is None
+                else self._preprocess_cmpl_singleton(
+                    dec_prompt,
+                    tok_params,
+                    # TODO: Move multi-modal processor into tokenization
+                    tokenize=not self.model_config.is_multimodal_model,
+                )
+            ),
+        )
+
+    def _preprocess_completion(
+        self,
+        prompts: PromptType | Sequence[PromptType],
+        tokenization_kwargs: dict[str, Any] | None = None,
+    ) -> list[EnginePrompt | EngineEncDecPrompt]:
+        """
+        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
+        a format that can be passed to `_add_request`.
+
+        Refer to [LLM.generate][] for a complete description of the arguments.
+
+        Returns:
+            A list of `TokensPrompt` objects containing the tokenized prompt
+            and the raw multi-modal inputs.
+ """ + tok_params = self._get_cmpl_tok_params(tokenization_kwargs) + + engine_prompts = list[EnginePrompt | EngineEncDecPrompt]() + for prompt in self._normalize_prompts(prompts): + if is_explicit_encoder_decoder_prompt(prompt): + engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params)) + else: + # Some MM models have non-default `add_special_tokens` + # TODO: Move multi-modal processor into tokenization + engine_prompts.append( + self._preprocess_cmpl_singleton( + prompt, + tok_params, + tokenize=not self.model_config.is_multimodal_model, + ) + ) + + return engine_prompts + + def _normalize_conversations( + self, + conversations: list[ChatCompletionMessageParam] + | list[list[ChatCompletionMessageParam]], + ) -> list[list[ChatCompletionMessageParam]]: + return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value] + + def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None): + model_config = self.model_config + encoder_config = model_config.encoder_config or {} + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + do_lower_case=encoder_config.get("do_lower_case", False), + add_special_tokens=False, + ).with_kwargs(tokenization_kwargs) + + def _preprocess_chat( + self, + conversations: list[ChatCompletionMessageParam] | list[list[ChatCompletionMessageParam]], chat_template: str | None = None, chat_template_content_format: ChatTemplateContentFormatOption = "auto", + chat_template_kwargs: dict[str, Any] | None = None, add_generation_prompt: bool = True, continue_final_message: bool = False, tools: list[dict[str, Any]] | None = None, - chat_template_kwargs: dict[str, Any] | None = None, + tokenization_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> list[TextPrompt | TokensPrompt]: + ) -> list[EnginePrompt]: """ - Generate prompt for a chat conversation. The pre-processed - prompt can then be used as input for the other LLM methods. + Convert a list of conversations into prompts so that they can then + be used as input for other LLM APIs. + + Refer to [LLM.chat][] for a complete description of the arguments. - Refer to `chat` for a complete description of the arguments. Returns: - A list of `TokensPrompts` objects containing the tokenized - prompt after chat template interpolation, and the - pre-processed multi-modal inputs. + A list of `TokensPrompts` objects containing the tokenized prompt + after chat template interpolation, and the raw multi-modal inputs. """ - list_of_messages: list[list[ChatCompletionMessageParam]] - - # Handle multi and single conversations - if is_list_of(messages, list): - # messages is list[list[...]] - list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages) - else: - # messages is list[...] - list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] - renderer = self.llm_engine.renderer - chat_template_kwargs = { - "chat_template": chat_template, - "add_generation_prompt": add_generation_prompt, - "continue_final_message": continue_final_message, - "tools": tools, - **(chat_template_kwargs or {}), - } - - prompts = list[TextPrompt | TokensPrompt]() - - for msgs in list_of_messages: - # NOTE: renderer.render_messages() currently doesn't - # handle mm_processor_kwargs, since there is no implementation in - # the chat message parsing for it. 
- _, prompt = renderer.render_messages( - msgs, - chat_template_content_format=chat_template_content_format, - **chat_template_kwargs, - ) + chat_params = ChatParams( + chat_template=chat_template, + chat_template_content_format=chat_template_content_format, + chat_template_kwargs=merge_kwargs( + chat_template_kwargs, + dict( + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + tokenize=isinstance(renderer.tokenizer, MistralTokenizer), + ), + ), + ) + tok_params = self._get_chat_tok_params(tokenization_kwargs) + + engine_prompts = list[EnginePrompt]() + for conversation in self._normalize_conversations(conversations): + _, in_prompt = renderer.render_messages(conversation, chat_params) if mm_processor_kwargs is not None: - prompt["mm_processor_kwargs"] = mm_processor_kwargs + in_prompt["mm_processor_kwargs"] = mm_processor_kwargs - prompts.append(prompt) + engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params)) - return prompts + return engine_prompts def chat( self, @@ -844,6 +948,7 @@ def chat( continue_final_message: bool = False, tools: list[dict[str, Any]] | None = None, chat_template_kwargs: dict[str, Any] | None = None, + tokenization_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None, ) -> list[RequestOutput]: """ @@ -889,22 +994,22 @@ def chat( `True` if `add_generation_prompt` is also `True`. chat_template_kwargs: Additional kwargs to pass to the chat template. - mm_processor_kwargs: Multimodal processor kwarg overrides for this - chat request. Only used for offline requests. + tokenization_kwargs: Overrides for `tokenizer.encode`. + mm_processor_kwargs: Overrides for `processor.__call__`. Returns: A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ - - prompts = self.preprocess_chat( - messages=messages, + prompts = self._preprocess_chat( + messages, chat_template=chat_template, chat_template_content_format=chat_template_content_format, + chat_template_kwargs=chat_template_kwargs, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, tools=tools, - chat_template_kwargs=chat_template_kwargs, + tokenization_kwargs=tokenization_kwargs, mm_processor_kwargs=mm_processor_kwargs, ) @@ -913,6 +1018,7 @@ def chat( sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, ) def encode( @@ -945,37 +1051,29 @@ def encode( If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. pooling_task: Override the pooling task to use. - tokenization_kwargs: overrides tokenization_kwargs set in - pooling_params + tokenization_kwargs: Overrides for `tokenizer.encode`. Returns: A list of `PoolingRequestOutput` objects containing the pooled hidden states in the same order as the input prompts. - - Note: - Using `prompts` and `prompt_token_ids` as keyword parameters is - considered legacy and may be deprecated in the future. You should - instead pass them via the `inputs` parameter. 
""" - error_str = ( - "pooling_task required for `LLM.encode`\n" - "Please use one of the more specific methods or set the " - "pooling_task when using `LLM.encode`:\n" - " - For embeddings, use `LLM.embed(...)` " - 'or `pooling_task="embed"`.\n' - " - For classification logits, use `LLM.classify(...)` " - 'or `pooling_task="classify"`.\n' - " - For similarity scores, use `LLM.score(...)`.\n" - " - For rewards, use `LLM.reward(...)` " - 'or `pooling_task="token_classify"`\n' - " - For token classification, " - 'use `pooling_task="token_classify"`\n' - ' - For multi-vector retrieval, use `pooling_task="token_embed"`' - ) - if pooling_task is None: - raise ValueError(error_str) + raise ValueError( + "pooling_task required for `LLM.encode`\n" + "Please use one of the more specific methods or set the " + "pooling_task when using `LLM.encode`:\n" + " - For embeddings, use `LLM.embed(...)` " + 'or `pooling_task="embed"`.\n' + " - For classification logits, use `LLM.classify(...)` " + 'or `pooling_task="classify"`.\n' + " - For similarity scores, use `LLM.score(...)`.\n" + " - For rewards, use `LLM.reward(...)` " + 'or `pooling_task="token_classify"`\n' + " - For token classification, " + 'use `pooling_task="token_classify"`\n' + ' - For multi-vector retrieval, use `pooling_task="token_embed"`' + ) model_config = self.model_config runner_type = model_config.runner_type @@ -986,6 +1084,20 @@ def encode( "pooling model." ) + if truncate_prompt_tokens is not None: + warnings.warn( + "The `truncate_prompt_tokens` parameter in `LLM.encode()` " + "is deprecated and will be removed in v0.16. " + "Please pass it via `tokenization_kwargs` instead.", + DeprecationWarning, + stacklevel=2, + ) + + tokenization_kwargs = merge_kwargs( + tokenization_kwargs, + dict(truncate_prompt_tokens=truncate_prompt_tokens), + ) + io_processor_prompt = False if isinstance(prompts, dict) and "data" in prompts: io_processor_prompt = True @@ -1017,19 +1129,16 @@ def encode( pooling_params = self.io_processor.validate_or_generate_params( pooling_params ) - else: - if pooling_params is None: - # Use default pooling params. - pooling_params = PoolingParams() + + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() if pooling_task not in self.supported_tasks: raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") for param in as_iter(pooling_params): param.verify(pooling_task, model_config) - # for backwards compatibility - if truncate_prompt_tokens is not None: - param.truncate_prompt_tokens = truncate_prompt_tokens self._validate_and_add_requests( prompts=prompts, @@ -1094,6 +1203,7 @@ def embed( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. + tokenization_kwargs: Overrides for `tokenizer.encode`. Returns: A list of `EmbeddingRequestOutput` objects containing the @@ -1105,9 +1215,14 @@ def embed( "Try converting the model using `--convert embed`." 
            )
 
+        if truncate_prompt_tokens is not None:
+            tokenization_kwargs = merge_kwargs(
+                tokenization_kwargs,
+                dict(truncate_prompt_tokens=truncate_prompt_tokens),
+            )
+
         items = self.encode(
             prompts,
-            truncate_prompt_tokens=truncate_prompt_tokens,
             use_tqdm=use_tqdm,
             pooling_params=pooling_params,
             lora_request=lora_request,
@@ -1121,8 +1236,8 @@ def classify(
         self,
         prompts: PromptType | Sequence[PromptType],
         *,
-        use_tqdm: bool | Callable[..., tqdm] = True,
         pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
+        use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[ClassificationRequestOutput]:
@@ -1137,13 +1252,15 @@ def classify(
             prompts: The prompts to the LLM. You may pass a sequence of prompts
                 for batch inference. See [PromptType][vllm.inputs.PromptType]
                 for more details about the format of each prompt.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
             use_tqdm: If `True`, shows a tqdm progress bar.
                 If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                 it is used to create the progress bar.
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
-            pooling_params: The pooling parameters for pooling. If None, we
-                use the default pooling parameters.
+            tokenization_kwargs: Overrides for `tokenizer.encode`.
+
         Returns:
             A list of `ClassificationRequestOutput` objects containing the
             classification outputs in the same order as the input prompts.
@@ -1170,9 +1287,9 @@ def reward(
         prompts: PromptType | Sequence[PromptType],
         /,
         *,
+        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         truncate_prompt_tokens: int | None = None,
         use_tqdm: bool | Callable[..., tqdm] = True,
-        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
         lora_request: list[LoRARequest] | LoRARequest | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
     ) -> list[PoolingRequestOutput]:
@@ -1183,13 +1300,15 @@ def reward(
             prompts: The prompts to the LLM. You may pass a sequence of prompts
                 for batch inference. See [PromptType][vllm.inputs.PromptType]
                 for more details about the format of each prompt.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
             use_tqdm: If `True`, shows a tqdm progress bar.
                 If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                 it is used to create the progress bar.
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
-            pooling_params: The pooling parameters for pooling. If None, we
-                use the default pooling parameters.
+            tokenization_kwargs: Overrides for `tokenizer.encode`.
+
         Returns:
             A list of `PoolingRequestOutput` objects containing the
             pooled hidden states in the same order as the input prompts.
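
The same migration applies at the `LLM` level: `truncate_prompt_tokens` on `encode`/`embed` is deprecated in favor of the `tokenization_kwargs` override, which these methods now fold it into via `merge_kwargs`. A small usage sketch (the model name is a placeholder for any embedding-capable model):

    from vllm import LLM

    llm = LLM(model="intfloat/e5-small-v2")  # placeholder embedding model

    # Deprecated spelling (still accepted, folded into tokenization_kwargs):
    #   llm.embed(["some very long document ..."], truncate_prompt_tokens=128)
    # Preferred spelling after this change:
    outputs = llm.embed(
        ["some very long document ..."],
        tokenization_kwargs={"truncate_prompt_tokens": 128},
    )
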
@@ -1207,18 +1326,18 @@ def reward( def _embedding_score( self, - tokenizer: TokenizerLike, - text_1: list[str | TextPrompt | TokensPrompt], - text_2: list[str | TextPrompt | TokensPrompt], - truncate_prompt_tokens: int | None = None, - use_tqdm: bool | Callable[..., tqdm] = True, - pooling_params: PoolingParams | None = None, - lora_request: list[LoRARequest] | LoRARequest | None = None, - tokenization_kwargs: dict[str, Any] | None = None, + text_1: list[SingletonPrompt], + text_2: list[SingletonPrompt], + *, + use_tqdm: bool | Callable[..., tqdm], + pooling_params: PoolingParams | None, + lora_request: list[LoRARequest] | LoRARequest | None, + tokenization_kwargs: dict[str, Any], ) -> list[ScoringRequestOutput]: - encoded_output: list[PoolingRequestOutput] = self.encode( + tokenizer = self.get_tokenizer() + + encoded_output = self.encode( text_1 + text_2, - truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, lora_request=lora_request, pooling_params=pooling_params, @@ -1226,14 +1345,16 @@ def _embedding_score( tokenization_kwargs=tokenization_kwargs, ) - encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)] - encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :] + encoded_output_1 = encoded_output[0 : len(text_1)] + encoded_output_2 = encoded_output[len(text_1) :] if len(encoded_output_1) == 1: encoded_output_1 = encoded_output_1 * len(encoded_output_2) scores = _cosine_similarity( - tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2 + tokenizer=tokenizer, + embed_1=encoded_output_1, + embed_2=encoded_output_2, ) items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) @@ -1241,17 +1362,17 @@ def _embedding_score( def _cross_encoding_score( self, - tokenizer: TokenizerLike, data_1: list[str] | list[ScoreContentPartParam], data_2: list[str] | list[ScoreContentPartParam], - truncate_prompt_tokens: int | None = None, - use_tqdm: bool | Callable[..., tqdm] = True, - pooling_params: PoolingParams | None = None, - lora_request: list[LoRARequest] | LoRARequest | None = None, - tokenization_kwargs: dict[str, Any] | None = None, - score_template: str | None = None, + *, + use_tqdm: bool | Callable[..., tqdm], + pooling_params: PoolingParams | None, + lora_request: list[LoRARequest] | LoRARequest | None, + tokenization_kwargs: dict[str, Any], + score_template: str | None, ) -> list[ScoringRequestOutput]: model_config = self.model_config + tokenizer = self.get_tokenizer() if isinstance(tokenizer, MistralTokenizer): raise ValueError("Score API is not supported for Mistral tokenizer") @@ -1265,13 +1386,6 @@ def _cross_encoding_score( pooling_params.verify("score", model_config) pooling_params_list = list[PoolingParams]() - local_kwargs = tokenization_kwargs or {} - tokenization_kwargs = local_kwargs.copy() - - _validate_truncation_size( - model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs - ) - prompts = list[PromptType]() input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] @@ -1314,10 +1428,10 @@ def score( data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam, /, *, - truncate_prompt_tokens: int | None = None, use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, chat_template: str | None = None, ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs `` or @@ -1344,20 +1458,22 @@ def 
score(
                 the LLM. Can be text or multi-modal data. See [PromptType]
                 [vllm.inputs.PromptType] for more details about the format of
                 each prompt.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
             use_tqdm: If `True`, shows a tqdm progress bar.
                 If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                 it is used to create the progress bar.
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
-            pooling_params: The pooling parameters for pooling. If None, we
-                use the default pooling parameters.
             chat_template: The chat template to use for the scoring. If None,
                 we use the model's default chat template.
+            tokenization_kwargs: Overrides for `tokenizer.encode`.
 
         Returns:
             A list of `ScoringRequestOutput` objects containing the
             generated scores in the same order as the input prompts.
         """
         model_config = self.model_config
+
         runner_type = model_config.runner_type
         if runner_type != "pooling":
             raise ValueError(
@@ -1445,26 +1561,27 @@ def ensure_str(prompt: SingletonPrompt):
 
         _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]
 
+        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
+        encode_kwargs = tok_params.get_encode_kwargs()
+
         if model_config.is_cross_encoder:
             return self._cross_encoding_score(
-                tokenizer,
                 data_1,  # type: ignore[arg-type]
                 data_2,  # type: ignore[arg-type]
-                truncate_prompt_tokens,
-                use_tqdm,
-                pooling_params,
-                lora_request,
+                use_tqdm=use_tqdm,
+                pooling_params=pooling_params,
+                lora_request=lora_request,
+                tokenization_kwargs=encode_kwargs,
                 score_template=chat_template,
             )
         else:
             return self._embedding_score(
-                tokenizer,
                 data_1,  # type: ignore[arg-type]
                 data_2,  # type: ignore[arg-type]
-                truncate_prompt_tokens,
-                use_tqdm,
-                pooling_params,
-                lora_request,
+                use_tqdm=use_tqdm,
+                pooling_params=pooling_params,
+                lora_request=lora_request,
+                tokenization_kwargs=encode_kwargs,
             )
 
     def start_profile(self) -> None:
@@ -1530,42 +1647,79 @@ def get_metrics(self) -> list["Metric"]:
 
     def _validate_and_add_requests(
         self,
-        prompts: PromptType | Sequence[PromptType] | DataPrompt,
+        prompts: PromptType | Sequence[PromptType],
         params: SamplingParams
         | Sequence[SamplingParams]
         | PoolingParams
        | Sequence[PoolingParams],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
-        lora_request: Sequence[LoRARequest] | LoRARequest | None,
-        priority: list[int] | None = None,
+        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
         tokenization_kwargs: dict[str, Any] | None = None,
+        priority: list[int] | None = None,
     ) -> None:
-        if isinstance(prompts, (str, dict)):
-            # Convert a single prompt to a list.
-            prompts = [prompts]  # type: ignore[list-item]
-
-        num_requests = len(prompts)
-        if isinstance(params, Sequence) and len(params) != num_requests:
-            raise ValueError("The lengths of prompts and params must be the same.")
-        if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
-            raise ValueError(
-                "The lengths of prompts and lora_request must be the same."
-            )
-        if priority is not None and len(priority) != num_requests:
-            raise ValueError(
-                "The lengths of prompts "
-                f"({num_requests}) and priority ({len(priority)}) "
-                "must be the same."
+        in_prompts = self._normalize_prompts(prompts)
+        num_requests = len(in_prompts)
+
+        if isinstance(params, Sequence):
+            if len(params) != num_requests:
+                raise ValueError(
+                    f"The lengths of prompts ({num_requests}) "
+                    f"and params ({len(params)}) must be the same."
+ ) + + engine_params = params + else: + engine_params = [params] * num_requests + + if isinstance(lora_request, Sequence): + if len(lora_request) != num_requests: + raise ValueError( + f"The lengths of prompts ({num_requests}) " + f"and lora_request ({len(lora_request)}) must be the same." + ) + + engine_lora_requests: Sequence[LoRARequest | None] = lora_request + else: + engine_lora_requests = [lora_request] * num_requests + + if priority is not None: + if len(priority) != num_requests: + raise ValueError( + f"The lengths of prompts ({num_requests}) " + f"and priority ({len(priority)}) must be the same." + ) + else: + priority = [0] * num_requests + + if any(param.truncate_prompt_tokens is not None for param in engine_params): + # TODO: Remove this after deprecating `param.truncate_prompt_tokens` + # Then, move the code from the `else` block to the top and let + # `self._preprocess_completion` handle prompt normalization + engine_prompts = [ + engine_prompt + for in_prompt, param in zip(in_prompts, engine_params) + for engine_prompt in self._preprocess_completion( + [in_prompt], + tokenization_kwargs=merge_kwargs( + tokenization_kwargs, + dict(truncate_prompt_tokens=param.truncate_prompt_tokens), + ), + ) + ] + else: + engine_prompts = self._preprocess_completion( + in_prompts, + tokenization_kwargs=tokenization_kwargs, ) - for sp in params if isinstance(params, Sequence) else (params,): + for sp in engine_params: if isinstance(sp, SamplingParams): # We only care about the final output sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. - it = prompts + it = engine_prompts if use_tqdm: tqdm_func = use_tqdm if callable(use_tqdm) else tqdm it = tqdm_func(it, desc="Adding requests") @@ -1576,12 +1730,10 @@ def _validate_and_add_requests( for i, prompt in enumerate(it): request_id = self._add_request( prompt, - params[i] if isinstance(params, Sequence) else params, - lora_request=lora_request[i] - if isinstance(lora_request, Sequence) - else lora_request, - priority=priority[i] if priority else 0, + engine_params[i], + lora_request=engine_lora_requests[i], tokenization_kwargs=tokenization_kwargs, + priority=priority[i], ) added_request_ids.append(request_id) except Exception as e: @@ -1589,54 +1741,42 @@ def _validate_and_add_requests( self.llm_engine.abort_request(added_request_ids, internal=True) raise e - def _process_inputs( - self, - request_id: str, - engine_prompt: PromptType, - params: SamplingParams | PoolingParams, - *, - lora_request: LoRARequest | None, - priority: int, - tokenization_kwargs: dict[str, Any] | None = None, - ) -> tuple[EngineCoreRequest, dict[str, Any]]: - """Use the Processor to process inputs for LLMEngine.""" - - local_kwargs = tokenization_kwargs or {} - tokenization_kwargs = local_kwargs.copy() - _validate_truncation_size( - self.model_config.max_model_len, - params.truncate_prompt_tokens, - tokenization_kwargs, - ) - - engine_request = self.input_processor.process_inputs( - request_id, - engine_prompt, - params, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - priority=priority, - ) - return engine_request, tokenization_kwargs - def _add_request( self, prompt: PromptType, params: SamplingParams | PoolingParams, lora_request: LoRARequest | None = None, - priority: int = 0, tokenization_kwargs: dict[str, Any] | None = None, + priority: int = 0, ) -> str: prompt_text, _, _ = get_prompt_components(prompt) request_id = str(next(self.request_counter)) - engine_request, tokenization_kwargs = self._process_inputs( + 
if params.truncate_prompt_tokens is not None: + params_type = type(params).__name__ + warnings.warn( + f"The `truncate_prompt_tokens` parameter in `{params_type}` " + "is deprecated and will be removed in v0.16. " + "Please pass it via `tokenization_kwargs` instead.", + DeprecationWarning, + stacklevel=2, + ) + + tokenization_kwargs = merge_kwargs( + tokenization_kwargs, + dict(truncate_prompt_tokens=params.truncate_prompt_tokens), + ) + + tok_params = self._get_cmpl_tok_params(tokenization_kwargs) + + tokenization_kwargs = tok_params.get_encode_kwargs() + engine_request = self.input_processor.process_inputs( request_id, prompt, params, lora_request=lora_request, - priority=priority, tokenization_kwargs=tokenization_kwargs, + priority=priority, ) self.llm_engine.add_request( diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index a76dc73d9ba3..e5a821edd1d9 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -13,12 +13,13 @@ ChatCompletionAudio as OpenAIChatCompletionAudio, ) from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation -from pydantic import ( - Field, - model_validator, -) +from pydantic import Field, model_validator -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, +) from vllm.entrypoints.openai.engine.protocol import ( AnyResponseFormat, DeltaMessage, @@ -36,6 +37,7 @@ from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger from vllm.logprobs import Logprob +from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.sampling_params import ( BeamSearchParams, RequestOutputKind, @@ -356,6 +358,43 @@ class ChatCompletionRequest(OpenAIBaseModel): # --8<-- [end:chat-completion-extra-params] + def build_chat_params( + self, + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + ) -> ChatParams: + return ChatParams( + chat_template=self.chat_template or default_template, + chat_template_content_format=default_template_content_format, + chat_template_kwargs=merge_kwargs( + self.chat_template_kwargs, + dict( + add_generation_prompt=self.add_generation_prompt, + continue_final_message=self.continue_final_message, + documents=self.documents, + reasoning_effort=self.reasoning_effort, + ), + ), + ) + + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + if self.max_completion_tokens is not None: + max_output_tokens: int | None = self.max_completion_tokens + max_output_tokens_param = "max_completion_tokens" + else: + max_output_tokens = self.max_tokens + max_output_tokens_param = "max_tokens" + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=max_output_tokens or 0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + add_special_tokens=self.add_special_tokens, + needs_detokenization=bool(self.echo and not self.return_token_ids), + max_total_tokens_param="max_model_len", + max_output_tokens_param=max_output_tokens_param, + ) + # Default sampling parameters for chat completion requests _DEFAULT_SAMPLING_PARAMS: dict = { "repetition_penalty": 1.0, diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index c5e6c5c6af6f..c200bb76bd92 100644 
--- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -67,7 +67,7 @@ ) from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import get_max_tokens, should_include_usage -from vllm.inputs.data import TokensPrompt +from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.parse import get_prompt_components from vllm.logger import init_logger from vllm.logprobs import Logprob @@ -185,8 +185,6 @@ async def warmup(self) -> None: start_time = time.perf_counter() try: - renderer = self.engine_client.renderer - # Create a minimal dummy request dummy_request = ChatCompletionRequest( messages=[{"role": "user", "content": "warmup"}], @@ -201,18 +199,10 @@ async def warmup(self) -> None: # 3. Tokenizer initialization for chat await self._preprocess_chat( dummy_request, - renderer, dummy_request.messages, - chat_template=self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=True, - continue_final_message=False, - tool_dicts=None, - documents=None, - chat_template_kwargs=None, - default_chat_template_kwargs=self.default_chat_template_kwargs, - tool_parser=None, - add_special_tokens=False, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=self.default_chat_template_kwargs, ) elapsed = (time.perf_counter() - start_time) * 1000 @@ -225,7 +215,10 @@ async def warmup(self) -> None: async def render_chat_request( self, request: ChatCompletionRequest, - ) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse: + ) -> ( + tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]] + | ErrorResponse + ): """ render chat request by validating and preprocessing inputs. @@ -302,23 +295,14 @@ async def render_chat_request( if error_check_ret is not None: return error_check_ret - chat_template_kwargs = request.chat_template_kwargs or {} - chat_template_kwargs.update(reasoning_effort=request.reasoning_effort) - conversation, engine_prompts = await self._preprocess_chat( request, - renderer, request.messages, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=self.default_chat_template_kwargs, tool_dicts=tool_dicts, - documents=request.documents, - chat_template_kwargs=chat_template_kwargs, - default_chat_template_kwargs=self.default_chat_template_kwargs, tool_parser=tool_parser, - add_special_tokens=request.add_special_tokens, ) else: # For GPT-OSS. 
@@ -428,11 +412,15 @@ async def create_chat_completion( trace_headers=trace_headers, ) else: - engine_request, tokenization_kwargs = await self._process_inputs( + tok_params = request.build_tok_params(self.model_config) + tokenization_kwargs = tok_params.get_encode_kwargs() + + engine_request = self.input_processor.process_inputs( sub_request_id, engine_prompt, sampling_params, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=request.priority, data_parallel_rank=data_parallel_rank, diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index fc773c402ede..aab733082558 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -9,11 +9,9 @@ from typing import Annotated, Any, Literal import torch -from pydantic import ( - Field, - model_validator, -) +from pydantic import Field, model_validator +from vllm.config import ModelConfig from vllm.entrypoints.openai.engine.protocol import ( AnyResponseFormat, LegacyStructuralTagResponseFormat, @@ -27,6 +25,7 @@ from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger from vllm.logprobs import Logprob +from vllm.renderers import TokenizeParams from vllm.sampling_params import ( BeamSearchParams, RequestOutputKind, @@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel): # --8<-- [end:completion-extra-params] + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=self.max_tokens or 0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + add_special_tokens=self.add_special_tokens, + needs_detokenization=bool(self.echo and not self.return_token_ids), + max_total_tokens_param="max_model_len", + max_output_tokens_param="max_tokens", + ) + # Default sampling parameters for completion requests _DEFAULT_SAMPLING_PARAMS: dict = { "repetition_penalty": 1.0, diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index 24cf486a61fe..dc59d5248e34 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -32,7 +32,6 @@ clamp_prompt_logprobs, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels -from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.exceptions import VLLMValidationError from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt @@ -111,11 +110,10 @@ async def render_completion_request( ) try: - renderer = self._get_completion_renderer() - engine_prompts = await renderer.render_prompt_and_embeds( - prompt_or_prompts=request.prompt, + engine_prompts = await self._preprocess_completion( + request, + prompt_input=request.prompt, prompt_embeds=request.prompt_embeds, - config=self._build_render_config(request), ) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") @@ -203,10 +201,6 @@ async def create_completion( else await self._get_trace_headers(raw_request.headers) ) - # Mypy inconsistently requires this second cast in different - # environments. It shouldn't be necessary (redundant from above) - # but pre-commit in CI fails without it. 
- engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt) if isinstance(sampling_params, BeamSearchParams): generator = self.beam_search( prompt=engine_prompt, @@ -216,11 +210,15 @@ async def create_completion( trace_headers=trace_headers, ) else: - engine_request, tokenization_kwargs = await self._process_inputs( + tok_params = request.build_tok_params(self.model_config) + tokenization_kwargs = tok_params.get_encode_kwargs() + + engine_request = self.input_processor.process_inputs( request_id_item, engine_prompt, sampling_params, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=request.priority, data_parallel_rank=data_parallel_rank, @@ -709,26 +707,3 @@ def _create_completion_logprobs( tokens=out_tokens, top_logprobs=out_top_logprobs, ) - - def _build_render_config( - self, - request: CompletionRequest, - max_input_length: int | None = None, - ) -> RenderConfig: - # Validate max_tokens before using it - if request.max_tokens is not None and request.max_tokens > self.max_model_len: - raise VLLMValidationError( - f"'max_tokens' ({request.max_tokens}) cannot be greater than " - f"the model's maximum context length ({self.max_model_len}).", - parameter="max_tokens", - value=request.max_tokens, - ) - - max_input_tokens_len = self.max_model_len - (request.max_tokens or 0) - return RenderConfig( - max_length=max_input_tokens_len, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - cache_salt=request.cache_salt, - needs_detokenization=bool(request.echo and not request.return_token_ids), - ) diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index e491f9399545..4fe5c3cc723c 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -16,9 +16,7 @@ from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.logger import init_logger -from vllm.sampling_params import ( - SamplingParams, -) +from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid from vllm.utils.import_utils import resolve_obj_by_qualname diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 8c3c028d05ed..7e2f65d26e01 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -5,10 +5,10 @@ import sys import time import traceback -from collections.abc import AsyncGenerator, Callable, Iterable, Mapping +from collections.abc import AsyncGenerator, Callable, Mapping from dataclasses import dataclass, field from http import HTTPStatus -from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast +from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar import numpy as np from fastapi import Request @@ -20,6 +20,7 @@ import vllm.envs as envs from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, @@ -86,7 +87,6 @@ ScoreResponse, ScoreTextRequest, ) -from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.tokenize.protocol import ( DetokenizeRequest, @@ -94,13 +94,9 @@ TokenizeCompletionRequest, TokenizeResponse, ) -from vllm.entrypoints.utils import ( - 
_validate_truncation_size, - get_max_tokens, - sanitize_message, -) +from vllm.entrypoints.utils import get_max_tokens, sanitize_message from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt from vllm.inputs.parse import ( get_prompt_components, is_explicit_encoder_decoder_prompt, @@ -112,7 +108,7 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.reasoning import ReasoningParser, ReasoningParserManager -from vllm.renderers import RendererLike +from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser, ToolParserManager @@ -123,11 +119,9 @@ ) from vllm.utils import random_uuid from vllm.utils.async_utils import ( - AsyncMicrobatchTokenizer, collect_from_async_generator, merge_async_iterators, ) -from vllm.v1.engine import EngineCoreRequest class GenerationError(Exception): @@ -140,6 +134,21 @@ def __init__(self, message: str = "Internal server error"): logger = init_logger(__name__) + +class RendererRequest(Protocol): + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + raise NotImplementedError + + +class RendererChatRequest(RendererRequest, Protocol): + def build_chat_params( + self, + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + ) -> ChatParams: + raise NotImplementedError + + CompletionLikeRequest: TypeAlias = ( CompletionRequest | TokenizeCompletionRequest @@ -158,7 +167,9 @@ def __init__(self, message: str = "Internal server error"): | ClassificationChatRequest | PoolingChatRequest ) + SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest + AnyRequest: TypeAlias = ( CompletionLikeRequest | ChatLikeRequest @@ -193,7 +204,7 @@ class ServeContext(Generic[RequestT]): request_id: str created_time: int = field(default_factory=lambda: int(time.time())) lora_request: LoRARequest | None = None - engine_prompts: list[TokensPrompt] | None = None + engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( None @@ -227,7 +238,6 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids - self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {} self.log_error_stack = log_error_stack self.input_processor = self.models.input_processor @@ -519,41 +529,6 @@ async def beam_search( prompt_logprobs=None, ) - def _get_completion_renderer(self) -> BaseRenderer: - """ - Get a Renderer instance with the provided tokenizer. - Uses shared async tokenizer pool for efficiency. - """ - return CompletionRenderer( - model_config=self.model_config, - tokenizer=self.renderer.tokenizer, - async_tokenizer_pool=self._async_tokenizer_pool, - ) - - def _build_render_config( - self, - request: Any, - ) -> RenderConfig: - """ - Build and return a `RenderConfig` for an endpoint. - - Used by the renderer to control how prompts are prepared - (e.g., tokenization and length handling). Endpoints should - implement this with logic appropriate to their request type. 
- """ - raise NotImplementedError - - def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: - """ - Return (and cache) an `AsyncMicrobatchTokenizer` bound to the - given tokenizer. - """ - async_tokenizer = self._async_tokenizer_pool.get(tokenizer) - if async_tokenizer is None: - async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) - self._async_tokenizer_pool[tokenizer] = async_tokenizer - return async_tokenizer - async def _preprocess( self, ctx: ServeContext, @@ -912,71 +887,6 @@ def _get_message_types(self, request: AnyRequest) -> set[str]: message_types.add(content_dict["type"].split("_")[0]) return message_types - async def _normalize_prompt_text_to_input( - self, - request: AnyRequest, - prompt: str, - tokenizer: TokenizerLike, - add_special_tokens: bool, - ) -> TokensPrompt: - async_tokenizer = self._get_async_tokenizer(tokenizer) - - if ( - self.model_config.encoder_config is not None - and self.model_config.encoder_config.get("do_lower_case", False) - ): - prompt = prompt.lower() - - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) - - if truncate_prompt_tokens is None: - encoded = await async_tokenizer( - prompt, add_special_tokens=add_special_tokens - ) - elif truncate_prompt_tokens < 0: - # Negative means we cap at the model's max length - encoded = await async_tokenizer( - prompt, - add_special_tokens=add_special_tokens, - truncation=True, - max_length=self.max_model_len, - ) - else: - encoded = await async_tokenizer( - prompt, - add_special_tokens=add_special_tokens, - truncation=True, - max_length=truncate_prompt_tokens, - ) - - input_ids = encoded.input_ids - input_text = prompt - - return self._validate_input(request, input_ids, input_text) - - async def _normalize_prompt_tokens_to_input( - self, - request: AnyRequest, - prompt_ids: list[int], - tokenizer: TokenizerLike | None, - ) -> TokensPrompt: - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) - - if truncate_prompt_tokens is None: - input_ids = prompt_ids - elif truncate_prompt_tokens < 0: - input_ids = prompt_ids[-self.max_model_len :] - else: - input_ids = prompt_ids[-truncate_prompt_tokens:] - - if tokenizer is None: - input_text = "" - else: - async_tokenizer = self._get_async_tokenizer(tokenizer) - input_text = await async_tokenizer.decode(input_ids) - - return self._validate_input(request, input_ids, input_text) - def _validate_input( self, request: object, @@ -1061,50 +971,6 @@ def _validate_input( return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) - async def _tokenize_prompt_input_async( - self, - request: AnyRequest, - tokenizer: TokenizerLike, - prompt_input: str | list[int], - add_special_tokens: bool = True, - ) -> TokensPrompt: - """ - A simpler implementation that tokenizes a single prompt input. - """ - async for result in self._tokenize_prompt_inputs_async( - request, - tokenizer, - [prompt_input], - add_special_tokens=add_special_tokens, - ): - return result - raise ValueError("No results yielded from tokenization") - - async def _tokenize_prompt_inputs_async( - self, - request: AnyRequest, - tokenizer: TokenizerLike, - prompt_inputs: Iterable[str | list[int]], - add_special_tokens: bool = True, - ) -> AsyncGenerator[TokensPrompt, None]: - """ - A simpler implementation that tokenizes multiple prompt inputs. 
- """ - for prompt in prompt_inputs: - if isinstance(prompt, str): - yield await self._normalize_prompt_text_to_input( - request, - prompt=prompt, - tokenizer=tokenizer, - add_special_tokens=add_special_tokens, - ) - else: - yield await self._normalize_prompt_tokens_to_input( - request, - prompt_ids=prompt, - tokenizer=tokenizer, - ) - def _validate_chat_template( self, request_chat_template: str | None, @@ -1137,131 +1003,94 @@ def _prepare_extra_chat_template_kwargs( # Apply server defaults first, then request kwargs override. return default_chat_template_kwargs | request_chat_template_kwargs + async def _preprocess_completion( + self, + request: RendererRequest, + prompt_input: str | list[str] | list[int] | list[list[int]] | None, + prompt_embeds: bytes | list[bytes] | None, + ) -> list[TokensPrompt | EmbedsPrompt]: + renderer = self.renderer + tok_params = request.build_tok_params(self.model_config) + + in_prompts = await renderer.render_completions_async( + prompt_input, prompt_embeds + ) + engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params) + + extra_items = { + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + } + for prompt in engine_prompts: + prompt.update(extra_items) # type: ignore + + return engine_prompts + async def _preprocess_chat( self, - request: ChatLikeRequest | ResponsesRequest, - renderer: RendererLike, + request: RendererChatRequest, messages: list[ChatCompletionMessageParam], - chat_template: str | None, - chat_template_content_format: ChatTemplateContentFormatOption, - add_generation_prompt: bool = True, - continue_final_message: bool = False, + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + default_template_kwargs: dict[str, Any] | None, tool_dicts: list[dict[str, Any]] | None = None, - documents: list[dict[str, str]] | None = None, - chat_template_kwargs: dict[str, Any] | None = None, - default_chat_template_kwargs: dict[str, Any] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, - add_special_tokens: bool = False, - ) -> tuple[list[ConversationMessage], list[TokensPrompt]]: - chat_template_kwargs = { - "chat_template": chat_template, - "add_generation_prompt": add_generation_prompt, - "continue_final_message": continue_final_message, - "tools": tool_dicts, - "documents": documents, - **(chat_template_kwargs or {}), - } - chat_template_kwargs = self._prepare_extra_chat_template_kwargs( - chat_template_kwargs, - default_chat_template_kwargs, - ) - - # Use the async tokenizer in `OpenAIServing` if possible. - # Later we can move it into the renderer so that we can return both - # text and token IDs in the same prompt from `render_messages_async` - # which is used for logging and `enable_response_messages`. 
+ ) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]: from vllm.tokenizers.mistral import MistralTokenizer - conversation, engine_prompt = await renderer.render_messages_async( - messages, - chat_template_content_format=chat_template_content_format, - tokenize=( - chat_template_kwargs.pop("tokenize", False) - or isinstance(renderer.tokenizer, MistralTokenizer) + renderer = self.renderer + + default_template_kwargs = merge_kwargs( + default_template_kwargs, + dict( + tools=tool_dicts, + tokenize=isinstance(renderer.tokenizer, MistralTokenizer), ), - **chat_template_kwargs, ) - if "prompt_token_ids" not in engine_prompt: - extra_data = engine_prompt - engine_prompt = await self._tokenize_prompt_input_async( - request, - renderer.get_tokenizer(), - engine_prompt["prompt"], - add_special_tokens=add_special_tokens, - ) - # Fill in other keys like MM data - engine_prompt.update(extra_data) # type: ignore - else: - self._validate_input( - request=request, - input_ids=engine_prompt["prompt_token_ids"], # type: ignore - input_text="", - ) + tok_params = request.build_tok_params(self.model_config) + chat_params = request.build_chat_params( + default_template, default_template_content_format + ).with_defaults(default_template_kwargs) - engine_prompt = cast(TokensPrompt, engine_prompt) + conversation, prompt = await renderer.render_messages_async( + messages, chat_params + ) + engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params) - if request.mm_processor_kwargs is not None: - engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs - if (cache_salt := getattr(request, "cache_salt", None)) is not None: - engine_prompt["cache_salt"] = cache_salt + extra_items = { + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + } + engine_prompt.update(extra_items) # type: ignore # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser # is set, we want to prevent parsing a tool_call hallucinated by the LLM - should_parse_tools = tool_parser is not None and ( - hasattr(request, "tool_choice") and request.tool_choice != "none" - ) - - if should_parse_tools: - if not isinstance(request, ChatCompletionRequest | ResponsesRequest): - msg = ( - "Tool usage is only supported for Chat Completions API " - "or Responses API requests." - ) - raise NotImplementedError(msg) + if tool_parser is not None: + tool_choice = getattr(request, "tool_choice", "none") + if tool_choice != "none": + if not isinstance(request, ChatCompletionRequest | ResponsesRequest): + msg = ( + "Tool usage is only supported for Chat Completions API " + "or Responses API requests." 
+ ) + raise NotImplementedError(msg) - tokenizer = renderer.get_tokenizer() - request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore + # TODO: Update adjust_request to accept ResponsesRequest + tokenizer = renderer.get_tokenizer() + request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] return conversation, [engine_prompt] - async def _process_inputs( - self, - request_id: str, - engine_prompt: PromptType, - params: SamplingParams | PoolingParams, - *, - lora_request: LoRARequest | None, - trace_headers: Mapping[str, str] | None, - priority: int, - data_parallel_rank: int | None = None, - ) -> tuple[EngineCoreRequest, dict[str, Any]]: - """Use the Processor to process inputs for AsyncLLM.""" - tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size( - self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs - ) - - engine_request = self.input_processor.process_inputs( - request_id, - engine_prompt, - params, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - trace_headers=trace_headers, - priority=priority, - data_parallel_rank=data_parallel_rank, - ) - return engine_request, tokenization_kwargs - async def _render_next_turn( self, request: ResponsesRequest, - renderer: RendererLike, messages: list[ResponseInputOutputItem], tool_dicts: list[dict[str, Any]] | None, - tool_parser, + tool_parser: Callable[[TokenizerLike], ToolParser] | None, chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, ): @@ -1271,24 +1100,25 @@ async def _render_next_turn( _, engine_prompts = await self._preprocess_chat( request, - renderer, new_messages, + default_template=chat_template, + default_template_content_format=chat_template_content_format, + default_template_kwargs=None, tool_dicts=tool_dicts, tool_parser=tool_parser, - chat_template=chat_template, - chat_template_content_format=chat_template_content_format, ) return engine_prompts async def _generate_with_builtin_tools( self, request_id: str, - engine_prompt: TokensPrompt, + engine_prompt: TokensPrompt | EmbedsPrompt, sampling_params: SamplingParams, + tok_params: TokenizeParams, context: ConversationContext, lora_request: LoRARequest | None = None, priority: int = 0, - **kwargs, + trace_headers: Mapping[str, str] | None = None, ): prompt_text, _, _ = get_prompt_components(engine_prompt) @@ -1297,18 +1127,21 @@ async def _generate_with_builtin_tools( while True: # Ensure that each sub-request has a unique request id. 
sub_request_id = f"{request_id}_{sub_request}" + self._log_inputs( sub_request_id, engine_prompt, params=sampling_params, lora_request=lora_request, ) - trace_headers = kwargs.get("trace_headers") - engine_request, tokenization_kwargs = await self._process_inputs( + + tokenization_kwargs = tok_params.get_encode_kwargs() + engine_request = self.input_processor.process_inputs( sub_request_id, engine_prompt, sampling_params, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=priority, ) @@ -1318,10 +1151,10 @@ async def _generate_with_builtin_tools( sampling_params, sub_request_id, lora_request=lora_request, + trace_headers=trace_headers, priority=priority, prompt_text=prompt_text, tokenization_kwargs=tokenization_kwargs, - **kwargs, ) async for res in generator: @@ -1350,7 +1183,6 @@ async def _generate_with_builtin_tools( elif isinstance(context, ParsableContext): engine_prompts = await self._render_next_turn( context.request, - context.renderer, context.parser.response_messages, context.tool_dicts, context.tool_parser_cls, diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py index b3ac24881aa4..a10567e40136 100644 --- a/vllm/entrypoints/openai/responses/context.py +++ b/vllm/entrypoints/openai/responses/context.py @@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.utils import construct_tool_dicts from vllm.outputs import RequestOutput from vllm.reasoning.abs_reasoning_parsers import ReasoningParser -from vllm.renderers import RendererLike from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.utils import random_uuid @@ -261,7 +260,7 @@ def __init__( self, *, response_messages: list[ResponseInputOutputItem], - renderer: RendererLike, + tokenizer: TokenizerLike, reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None, request: ResponsesRequest, available_tools: list[str] | None, @@ -280,7 +279,6 @@ def __init__( if reasoning_parser_cls is None: raise ValueError("reasoning_parser_cls must be provided.") - tokenizer = renderer.get_tokenizer() self.parser = get_responses_parser_for_simple_context( tokenizer=tokenizer, reasoning_parser_cls=reasoning_parser_cls, @@ -290,8 +288,6 @@ def __init__( ) self.tool_parser_cls = tool_parser_cls self.request = request - self.renderer = renderer - self.tokenizer = tokenizer self.available_tools = available_tools or [] self._tool_sessions: dict[str, ClientSession | Tool] = {} diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index 1109d78a5362..2a6cf1bc5484 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -59,12 +59,15 @@ model_validator, ) -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam -from vllm.entrypoints.openai.engine.protocol import ( - OpenAIBaseModel, +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, ) +from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger +from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.sampling_params import ( RequestOutputKind, SamplingParams, @@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel): previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None # 
--8<-- [end:responses-extra-params] + def build_chat_params( + self, + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + ) -> ChatParams: + from .utils import should_continue_final_message + + # Check if we should continue the final message (partial completion) + # This enables Anthropic-style partial message completion where the + # user provides an incomplete assistant message to continue from. + continue_final = should_continue_final_message(self.input) + + reasoning = self.reasoning + + return ChatParams( + chat_template=default_template, + chat_template_content_format=default_template_content_format, + chat_template_kwargs=merge_kwargs( # To remove unset values + {}, + dict( + add_generation_prompt=not continue_final, + continue_final_message=continue_final, + reasoning_effort=None if reasoning is None else reasoning.effort, + ), + ), + ) + + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=self.max_output_tokens or 0, + truncate_prompt_tokens=-1 if self.truncation != "disabled" else None, + max_total_tokens_param="max_model_len", + max_output_tokens_param="max_output_tokens", + ) + _DEFAULT_SAMPLING_PARAMS = { "temperature": 1.0, "top_p": 1.0, diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 582e41bdf85c..cd6aa48c30d9 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -114,16 +114,15 @@ construct_input_messages, construct_tool_dicts, extract_tool_types, - should_continue_final_message, ) from vllm.entrypoints.utils import get_max_tokens from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import TokensPrompt +from vllm.inputs.data import EmbedsPrompt, TokensPrompt +from vllm.inputs.parse import get_prompt_len from vllm.logger import init_logger from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs from vllm.outputs import CompletionOutput -from vllm.renderers import RendererLike from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.tokenizers import TokenizerLike from vllm.utils import random_uuid @@ -291,13 +290,14 @@ def __init__( self.tool_server = tool_server def _validate_generator_input( - self, engine_prompt: TokensPrompt + self, + engine_prompt: TokensPrompt | EmbedsPrompt, ) -> ErrorResponse | None: """Add validations to the input to the generator here.""" - if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): + prompt_len = get_prompt_len(engine_prompt) + if self.max_model_len <= prompt_len: error_message = ( - "The engine prompt length" - f" {len(engine_prompt['prompt_token_ids'])} " + f"The engine prompt length {prompt_len} " f"exceeds the max_model_len {self.max_model_len}. " "Please reduce prompt." 
) @@ -307,6 +307,7 @@ def _validate_generator_input( status_code=HTTPStatus.BAD_REQUEST, param="input", ) + return None def _validate_create_responses_input( @@ -387,8 +388,6 @@ async def create_responses( try: lora_request = self._maybe_get_adapters(request) model_name = self.models.model_name(lora_request) - renderer = self.engine_client.renderer - tokenizer = renderer.get_tokenizer() if self.use_harmony: messages, engine_prompts = self._make_request_with_harmony( @@ -396,7 +395,7 @@ async def create_responses( ) else: messages, engine_prompts = await self._make_request( - request, prev_response, renderer + request, prev_response ) except ( @@ -431,6 +430,9 @@ async def create_responses( assert len(builtin_tool_list) == 0 available_tools = [] try: + renderer = self.engine_client.renderer + tokenizer = renderer.get_tokenizer() + for engine_prompt in engine_prompts: maybe_error = self._validate_generator_input(engine_prompt) if maybe_error is not None: @@ -446,6 +448,7 @@ async def create_responses( sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params ) + tok_params = request.build_tok_params(self.model_config) trace_headers = ( None @@ -465,7 +468,7 @@ async def create_responses( # tokens during generation instead of at the end context = ParsableContext( response_messages=messages, - renderer=renderer, + tokenizer=tokenizer, reasoning_parser_cls=self.reasoning_parser, request=request, tool_parser_cls=self.tool_parser, @@ -495,6 +498,7 @@ async def create_responses( request_id=request.request_id, engine_prompt=engine_prompt, sampling_params=sampling_params, + tok_params=tok_params, context=context, lora_request=lora_request, priority=request.priority, @@ -596,7 +600,6 @@ async def _make_request( self, request: ResponsesRequest, prev_response: ResponsesResponse | None, - renderer: RendererLike, ): tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) # Construct the input messages. @@ -606,30 +609,15 @@ async def _make_request( prev_msg=self.msg_store.get(prev_response.id) if prev_response else None, prev_response_output=prev_response.output if prev_response else None, ) - # Check if we should continue the final message (partial completion) - # This enables Anthropic-style partial message completion where the - # user provides an incomplete assistant message to continue from. - continue_final = should_continue_final_message(request.input) - chat_template_kwargs = dict( - reasoning_effort=None - if request.reasoning is None - else request.reasoning.effort - ) _, engine_prompts = await self._preprocess_chat( request, - renderer, messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, tool_dicts=tool_dicts, tool_parser=self.tool_parser, - chat_template=self.chat_template, - chat_template_content_format=self.chat_template_content_format, - # When continuing a partial message, we set continue_final_message=True - # and add_generation_prompt=False so the model continues the message - # rather than starting a new one. 
- add_generation_prompt=not continue_final, - continue_final_message=continue_final, - chat_template_kwargs=chat_template_kwargs, ) return messages, engine_prompts diff --git a/vllm/entrypoints/pooling/base/protocol.py b/vllm/entrypoints/pooling/base/protocol.py index dd185e574386..19a44a3615d5 100644 --- a/vllm/entrypoints/pooling/base/protocol.py +++ b/vllm/entrypoints/pooling/base/protocol.py @@ -8,8 +8,12 @@ from vllm import PoolingParams from vllm.config.pooler import get_use_activation -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, +) from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel +from vllm.renderers import ChatParams, merge_kwargs from vllm.utils import random_uuid from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness @@ -119,6 +123,23 @@ def check_generation_prompt(cls, data): ) return data + def build_chat_params( + self, + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + ) -> ChatParams: + return ChatParams( + chat_template=self.chat_template or default_template, + chat_template_content_format=default_template_content_format, + chat_template_kwargs=merge_kwargs( + self.chat_template_kwargs, + dict( + add_generation_prompt=self.add_generation_prompt, + continue_final_message=self.continue_final_message, + ), + ), + ) + class EncodingRequestMixin(OpenAIBaseModel): # --8<-- [start:encoding-params] diff --git a/vllm/entrypoints/pooling/classify/protocol.py b/vllm/entrypoints/pooling/classify/protocol.py index a94c7b49e589..33a25335d1f9 100644 --- a/vllm/entrypoints/pooling/classify/protocol.py +++ b/vllm/entrypoints/pooling/classify/protocol.py @@ -4,10 +4,9 @@ import time from typing import Any, TypeAlias -from pydantic import ( - Field, -) +from pydantic import Field +from vllm.config import ModelConfig from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.pooling.base.protocol import ( ChatRequestMixin, @@ -15,13 +14,24 @@ CompletionRequestMixin, PoolingBasicRequestMixin, ) +from vllm.renderers import TokenizeParams from vllm.utils import random_uuid class ClassificationCompletionRequest( PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin ): - pass + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + add_special_tokens=self.add_special_tokens, + max_total_tokens_param="max_model_len", + ) class ClassificationChatRequest( @@ -33,6 +43,18 @@ class ClassificationChatRequest( description=("Additional kwargs to pass to the HF processor."), ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + add_special_tokens=self.add_special_tokens, + max_total_tokens_param="max_model_len", + ) + ClassificationRequest: TypeAlias = ( ClassificationCompletionRequest | ClassificationChatRequest diff --git a/vllm/entrypoints/pooling/classify/serving.py 
b/vllm/entrypoints/pooling/classify/serving.py index 578e02ca3557..d9f7db953812 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from http import HTTPStatus -from typing import Final, cast +from typing import Final, TypeAlias import jinja2 import numpy as np @@ -21,15 +20,14 @@ ClassificationRequest, ClassificationResponse, ) -from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger -from vllm.outputs import ClassificationOutput, PoolingRequestOutput +from vllm.outputs import ClassificationOutput from vllm.pooling_params import PoolingParams logger = init_logger(__name__) -ClassificationServeContext = ServeContext[ClassificationRequest] +ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest] class ServingClassification(OpenAIServing): @@ -77,34 +75,18 @@ async def _preprocess( if error_check_ret: return error_check_ret - _, engine_prompts = await self._preprocess_chat( + _, ctx.engine_prompts = await self._preprocess_chat( ctx.request, - self.renderer, ctx.request.messages, - chat_template=ctx.request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=ctx.request.add_generation_prompt, - continue_final_message=ctx.request.continue_final_message, - add_special_tokens=ctx.request.add_special_tokens, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, ) - ctx.engine_prompts = engine_prompts - elif isinstance(ctx.request, ClassificationCompletionRequest): - input_data = ctx.request.input - if input_data in (None, ""): - return self.create_error_response( - "Input or messages must be provided", - status_code=HTTPStatus.BAD_REQUEST, - ) - if isinstance(input_data, list) and not input_data: - ctx.engine_prompts = [] - return None - - renderer = self._get_completion_renderer() - prompt_input = cast(str | list[str], input_data) - ctx.engine_prompts = await renderer.render_prompt( - prompt_or_prompts=prompt_input, - config=self._build_render_config(ctx.request), + ctx.engine_prompts = await self._preprocess_completion( + ctx.request, + prompt_input=ctx.request.input, + prompt_embeds=None, ) else: return self.create_error_response("Invalid classification request type") @@ -128,7 +110,7 @@ def _build_response( items: list[ClassificationData] = [] num_prompt_tokens = 0 - final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) + final_res_batch_checked = ctx.final_res_batch for idx, final_res in enumerate(final_res_batch_checked): classify_res = ClassificationOutput.from_base(final_res.outputs) @@ -161,13 +143,6 @@ def _build_response( usage=usage, ) - def _build_render_config(self, request: ClassificationRequest) -> RenderConfig: - return RenderConfig( - max_length=self.max_model_len, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) - async def create_classify( self, request: ClassificationRequest, diff --git a/vllm/entrypoints/pooling/embed/protocol.py b/vllm/entrypoints/pooling/embed/protocol.py index 6cebe046deed..1ab6097e7921 100644 --- a/vllm/entrypoints/pooling/embed/protocol.py +++ b/vllm/entrypoints/pooling/embed/protocol.py @@ -3,10 +3,9 @@ import time from typing import Any, TypeAlias -from pydantic import ( - Field, 
-) +from pydantic import Field +from vllm.config import ModelConfig from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.pooling.base.protocol import ( ChatRequestMixin, @@ -14,15 +13,47 @@ EmbedRequestMixin, PoolingBasicRequestMixin, ) +from vllm.renderers import TokenizeParams from vllm.utils import random_uuid +def _get_max_total_output_tokens( + model_config: ModelConfig, +) -> tuple[int | None, int]: + max_total_tokens = model_config.max_model_len + pooler_config = model_config.pooler_config + + if pooler_config is None: + return max_total_tokens, 0 + + if pooler_config.enable_chunked_processing: + return None, 0 + + max_embed_len = pooler_config.max_embed_len or max_total_tokens + max_output_tokens = max_total_tokens - max_embed_len + return max_total_tokens, max_output_tokens + + class EmbeddingCompletionRequest( PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin ): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/embeddings - pass + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + ( + max_total_tokens, + max_output_tokens, + ) = _get_max_total_output_tokens(model_config) + + return TokenizeParams( + max_total_tokens=max_total_tokens, + max_output_tokens=max_output_tokens, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + add_special_tokens=self.add_special_tokens, + max_total_tokens_param="max_model_len", + max_output_tokens_param="max_model_len - max_embed_len", + ) class EmbeddingChatRequest( @@ -33,6 +64,24 @@ class EmbeddingChatRequest( description=("Additional kwargs to pass to the HF processor."), ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + ( + max_total_tokens, + max_output_tokens, + ) = _get_max_total_output_tokens(model_config) + + return TokenizeParams( + max_total_tokens=max_total_tokens, + max_output_tokens=max_output_tokens, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + add_special_tokens=self.add_special_tokens, + max_total_tokens_param="max_model_len", + max_output_tokens_param="max_model_len - max_embed_len", + ) + EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index 22a6188878c8..7c9e840ead43 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json from collections.abc import AsyncGenerator, Mapping -from typing import Any, Final, cast +from typing import Any, Final, TypeAlias import torch from fastapi import Request @@ -22,8 +22,7 @@ EmbeddingResponse, EmbeddingResponseData, ) -from vllm.entrypoints.renderer import RenderConfig -from vllm.inputs.data import TokensPrompt +from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.pooling_params import PoolingParams @@ -37,7 +36,7 @@ logger = init_logger(__name__) -EmbeddingServeContext = ServeContext[EmbeddingRequest] +EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest] class OpenAIServingEmbedding(OpenAIServing): 
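A quick sanity check on the embedding token budget introduced above: `_get_max_total_output_tokens` treats `max_embed_len` as the cap on the prompt itself and reports the remainder of `max_model_len` as the "output" share — presumably so the effective prompt limit enforced via `TokenizeParams` works out to `max_embed_len` — while chunked processing lifts the total cap entirely. The standalone sketch below mirrors that arithmetic with hypothetical values; it is illustrative only and not part of this diff.

def embed_budget(
    max_model_len: int,
    max_embed_len: int | None,
    chunked: bool,
) -> tuple[int | None, int]:
    # Mirrors _get_max_total_output_tokens: chunked processing removes the
    # total cap; otherwise the prompt may use up to max_embed_len tokens
    # and the rest of max_model_len is reserved as the "output" share.
    if chunked:
        return None, 0
    embed_cap = max_embed_len or max_model_len
    return max_model_len, max_model_len - embed_cap

# max_model_len=8192, max_embed_len=4096: the effective prompt limit is
# max_total - max_output = 8192 - 4096 = 4096 tokens.
assert embed_budget(8192, 4096, chunked=False) == (8192, 4096)
assert embed_budget(8192, None, chunked=False) == (8192, 0)
assert embed_budget(8192, 4096, chunked=True) == (None, 0)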
@@ -95,19 +94,16 @@ async def _preprocess( _, ctx.engine_prompts = await self._preprocess_chat( ctx.request, - self.renderer, ctx.request.messages, - chat_template=ctx.request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=ctx.request.add_generation_prompt, - continue_final_message=ctx.request.continue_final_message, - add_special_tokens=ctx.request.add_special_tokens, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, ) elif isinstance(ctx.request, EmbeddingCompletionRequest): - renderer = self._get_completion_renderer() - ctx.engine_prompts = await renderer.render_prompt( - prompt_or_prompts=ctx.request.input, - config=self._build_render_config(ctx.request), + ctx.engine_prompts = await self._preprocess_completion( + ctx.request, + prompt_input=ctx.request.input, + prompt_embeds=None, ) else: return self.create_error_response("Invalid classification request type") @@ -117,19 +113,6 @@ async def _preprocess( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig: - # Set max_length based on chunked processing capability - if self._should_use_chunked_processing(request): - max_length = None - else: - max_length = self.max_embed_len or self.max_model_len - - return RenderConfig( - max_length=max_length, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) - def _build_response( self, ctx: EmbeddingServeContext, @@ -246,14 +229,18 @@ async def _process_chunked_request( lora_request=ctx.lora_request, ) + tok_params = ctx.request.build_tok_params(self.model_config) + tokenization_kwargs = tok_params.get_encode_kwargs() + # Create generator for this chunk and wrap it to return indices original_generator = self.engine_client.encode( chunk_engine_prompt, pooling_params, chunk_request_id, lora_request=ctx.lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), + priority=ctx.request.priority, ) generators.append(original_generator) @@ -338,7 +325,7 @@ def _validate_input( async def _create_single_prompt_generator( self, ctx: EmbeddingServeContext, - engine_prompt: TokensPrompt, + engine_prompt: TokensPrompt | EmbedsPrompt, pooling_params: PoolingParams, trace_headers: Mapping[str, str] | None, prompt_index: int, @@ -353,23 +340,25 @@ async def _create_single_prompt_generator( lora_request=ctx.lora_request, ) + tok_params = ctx.request.build_tok_params(self.model_config) + tokenization_kwargs = tok_params.get_encode_kwargs() + # Return the original generator without wrapping return self.engine_client.encode( engine_prompt, pooling_params, request_id_item, lora_request=ctx.lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), + priority=ctx.request.priority, ) async def _prepare_generators( self, - ctx: ServeContext, + ctx: EmbeddingServeContext, ) -> ErrorResponse | None: """Override to support chunked processing.""" - ctx = cast(EmbeddingServeContext, ctx) - # Check if we should use chunked processing use_chunked = self._should_use_chunked_processing(ctx.request) @@ -405,7 +394,8 @@ async def _prepare_generators( for i, engine_prompt in enumerate(ctx.engine_prompts): # Check if this specific 
prompt needs chunked processing if "prompt_token_ids" in engine_prompt: - prompt_token_ids = engine_prompt["prompt_token_ids"] + prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item] + if len(prompt_token_ids) > max_pos_embeddings: # Use chunked processing for this prompt chunk_generators = await self._process_chunked_request( @@ -573,7 +563,7 @@ async def _collect_batch( "token IDs" ) - original_token_ids = original_prompt["prompt_token_ids"] + original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item] pooling_request_output = PoolingRequestOutput( request_id=aggregator["request_id"], diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py index f3b043ca0dbd..633d0bb85632 100644 --- a/vllm/entrypoints/pooling/pooling/protocol.py +++ b/vllm/entrypoints/pooling/pooling/protocol.py @@ -3,11 +3,10 @@ import time from typing import Any, Generic, TypeAlias, TypeVar -from pydantic import ( - Field, -) +from pydantic import Field from vllm import PoolingParams +from vllm.config import ModelConfig from vllm.config.pooler import get_use_activation from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.pooling.base.protocol import ( @@ -18,6 +17,7 @@ EncodingRequestMixin, PoolingBasicRequestMixin, ) +from vllm.renderers import TokenizeParams from vllm.tasks import PoolingTask from vllm.utils import random_uuid @@ -30,6 +30,18 @@ class PoolingCompletionRequest( ): task: PoolingTask | None = None + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + add_special_tokens=self.add_special_tokens, + max_total_tokens_param="max_model_len", + ) + def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, @@ -48,6 +60,18 @@ class PoolingChatRequest( description=("Additional kwargs to pass to the HF processor."), ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + add_special_tokens=self.add_special_tokens, + max_total_tokens_param="max_model_len", + ) + def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 1900e446dbb9..4efc7572b276 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -5,7 +5,7 @@ import json import time from collections.abc import AsyncGenerator, Sequence -from typing import Final, cast +from typing import Any, Final, cast import jinja2 from fastapi import Request @@ -14,10 +14,7 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.protocol import ( - ErrorResponse, - UsageInfo, -) +from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo from 
vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.pooling.protocol import ( @@ -30,8 +27,6 @@ PoolingResponse, PoolingResponseData, ) -from vllm.entrypoints.renderer import RenderConfig -from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput from vllm.tasks import PoolingTask, SupportedTask @@ -99,11 +94,6 @@ async def create_pooling( "dimensions is currently not supported" ) - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) - truncate_prompt_tokens = _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens - ) - if is_io_processor_request: if self.io_processor is None: raise ValueError( @@ -134,19 +124,16 @@ async def create_pooling( _, engine_prompts = await self._preprocess_chat( request, - self.renderer, request.messages, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - add_special_tokens=request.add_special_tokens, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, ) elif isinstance(request, PoolingCompletionRequest): - renderer = self._get_completion_renderer() - engine_prompts = await renderer.render_prompt( - prompt_or_prompts=request.input, - config=self._build_render_config(request), + engine_prompts = await self._preprocess_completion( + request, + prompt_input=request.input, + prompt_embeds=None, ) else: raise ValueError(f"Unsupported request of type {type(request)}") @@ -207,11 +194,18 @@ async def create_pooling( else await self._get_trace_headers(raw_request.headers) ) + if is_io_processor_request: + tokenization_kwargs: dict[str, Any] = {} + else: + tok_params = request.build_tok_params(self.model_config) # type: ignore + tokenization_kwargs = tok_params.get_encode_kwargs() + generator = self.engine_client.encode( engine_prompt, pooling_params, request_id_item, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=request.priority, ) @@ -338,10 +332,3 @@ def encode_bytes(bytes_only: bool) -> PoolingBytesResponse: return encode_bytes(bytes_only=encoding_format == "bytes_only") else: assert_never(encoding_format) - - def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig: - return RenderConfig( - max_length=self.max_model_len, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) diff --git a/vllm/entrypoints/pooling/score/protocol.py b/vllm/entrypoints/pooling/score/protocol.py index 2af43c4a8115..e080ffd67ff0 100644 --- a/vllm/entrypoints/pooling/score/protocol.py +++ b/vllm/entrypoints/pooling/score/protocol.py @@ -3,12 +3,10 @@ import time from typing import Any, TypeAlias -from pydantic import ( - BaseModel, - Field, -) +from pydantic import BaseModel, Field from vllm import PoolingParams +from vllm.config import ModelConfig from vllm.config.pooler import get_use_activation from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.pooling.base.protocol import ( @@ -19,6 +17,7 @@ ScoreContentPartParam, ScoreMultiModalParam, ) +from vllm.renderers import TokenizeParams from vllm.utils 
import random_uuid @@ -30,6 +29,17 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): ) # --8<-- [end:score-extra-params] + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + max_total_tokens_param="max_model_len", + ) + def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, @@ -85,6 +95,17 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): ) # --8<-- [end:rerank-extra-params] + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + encoder_config = model_config.encoder_config or {} + + return TokenizeParams( + max_total_tokens=model_config.max_model_len, + max_output_tokens=0, + truncate_prompt_tokens=self.truncate_prompt_tokens, + do_lower_case=encoder_config.get("do_lower_case", False), + max_total_tokens_param="max_model_len", + ) + class RerankDocument(BaseModel): text: str | None = None diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index 85c74e5a26c2..1bd28f8bdccb 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -34,7 +34,6 @@ compress_token_type_ids, get_score_prompt, ) -from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -68,31 +67,31 @@ def __init__( async def _embedding_score( self, - tokenizer: TokenizerLike, data_1: list[str], data_2: list[str], request: RerankRequest | ScoreRequest, request_id: str, - tokenization_kwargs: dict[str, Any] | None = None, lora_request: LoRARequest | None | None = None, trace_headers: Mapping[str, str] | None = None, ) -> list[PoolingRequestOutput] | ErrorResponse: - input_texts = data_1 + data_2 + model_config = self.model_config + tokenizer = self.renderer.get_tokenizer() - engine_prompts: list[TokensPrompt] = [] - tokenize_async = make_async( - tokenizer.__call__, executor=self._tokenizer_executor + encode_async = make_async( + tokenizer.encode, + executor=self._tokenizer_executor, ) - tokenization_kwargs = tokenization_kwargs or {} + input_texts = data_1 + data_2 + + tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs() tokenized_prompts = await asyncio.gather( - *(tokenize_async(t, **tokenization_kwargs) for t in input_texts) + *(encode_async(t, **tokenization_kwargs) for t in input_texts) ) + engine_prompts: list[TokensPrompt] = [] for tok_result, input_text in zip(tokenized_prompts, input_texts): - text_token_prompt = self._validate_input( - request, tok_result["input_ids"], input_text - ) + text_token_prompt = self._validate_input(request, tok_result, input_text) engine_prompts.append( TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"]) @@ -184,15 +183,16 @@ def _preprocess_score( async def _cross_encoding_score( self, - tokenizer: TokenizerLike, data_1: list[str] | list[ScoreContentPartParam], data_2: list[str] | list[ScoreContentPartParam], request: RerankRequest | ScoreRequest, request_id: str, - tokenization_kwargs: dict[str, Any] | None = None, lora_request: LoRARequest | None | None = None, trace_headers: Mapping[str, str] | None = None, ) -> 
list[PoolingRequestOutput] | ErrorResponse: + model_config = self.model_config + tokenizer = self.renderer.get_tokenizer() + request_prompts: list[str] = [] engine_prompts: list[TokensPrompt] = [] @@ -202,12 +202,13 @@ async def _cross_encoding_score( if isinstance(tokenizer, MistralTokenizer): raise ValueError("MistralTokenizer not supported for cross-encoding") - tokenization_kwargs = tokenization_kwargs or {} + tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs() input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] preprocess_async = make_async( - self._preprocess_score, executor=self._tokenizer_executor + self._preprocess_score, + executor=self._tokenizer_executor, ) preprocessed_prompts = await asyncio.gather( @@ -215,7 +216,7 @@ async def _cross_encoding_score( preprocess_async( request=request, tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, + tokenization_kwargs=tok_kwargs, data_1=t1, data_2=t2, ) @@ -286,14 +287,6 @@ async def _run_scoring( raw_request: Request | None = None, ) -> list[PoolingRequestOutput] | ErrorResponse: lora_request = self._maybe_get_adapters(request) - tokenizer = self.renderer.get_tokenizer() - - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) - - tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens, tokenization_kwargs - ) trace_headers = ( None @@ -322,24 +315,20 @@ async def _run_scoring( if self.model_config.is_cross_encoder: return await self._cross_encoding_score( - tokenizer=tokenizer, data_1=data_1, # type: ignore[arg-type] data_2=data_2, # type: ignore[arg-type] request=request, request_id=request_id, - tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, trace_headers=trace_headers, ) else: return await self._embedding_score( - tokenizer=tokenizer, data_1=data_1, # type: ignore[arg-type] data_2=data_2, # type: ignore[arg-type] request=request, request_id=request_id, - tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, trace_headers=trace_headers, ) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py deleted file mode 100644 index 8a88eff430d9..000000000000 --- a/vllm/entrypoints/renderer.py +++ /dev/null @@ -1,411 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import io -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Annotated - -import pybase64 -import torch -from pydantic import Field - -from vllm.config import ModelConfig -from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt -from vllm.inputs.parse import get_prompt_components, parse_raw_prompts -from vllm.tokenizers import TokenizerLike -from vllm.utils.async_utils import AsyncMicrobatchTokenizer - - -@dataclass(frozen=True) -class RenderConfig: - """Configuration to control how prompts are prepared.""" - - max_length: int | None = None - """Maximum allowable total input token length. If provided, - token inputs longer than this raise `ValueError`.""" - - truncate_prompt_tokens: int | None = None - """Number of tokens to keep. `None` means no truncation. - `0` yields an empty list (and skips embeds). 
- `-1` maps to `model_config.max_model_len`.""" - - add_special_tokens: bool = True - """Whether to add model-specific special tokens during tokenization.""" - - cache_salt: str | None = None - """String to disambiguate prefix cache entries.""" - - needs_detokenization: bool | None = False - """If True, detokenize IDs back to text for inclusion in outputs.""" - - def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> int | None: - """Validate and normalize `truncate_prompt_tokens` parameter.""" - truncate_prompt_tokens = self.truncate_prompt_tokens - if truncate_prompt_tokens is None or truncate_prompt_tokens == 0: - return truncate_prompt_tokens - - if truncate_prompt_tokens < 0: - truncate_prompt_tokens = model_config.max_model_len - - max_length = self.max_length - if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator] - raise ValueError( - f"{truncate_prompt_tokens=} cannot be greater than " - f"{max_length=}. Please select a smaller truncation size." - ) - - return truncate_prompt_tokens - - -class BaseRenderer(ABC): - """ - Base class for unified input processing and rendering. - - The Renderer serves as a unified input processor that consolidates - tokenization, chat template formatting, and multimodal input handling - into a single component. - It converts high-level API requests (OpenAI-style JSON) into token IDs and - multimodal features ready for engine consumption. - - Key responsibilities: - - Convert text prompts to token sequences with proper special tokens - - Apply chat templates and format conversations - - Handle multimodal inputs (images, audio, etc.) when applicable - - Manage prompt truncation and length validation - - Provide clean separation between API layer and engine core - """ - - def __init__( - self, - model_config: ModelConfig, - tokenizer: TokenizerLike | None = None, - ): - super().__init__() - self.model_config = model_config - self.tokenizer = tokenizer - - @abstractmethod - async def render_prompt( - self, - *, - prompt_or_prompts: str | list[str] | list[int] | list[list[int]], - config: RenderConfig, - ) -> list[TokensPrompt]: - """ - Convert text or token inputs into engine-ready TokensPrompt objects. - - This method accepts text or token inputs and produces a - list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects - for the engine. - - Args: - prompt_or_prompts: One of: - - `str`: Single text prompt. - - `list[str]`: Batch of text prompts. - - `list[int]`: Single pre-tokenized sequence. - - `list[list[int]]`: Batch of pre-tokenized sequences. - config: Render configuration controlling how prompts are prepared - (e.g., tokenization and length handling). - - Returns: - list[TokensPrompt]: Engine-ready token prompts. - - Raises: - ValueError: If input formats are invalid or length limits exceeded. - """ - raise NotImplementedError - - @abstractmethod - async def render_prompt_and_embeds( - self, - *, - prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, - prompt_embeds: bytes | list[bytes] | None = None, - config: RenderConfig, - ) -> list[TokensPrompt | EmbedsPrompt]: - """ - Convert text/token and/or base64-encoded embeddings inputs into - engine-ready prompt objects using a unified RenderConfig. - - At least one of `prompt_or_prompts` or `prompt_embeds` must be - provided and non-empty. If both are omitted or empty (e.g., empty - string and empty list), a `ValueError` is raised. - - Args: - prompt_or_prompts: Text or token inputs to include. 
- prompt_embeds: Base64-encoded bytes (or list thereof) containing a - torch-saved tensor to be used as prompt embeddings. - config: Render configuration controlling how prompts are prepared - (e.g., tokenization and length handling). - - Returns: - list[Union[TokensPrompt, EmbedsPrompt]]: - Engine-ready prompt objects. - - Raises: - ValueError: If both `prompt_or_prompts` and `prompt_embeds` - are omitted or empty (decoder prompt cannot be empty), or if - length limits are exceeded. - """ - raise NotImplementedError - - def load_prompt_embeds( - self, - prompt_embeds: bytes | list[bytes], - truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None, - cache_salt: str | None = None, - ) -> list[EmbedsPrompt]: - """Load and validate base64-encoded embeddings into prompt objects.""" - if not self.model_config.enable_prompt_embeds: - raise VLLMValidationError( - "You must set `--enable-prompt-embeds` to input `prompt_embeds`.", - parameter="prompt_embeds", - ) - - def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: - # Enable sparse tensor integrity checks to prevent out-of-bounds - # writes from maliciously crafted tensors - with torch.sparse.check_sparse_tensor_invariants(): - tensor = torch.load( - io.BytesIO(pybase64.b64decode(embed, validate=True)), - weights_only=True, - map_location=torch.device("cpu"), - ) - assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( - torch.float32, - torch.bfloat16, - torch.float16, - ) - tensor = tensor.to_dense() - if tensor.dim() > 2: - tensor = tensor.squeeze(0) - assert tensor.dim() == 2 - if truncate_prompt_tokens is not None: - tensor = tensor[-truncate_prompt_tokens:] - embeds_prompt = EmbedsPrompt(prompt_embeds=tensor) - if cache_salt is not None: - embeds_prompt["cache_salt"] = cache_salt - return embeds_prompt - - if isinstance(prompt_embeds, list): - return [_load_and_validate_embed(embed) for embed in prompt_embeds] - - return [_load_and_validate_embed(prompt_embeds)] - - -class CompletionRenderer(BaseRenderer): - def __init__( - self, - model_config: ModelConfig, - tokenizer: TokenizerLike | None = None, - async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] - | None = None, - ): - super().__init__(model_config, tokenizer) - self.async_tokenizer_pool = async_tokenizer_pool - self.async_tokenizer: AsyncMicrobatchTokenizer | None = None - - async def render_prompt( - self, - *, - prompt_or_prompts: str | list[str] | list[int] | list[list[int]], - config: RenderConfig, - ) -> list[TokensPrompt]: - """Implementation of prompt rendering for completion-style requests. - - Uses async tokenizer pooling for improved performance. See base class - for detailed parameter documentation. - """ - truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config) - if truncate_prompt_tokens == 0: - return [] - - tasks = ( - self._create_prompt( - prompt_input, - config=config, - truncate_prompt_tokens=truncate_prompt_tokens, - ) - for prompt_input in parse_raw_prompts(prompt_or_prompts) - ) - - return await asyncio.gather(*tasks) - - async def render_prompt_and_embeds( - self, - *, - prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, - prompt_embeds: bytes | list[bytes] | None = None, - config: RenderConfig, - ) -> list[TokensPrompt | EmbedsPrompt]: - """ - Render text/token prompts and/or precomputed embedding prompts. At - least one of `prompt_or_prompts` or `prompt_embeds` must be provided. 
- """ - truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config) - if truncate_prompt_tokens == 0: - return [] - - rendered: list[TokensPrompt | EmbedsPrompt] = [] - - if prompt_embeds is not None: - rendered.extend( - self.load_prompt_embeds( - prompt_embeds, truncate_prompt_tokens, config.cache_salt - ) - ) - if prompt_or_prompts is None or prompt_or_prompts == "": - return rendered - - token_prompts = await self.render_prompt( - prompt_or_prompts=prompt_or_prompts, - config=config, - ) - rendered.extend(token_prompts) - - return rendered - - def _maybe_apply_truncation( - self, token_ids: list[int], truncate_prompt_tokens: int | None - ) -> list[int]: - """Apply truncation to token sequence.""" - if truncate_prompt_tokens is None: - return token_ids - if truncate_prompt_tokens >= len(token_ids): - return token_ids - - return token_ids[-truncate_prompt_tokens:] - - async def _create_prompt( - self, - prompt_input: TextPrompt | TokensPrompt, - config: RenderConfig, - truncate_prompt_tokens: int | None, - ) -> TokensPrompt: - prompt, prompt_token_ids, _ = get_prompt_components(prompt_input) - - if prompt_token_ids is not None: - # NOTE: detokenization is needed when echo is enabled, - # where the input token IDs are decoded back to text. - return await self._create_prompt_from_token_ids( - prompt_token_ids, - config.max_length, - truncate_prompt_tokens, - config.cache_salt, - config.needs_detokenization, - ) - - if prompt is not None: - return await self._create_prompt_from_text( - prompt, - config.max_length, - truncate_prompt_tokens, - config.add_special_tokens, - config.cache_salt, - ) - - # TODO: Also handle embeds prompt using this method - raise NotImplementedError - - async def _create_prompt_from_text( - self, - text: str, - max_length: int | None, - truncate_prompt_tokens: int | None, - add_special_tokens: bool, - cache_salt: str | None, - ) -> TokensPrompt: - """Tokenize text input asynchronously.""" - async_tokenizer = self._get_async_tokenizer() - - # Handle encoder-specific preprocessing - if ( - self.model_config.encoder_config is not None - and self.model_config.encoder_config.get("do_lower_case", False) - ): - text = text.lower() - - # Tokenize texts - if truncate_prompt_tokens is None: - encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens) - else: - encoded = await async_tokenizer( - text, - add_special_tokens=add_special_tokens, - truncation=True, - max_length=truncate_prompt_tokens, - ) - - return self._create_tokens_prompt( - encoded.input_ids, max_length, cache_salt, text - ) - - async def _create_prompt_from_token_ids( - self, - token_ids: list[int], - max_length: int | None, - truncate_prompt_tokens: int | None, - cache_salt: str | None, - needs_detokenization: bool | None = False, - ) -> TokensPrompt: - """Optionally detokenize token IDs and build a tokens prompt.""" - token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens) - - prompt = None - if needs_detokenization: - async_tokenizer = self._get_async_tokenizer() - prompt = await async_tokenizer.decode(token_ids) - - return self._create_tokens_prompt( - token_ids=token_ids, - max_length=max_length, - cache_salt=cache_salt, - prompt=prompt, - ) - - def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: - """Get or create async tokenizer using shared pool.""" - async_tokenizer = self.async_tokenizer - if async_tokenizer is not None: - return async_tokenizer - - tokenizer = self.tokenizer - if tokenizer is None: - raise ValueError("No tokenizer 
available for text input processing") - - if self.async_tokenizer_pool is None: - async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) - else: - async_tokenizer = self.async_tokenizer_pool.get(tokenizer) - if async_tokenizer is None: - async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) - self.async_tokenizer_pool[tokenizer] = async_tokenizer - self.async_tokenizer = async_tokenizer - return async_tokenizer - - def _create_tokens_prompt( - self, - token_ids: list[int], - max_length: int | None = None, - cache_salt: str | None = None, - prompt: str | None = None, - ) -> TokensPrompt: - """Create validated TokensPrompt.""" - if max_length is not None and len(token_ids) > max_length: - raise VLLMValidationError( - f"This model's maximum context length is {max_length} tokens. " - f"However, your request has {len(token_ids)} input tokens. " - "Please reduce the length of the input messages.", - parameter="input_tokens", - value=len(token_ids), - ) - - tokens_prompt = TokensPrompt(prompt_token_ids=token_ids) - if cache_salt is not None: - tokens_prompt["cache_salt"] = cache_salt - if prompt is not None: - tokens_prompt["prompt"] = prompt - return tokens_prompt diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index 659cf9e344ba..da13ea0cd476 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -4,12 +4,14 @@ from pydantic import BaseModel, Field +from vllm.config import ModelConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs from vllm.entrypoints.openai.engine.protocol import ( SamplingParams, StreamOptions, ) from vllm.logprobs import Logprob +from vllm.renderers import TokenizeParams from vllm.utils import random_uuid @@ -62,6 +64,12 @@ class GenerateRequest(BaseModel): description="KVTransfer parameters used for disaggregated serving.", ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + return TokenizeParams( + max_total_tokens=None, + max_output_tokens=0, + ) + class GenerateResponseChoice(BaseModel): index: int diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index 5253040c5fb0..b74b50611f6e 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -101,12 +101,13 @@ async def serve_tokens( # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is # completed - engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids) - if request.features is not None: - engine_prompt["multi_modal_data"] = None - - if hasattr(request, "cache_salt") and request.cache_salt is not None: - engine_prompt["cache_salt"] = request.cache_salt + engine_prompts = await self._preprocess_completion( + request, + prompt_input=request.token_ids, + prompt_embeds=None, + ) + assert len(engine_prompts) == 1 + engine_prompt = engine_prompts[0] # Schedule the request and get the result generator. 
result_generator: AsyncGenerator[RequestOutput, None] | None = None @@ -128,11 +129,15 @@ async def serve_tokens( else await self._get_trace_headers(raw_request.headers) ) + tok_params = request.build_tok_params(self.model_config) + tokenization_kwargs = tok_params.get_encode_kwargs() + result_generator = self.engine_client.generate( engine_prompt, sampling_params, request_id, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=request.priority, ) diff --git a/vllm/entrypoints/serve/tokenize/protocol.py b/vllm/entrypoints/serve/tokenize/protocol.py index 49d737ef0a89..39b181aa7ea5 100644 --- a/vllm/entrypoints/serve/tokenize/protocol.py +++ b/vllm/entrypoints/serve/tokenize/protocol.py @@ -6,8 +6,10 @@ from pydantic import ConfigDict, Field, model_validator +from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, + ChatTemplateContentFormatOption, ) from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionToolsParam, @@ -15,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import ( OpenAIBaseModel, ) +from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs class TokenizeCompletionRequest(OpenAIBaseModel): @@ -35,6 +38,13 @@ class TokenizeCompletionRequest(OpenAIBaseModel): ), ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + return TokenizeParams( + max_total_tokens=None, + max_output_tokens=0, + add_special_tokens=self.add_special_tokens, + ) + class TokenizeChatRequest(OpenAIBaseModel): model: str | None = None @@ -109,6 +119,30 @@ def check_generation_prompt(cls, data): ) return data + def build_chat_params( + self, + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + ) -> ChatParams: + return ChatParams( + chat_template=self.chat_template or default_template, + chat_template_content_format=default_template_content_format, + chat_template_kwargs=merge_kwargs( + self.chat_template_kwargs, + dict( + add_generation_prompt=self.add_generation_prompt, + continue_final_message=self.continue_final_message, + ), + ), + ) + + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + return TokenizeParams( + max_total_tokens=None, + max_output_tokens=0, + add_special_tokens=self.add_special_tokens, + ) + TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest @@ -124,6 +158,13 @@ class DetokenizeRequest(OpenAIBaseModel): model: str | None = None tokens: list[int] + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: + return TokenizeParams( + max_total_tokens=None, + max_output_tokens=0, + needs_detokenization=True, + ) + class DetokenizeResponse(OpenAIBaseModel): prompt: str diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index f0cfb8af1748..64a2741acdf6 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -9,12 +9,9 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.protocol import ( - ErrorResponse, -) +from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.models.serving import OpenAIServingModels -from vllm.entrypoints.renderer import RenderConfig from 
vllm.entrypoints.serve.tokenize.protocol import ( DetokenizeRequest, DetokenizeResponse, @@ -83,21 +80,17 @@ async def create_tokenize( _, engine_prompts = await self._preprocess_chat( request, - self.renderer, request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, tool_dicts=tool_dicts, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - chat_template_kwargs=request.chat_template_kwargs, - add_special_tokens=request.add_special_tokens, ) else: - renderer = self._get_completion_renderer() - engine_prompts = await renderer.render_prompt( - prompt_or_prompts=request.prompt, - config=self._build_render_config(request), + engine_prompts = await self._preprocess_completion( + request, + prompt_input=request.prompt, + prompt_embeds=None, ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") @@ -106,11 +99,14 @@ async def create_tokenize( input_ids: list[int] = [] for engine_prompt in engine_prompts: self._log_inputs( - request_id, engine_prompt, params=None, lora_request=lora_request + request_id, + engine_prompt, + params=None, + lora_request=lora_request, ) - if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt: - input_ids.extend(engine_prompt["prompt_token_ids"]) + if "prompt_token_ids" in engine_prompt: + input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item] token_strs = None if request.return_token_strs: @@ -136,7 +132,6 @@ async def create_detokenize( request_id = f"tokenize-{self._base_request_id(raw_request)}" lora_request = self._maybe_get_adapters(request) - tokenizer = self.renderer.get_tokenizer() self._log_inputs( request_id, @@ -145,14 +140,13 @@ async def create_detokenize( lora_request=lora_request, ) - prompt_input = await self._tokenize_prompt_input_async( - request, - tokenizer, - request.tokens, + engine_prompt = await self.renderer.tokenize_prompt_async( + TokensPrompt(prompt_token_ids=request.tokens), + request.build_tok_params(self.model_config), ) - input_text = prompt_input["prompt"] + prompt_text = engine_prompt["prompt"] # type: ignore[typeddict-item] - return DetokenizeResponse(prompt=input_text) + return DetokenizeResponse(prompt=prompt_text) async def get_tokenizer_info( self, @@ -165,9 +159,6 @@ async def get_tokenizer_info( except Exception as e: return self.create_error_response(f"Failed to get tokenizer info: {str(e)}") - def _build_render_config(self, request: TokenizeRequest) -> RenderConfig: - return RenderConfig(add_special_tokens=request.add_special_tokens) - @dataclass class TokenizerInfo: diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index c329e7a19149..290f8fd89af2 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -8,7 +8,7 @@ from argparse import Namespace from logging import Logger from string import Template -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import regex as re from fastapi import Request @@ -18,9 +18,9 @@ from vllm import envs from vllm.engine.arg_utils import EngineArgs from vllm.inputs import EmbedsPrompt, TokensPrompt +from vllm.inputs.parse import get_prompt_len from vllm.logger import current_formatter_type, init_logger from vllm.platforms import current_platform 
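Reviewer note: `get_prompt_len` (added in `vllm/inputs/parse.py` further down in this diff) replaces the open-coded `length_from_prompt_token_ids_or_embeds` call in `vllm/entrypoints/utils.py`. A minimal standalone sketch of its assumed semantics, not the vLLM implementation itself:

```python
import torch

# Assumed semantics: token prompts report their token count; embedding
# prompts report the number of rows of the (seq_len, hidden_size) tensor.
def get_prompt_len(prompt: dict) -> int:
    token_ids = prompt.get("prompt_token_ids")
    if token_ids is not None:
        return len(token_ids)
    return prompt["prompt_embeds"].shape[0]

assert get_prompt_len({"prompt_token_ids": [1, 2, 3]}) == 3
assert get_prompt_len({"prompt_embeds": torch.zeros(5, 16)}) == 5
```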
-from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils.argparse_utils import FlexibleArgumentParser if TYPE_CHECKING: @@ -34,9 +34,7 @@ StreamOptions, ) from vllm.entrypoints.openai.models.protocol import LoRAModulePath - from vllm.entrypoints.openai.responses.protocol import ( - ResponsesRequest, - ) + from vllm.entrypoints.openai.responses.protocol import ResponsesRequest else: ChatCompletionRequest = object CompletionRequest = object @@ -188,33 +186,6 @@ def cli_env_setup(): os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" -def _validate_truncation_size( - max_model_len: int, - truncate_prompt_tokens: int | None, - tokenization_kwargs: dict[str, Any] | None = None, -) -> int | None: - if truncate_prompt_tokens is not None: - if truncate_prompt_tokens <= -1: - truncate_prompt_tokens = max_model_len - - if truncate_prompt_tokens > max_model_len: - raise ValueError( - f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " - f"is greater than max_model_len ({max_model_len})." - f" Please, select a smaller truncation size." - ) - - if tokenization_kwargs is not None: - tokenization_kwargs["truncation"] = True - tokenization_kwargs["max_length"] = truncate_prompt_tokens - - else: - if tokenization_kwargs is not None: - tokenization_kwargs["truncation"] = False - - return truncate_prompt_tokens - - def get_max_tokens( max_model_len: int, request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", @@ -233,10 +204,7 @@ def get_max_tokens( # CompletionRequest (also a fallback for ChatCompletionRequest) max_tokens = getattr(request, "max_tokens", None) - input_length = length_from_prompt_token_ids_or_embeds( - prompt.get("prompt_token_ids"), # type: ignore[arg-type] - prompt.get("prompt_embeds"), # type: ignore[arg-type] - ) + input_length = get_prompt_len(prompt) default_max_tokens = max_model_len - input_length max_output_tokens = current_platform.get_max_output_tokens(input_length) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 1f138a72d084..82832dcaa0b5 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -18,12 +18,7 @@ MultiModalUUIDDict = object -class TextPrompt(TypedDict): - """Schema for a text prompt.""" - - prompt: str - """The input text to be tokenized before passing to the model.""" - +class _CommonKeys(TypedDict): multi_modal_data: NotRequired[MultiModalDataDict | None] """ Optional multi-modal data to pass to the model, @@ -53,7 +48,14 @@ class TextPrompt(TypedDict): """ -class TokensPrompt(TypedDict): +class TextPrompt(_CommonKeys): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + +class TokensPrompt(_CommonKeys): """Schema for a tokenized prompt.""" prompt_token_ids: list[int] @@ -65,47 +67,15 @@ class TokensPrompt(TypedDict): token_type_ids: NotRequired[list[int]] """A list of token type IDs to pass to the cross encoder model.""" - multi_modal_data: NotRequired[MultiModalDataDict | None] - """ - Optional multi-modal data to pass to the model, - if the model supports it. - """ - - mm_processor_kwargs: NotRequired[dict[str, Any] | None] - """ - Optional multi-modal processor kwargs to be forwarded to the - multimodal input mapper & processor. Note that if multiple modalities - have registered mappers etc for the model being considered, we attempt - to pass the mm_processor_kwargs to each of them. - """ - - multi_modal_uuids: NotRequired[MultiModalUUIDDict] - """ - Optional user-specified UUIDs for multimodal items, mapped by modality. 
- Lists must match the number of items per modality and may contain `None`. - For `None` entries, the hasher will compute IDs automatically; non-None - entries override the default hashes for caching. - """ - cache_salt: NotRequired[str] - """ - Optional cache salt to be used for prefix caching. - """ - - -class EmbedsPrompt(TypedDict): +class EmbedsPrompt(_CommonKeys): """Schema for a prompt provided via token embeddings.""" prompt_embeds: torch.Tensor """The embeddings of the prompt.""" - cache_salt: NotRequired[str] - """ - Optional cache salt to be used for prefix caching. - """ - -class DataPrompt(TypedDict): +class DataPrompt(_CommonKeys): """Represents generic inputs handled by IO processor plugins.""" data: Any @@ -194,7 +164,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): mm_processor_kwargs: NotRequired[dict[str, Any]] -PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt +PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 5e7795a14072..5f832afdbf4c 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast +from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict from typing_extensions import TypeIs -from vllm.utils.collection_utils import is_list_of +from vllm.utils import length_from_prompt_token_ids_or_embeds from .data import ( EmbedsPrompt, @@ -22,50 +21,6 @@ import torch -def parse_raw_prompts( - prompt: str | list[str] | list[int] | list[list[int]], -) -> Sequence[TextPrompt] | Sequence[TokensPrompt]: - if isinstance(prompt, str): - # case 1: a string - return [TextPrompt(prompt=prompt)] - - if isinstance(prompt, list): - if len(prompt) == 0: - raise ValueError("please provide at least one prompt") - - # case 2: array of strings - if is_list_of(prompt, str): - prompt = cast(list[str], prompt) - return [TextPrompt(prompt=elem) for elem in prompt] - - # case 3: array of tokens - if is_list_of(prompt, int): - prompt = cast(list[int], prompt) - return [TokensPrompt(prompt_token_ids=prompt)] - - # case 4: array of token arrays - if is_list_of(prompt, list): - if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0: - raise ValueError("please provide at least one prompt") - for elem in prompt: - if not isinstance(elem, list): - raise TypeError( - "prompt must be a list of lists, but found a non-list element." - ) - if not is_list_of(elem, int): - raise TypeError( - "Nested lists of tokens must contain only integers." 
- ) - - prompt = cast(list[list[int]], prompt) - return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] - - raise TypeError( - "prompt must be a string, array of strings, " - "array of tokens, or array of token arrays" - ) - - class ParsedStrPrompt(TypedDict): type: Literal["str"] content: str @@ -145,3 +100,10 @@ def get_prompt_components(prompt: PromptType) -> PromptComponents: token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type] embeds=prompt.get("prompt_embeds"), ) + + +def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt): + return length_from_prompt_token_ids_or_embeds( + prompt.get("prompt_token_ids"), # type: ignore[arg-type] + prompt.get("prompt_embeds"), # type: ignore[arg-type] + ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 124d657613d9..944dc5e12130 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -209,6 +209,7 @@ def _call_hf_processor( item_processor_data = dict(**mm_data, audios=audios) # some tokenizer kwargs are incompatible with UltravoxProcessor + tok_kwargs.pop("add_special_tokens", None) tok_kwargs.pop("padding", None) tok_kwargs.pop("truncation", None) diff --git a/vllm/renderers/__init__.py b/vllm/renderers/__init__.py index cd6a11dcc833..d8f76dd87fcb 100644 --- a/vllm/renderers/__init__.py +++ b/vllm/renderers/__init__.py @@ -1,7 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .params import ChatParams, TokenizeParams, merge_kwargs from .protocol import RendererLike from .registry import RendererRegistry, renderer_from_config -__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"] +__all__ = [ + "RendererLike", + "RendererRegistry", + "renderer_from_config", + "ChatParams", + "TokenizeParams", + "merge_kwargs", +] diff --git a/vllm/renderers/deepseek_v32.py b/vllm/renderers/deepseek_v32.py index 123911654d8c..91b95db06f41 100644 --- a/vllm/renderers/deepseek_v32.py +++ b/vllm/renderers/deepseek_v32.py @@ -9,11 +9,12 @@ parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import TextPrompt, TokensPrompt +from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer +from .params import ChatParams from .protocol import RendererLike logger = init_logger(__name__) @@ -61,8 +62,8 @@ def get_tokenizer(self) -> DeepseekV32Tokenizer: def render_messages( self, messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = parse_chat_messages( messages, @@ -73,26 +74,22 @@ def render_messages( prompt_raw = tokenizer.apply_chat_template( conversation=conversation, messages=messages, - **kwargs, + **params.get_apply_chat_template_kwargs(), ) - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else TokensPrompt(prompt_token_ids=prompt_raw) - ) + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt async def render_messages_async( self, 
messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = await parse_chat_messages_async( messages, @@ -103,17 +100,13 @@ async def render_messages_async( prompt_raw = tokenizer.apply_chat_template( conversation=conversation, messages=messages, - **kwargs, + **params.get_apply_chat_template_kwargs(), ) - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else TokensPrompt(prompt_token_ids=prompt_raw) - ) + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt diff --git a/vllm/renderers/embed_utils.py b/vllm/renderers/embed_utils.py new file mode 100644 index 000000000000..a51fc53a24ad --- /dev/null +++ b/vllm/renderers/embed_utils.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from io import BytesIO +from typing import TYPE_CHECKING + +import pybase64 +import torch + +from vllm.exceptions import VLLMValidationError + +if TYPE_CHECKING: + from vllm.config import ModelConfig + + +def safe_load_prompt_embeds( + model_config: "ModelConfig", + embed: bytes, +) -> torch.Tensor: + if not model_config.enable_prompt_embeds: + raise VLLMValidationError( + "You must set `--enable-prompt-embeds` to input `prompt_embeds`.", + parameter="prompt_embeds", + ) + + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load( + BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) + tensor = tensor.to_dense() + + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + + return tensor diff --git a/vllm/renderers/grok2.py b/vllm/renderers/grok2.py index 06de760f8f90..feefe8f0b4e8 100644 --- a/vllm/renderers/grok2.py +++ b/vllm/renderers/grok2.py @@ -9,11 +9,12 @@ parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import TextPrompt, TokensPrompt +from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.grok2 import Grok2Tokenizer +from .params import ChatParams from .protocol import RendererLike logger = init_logger(__name__) @@ -61,8 +62,8 @@ def get_tokenizer(self) -> Grok2Tokenizer: def render_messages( self, messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = parse_chat_messages( messages, @@ -73,26 +74,22 @@ def render_messages( prompt_raw = tokenizer.apply_chat_template( conversation=conversation, messages=messages, - **kwargs, + **params.get_apply_chat_template_kwargs(), ) - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else 
TokensPrompt(prompt_token_ids=prompt_raw) - ) + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt async def render_messages_async( self, messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = await parse_chat_messages_async( messages, @@ -103,17 +100,13 @@ async def render_messages_async( prompt_raw = tokenizer.apply_chat_template( conversation=conversation, messages=messages, - **kwargs, + **params.get_apply_chat_template_kwargs(), ) - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else TokensPrompt(prompt_token_ids=prompt_raw) - ) + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py index 252e6e753c9a..eb90c9e6d3a6 100644 --- a/vllm/renderers/hf.py +++ b/vllm/renderers/hf.py @@ -25,7 +25,7 @@ parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import TextPrompt, TokensPrompt +from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer @@ -33,6 +33,7 @@ from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.func_utils import supports_kw +from .params import ChatParams from .protocol import RendererLike if TYPE_CHECKING: @@ -631,9 +632,8 @@ def get_tokenizer(self) -> HfTokenizer: def render_messages( self, messages: list[ChatCompletionMessageParam], - chat_template_content_format: ChatTemplateContentFormatOption = "auto", - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: model_config = self.config tokenizer = self.get_tokenizer() @@ -641,9 +641,9 @@ def render_messages( messages, model_config, content_format=resolve_chat_template_content_format( - chat_template=kwargs.get("chat_template"), - tools=kwargs.get("tools"), - given_format=chat_template_content_format, + chat_template=params.chat_template, + tools=params.chat_template_kwargs.get("tools"), + given_format=params.chat_template_content_format, tokenizer=tokenizer, model_config=model_config, ), @@ -653,7 +653,7 @@ def render_messages( model_config, tokenizer, conversation, - **kwargs, + **params.get_apply_chat_template_kwargs(), ) # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5 @@ -665,7 +665,7 @@ def render_messages( ): mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data) - # get video placehoder, replace it with runtime video-chunk prompts + # get video placeholder, replace it with runtime video-chunk prompts video_placeholder = getattr( model_config.hf_config, "video_placeholder", None ) @@ -675,24 +675,19 @@ def render_messages( video_placeholder, ) - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else 
TokensPrompt(prompt_token_ids=prompt_raw) - ) + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt async def render_messages_async( self, messages: list[ChatCompletionMessageParam], - chat_template_content_format: ChatTemplateContentFormatOption = "auto", - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: model_config = self.config tokenizer = self.get_tokenizer() @@ -700,9 +695,9 @@ messages, model_config, content_format=resolve_chat_template_content_format( - chat_template=kwargs.get("chat_template"), - tools=kwargs.get("tools"), - given_format=chat_template_content_format, + chat_template=params.chat_template, + tools=params.chat_template_kwargs.get("tools"), + given_format=params.chat_template_content_format, tokenizer=tokenizer, model_config=model_config, ), @@ -712,7 +707,7 @@ model_config, tokenizer, conversation, - **kwargs, + **params.get_apply_chat_template_kwargs(), ) # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5 @@ -722,9 +717,9 @@ and mm_uuids is not None and mm_data is not None ): mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data) - # get video placehoder, replace it with runtime video-chunk prompts + # get video placeholder, replace it with runtime video-chunk prompts video_placeholder = getattr( model_config.hf_config, "video_placeholder", None ) @@ -734,14 +729,10 @@ video_placeholder, ) - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else TokensPrompt(prompt_token_ids=prompt_raw) - ) + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt diff --git a/vllm/renderers/mistral.py b/vllm/renderers/mistral.py index c45fb1f77ed8..109690dd7423 100644 --- a/vllm/renderers/mistral.py +++ b/vllm/renderers/mistral.py @@ -10,12 +10,13 @@ parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import TextPrompt, TokensPrompt +from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.async_utils import make_async +from .params import ChatParams from .protocol import RendererLike logger = init_logger(__name__) @@ -95,8 +96,8 @@ def get_tokenizer(self) -> MistralTokenizer: def render_messages( self, messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = parse_chat_messages( messages, @@ -104,25 +105,25 @@ def render_messages( content_format="string", ) - prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs) - - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else TokensPrompt(prompt_token_ids=prompt_raw) +
prompt_raw = safe_apply_chat_template( + tokenizer, + messages, + **params.get_apply_chat_template_kwargs(), ) + + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt async def render_messages_async( self, messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = await parse_chat_messages_async( messages, @@ -131,17 +132,15 @@ async def render_messages_async( ) prompt_raw = await self._apply_chat_template_async( - tokenizer, messages, **kwargs + tokenizer, + messages, + **params.get_apply_chat_template_kwargs(), ) - prompt = ( - TextPrompt(prompt=prompt_raw) - if isinstance(prompt_raw, str) - else TokensPrompt(prompt_token_ids=prompt_raw) - ) + prompt = self.render_completion(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: prompt["multi_modal_uuids"] = mm_uuids - return conversation, prompt # type: ignore[return-value] + return conversation, prompt diff --git a/vllm/renderers/params.py b/vllm/renderers/params.py new file mode 100644 index 000000000000..19555bba04b1 --- /dev/null +++ b/vllm/renderers/params.py @@ -0,0 +1,351 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, TypeVar + +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.exceptions import VLLMValidationError +from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.utils.import_utils import LazyLoader + +if TYPE_CHECKING: + import torch +else: + torch = LazyLoader("torch", globals(), "torch") + +logger = init_logger(__name__) + + +_S = TypeVar("_S", list[int], "torch.Tensor") + + +def merge_kwargs( + defaults: dict[str, Any] | None, + overrides: dict[str, Any] | None, + /, + *, + unset_values: tuple[object, ...] 
= (None, "auto"), +) -> dict[str, Any]: + if defaults is None: + defaults = {} + if overrides is None: + overrides = {} + + return defaults | {k: v for k, v in overrides.items() if v not in unset_values} + + +@dataclass(frozen=True) +class ChatParams: + """Configuration to control how to parse chat messages.""" + + chat_template: str | None = None + """The chat template to apply.""" + + chat_template_content_format: ChatTemplateContentFormatOption = "auto" + """The format of the chat template.""" + + chat_template_kwargs: dict[str, Any] = field(default_factory=dict) + """The kwargs to pass to the chat template.""" + + def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None): + if not default_chat_template_kwargs: + return self + + return ChatParams( + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + chat_template_kwargs=merge_kwargs( + default_chat_template_kwargs, + self.chat_template_kwargs, + ), + ) + + def get_apply_chat_template_kwargs(self) -> dict[str, Any]: + """The arguments to pass to `tokenizer.apply_chat_template`.""" + return merge_kwargs( + self.chat_template_kwargs, + dict(chat_template=self.chat_template), + ) + + +@dataclass(frozen=True) +class TokenizeParams: + """Configuration to control how prompts are tokenized.""" + + max_total_tokens: int | None + """ + Maximum allowed number of input + output tokens. + + Usually, this refers to the model's context length. + """ + + max_output_tokens: int = 0 + """Maximum requested number of output tokens.""" + + pad_prompt_tokens: int | None = None + """ + Number of tokens to pad to: + - `None` means no padding. + - `-1` maps to `max_input_tokens`. + """ + + truncate_prompt_tokens: int | None = None + """ + Number of tokens to keep: + - `None` means no truncation. + - `-1` maps to `max_input_tokens`. + """ + + do_lower_case: bool = False + """Whether to normalize text to lower case before tokenization.""" + + add_special_tokens: bool = True + """Whether to add special tokens.""" + + needs_detokenization: bool = False + """ + Whether the tokenized prompt needs to contain the original text. + + Not to be confused with `SamplingParams.detokenize` which deals + with the output generated by the model. + """ + + max_total_tokens_param: str = "max_total_tokens" + """Override this to edit the message for validation errors.""" + + max_output_tokens_param: str = "max_output_tokens" + """Override this to edit the message for validation errors.""" + + truncate_prompt_tokens_param: str = "truncate_prompt_tokens" + """Override this to edit the message for validation errors.""" + + @property + def max_input_tokens(self) -> int | None: + """Maximum allowed number of input tokens.""" + if self.max_total_tokens is None: + return None + + return self.max_total_tokens - self.max_output_tokens + + def __post_init__(self) -> None: + max_total_tokens = self.max_total_tokens + max_output_tokens = self.max_output_tokens + max_input_tokens = self.max_input_tokens + truncate_prompt_tokens = self.truncate_prompt_tokens + + if ( + max_output_tokens is not None + and max_total_tokens is not None + and max_output_tokens > max_total_tokens + ): + raise VLLMValidationError( + f"{self.max_output_tokens_param}={max_output_tokens} " + f"cannot be greater than " + f"{self.max_total_tokens_param}={max_total_tokens}. 
" + f"Please request fewer output tokens.", + parameter=self.max_output_tokens_param, + value=max_output_tokens, + ) + + if ( + max_input_tokens is not None + and truncate_prompt_tokens is not None + and truncate_prompt_tokens > max_input_tokens + ): + raise VLLMValidationError( + f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} " + f"cannot be greater than {self.max_total_tokens_param} - " + f"{self.max_output_tokens_param} = {max_input_tokens}. " + f"Please request a smaller truncation size.", + parameter=self.truncate_prompt_tokens_param, + value=truncate_prompt_tokens, + ) + + def with_kwargs(self, tokenization_kwargs: dict[str, Any] | None): + if tokenization_kwargs is None: + tokenization_kwargs = {} + + max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens) + pad_prompt_tokens = tokenization_kwargs.pop( + "pad_prompt_tokens", self.pad_prompt_tokens + ) + truncate_prompt_tokens = tokenization_kwargs.pop( + "truncate_prompt_tokens", self.truncate_prompt_tokens + ) + do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case) + add_special_tokens = tokenization_kwargs.pop( + "add_special_tokens", self.add_special_tokens + ) + needs_detokenization = tokenization_kwargs.pop( + "needs_detokenization", self.needs_detokenization + ) + + # https://huggingface.co/docs/transformers/en/pad_truncation + if (padding := tokenization_kwargs.pop("padding", None)) is not None: + if padding == "max_length": + pad_prompt_tokens = max_length + elif padding in (False, "do_not_pad"): + pad_prompt_tokens = None + else: + # To emit the below warning + tokenization_kwargs["padding"] = padding + + if (truncation := tokenization_kwargs.pop("truncation", None)) is not None: + if truncation in (True, "longest_first"): + truncate_prompt_tokens = max_length + elif truncation in (False, "do_not_truncate"): + truncate_prompt_tokens = None + else: + # To emit the below warning + tokenization_kwargs["truncation"] = truncation + + if tokenization_kwargs: + logger.warning( + "The following tokenization arguments are not supported " + "by vLLM Renderer and will be ignored: %s", + tokenization_kwargs, + ) + + max_total_tokens = self.max_total_tokens + + return TokenizeParams( + max_total_tokens=max_total_tokens, + max_output_tokens=( + 0 + if max_total_tokens is None or max_length is None + else max_total_tokens - max_length + ), + pad_prompt_tokens=pad_prompt_tokens, + truncate_prompt_tokens=truncate_prompt_tokens, + do_lower_case=do_lower_case, + add_special_tokens=add_special_tokens, + needs_detokenization=needs_detokenization, + ) + + def get_encode_kwargs(self) -> dict[str, Any]: + """The arguments to pass to `tokenizer.encode`.""" + max_length = self.truncate_prompt_tokens + if max_length is not None and max_length < 0: + max_length = self.max_input_tokens + + return dict( + truncation=self.truncate_prompt_tokens is not None, + max_length=max_length, + add_special_tokens=self.add_special_tokens, + ) + + def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str: + if self.do_lower_case: + text = text.lower() + + return text + + def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str: + """Apply all validators to prompt text.""" + # TODO: Implement https://github.com/vllm-project/vllm/pull/31366 + for validator in (self._apply_lowercase,): + text = validator(tokenizer, text) + + return text + + def apply_pre_tokenization( + self, + tokenizer: TokenizerLike | None, + prompt: TextPrompt, + ) -> TextPrompt: + """ + Ensure that the prompt meets the requirements 
set out by this config. + If that is not possible, raise a `VLLMValidationError`. + + This method is run before tokenization occurs. + """ + prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"]) + + return prompt + + def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: + """Apply padding to a token sequence.""" + pad_length = self.pad_prompt_tokens + if pad_length is not None and pad_length < 0: + pad_length = self.max_input_tokens + + if pad_length is None or pad_length <= len(tokens): + return tokens + + if tokenizer is None: + raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`") + if not isinstance(tokens, list): + raise ValueError("Cannot pad tokens for embedding inputs") + + return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens)) + + def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: + """Apply truncation to a token sequence.""" + max_length = self.truncate_prompt_tokens + if max_length is not None and max_length < 0: + max_length = self.max_input_tokens + + if max_length is None or max_length >= len(tokens): + return tokens + if max_length == 0: + return tokens[:0] + + if getattr(tokenizer, "truncation_side", "left") == "left": + return tokens[-max_length:] + + return tokens[:max_length] + + def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: + """Apply length checks to a token sequence.""" + max_input_tokens = self.max_input_tokens + + if max_input_tokens is not None and len(tokens) > max_input_tokens: + raise VLLMValidationError( + f"You passed {len(tokens)} input tokens and " + f"requested {self.max_output_tokens} output tokens. " + f"However, the model's context length is only " + f"{self.max_total_tokens}, resulting in a maximum " + f"input length of {max_input_tokens}. " + f"Please reduce the length of the input messages.", + parameter="input_tokens", + value=len(tokens), + ) + + return tokens + + def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: + """Apply all validators to a token sequence.""" + for validator in ( + self._apply_padding, + self._apply_truncation, + self._apply_length_check, + ): + tokens = validator(tokenizer, tokens) + + return tokens + + def apply_post_tokenization( + self, + tokenizer: TokenizerLike | None, + prompt: TokensPrompt | EmbedsPrompt, + ) -> TokensPrompt | EmbedsPrompt: + """ + Ensure that the prompt meets the requirements set out by this config. + If that is not possible, raise a `VLLMValidationError`. + + This method is run after tokenization occurs. 
+ """ + if "prompt_token_ids" in prompt: + prompt["prompt_token_ids"] = self._validate_tokens( # type: ignore[typeddict-unknown-key] + tokenizer, + prompt["prompt_token_ids"], # type: ignore[typeddict-item] + ) + if "prompt_embeds" in prompt: + prompt["prompt_embeds"] = self._validate_tokens( # type: ignore[typeddict-unknown-key] + tokenizer, + prompt["prompt_embeds"], # type: ignore[typeddict-item] + ) + + return prompt diff --git a/vllm/renderers/protocol.py b/vllm/renderers/protocol.py index e788f431b0f8..71d30c3b8c88 100644 --- a/vllm/renderers/protocol.py +++ b/vllm/renderers/protocol.py @@ -1,9 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio from typing import TYPE_CHECKING, Any, Protocol -from vllm.inputs import TextPrompt, TokensPrompt +from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.tokenizers import TokenizerLike +from vllm.utils.async_utils import AsyncMicrobatchTokenizer +from vllm.utils.collection_utils import is_list_of + +from .embed_utils import safe_load_prompt_embeds +from .params import ChatParams, TokenizeParams if TYPE_CHECKING: from vllm.config import ModelConfig @@ -14,6 +20,9 @@ class RendererLike(Protocol): + config: "ModelConfig" + _async_tokenizer: AsyncMicrobatchTokenizer + @classmethod def from_config( cls, @@ -33,16 +42,147 @@ def get_tokenizer(self) -> TokenizerLike: return tokenizer + def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: + # Lazy initialization since offline LLM doesn't use async + if not hasattr(self, "_async_tokenizer"): + self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer()) + + return self._async_tokenizer + + # Step 1: Convert raw inputs to prompts + def render_completion( + self, + prompt_raw: str | list[int] | bytes, + ) -> TextPrompt | TokensPrompt | EmbedsPrompt: + error_msg = "Each prompt must be a string or an array of tokens" + + if isinstance(prompt_raw, str): + return TextPrompt(prompt=prompt_raw) + + if isinstance(prompt_raw, list): + if not is_list_of(prompt_raw, int): + raise TypeError(error_msg) + + return TokensPrompt(prompt_token_ids=prompt_raw) + + if isinstance(prompt_raw, bytes): + embeds = safe_load_prompt_embeds(self.config, prompt_raw) + return EmbedsPrompt(prompt_embeds=embeds) + + raise TypeError(error_msg) + + def render_completions( + self, + prompt_input: str | list[str] | list[int] | list[list[int]] | None = None, + prompt_embeds: bytes | list[bytes] | None = None, + ) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]: + prompts_raw = list[str | list[int] | bytes]() + + if prompt_embeds is not None: # embeds take higher priority + if isinstance(prompt_embeds, bytes): + prompts_raw.append(prompt_embeds) + else: + prompts_raw.extend(prompt_embeds) + + if prompt_input is not None: + if isinstance(prompt_input, str) or ( + len(prompt_input) > 0 and is_list_of(prompt_input, int) + ): + prompts_raw.append(prompt_input) # type: ignore[arg-type] + else: + prompts_raw.extend(prompt_input) # type: ignore[arg-type] + + if len(prompts_raw) == 0: + raise ValueError("You must pass at least one prompt") + + return [self.render_completion(prompt) for prompt in prompts_raw] + + async def render_completions_async( + self, + prompt_input: str | list[str] | list[int] | list[list[int]] | None = None, + prompt_embeds: bytes | list[bytes] | None = None, + ) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]: + return self.render_completions(prompt_input, prompt_embeds) + def render_messages( self, 
messages: list["ChatCompletionMessageParam"], - **kwargs, - ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]: raise NotImplementedError async def render_messages_async( self, messages: list["ChatCompletionMessageParam"], - **kwargs, - ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]: - return self.render_messages(messages, **kwargs) + params: ChatParams, + ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]: + return self.render_messages(messages, params) + + # Step 2: Tokenize prompts if necessary + def tokenize_prompt( + self, + prompt: TextPrompt | TokensPrompt | EmbedsPrompt, + params: TokenizeParams, + ) -> TokensPrompt | EmbedsPrompt: + if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: + prompt = params.apply_pre_tokenization(self.tokenizer, prompt) + + tokenizer = self.get_tokenizer() + prompt_token_ids = tokenizer.encode( + prompt["prompt"], + **params.get_encode_kwargs(), + ) + + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) + + if params.needs_detokenization and "prompt" not in prompt: + if "prompt_token_ids" not in prompt: + raise RuntimeError("Cannot run detokenization on embeddings") + + tokenizer = self.get_tokenizer() + prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item] + prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key] + + return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] + + def tokenize_prompts( + self, + prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt], + params: TokenizeParams, + ) -> list[TokensPrompt | EmbedsPrompt]: + return [self.tokenize_prompt(prompt, params) for prompt in prompts] + + async def tokenize_prompt_async( + self, + prompt: TextPrompt | TokensPrompt | EmbedsPrompt, + params: TokenizeParams, + ) -> TokensPrompt | EmbedsPrompt: + if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: + prompt = params.apply_pre_tokenization(self.tokenizer, prompt) + + tokenizer = self.get_async_tokenizer() + prompt_token_ids = await tokenizer.encode( + prompt["prompt"], + **params.get_encode_kwargs(), + ) + + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) + + if params.needs_detokenization and "prompt" not in prompt: + if "prompt_token_ids" not in prompt: + raise RuntimeError("Cannot run detokenization on embeddings") + + tokenizer = self.get_async_tokenizer() + prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item] + prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key] + + return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] + + async def tokenize_prompts_async( + self, + prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt], + params: TokenizeParams, + ) -> list[TokensPrompt | EmbedsPrompt]: + return await asyncio.gather( + *(self.tokenize_prompt_async(prompt, params) for prompt in prompts) + ) diff --git a/vllm/renderers/terratorch.py b/vllm/renderers/terratorch.py index fc41a94c85b2..23f8eecc2e37 100644 --- a/vllm/renderers/terratorch.py +++ b/vllm/renderers/terratorch.py @@ -9,10 +9,11 @@ parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import TextPrompt, TokensPrompt +from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike 
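Reviewer note: the `RendererLike` protocol changes above spell out a two-step pipeline: step 1 renders raw inputs into `TextPrompt`/`TokensPrompt`/`EmbedsPrompt` dicts, step 2 tokenizes and validates them. A hedged usage sketch using only names from this diff; the concrete renderer construction is assumed and may differ:

```python
from vllm.renderers import RendererLike, TokenizeParams

def tokenize_for_engine(renderer: RendererLike):
    # Step 1: raw strings (or token lists / embed bytes) -> prompt dicts.
    prompts = renderer.render_completions(
        prompt_input=["Hello world", "How are you?"],
    )
    # Step 2: tokenize, then apply the padding/truncation/length checks
    # encoded in TokenizeParams.
    params = TokenizeParams(max_total_tokens=4096, max_output_tokens=16)
    return renderer.tokenize_prompts(prompts, params)
```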
+from .params import ChatParams from .protocol import RendererLike logger = init_logger(__name__) @@ -45,8 +46,8 @@ def get_tokenizer(self) -> TokenizerLike: def render_messages( self, messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: model_config = self.config conversation, mm_data, mm_uuids = parse_chat_messages( @@ -55,7 +56,7 @@ def render_messages( content_format="string", ) - prompt = TokensPrompt(prompt_token_ids=[1]) + prompt = self.render_completion([1]) # Dummy token IDs if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: @@ -66,8 +67,8 @@ def render_messages( async def render_messages_async( self, messages: list[ChatCompletionMessageParam], - **kwargs, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + params: ChatParams, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: model_config = self.config conversation, mm_data, mm_uuids = await parse_chat_messages_async( @@ -76,7 +77,7 @@ async def render_messages_async( content_format="string", ) - prompt = TokensPrompt(prompt_token_ids=[1]) # Dummy token IDs + prompt = self.render_completion([1]) # Dummy token IDs if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3ec0cc5d09dd..a3f63786c139 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -830,7 +830,7 @@ def parse_pooling_type(pooling_name: str): @cache def get_sentence_transformer_tokenizer_config( model: str | Path, revision: str | None = "main" -): +) -> dict[str, Any] | None: """ Returns the tokenization configuration dictionary for a given Sentence Transformer BERT model. 
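Reviewer note: the `merge_kwargs` helper introduced in `vllm/renderers/params.py` treats `None` and `"auto"` as unset, so a caller override only wins when it carries a real value. A standalone copy for illustration (same semantics as the helper above, duplicated here so the snippet runs on its own):

```python
from typing import Any

def merge_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
    /,
    *,
    unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
    if defaults is None:
        defaults = {}
    if overrides is None:
        overrides = {}
    # Keep only overrides that carry a real value.
    return defaults | {k: v for k, v in overrides.items() if v not in unset_values}

merged = merge_kwargs(
    {"chat_template": "default.jinja", "content_format": "openai"},
    {"chat_template": None, "content_format": "auto", "tools": []},
)
# None/"auto" overrides are dropped; the explicit empty list survives.
assert merged == {
    "chat_template": "default.jinja",
    "content_format": "openai",
    "tools": [],
}
```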
diff --git a/vllm/utils/async_utils.py b/vllm/utils/async_utils.py index 77234cbd0c8c..f0336dc8ed0a 100644 --- a/vllm/utils/async_utils.py +++ b/vllm/utils/async_utils.py @@ -50,14 +50,17 @@ def __init__( self._executor = ThreadPoolExecutor(max_workers=1) # === Public async API === - async def __call__(self, prompt, **kwargs): + async def __call__(self, prompt, **kwargs) -> BatchEncoding: result_future: Future = self._loop.create_future() key = self._queue_key("encode", kwargs) queue = self._get_queue(self._loop, key) await queue.put((prompt, kwargs, result_future)) return await result_future - async def decode(self, token_ids, **kwargs): + async def encode(self, prompt, **kwargs) -> list[int]: + return (await self(prompt, **kwargs)).input_ids + + async def decode(self, token_ids, **kwargs) -> str: result_future: Future = self._loop.create_future() key = self._queue_key("decode", kwargs) queue = self._get_queue(self._loop, key) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 9f40f41a10a5..71ecd2065fc8 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -17,7 +17,6 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient -from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -25,7 +24,7 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams -from vllm.renderers import RendererLike +from vllm.renderers import RendererLike, merge_kwargs from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tasks import SupportedTask from vllm.tokenizers import TokenizerLike @@ -316,13 +315,20 @@ async def add_request( "prompt logprobs" ) - if tokenization_kwargs is None: - tokenization_kwargs = {} - _validate_truncation_size( - self.model_config.max_model_len, - params.truncate_prompt_tokens, - tokenization_kwargs, - ) + if params.truncate_prompt_tokens is not None: + params_type = type(params).__name__ + warnings.warn( + f"The `truncate_prompt_tokens` parameter in `{params_type}` " + "is deprecated and will be removed in v0.16. " + "Please pass it via `tokenization_kwargs` instead.", + DeprecationWarning, + stacklevel=2, + ) + + tokenization_kwargs = merge_kwargs( + tokenization_kwargs, + dict(truncate_prompt_tokens=params.truncate_prompt_tokens), + ) if isinstance(prompt, AsyncGenerator): # Streaming input case. @@ -356,12 +362,12 @@ async def add_request( request_id, prompt, params, - arrival_time, - lora_request, - tokenization_kwargs, - trace_headers, - priority, - data_parallel_rank, + arrival_time=arrival_time, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + trace_headers=trace_headers, + priority=priority, + data_parallel_rank=data_parallel_rank, ) prompt_text = get_prompt_text(prompt) @@ -769,7 +775,6 @@ async def encode( lora_request: LoRARequest | None = None, trace_headers: Mapping[str, str] | None = None, priority: int = 0, - truncate_prompt_tokens: int | None = None, tokenization_kwargs: dict[str, Any] | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """ @@ -784,22 +789,10 @@ async def encode( The caller of generate() iterates the returned AsyncGenerator, returning the RequestOutput back to the caller. 
- - NOTE: truncate_prompt_tokens is deprecated in v0.14. - TODO: Remove truncate_prompt_tokens in v0.15. """ q: RequestOutputCollector | None = None try: - if truncate_prompt_tokens is not None: - warnings.warn( - "The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` " - "is deprecated and will be removed in v0.15. " - "Please use `pooling_params.truncate_prompt_tokens` instead.", - DeprecationWarning, - stacklevel=2, - ) - q = await self.add_request( request_id, prompt,
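Reviewer note: with `truncate_prompt_tokens` on the params objects now deprecated (removal slated for v0.16 per the warning added above), callers move truncation into `tokenization_kwargs`, which `TokenizeParams.with_kwargs` pops back out. A hedged migration sketch; the engine client setup is assumed:

```python
from vllm.sampling_params import SamplingParams

async def stream_with_truncation(engine, prompt: str, request_id: str):
    # Before (deprecated):
    #   params = SamplingParams(max_tokens=16, truncate_prompt_tokens=512)
    # After: keep the sampling params clean and pass truncation to the
    # tokenizer via tokenization_kwargs instead.
    params = SamplingParams(max_tokens=16)
    async for output in engine.generate(
        prompt,
        params,
        request_id,
        tokenization_kwargs={"truncate_prompt_tokens": 512},
    ):
        yield output
```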