diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index d69d74ca61f5..831b76b66e09 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -620,7 +620,7 @@ def get_tokenizer( kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: - from vllm.tokenizers import MistralTokenizer + from vllm.tokenizers.mistral import MistralTokenizer except ImportError as e: raise ImportError( "MistralTokenizer requires vllm package.\n" diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index b050cfdb561c..73e27a768f75 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -3,8 +3,9 @@ import pytest -from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.renderers.hf import safe_apply_chat_template from vllm.tokenizers import get_tokenizer from ...models.registry import HF_EXAMPLE_MODELS @@ -125,14 +126,15 @@ def test_get_gen_prompt( ) # Call the function and get the result - result = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=mock_request.messages, - chat_template=mock_request.chat_template or template_content, - renderer_config=renderer_config, + result = safe_apply_chat_template( + renderer_config, + tokenizer, + mock_request.messages, tools=None, + chat_template=mock_request.chat_template or template_content, add_generation_prompt=mock_request.add_generation_prompt, continue_final_message=mock_request.continue_final_message, + tokenize=False, ) # Test assertion diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 6ab0942b58da..679113096130 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig, RendererConfig from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer @pytest.fixture() diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 7b296eae7c5a..0e938492033f 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -29,7 +29,8 @@ encode_image_base64, encode_video_base64, ) -from vllm.tokenizers import MistralTokenizer, get_tokenizer +from vllm.tokenizers import get_tokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.serial_utils import tensor2base64 from ..models.registry import HF_EXAMPLE_MODELS diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index e2d6271e2fae..bc8bb05c284e 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -10,7 +10,7 @@ MistralToolParser, ) from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from ...utils import check_logprobs_close diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index 9e9087cb0fc4..0eaef49e2395 100644 --- 
a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -9,7 +9,7 @@ from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk from mistral_common.protocol.instruct.messages import UserMessage -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from ....conftest import AudioTestAssets from ....utils import RemoteOpenAIServer diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 9b2b29b75876..324853e15b07 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -22,10 +22,8 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext -from vllm.tokenizers import ( - MistralTokenizer, - TokenizerLike, -) +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from ....multimodal.utils import random_audio, random_image, random_video from ...registry import ( diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 48fd076ab3c6..fd5235266bec 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -7,7 +7,7 @@ from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.preprocess import InputPreprocessor -from vllm.tokenizers import init_tokenizer_from_config +from vllm.renderers import RendererRegistry pytestmark = pytest.mark.cpu_test @@ -109,10 +109,11 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): def test_preprocessor_always_mm_code_path(model_id, prompt): model_config = ModelConfig(model=model_id) renderer_config = RendererConfig(model_config=model_config) - tokenizer = init_tokenizer_from_config(renderer_config) - input_preprocessor = InputPreprocessor(renderer_config, tokenizer) + renderer = RendererRegistry.get_renderer(renderer_config) + input_preprocessor = InputPreprocessor(renderer_config, renderer) # HF processor adds sep token + tokenizer = renderer.get_tokenizer() sep_token_id = tokenizer.vocab[tokenizer.sep_token] processed_inputs = input_preprocessor.preprocess(prompt) diff --git a/tests/tokenizers_/test_registry.py b/tests/tokenizers_/test_registry.py index 7e795350d64c..f2001bc19a37 100644 --- a/tests/tokenizers_/test_registry.py +++ b/tests/tokenizers_/test_registry.py @@ -43,7 +43,7 @@ def is_fast(self) -> bool: def test_customized_tokenizer(): TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__) - tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer", "abc") + tokenizer = TokenizerRegistry.init_tokenizer("test_tokenizer", "abc") assert isinstance(tokenizer, TestTokenizer) assert tokenizer.path_or_repo_id == "abc" assert tokenizer.bos_token_id == 0 diff --git a/tests/v1/engine/test_process_multi_modal_uuids.py b/tests/v1/engine/test_process_multi_modal_uuids.py index 85fab3a855fd..ae77f922f1b6 100644 --- a/tests/v1/engine/test_process_multi_modal_uuids.py +++ b/tests/v1/engine/test_process_multi_modal_uuids.py @@ -13,6 +13,7 @@ RendererConfig, VllmConfig, ) +from vllm.renderers.terratorch import TerratorchRenderer from vllm.sampling_params import SamplingParams from vllm.v1.engine import input_processor as input_processor_mod from vllm.v1.engine.input_processor import InputProcessor @@ -59,7 
+60,6 @@ def _mock_input_processor( renderer_config = RendererConfig( model_config=model_config, - tokenizer="dummy", skip_tokenizer_init=True, ) @@ -70,7 +70,7 @@ def _mock_input_processor( device_config=DeviceConfig(device="cpu"), ) - return InputProcessor(vllm_config, tokenizer=None) + return InputProcessor(vllm_config, renderer=TerratorchRenderer(renderer_config)) def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 7b60e7f89861..dc79acbffdc9 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -11,6 +11,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import IOProcessor from vllm.pooling_params import PoolingParams +from vllm.renderers import RendererLike from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tokenizers import TokenizerLike @@ -27,6 +28,10 @@ class EngineClient(ABC): input_processor: InputProcessor io_processor: IOProcessor | None + @property + @abstractmethod + def renderer(self) -> RendererLike: ... + @property @abstractmethod def is_running(self) -> bool: ... diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 5ad256c2f3eb..9141d95e3374 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,22 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import inspect import json +import warnings from abc import ABC, abstractmethod -from collections import Counter, defaultdict, deque +from collections import Counter, defaultdict from collections.abc import Awaitable, Callable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast - -import jinja2 -import jinja2.ext -import jinja2.meta -import jinja2.nodes -import jinja2.parser -import jinja2.sandbox -import transformers.utils.chat_template_utils as hf_chat_utils +from typing import Generic, Literal, TypeAlias, TypeVar, cast + from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionContentPartImageParam, @@ -38,7 +31,6 @@ from openai_harmony import Message as OpenAIHarmonyMessage from PIL import Image from pydantic import BaseModel, ConfigDict, TypeAdapter -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin # pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypedDict @@ -49,14 +41,26 @@ from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector -from vllm.tokenizers import MistralTokenizer, TokenizerLike -from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path -from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import random_uuid -from vllm.utils.func_utils import supports_kw logger = init_logger(__name__) + +def __getattr__(name: str): + if name == "resolve_hf_chat_template": + from vllm.renderers.hf import resolve_chat_template + + warnings.warn( + "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to " + "`vllm.renderers.hf.resolve_chat_template`. 
" + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + + return resolve_chat_template + + MODALITY_PLACEHOLDERS_MAP = { "image": "<##IMAGE##>", "audio": "<##AUDIO##>", @@ -295,329 +299,8 @@ class ConversationMessage(TypedDict, total=False): # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] -# Used internally -_ChatTemplateContentFormat = Literal["string", "openai"] - - -def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: - if isinstance(node, jinja2.nodes.Name): - return node.ctx == "load" and node.name == varname - - return False - - -def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: - if isinstance(node, jinja2.nodes.Getitem): - return ( - _is_var_access(node.node, varname) - and isinstance(node.arg, jinja2.nodes.Const) - and node.arg.value == key - ) - - if isinstance(node, jinja2.nodes.Getattr): - return _is_var_access(node.node, varname) and node.attr == key - - return False - - -def _is_var_or_elems_access( - node: jinja2.nodes.Node, - varname: str, - key: str | None = None, -) -> bool: - if isinstance(node, jinja2.nodes.Filter): - return node.node is not None and _is_var_or_elems_access( - node.node, varname, key - ) - if isinstance(node, jinja2.nodes.Test): - return _is_var_or_elems_access(node.node, varname, key) - - if isinstance(node, jinja2.nodes.Getitem) and isinstance( - node.arg, jinja2.nodes.Slice - ): - return _is_var_or_elems_access(node.node, varname, key) - - return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) - - -def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): - # Global variable that is implicitly defined at the root - yield root, varname - - # Iterative BFS - related_varnames = deque([varname]) - while related_varnames: - related_varname = related_varnames.popleft() - - for assign_ast in root.find_all(jinja2.nodes.Assign): - lhs = assign_ast.target - rhs = assign_ast.node - - if _is_var_or_elems_access(rhs, related_varname): - assert isinstance(lhs, jinja2.nodes.Name) - yield assign_ast, lhs.name - - # Avoid infinite looping for self-assignment - if lhs.name != related_varname: - related_varnames.append(lhs.name) - - -# NOTE: The proper way to handle this is to build a CFG so that we can handle -# the scope in which each variable is defined, but that is too complicated -def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): - messages_varnames = [ - varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") - ] - - # Search for {%- for message in messages -%} loops - for loop_ast in root.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter - loop_target = loop_ast.target - - for varname in messages_varnames: - if _is_var_or_elems_access(loop_iter, varname): - assert isinstance(loop_target, jinja2.nodes.Name) - yield loop_ast, loop_target.name - break - - -def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): - message_varnames = [ - varname for _, varname in _iter_nodes_assign_messages_item(root) - ] - - # Search for {%- for content in message['content'] -%} loops - for loop_ast in root.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter - loop_target = loop_ast.target - - for varname in message_varnames: - if _is_var_or_elems_access(loop_iter, varname, "content"): - assert isinstance(loop_target, jinja2.nodes.Name) - yield loop_ast, loop_target.name - break - - -def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None: - try: - 
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) - return jinja_compiled.environment.parse(chat_template) - except Exception: - logger.exception("Error when compiling Jinja template") - return None - - -@lru_cache(maxsize=32) -def _detect_content_format( - chat_template: str, - *, - default: _ChatTemplateContentFormat, -) -> _ChatTemplateContentFormat: - jinja_ast = _try_extract_ast(chat_template) - if jinja_ast is None: - return default - - try: - next(_iter_nodes_assign_content_item(jinja_ast)) - except StopIteration: - return "string" - except Exception: - logger.exception("Error when parsing AST of Jinja template") - return default - else: - return "openai" - - -def resolve_mistral_chat_template( - chat_template: str | None, - **kwargs: Any, -) -> str | None: - if chat_template is not None or kwargs.get("chat_template_kwargs") is not None: - raise ValueError( - "'chat_template' or 'chat_template_kwargs' cannot be overridden " - "for mistral tokenizer." - ) - - return None - - -_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]() -""" -Used in `_try_get_processor_chat_template` to avoid calling -`cached_get_processor` again if the processor fails to be loaded. - -This is needed because `lru_cache` does not cache when an exception happens. -""" - - -def _try_get_processor_chat_template( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - *, - trust_remote_code: bool, -) -> str | None: - cache_key = (tokenizer.name_or_path, trust_remote_code) - if cache_key in _PROCESSOR_CHAT_TEMPLATES: - return _PROCESSOR_CHAT_TEMPLATES[cache_key] - - try: - processor = cached_get_processor( - tokenizer.name_or_path, - processor_cls=( - PreTrainedTokenizer, - PreTrainedTokenizerFast, - ProcessorMixin, - ), - trust_remote_code=trust_remote_code, - ) - if ( - isinstance(processor, ProcessorMixin) - and hasattr(processor, "chat_template") - and (chat_template := processor.chat_template) is not None - ): - _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template - return chat_template - except Exception: - logger.debug( - "Failed to load AutoProcessor chat template for %s", - tokenizer.name_or_path, - exc_info=True, - ) - - _PROCESSOR_CHAT_TEMPLATES[cache_key] = None - return None - - -def resolve_hf_chat_template( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - chat_template: str | None, - tools: list[dict[str, Any]] | None, - *, - model_config: ModelConfig, -) -> str | None: - # 1st priority: The given chat template - if chat_template is not None: - return chat_template - - # 2nd priority: AutoProcessor chat template, unless tool calling is enabled - if tools is None: - chat_template = _try_get_processor_chat_template( - tokenizer, - trust_remote_code=model_config.trust_remote_code, - ) - if chat_template is not None: - return chat_template - - # 3rd priority: AutoTokenizer chat template - try: - return tokenizer.get_chat_template(chat_template, tools=tools) - except Exception: - logger.debug( - "Failed to load AutoTokenizer chat template for %s", - tokenizer.name_or_path, - exc_info=True, - ) - - # 4th priority: Predefined fallbacks - path = get_chat_template_fallback_path( - model_type=model_config.hf_config.model_type, - tokenizer_name_or_path=tokenizer.name_or_path, - ) - if path is not None: - logger.info_once( - "Loading chat template fallback for %s as there isn't one " - "defined on HF Hub.", - tokenizer.name_or_path, - ) - chat_template = load_chat_template(path) - else: - logger.debug_once( - "There is no chat template fallback for %s", 
tokenizer.name_or_path - ) - - return chat_template - - -def _resolve_chat_template_content_format( - chat_template: str | None, - tools: list[dict[str, Any]] | None, - tokenizer: TokenizerLike | None, - *, - renderer_config: RendererConfig, -) -> _ChatTemplateContentFormat: - if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - hf_chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=chat_template, - tools=tools, - model_config=renderer_config.model_config, - ) - else: - hf_chat_template = None - - jinja_text = ( - hf_chat_template - if isinstance(hf_chat_template, str) - else load_chat_template(chat_template, is_literal=True) - ) - - detected_format = ( - "string" - if jinja_text is None - else _detect_content_format(jinja_text, default="string") - ) - - return detected_format - - -@lru_cache -def _log_chat_template_content_format( - chat_template: str | None, - given_format: ChatTemplateContentFormatOption, - detected_format: ChatTemplateContentFormatOption, -): - logger.info( - "Detected the chat template content format to be '%s'. " - "You can set `--chat-template-content-format` to override this.", - detected_format, - ) - - if given_format != "auto" and given_format != detected_format: - logger.warning( - "You specified `--chat-template-content-format %s` " - "which is different from the detected format '%s'. " - "If our automatic detection is incorrect, please consider " - "opening a GitHub issue so that we can improve it: " - "https://github.com/vllm-project/vllm/issues/new/choose", - given_format, - detected_format, - ) - - -def resolve_chat_template_content_format( - chat_template: str | None, - tools: list[dict[str, Any]] | None, - given_format: ChatTemplateContentFormatOption, - tokenizer: TokenizerLike | None, - *, - renderer_config: RendererConfig, -) -> _ChatTemplateContentFormat: - if given_format != "auto": - return given_format - - detected_format = _resolve_chat_template_content_format( - chat_template, - tools, - tokenizer, - renderer_config=renderer_config, - ) - - _log_chat_template_content_format( - chat_template, - given_format=given_format, - detected_format=detected_format, - ) - - return detected_format +# After resolving "auto" +ChatTemplateContentFormat = Literal["string", "openai"] ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"] @@ -1543,7 +1226,7 @@ def _parse_chat_message_content_part( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, interleave_strings: bool, ) -> list[ConversationMessage]: role = message["role"] @@ -1612,7 +1295,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: def parse_chat_messages( messages: list[ChatCompletionMessageParam], renderer_config: RendererConfig, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], MultiModalDataDict | None, @@ -1645,7 +1328,7 @@ def parse_chat_messages( def parse_chat_messages_futures( messages: list[ChatCompletionMessageParam], renderer_config: RendererConfig, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], Awaitable[MultiModalDataDict | None], @@ -1675,173 +1358,6 @@ def parse_chat_messages_futures( return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() -# adapted from 
https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412 -# only preserve the parse function used to resolve chat template kwargs -class AssistantTracker(jinja2.ext.Extension): - tags = {"generation"} - - def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: - lineno = next(parser.stream).lineno - body = parser.parse_statements(["name:endgeneration"], drop_needle=True) - call = self.call_method("_generation_support") - call_block = jinja2.nodes.CallBlock(call, [], [], body) - return call_block.set_lineno(lineno) - - -def _resolve_chat_template_kwargs( - chat_template: str, -): - env = jinja2.sandbox.ImmutableSandboxedEnvironment( - trim_blocks=True, - lstrip_blocks=True, - extensions=[AssistantTracker, jinja2.ext.loopcontrols], - ) - parsed_content = env.parse(chat_template) - template_vars = jinja2.meta.find_undeclared_variables(parsed_content) - return template_vars - - -_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs) - - -@lru_cache -def _get_hf_base_chat_template_params() -> frozenset[str]: - # Get standard parameters from HuggingFace's base tokenizer class. - # This dynamically extracts parameters from PreTrainedTokenizer's - # apply_chat_template method, ensuring compatibility with tokenizers - # that use **kwargs to receive standard parameters. - - # Read signature from HF's base class - the single source of truth - base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template) - # Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders - return frozenset( - p.name - for p in base_sig.parameters.values() - if p.kind - not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) - ) - - -def resolve_chat_template_kwargs( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - chat_template: str, - chat_template_kwargs: dict[str, Any], - raise_on_unexpected: bool = True, -) -> dict[str, Any]: - # We exclude chat_template from kwargs here, because - # chat template has been already resolved at this stage - unexpected_vars = {"chat_template", "tokenize"} - if raise_on_unexpected and ( - unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys() - ): - raise ValueError( - "Found unexpected chat template kwargs from request: " - f"{unexpected_in_kwargs}" - ) - - fn_kw = { - k - for k in chat_template_kwargs - if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) - } - template_vars = _cached_resolve_chat_template_kwargs(chat_template) - - # Allow standard HF parameters even if tokenizer uses **kwargs to receive them - hf_base_params = _get_hf_base_chat_template_params() - - accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars - return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars} - - -def apply_hf_chat_template( - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - conversation: list[ConversationMessage], - chat_template: str | None, - tools: list[dict[str, Any]] | None, - *, - renderer_config: RendererConfig, - **kwargs: Any, -) -> str: - hf_chat_template = resolve_hf_chat_template( - tokenizer, - chat_template=chat_template, - tools=tools, - model_config=renderer_config.model_config, - ) - - if hf_chat_template is None: - raise ValueError( - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." 
- ) - - resolved_kwargs = resolve_chat_template_kwargs( - tokenizer=tokenizer, - chat_template=hf_chat_template, - chat_template_kwargs=kwargs, - ) - - try: - return tokenizer.apply_chat_template( - conversation=conversation, # type: ignore[arg-type] - tools=tools, # type: ignore[arg-type] - chat_template=hf_chat_template, - tokenize=False, - **resolved_kwargs, - ) - - # External library exceptions can sometimes occur despite the framework's - # internal exception management capabilities. - except Exception as e: - # Log and report any library-related exceptions for further - # investigation. - logger.exception( - "An error occurred in `transformers` while applying chat template" - ) - raise ValueError(str(e)) from e - - -def apply_mistral_chat_template( - tokenizer: MistralTokenizer, - messages: list[ChatCompletionMessageParam], - chat_template: str | None, - tools: list[dict[str, Any]] | None, - **kwargs: Any, -) -> list[int]: - from mistral_common.exceptions import MistralCommonException - - # The return value of resolve_mistral_chat_template is always None, - # and we won't use it. - resolve_mistral_chat_template( - chat_template=chat_template, - **kwargs, - ) - - try: - return tokenizer.apply_chat_template( - messages=messages, - tools=tools, - **kwargs, - ) - # mistral-common uses assert statements to stop processing of input - # if input does not comply with the expected format. - # We convert those assertion errors to ValueErrors so they can be - # properly caught in the preprocessing_input step - except (AssertionError, MistralCommonException) as e: - raise ValueError(str(e)) from e - - # External library exceptions can sometimes occur despite the framework's - # internal exception management capabilities. - except Exception as e: - # Log and report any library-related exceptions for further - # investigation. 
- logger.exception( - "An error occurred in `mistral_common` while applying chat template" - ) - raise ValueError(str(e)) from e - - def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): idx = 0 for msg in conversation: diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index a484a437c853..25aca81d619b 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -39,8 +39,8 @@ from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput from vllm.reasoning.abs_reasoning_parsers import ReasoningParser -from vllm.tokenizers.protocol import TokenizerLike -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.renderers import RendererLike +from vllm.tokenizers import TokenizerLike from vllm.utils import random_uuid if TYPE_CHECKING: @@ -228,8 +228,8 @@ def __init__( self, *, response_messages: list[ResponseInputOutputItem], - tokenizer: AnyTokenizer, - reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None, + renderer: RendererLike, + reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None, request: ResponsesRequest, available_tools: list[str] | None, tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, @@ -248,7 +248,7 @@ def __init__( raise ValueError("reasoning_parser_cls must be provided.") self.parser = get_responses_parser_for_simple_context( - tokenizer=tokenizer, + tokenizer=renderer.get_tokenizer(), reasoning_parser_cls=reasoning_parser_cls, response_messages=response_messages, request=request, @@ -256,7 +256,8 @@ def __init__( ) self.tool_parser_cls = tool_parser_cls self.request = request - self.tokenizer = tokenizer + self.renderer = renderer + self.tokenizer = renderer.get_tokenizer() self.available_tools = available_tools or [] self._tool_sessions: dict[str, ClientSession | Tool] = {} diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6b3cb26afb62..d1e90ca769ea 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -35,10 +35,6 @@ from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages, - resolve_chat_template_content_format, ) from vllm.entrypoints.score_utils import ( ScoreContentPartParam, @@ -71,7 +67,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike from vllm.tokenizers.hf import get_cached_tokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils.collection_utils import as_iter, is_list_of @@ -787,7 +783,7 @@ def preprocess_chat( tools: list[dict[str, Any]] | None = None, chat_template_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> list[TokensPrompt]: + ) -> list[TextPrompt | TokensPrompt]: """ Generate prompt for a chat conversation. The pre-processed prompt can then be used as input for the other LLM methods. @@ -808,63 +804,27 @@ def preprocess_chat( # messages is list[...] 
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] - tokenizer = self.get_tokenizer() - renderer_config = self.renderer_config - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tools, - chat_template_content_format, - tokenizer, - renderer_config=renderer_config, - ) + renderer = self.llm_engine.renderer - _chat_template_kwargs: dict[str, Any] = dict( - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - tools=tools, - ) - _chat_template_kwargs.update(chat_template_kwargs or {}) + chat_template_kwargs = { + "chat_template": chat_template, + "add_generation_prompt": add_generation_prompt, + "continue_final_message": continue_final_message, + "tools": tools, + **(chat_template_kwargs or {}), + } - prompts: list[TokensPrompt] = [] + prompts = list[TextPrompt | TokensPrompt]() for msgs in list_of_messages: - # NOTE: _parse_chat_message_content_parts() currently doesn't + # NOTE: render_messages() currently doesn't # handle mm_processor_kwargs, since there is no implementation in # the chat message parsing for it. - conversation, mm_data, mm_uuids = parse_chat_messages( + _, prompt = renderer.render_messages( msgs, - renderer_config, - content_format=resolved_content_format, + chat_template_content_format=chat_template_content_format, + **chat_template_kwargs, ) - - if isinstance(tokenizer, MistralTokenizer): - prompt_token_ids = apply_mistral_chat_template( - tokenizer, - messages=msgs, - **_chat_template_kwargs, - ) - else: - prompt_str = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - renderer_config=renderer_config, - **_chat_template_kwargs, - ) - # Special tokens are already included in chat templates so - # should not be added by the tokenizer in this case.
- prompt_token_ids = tokenizer.encode( - prompt_str, add_special_tokens=False - ) - - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - - if mm_data is not None: - prompt["multi_modal_data"] = mm_data - - if mm_uuids is not None: - prompt["multi_modal_uuids"] = mm_uuids - if mm_processor_kwargs is not None: prompt["mm_processor_kwargs"] = mm_processor_kwargs @@ -1294,9 +1254,6 @@ def _cross_encoding_score( renderer_config = self.renderer_config model_config = self.model_config - if isinstance(tokenizer, MistralTokenizer): - raise ValueError("Score API is not supported for Mistral tokenizer") - if len(data_1) == 1: data_1 = data_1 * len(data_2) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c6333d170c66..1476aa8f38ce 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -188,7 +188,8 @@ async def create_chat_completion( model_name = self.models.model_name(lora_request) - tokenizer = await self.engine_client.get_tokenizer() + renderer = self.engine_client.renderer + tokenizer = renderer.get_tokenizer() tool_parser = self.tool_parser @@ -236,7 +237,7 @@ async def create_chat_completion( engine_prompts, ) = await self._preprocess_chat( request, - tokenizer, + renderer, request.messages, chat_template=request.chat_template or self.chat_template, chat_template_content_format=self.chat_template_content_format, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index d887cf48d89f..7fd4cb714e88 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -6,7 +6,6 @@ import time import traceback from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from http import HTTPStatus from typing import Any, ClassVar, Generic, TypeAlias, TypeVar @@ -49,7 +48,6 @@ ScoreRequest, ScoreResponse, ) -from vllm.transformers_utils.tokenizer import AnyTokenizer if sys.version_info >= (3, 12): from typing import TypedDict @@ -67,10 +65,6 @@ ChatCompletionMessageParam, ChatTemplateContentFormatOption, ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages_futures, - resolve_chat_template_content_format, ) from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.logger import RequestLogger @@ -99,7 +93,7 @@ ) from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.utils import _validate_truncation_size -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptType, SingletonPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import ( PromptComponents, @@ -109,15 +103,13 @@ from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest -from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin - MultiModalDataDict, - MultiModalUUIDDict, -) +from vllm.multimodal import MultiModalDataDict from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.renderers import RendererLike from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.tokenizers import 
DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike from vllm.tracing import ( contains_trace_headers, extract_trace_headers, @@ -127,10 +119,8 @@ from vllm.utils.async_utils import ( AsyncMicrobatchTokenizer, collect_from_async_generator, - make_async, merge_async_iterators, ) -from vllm.utils.collection_utils import is_list_of from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) @@ -183,7 +173,7 @@ class EmbedsPrompt(TypedDict): prompt_embeds: torch.Tensor -RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt +RequestPrompt: TypeAlias = list[int] | TextTokensPrompt | EmbedsPrompt | SingletonPrompt def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: @@ -235,7 +225,6 @@ class ResponseGenerationMixin: @dataclass(kw_only=True) class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]): - # Shared across all requests request: RequestT raw_request: Request | None = None model_name: str @@ -243,9 +232,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[Requ created_time: int = field(default_factory=lambda: int(time.time())) lora_request: LoRARequest | None = None - # Shared across most requests - tokenizer: TokenizerLike | None = None - @dataclass(kw_only=True) class ClassificationServeContext(ServeContext[ClassificationRequest]): @@ -281,16 +267,13 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids - self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) - self._apply_mistral_chat_template_async = make_async( - apply_mistral_chat_template, executor=self._tokenizer_executor - ) self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {} self.log_error_stack = log_error_stack self.input_processor = self.models.input_processor self.io_processor = self.models.io_processor + self.renderer = self.models.renderer self.renderer_config = self.models.renderer_config self.model_config = self.models.model_config self.max_model_len = self.model_config.max_model_len @@ -1085,7 +1068,7 @@ def _validate_chat_template( async def _preprocess_chat( self, request: ChatLikeRequest | ResponsesRequest, - tokenizer: TokenizerLike | None, + renderer: RendererLike, messages: list[ChatCompletionMessageParam], chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, @@ -1101,56 +1084,33 @@ async def _preprocess_chat( Sequence[RequestPrompt], list[EngineTokensPrompt], ]: - renderer_config = self.renderer_config - - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tool_dicts, - chat_template_content_format, - tokenizer, - renderer_config=renderer_config, - ) - conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + chat_template_kwargs = { + "chat_template": chat_template, + "add_generation_prompt": add_generation_prompt, + "continue_final_message": continue_final_message, + "tools": tool_dicts, + "documents": documents, + **(chat_template_kwargs or {}), + } + + conversation, engine_prompt = await renderer.render_messages_async( messages, - renderer_config, - content_format=resolved_content_format, - ) - - _chat_template_kwargs: dict[str, Any] = dict( - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - tools=tool_dicts, - documents=documents, + chat_template_content_format=chat_template_content_format, + 
**chat_template_kwargs, ) - _chat_template_kwargs.update(chat_template_kwargs or {}) - - request_prompt: str | list[int] - if tokenizer is None: - request_prompt = "placeholder" - elif isinstance(tokenizer, MistralTokenizer): - request_prompt = await self._apply_mistral_chat_template_async( - tokenizer, - messages=messages, - **_chat_template_kwargs, - ) - elif isinstance(tokenizer, DeepseekV32Tokenizer): - request_prompt = tokenizer.apply_chat_template( - conversation=conversation, - messages=messages, - model_config=renderer_config.model_config, - **_chat_template_kwargs, - ) - else: - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - renderer_config=renderer_config, - **_chat_template_kwargs, + if "prompt_token_ids" not in engine_prompt: + engine_prompt = await self._tokenize_prompt_input_async( + request, + renderer.get_tokenizer(), + engine_prompt["prompt"], + add_special_tokens=add_special_tokens, ) - mm_data = await mm_data_future + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + if (cache_salt := getattr(request, "cache_salt", None)) is not None: + engine_prompt["cache_salt"] = cache_salt # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser @@ -1166,49 +1126,11 @@ async def _preprocess_chat( "or Responses API requests." ) raise NotImplementedError(msg) - request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore - - if tokenizer is None: - assert isinstance(request_prompt, str), ( - "Prompt has to be a string", - "when the tokenizer is not initialised", - ) - prompt_inputs = TextTokensPrompt( - prompt=request_prompt, prompt_token_ids=[1] - ) - elif isinstance(request_prompt, str): - prompt_inputs = await self._tokenize_prompt_input_async( - request, - tokenizer, - request_prompt, - add_special_tokens=add_special_tokens, - ) - else: - # For MistralTokenizer - assert is_list_of(request_prompt, int), ( - "Prompt has to be either a string or a list of token ids" - ) - prompt_inputs = TextTokensPrompt( - prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt, - ) - - engine_prompt = EngineTokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"] - ) - if mm_data is not None: - engine_prompt["multi_modal_data"] = mm_data - - if mm_uuids is not None: - engine_prompt["multi_modal_uuids"] = mm_uuids - if request.mm_processor_kwargs is not None: - engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs - - if hasattr(request, "cache_salt") and request.cache_salt is not None: - engine_prompt["cache_salt"] = request.cache_salt + tokenizer = renderer.get_tokenizer() + request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore - return conversation, [request_prompt], [engine_prompt] + return conversation, [engine_prompt], [engine_prompt] async def _process_inputs( self, @@ -1240,7 +1162,7 @@ async def _process_inputs( async def _render_next_turn( self, request: ResponsesRequest, - tokenizer: AnyTokenizer, + renderer: RendererLike, messages: list[ResponseInputOutputItem], tool_dicts: list[dict[str, Any]] | None, tool_parser, @@ -1253,7 +1175,7 @@ async def _render_next_turn( _, request_prompts, engine_prompts = await self._preprocess_chat( request, - tokenizer, + renderer, new_messages, tool_dicts=tool_dicts, tool_parser=tool_parser, @@ -1331,7 +1253,7 @@ async def _generate_with_builtin_tools( elif isinstance(context, 
ParsableContext): request_prompts, engine_prompts = await self._render_next_turn( context.request, - context.tokenizer, + context.renderer, context.parser.response_messages, context.tool_dicts, context.tool_parser_cls, diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index ec65e659383d..dd35b3cd6c5c 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -71,6 +71,7 @@ def __init__( self.input_processor = self.engine_client.input_processor self.io_processor = self.engine_client.io_processor + self.renderer = self.engine_client.renderer self.renderer_config = self.engine_client.renderer_config self.model_config = self.engine_client.model_config self.max_model_len = self.model_config.max_model_len diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 91616a78e11d..1bb0c76f8e8e 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -108,6 +108,7 @@ from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs from vllm.outputs import CompletionOutput +from vllm.renderers import RendererLike from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.tokenizers import TokenizerLike from vllm.utils import random_uuid @@ -346,7 +347,8 @@ async def create_responses( try: lora_request = self._maybe_get_adapters(request) model_name = self.models.model_name(lora_request) - tokenizer = await self.engine_client.get_tokenizer() + renderer = self.engine_client.renderer + tokenizer = renderer.get_tokenizer() if self.use_harmony: messages, request_prompts, engine_prompts = ( @@ -354,7 +356,7 @@ async def create_responses( ) else: messages, request_prompts, engine_prompts = await self._make_request( - request, prev_response, tokenizer + request, prev_response, renderer ) except ( @@ -420,7 +422,7 @@ async def create_responses( # tokens during generation instead of at the end context = ParsableContext( response_messages=messages, - tokenizer=tokenizer, + renderer=renderer, reasoning_parser_cls=self.reasoning_parser, request=request, tool_parser_cls=self.tool_parser, @@ -548,7 +550,7 @@ async def _make_request( self, request: ResponsesRequest, prev_response: ResponsesResponse | None, - tokenizer: TokenizerLike, + renderer: RendererLike, ): tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) # Construct the input messages. 
@@ -560,7 +562,7 @@ async def _make_request( ) _, request_prompts, engine_prompts = await self._preprocess_chat( request, - tokenizer, + renderer, messages, tool_dicts=tool_dicts, tool_parser=self.tool_parser, diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 19c1c83268ed..14cf2f38b70c 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -22,7 +22,8 @@ ToolParser, ) from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 4655da8dd454..92b09917c252 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -21,7 +21,8 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index aa5089ffe84d..54a90cb0f7b0 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -6,6 +6,7 @@ from enum import Enum, auto from random import choices from string import ascii_letters, digits +from typing import Any import ijson import regex as re @@ -24,7 +25,8 @@ ToolParser, ) from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) @@ -112,6 +114,8 @@ def __init__(self, tokenizer: TokenizerLike): "the tokenizer!" 
) + self.prev_tool_call_arr: list[dict[str, Any]] + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: request = super().adjust_request(request) if ( diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index d6d3825daf7b..377d037b2c10 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -52,8 +52,6 @@ async def _preprocess( """ ctx = cast(ClassificationServeContext, ctx) try: - ctx.tokenizer = await self.engine_client.get_tokenizer() - request_obj = ctx.request if isinstance(request_obj, ClassificationChatRequest): @@ -78,7 +76,7 @@ async def _preprocess( engine_prompts, ) = await self._preprocess_chat( cast(ChatCompletionRequest, chat_request), - ctx.tokenizer, + self.renderer, messages, chat_template=( chat_request.chat_template @@ -106,7 +104,7 @@ async def _preprocess( ctx.engine_prompts = [] return None - renderer = self._get_renderer(ctx.tokenizer) + renderer = self._get_renderer(self.renderer.tokenizer) prompt_input = cast(str | list[str], input_data) ctx.engine_prompts = await renderer.render_prompt( prompt_or_prompts=prompt_input, diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index 868a3cb017a6..6675b1fc8793 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -79,9 +79,6 @@ async def _preprocess( try: ctx.lora_request = self._maybe_get_adapters(ctx.request) - tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(tokenizer) - if isinstance(ctx.request, EmbeddingChatRequest): ( _, @@ -89,7 +86,7 @@ async def _preprocess( ctx.engine_prompts, ) = await self._preprocess_chat( ctx.request, - tokenizer, + self.renderer, ctx.request.messages, chat_template=ctx.request.chat_template or ctx.chat_template, chat_template_content_format=ctx.chat_template_content_format, @@ -98,6 +95,8 @@ async def _preprocess( add_special_tokens=ctx.request.add_special_tokens, ) else: + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) ctx.engine_prompts = await renderer.render_prompt( prompt_or_prompts=ctx.request.input, config=self._build_render_config(ctx.request), diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index cd28ccba9ef9..004b1e9d93c5 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -94,12 +94,6 @@ async def create_pooling( try: lora_request = self._maybe_get_adapters(request) - if self.renderer_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(tokenizer) - if getattr(request, "dimensions", None) is not None: return self.create_error_response( "dimensions is currently not supported" @@ -143,7 +137,7 @@ async def create_pooling( engine_prompts, ) = await self._preprocess_chat( request, - tokenizer, + self.renderer, request.messages, chat_template=request.chat_template or self.chat_template, chat_template_content_format=self.chat_template_content_format, @@ -154,6 +148,11 @@ async def create_pooling( add_special_tokens=request.add_special_tokens, ) elif isinstance(request, PoolingCompletionRequest): + if self.renderer_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) 
engine_prompts = await renderer.render_prompt( prompt_or_prompts=request.input, config=self._build_render_config(request), diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index f657fcefd3a8..d08ec32f74b3 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -3,6 +3,7 @@ import asyncio import time from collections.abc import AsyncGenerator, Mapping +from concurrent.futures import ThreadPoolExecutor from typing import Any from fastapi import Request @@ -38,7 +39,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.async_utils import make_async, merge_async_iterators logger = init_logger(__name__) @@ -60,6 +62,8 @@ def __init__( log_error_stack=log_error_stack, ) + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + async def _embedding_score( self, tokenizer: TokenizerLike, diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 979da02d1450..a1293c82c355 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -64,9 +64,6 @@ async def create_tokenize( try: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer() - renderer = self._get_renderer(tokenizer) - if isinstance(request, TokenizeChatRequest): tool_dicts = ( None @@ -86,7 +83,7 @@ async def create_tokenize( engine_prompts, ) = await self._preprocess_chat( request, - tokenizer, + self.renderer, request.messages, tool_dicts=tool_dicts, chat_template=request.chat_template or self.chat_template, @@ -97,6 +94,8 @@ async def create_tokenize( add_special_tokens=request.add_special_tokens, ) else: + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) engine_prompts = await renderer.render_prompt( prompt_or_prompts=request.prompt, config=self._build_render_config(request), diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index a81f73ac9e61..b4a697762560 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -30,7 +30,7 @@ from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index f534d102fc3b..7ab6fd832a6b 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -17,6 +17,7 @@ MultiModalUUIDDict, ) from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.renderers import RendererLike from vllm.tokenizers import TokenizerLike from vllm.utils.jsontree import json_iter_leaves from vllm.v1.metrics.stats import MultiModalCacheStats @@ -46,7 +47,7 @@ class InputPreprocessor: def __init__( self, renderer_config: RendererConfig, - tokenizer: TokenizerLike | None, + renderer: RendererLike, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_processor_cache: BaseMultiModalProcessorCache | None = None, ) -> None: @@ -54,19 +55,22 @@ def __init__( 
self.renderer_config = renderer_config self.model_config = renderer_config.model_config - self.tokenizer = tokenizer + self.renderer = renderer self.mm_registry = mm_registry self.mm_processor_cache = mm_processor_cache self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None - def get_tokenizer(self) -> TokenizerLike: - if self.tokenizer is None: - raise ValueError( - "You cannot pass text prompts when `skip_tokenizer_init=True`" - ) + @property + def tokenizer(self) -> TokenizerLike | None: + return self.renderer.tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: TokenizerLike | None) -> None: + self.renderer.tokenizer = tokenizer - return self.tokenizer + def get_tokenizer(self) -> TokenizerLike: + return self.renderer.get_tokenizer() def get_bos_token_id(self) -> int | None: if self.tokenizer is None: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index ebe743fa82a0..27acf67087a4 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -59,7 +59,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 0acd564e2e54..e9a564d04ade 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -51,7 +51,8 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .utils import init_vllm_registered_model, maybe_prefix diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py index b61e50c188f8..1eb4ead69916 100644 --- a/vllm/reasoning/mistral_reasoning_parser.py +++ b/vllm/reasoning/mistral_reasoning_parser.py @@ -6,7 +6,7 @@ from vllm.logger import init_logger from vllm.reasoning import ReasoningParser from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) diff --git a/vllm/renderers/__init__.py b/vllm/renderers/__init__.py new file mode 100644 index 000000000000..cd6a11dcc833 --- /dev/null +++ b/vllm/renderers/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .protocol import RendererLike +from .registry import RendererRegistry, renderer_from_config + +__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"] diff --git a/vllm/renderers/deepseekv32.py b/vllm/renderers/deepseekv32.py new file mode 100644 index 000000000000..ba8123c0b70e --- /dev/null +++ b/vllm/renderers/deepseekv32.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project +from typing import Any + +from vllm.config import RendererConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ConversationMessage, + parse_chat_messages, + parse_chat_messages_futures, +) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.tokenizers import get_tokenizer +from vllm.tokenizers.deepseekv32 import DeepseekV32Tokenizer + +from .protocol import RendererLike + +logger = init_logger(__name__) + + +class DeepseekV32Renderer(RendererLike): + @classmethod + def from_config( + cls, + config: RendererConfig, + tokenizer_kwargs: dict[str, Any], + ) -> "RendererLike": + return cls(config, tokenizer_kwargs) + + def __init__( + self, + config: RendererConfig, + tokenizer_kwargs: dict[str, Any], + ) -> None: + super().__init__() + + self.config = config + + if config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = get_tokenizer(DeepseekV32Tokenizer, **tokenizer_kwargs) + + self._tokenizer = tokenizer + + @property + def tokenizer(self) -> DeepseekV32Tokenizer | None: + return self._tokenizer + + # NOTE: Remove this once LLM.tokenizer.setter is removed + @tokenizer.setter + def tokenizer(self, tokenizer: DeepseekV32Tokenizer | None) -> None: + self._tokenizer = tokenizer + + def get_tokenizer(self) -> DeepseekV32Tokenizer: + tokenizer = self.tokenizer + if tokenizer is None: + raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") + + return tokenizer + + def render_messages( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + tokenizer = self.get_tokenizer() + conversation, mm_data, mm_uuids = parse_chat_messages( + messages, + self.config, + content_format="string", + ) + + prompt_raw = tokenizer.apply_chat_template( + conversation=conversation, + messages=messages, + **kwargs, + ) + if isinstance(prompt_raw, str): + prompt = TextPrompt(prompt=prompt_raw) + else: + prompt = TokensPrompt(prompt_token_ids=prompt_raw) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt + + async def render_messages_async( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + tokenizer = self.get_tokenizer() + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + messages, + self.config, + content_format="string", + ) + + prompt_raw = tokenizer.apply_chat_template( + conversation=conversation, + messages=messages, + **kwargs, + ) + if isinstance(prompt_raw, str): + prompt = TextPrompt(prompt=prompt_raw) + else: + prompt = TokensPrompt(prompt_token_ids=prompt_raw) + + if mm_data_future is not None: + prompt["multi_modal_data"] = await mm_data_future + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py new file mode 100644 index 000000000000..d1990f346592 --- /dev/null +++ b/vllm/renderers/hf.py @@ -0,0 +1,598 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect +from collections import deque +from collections.abc import Set +from functools import lru_cache +from typing import Any + +import jinja2 +import jinja2.ext +import jinja2.meta +import jinja2.nodes +import jinja2.parser +import 
jinja2.sandbox + +from vllm.config import ModelConfig, RendererConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormat, + ChatTemplateContentFormatOption, + ConversationMessage, + load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, +) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.tokenizers import get_tokenizer +from vllm.tokenizers.hf import HfTokenizer +from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils.func_utils import supports_kw + +from .protocol import RendererLike + +logger = init_logger(__name__) + + +_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]() +""" +Used in `_try_get_processor_chat_template` to avoid calling +`cached_get_processor` again if the processor fails to be loaded. + +This is needed because `lru_cache` does not cache when an exception happens. +""" + + +def _try_get_processor_chat_template( + tokenizer: HfTokenizer, + *, + trust_remote_code: bool, +) -> str | None: + cache_key = (tokenizer.name_or_path, trust_remote_code) + if cache_key in _PROCESSOR_CHAT_TEMPLATES: + return _PROCESSOR_CHAT_TEMPLATES[cache_key] + + from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ) + + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), + trust_remote_code=trust_remote_code, + ) + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and (chat_template := processor.chat_template) is not None + ): + _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template + return chat_template + except Exception: + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + _PROCESSOR_CHAT_TEMPLATES[cache_key] = None + return None + + +def resolve_chat_template( + tokenizer: HfTokenizer, + chat_template: str | None, + tools: list[dict[str, Any]] | None, + *, + model_config: "ModelConfig", +) -> str | None: + # 1st priority: The given chat template + if chat_template is not None: + return chat_template + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: + chat_template = _try_get_processor_chat_template( + tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + if chat_template is not None: + return chat_template + + # 3rd priority: AutoTokenizer chat template + try: + return tokenizer.get_chat_template(chat_template, tools=tools) + except Exception: + logger.debug( + "Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + # 4th priority: Predefined fallbacks + path = get_chat_template_fallback_path( + model_type=model_config.hf_config.model_type, + tokenizer_name_or_path=tokenizer.name_or_path, + ) + if path is not None: + logger.info_once( + "Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", + tokenizer.name_or_path, + ) + chat_template = load_chat_template(path) + else: + logger.debug_once( + "There is no chat template fallback for %s", tokenizer.name_or_path + ) + + return chat_template + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return 
False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return ( + _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key + ) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: str | None = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key + ) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice + ): + return _is_var_or_elems_access(node.node, varname, key) + + return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None: + import transformers.utils.chat_template_utils as hf_chat_utils + + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +@lru_cache(maxsize=32) +def _detect_content_format( + chat_template: str, + *, + default: ChatTemplateContentFormat, +) -> ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return 
"string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def _resolve_chat_template_content_format( + chat_template: str | None, + tools: list[dict[str, Any]] | None, + tokenizer: HfTokenizer, + *, + model_config: "ModelConfig", +) -> ChatTemplateContentFormat: + resolved_chat_template = resolve_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + + jinja_text = ( + resolved_chat_template + if isinstance(resolved_chat_template, str) + else load_chat_template(chat_template, is_literal=True) + ) + + detected_format = ( + "string" + if jinja_text is None + else _detect_content_format(jinja_text, default="string") + ) + + return detected_format + + +@lru_cache +def _log_chat_template_content_format( + chat_template: str | None, # For caching purposes + given_format: ChatTemplateContentFormatOption, + detected_format: ChatTemplateContentFormatOption, +): + logger.info( + "Detected the chat template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + +def resolve_chat_template_content_format( + chat_template: str | None, + tools: list[dict[str, Any]] | None, + given_format: ChatTemplateContentFormatOption, + tokenizer: HfTokenizer, + *, + model_config: "ModelConfig", +) -> ChatTemplateContentFormat: + if given_format != "auto": + return given_format + + detected_format = _resolve_chat_template_content_format( + chat_template, + tools, + tokenizer, + model_config=model_config, + ) + + _log_chat_template_content_format( + chat_template, + given_format=given_format, + detected_format=detected_format, + ) + + return detected_format + + +# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412 +# only preserve the parse function used to resolve chat template kwargs +class AssistantTracker(jinja2.ext.Extension): + tags = {"generation"} + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node: + lineno = next(parser.stream).lineno + body = parser.parse_statements(("name:endgeneration",), drop_needle=True) + call = self.call_method("_generation_support") + call_block = jinja2.nodes.CallBlock(call, [], [], body) + return call_block.set_lineno(lineno) + + +def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]: + env = jinja2.sandbox.ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[AssistantTracker, jinja2.ext.loopcontrols], + ) + parsed_content = env.parse(chat_template) + template_vars = jinja2.meta.find_undeclared_variables(parsed_content) + return template_vars + + +_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs) + + +@lru_cache +def _get_hf_base_chat_template_params() -> frozenset[str]: + from transformers import PreTrainedTokenizer + + # Get standard parameters from HuggingFace's base tokenizer class. 
+ # This dynamically extracts parameters from PreTrainedTokenizer's + # apply_chat_template method, ensuring compatibility with tokenizers + # that use **kwargs to receive standard parameters. + + # Read signature from HF's base class - the single source of truth + base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template) + + # Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders + return frozenset( + p.name + for p in base_sig.parameters.values() + if p.kind + not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + ) + + +def resolve_chat_template_kwargs( + tokenizer: HfTokenizer, + chat_template: str, + chat_template_kwargs: dict[str, Any], + raise_on_unexpected: bool = True, +) -> dict[str, Any]: + # We exclude chat_template from kwargs here, because + # chat template has been already resolved at this stage + unexpected_vars = {"chat_template", "tokenize"} + if raise_on_unexpected and ( + unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys() + ): + raise ValueError( + "Found unexpected chat template kwargs from request: " + f"{unexpected_in_kwargs}" + ) + + fn_kw = { + k + for k in chat_template_kwargs + if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) + } + template_vars = _cached_resolve_chat_template_kwargs(chat_template) + + # Allow standard HF parameters even if tokenizer uses **kwargs to receive them + hf_base_params = _get_hf_base_chat_template_params() + + accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars + return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars} + + +def safe_apply_chat_template( + model_config: "ModelConfig", + tokenizer: HfTokenizer, + conversation: list[ConversationMessage], + *, + tools: list[dict[str, Any]] | None = None, + chat_template: str | None = None, + tokenize: bool = True, + **kwargs, +) -> str | list[int]: + chat_template = resolve_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + if chat_template is None: + raise ValueError( + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one." + ) + + resolved_kwargs = resolve_chat_template_kwargs( + tokenizer=tokenizer, + chat_template=chat_template, + chat_template_kwargs=kwargs, + ) + + try: + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + tools=tools, # type: ignore[arg-type] + chat_template=chat_template, + tokenize=tokenize, + **resolved_kwargs, + ) + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + # Log and report any library-related exceptions for further + # investigation. 
+ logger.exception( + "An error occurred in `transformers` while applying chat template" + ) + raise ValueError(str(e)) from e + + +class HfRenderer(RendererLike): + @classmethod + def from_config( + cls, + config: RendererConfig, + tokenizer_kwargs: dict[str, Any], + ) -> "RendererLike": + return cls(config, tokenizer_kwargs) + + def __init__( + self, + config: RendererConfig, + tokenizer_kwargs: dict[str, Any], + ) -> None: + super().__init__() + + self.config = config + + if config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = get_tokenizer(HfTokenizer, **tokenizer_kwargs) + + self._tokenizer = tokenizer + + @property + def tokenizer(self) -> HfTokenizer | None: + return self._tokenizer + + # NOTE: Remove this once LLM.tokenizer.setter is removed + @tokenizer.setter + def tokenizer(self, tokenizer: HfTokenizer | None) -> None: + self._tokenizer = tokenizer + + def get_tokenizer(self) -> HfTokenizer: + tokenizer = self.tokenizer + if tokenizer is None: + raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") + + return tokenizer + + def render_messages( + self, + messages: list[ChatCompletionMessageParam], + chat_template_content_format: ChatTemplateContentFormat, + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + renderer_config = self.config + model_config = renderer_config.model_config + tokenizer = self.get_tokenizer() + + conversation, mm_data, mm_uuids = parse_chat_messages( + messages, + renderer_config, + content_format=resolve_chat_template_content_format( + chat_template=kwargs.get("chat_template"), + tools=kwargs.get("tools"), + given_format=chat_template_content_format, + tokenizer=tokenizer, + model_config=model_config, + ), + ) + + prompt_raw = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + **kwargs, + ) + if isinstance(prompt_raw, str): + prompt = TextPrompt(prompt=prompt_raw) + else: + prompt = TokensPrompt(prompt_token_ids=prompt_raw) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt + + async def render_messages_async( + self, + messages: list[ChatCompletionMessageParam], + chat_template_content_format: ChatTemplateContentFormat, + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + renderer_config = self.config + model_config = renderer_config.model_config + tokenizer = self.get_tokenizer() + + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + messages, + renderer_config, + content_format=resolve_chat_template_content_format( + chat_template=kwargs.get("chat_template"), + tools=kwargs.get("tools"), + given_format=chat_template_content_format, + tokenizer=tokenizer, + model_config=model_config, + ), + ) + + prompt_raw = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + **kwargs, + ) + if isinstance(prompt_raw, str): + prompt = TextPrompt(prompt=prompt_raw) + else: + prompt = TokensPrompt(prompt_token_ids=prompt_raw) + + if mm_data_future is not None: + prompt["multi_modal_data"] = await mm_data_future + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt diff --git a/vllm/renderers/mistral.py b/vllm/renderers/mistral.py new file mode 100644 index 000000000000..3f91d76283fb --- /dev/null +++ b/vllm/renderers/mistral.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
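
# A standalone sketch of the content-format detection idea implemented in
# vllm/renderers/hf.py above, assuming only that jinja2 is installed. The
# real _detect_content_format additionally follows variable reassignments,
# filters, tests, and slices across the Jinja AST; this simplified version
# only looks for a nested loop over message['content'] / message.content,
# which is what distinguishes OpenAI-style content parts ("openai") from
# plain-string content ("string").
import jinja2
import jinja2.nodes


def guess_content_format(chat_template: str) -> str:
    """Return "openai" if the template loops over message['content'],
    otherwise "string"."""
    root = jinja2.Environment().parse(chat_template)

    # Find {% for message in messages %} loops and remember each loop variable.
    message_vars = set()
    for loop in root.find_all(jinja2.nodes.For):
        if (
            isinstance(loop.iter, jinja2.nodes.Name)
            and loop.iter.name == "messages"
            and isinstance(loop.target, jinja2.nodes.Name)
        ):
            message_vars.add(loop.target.name)

    # If any loop iterates over <message>.content or <message>['content'],
    # the template expects a list of content parts rather than a plain string.
    for loop in root.find_all(jinja2.nodes.For):
        it = loop.iter
        if (
            isinstance(it, jinja2.nodes.Getattr)
            and it.attr == "content"
            and isinstance(it.node, jinja2.nodes.Name)
            and it.node.name in message_vars
        ):
            return "openai"
        if (
            isinstance(it, jinja2.nodes.Getitem)
            and isinstance(it.arg, jinja2.nodes.Const)
            and it.arg.value == "content"
            and isinstance(it.node, jinja2.nodes.Name)
            and it.node.name in message_vars
        ):
            return "openai"
    return "string"


assert guess_content_format(
    "{% for m in messages %}{{ m.content }}{% endfor %}"
) == "string"
assert guess_content_format(
    "{% for m in messages %}{% for part in m['content'] %}"
    "{{ part.text }}{% endfor %}{% endfor %}"
) == "openai"
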
+from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from vllm.config import RendererConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ConversationMessage, + parse_chat_messages, + parse_chat_messages_futures, +) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.tokenizers import get_tokenizer +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.utils.async_utils import make_async + +from .protocol import RendererLike + +logger = init_logger(__name__) + + +def safe_apply_chat_template( + tokenizer: MistralTokenizer, + messages: list[ChatCompletionMessageParam], + **kwargs, +) -> str | list[int]: + from mistral_common.exceptions import MistralCommonException + + try: + return tokenizer.apply_chat_template(messages, **kwargs) + # mistral-common uses assert statements to stop processing of input + # if input does not comply with the expected format. + # We convert those assertion errors to ValueErrors so they can be + # properly caught in the preprocessing_input step + except (AssertionError, MistralCommonException) as e: + raise ValueError(str(e)) from e + + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + # Log and report any library-related exceptions for further + # investigation. + logger.exception( + "An error occurred in `mistral_common` while applying chat template" + ) + raise ValueError(str(e)) from e + + +class MistralRenderer(RendererLike): + @classmethod + def from_config( + cls, + config: RendererConfig, + tokenizer_kwargs: dict[str, Any], + ) -> "RendererLike": + return cls(config, tokenizer_kwargs) + + def __init__( + self, + config: RendererConfig, + tokenizer_kwargs: dict[str, Any], + ) -> None: + super().__init__() + + self.config = config + + if config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = get_tokenizer(MistralTokenizer, **tokenizer_kwargs) + + self._tokenizer = tokenizer + + self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1) + self._apply_chat_template_async = make_async( + safe_apply_chat_template, executor=self._apply_chat_template_executor + ) + + @property + def tokenizer(self) -> MistralTokenizer | None: + return self._tokenizer + + # NOTE: Remove this once LLM.tokenizer.setter is removed + @tokenizer.setter + def tokenizer(self, tokenizer: MistralTokenizer | None) -> None: + self._tokenizer = tokenizer + + def get_tokenizer(self) -> MistralTokenizer: + tokenizer = self.tokenizer + if tokenizer is None: + raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") + + return tokenizer + + def render_messages( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + tokenizer = self.get_tokenizer() + conversation, mm_data, mm_uuids = parse_chat_messages( + messages, + self.config, + content_format="string", + ) + + prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs) + if isinstance(prompt_raw, str): + prompt = TextPrompt(prompt=prompt_raw) + else: + prompt = TokensPrompt(prompt_token_ids=prompt_raw) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt + + async def render_messages_async( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> 
tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + tokenizer = self.get_tokenizer() + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + messages, + self.config, + content_format="string", + ) + + prompt_raw = await self._apply_chat_template_async( + tokenizer, messages, **kwargs + ) + if isinstance(prompt_raw, str): + prompt = TextPrompt(prompt=prompt_raw) + else: + prompt = TokensPrompt(prompt_token_ids=prompt_raw) + + if mm_data_future is not None: + prompt["multi_modal_data"] = await mm_data_future + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt diff --git a/vllm/renderers/protocol.py b/vllm/renderers/protocol.py new file mode 100644 index 000000000000..98e164526260 --- /dev/null +++ b/vllm/renderers/protocol.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Any, Protocol + +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.tokenizers import TokenizerLike + +if TYPE_CHECKING: + from vllm.config import RendererConfig + from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ConversationMessage, + ) + + +class RendererLike(Protocol): + @classmethod + def from_config( + cls, + config: "RendererConfig", + tokenizer_kwargs: dict[str, Any], + ) -> "RendererLike": + raise NotImplementedError + + @property + def tokenizer(self) -> TokenizerLike | None: + raise NotImplementedError + + # NOTE: Remove this once LLM.tokenizer.setter is removed + @tokenizer.setter + def tokenizer(self, tokenizer: TokenizerLike | None) -> None: + self._tokenizer = tokenizer + + def get_tokenizer(self) -> TokenizerLike: + tokenizer = self.tokenizer + if tokenizer is None: + raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") + + return tokenizer + + def render_messages( + self, + messages: list["ChatCompletionMessageParam"], + **kwargs, + ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]: + raise NotImplementedError + + async def render_messages_async( + self, + messages: list["ChatCompletionMessageParam"], + **kwargs, + ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]: + return self.render_messages(messages, **kwargs) diff --git a/vllm/renderers/registry.py b/vllm/renderers/registry.py new file mode 100644 index 000000000000..e7021e1fc548 --- /dev/null +++ b/vllm/renderers/registry.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib.util +from typing import TYPE_CHECKING, Any, TypeVar + +from typing_extensions import assert_never + +from vllm.logger import init_logger +from vllm.transformers_utils.repo_utils import list_filtered_repo_files +from vllm.utils.import_utils import resolve_obj_by_qualname + +from .protocol import RendererLike + +if TYPE_CHECKING: + from vllm.config import RendererConfig + +logger = init_logger(__name__) + +_T = TypeVar("_T", bound=type[RendererLike]) + + +class RendererRegistry: + # Renderer name -> (renderer module, renderer class) + REGISTRY: dict[str, tuple[str, str]] = { + "deepseekv32": ("vllm.renderers.deepseekv32", "DeepseekV32Renderer"), + "hf": ("vllm.renderers.hf", "HfRenderer"), + "mistral": ("vllm.renderers.mistral", "MistralRenderer"), + "terratorch": ("vllm.renderers.terratorch", "TerratorchRenderer"), + } + + @staticmethod + def register(renderer_mode: str, module: str, class_name: str) -> None: + 
if renderer_mode in RendererRegistry.REGISTRY: + logger.warning( + "%s.%s is already registered for renderer_mode=%r. " + "It is overwritten by the new one.", + module, + class_name, + renderer_mode, + ) + + RendererRegistry.REGISTRY[renderer_mode] = (module, class_name) + + return None + + @staticmethod + def init_renderer( + renderer_mode: str, + renderer_config: "RendererConfig", + tokenizer_kwargs: dict[str, Any], + ) -> RendererLike: + if renderer_mode not in RendererRegistry.REGISTRY: + raise ValueError(f"No renderer registered for {renderer_mode=!r}.") + + module, class_name = RendererRegistry.REGISTRY[renderer_mode] + logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}") + + cls_: type[RendererLike] = resolve_obj_by_qualname(f"{module}.{class_name}") + return cls_.from_config(renderer_config, tokenizer_kwargs) + + +def renderer_from_config(renderer_config: "RendererConfig"): + tokenizer_name = renderer_config.tokenizer + tokenizer_mode = renderer_config.tokenizer_mode + tokenizer_revision = renderer_config.tokenizer_revision + trust_remote_code = renderer_config.trust_remote_code + tokenizer_kwargs = dict[str, Any]() + + runner_type = renderer_config.model_config.runner_type + if runner_type == "generate" or runner_type == "draft": + tokenizer_kwargs["truncation_side"] = "left" + elif runner_type == "pooling": + tokenizer_kwargs["truncation_side"] = "right" + else: + assert_never(runner_type) + + tokenizer_mode = renderer_config.tokenizer_mode + if tokenizer_mode == "slow": + if tokenizer_kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + + tokenizer_mode = "hf" + tokenizer_kwargs["use_fast"] = False + + # Try to use official Mistral tokenizer if possible + if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): + allow_patterns = ["tekken.json", "tokenizer.model.v*"] + files_list = list_filtered_repo_files( + model_name_or_path=str(tokenizer_name), + allow_patterns=allow_patterns, + revision=tokenizer_revision, + ) + if len(files_list) > 0: + tokenizer_mode = "mistral" + + # Fallback to HF tokenizer + if tokenizer_mode == "auto": + tokenizer_mode = "hf" + + tokenizer_kwargs = dict[str, Any]( + trust_remote_code=trust_remote_code, + revision=tokenizer_revision, + **tokenizer_kwargs, + ) + + return RendererRegistry.init_renderer( + tokenizer_mode, + renderer_config, + tokenizer_kwargs, + ) diff --git a/vllm/renderers/terratorch.py b/vllm/renderers/terratorch.py new file mode 100644 index 000000000000..f5465df9dbd3 --- /dev/null +++ b/vllm/renderers/terratorch.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.config import RendererConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ConversationMessage, + parse_chat_messages, + parse_chat_messages_futures, +) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike + +from .protocol import RendererLike + +logger = init_logger(__name__) + + +class TerratorchRenderer(RendererLike): + @classmethod + def from_config(cls, config: "RendererConfig") -> "RendererLike": + return cls(config) + + def __init__(self, config: "RendererConfig") -> None: + super().__init__() + + self.config = config + + if not config.skip_tokenizer_init: + raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`") + + @property + def tokenizer(self) -> TokenizerLike | 
None: + return None + + def get_tokenizer(self) -> TokenizerLike: + raise ValueError("Tokenizer not available for Terratorch renderer") + + def render_messages( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + renderer_config = self.config + + conversation, mm_data, mm_uuids = parse_chat_messages( + messages, + renderer_config, + content_format="string", + ) + + prompt = TokensPrompt(prompt_token_ids=[1]) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt + + async def render_messages_async( + self, + messages: list[ChatCompletionMessageParam], + **kwargs, + ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]: + renderer_config = self.config + + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + messages, + renderer_config, + content_format="string", + ) + + prompt = TokensPrompt(prompt_token_ids=[1]) # Dummy token IDs + if mm_data_future is not None: + prompt["multi_modal_data"] = await mm_data_future + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt diff --git a/vllm/tokenizers/__init__.py b/vllm/tokenizers/__init__.py index 67a6d7c8eb3d..31e74b1a16e2 100644 --- a/vllm/tokenizers/__init__.py +++ b/vllm/tokenizers/__init__.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .deepseekv32 import DeepseekV32Tokenizer -from .hf import HfTokenizer -from .mistral import MistralTokenizer from .protocol import TokenizerLike from .registry import ( TokenizerRegistry, @@ -15,12 +12,9 @@ __all__ = [ "TokenizerLike", - "HfTokenizer", - "MistralTokenizer", "TokenizerRegistry", "cached_get_tokenizer", "get_tokenizer", "cached_tokenizer_from_config", "init_tokenizer_from_config", - "DeepseekV32Tokenizer", ] diff --git a/vllm/tokenizers/deepseekv32.py b/vllm/tokenizers/deepseekv32.py index b0490dacbe2d..43dbb2b5c56c 100644 --- a/vllm/tokenizers/deepseekv32.py +++ b/vllm/tokenizers/deepseekv32.py @@ -2,22 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path +from typing import Any from transformers import BatchEncoding +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + from .deepseek_v32_encoding import encode_messages -from .hf import HfTokenizer, TokenizerLike -from .registry import TokenizerRegistry +from .hf import CachedHfTokenizer +from .protocol import TokenizerLike -@TokenizerRegistry.register("deepseek_v32") -class DeepseekV32Tokenizer(HfTokenizer): - def __init__(self, tokenizer: TokenizerLike): - self.tokenizer = tokenizer - self.name_or_path = ( - tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else "" - ) - +class DeepseekV32Tokenizer(CachedHfTokenizer): @classmethod def from_pretrained( cls, @@ -38,7 +34,18 @@ def from_pretrained( ) return DeepseekV32Tokenizer(tokenizer) - def apply_chat_template(self, messages, tools=None, **kwargs): + def __init__(self, tokenizer: TokenizerLike) -> None: + super().__init__() + + self.tokenizer = tokenizer + self.name_or_path = getattr(tokenizer, "name_or_path", "") + + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: list[dict[str, Any]] | None = None, + **kwargs, + ) -> str: thinking = kwargs.get("thinking", False) thinking_mode = "thinking" if not thinking: @@ -48,7 +55,7 @@ 
def apply_chat_template(self, messages, tools=None, **kwargs): drop_thinking = True if tools is not None and len(tools) > 0: messages.insert(0, {"role": "system"}) - messages[0]["tools"] = tools + messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key] drop_thinking = False encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) prompt_str = encode_messages(messages, **encode_config) # type: ignore diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py index 344507312038..c455d80a3fbc 100644 --- a/vllm/tokenizers/hf.py +++ b/vllm/tokenizers/hf.py @@ -3,22 +3,18 @@ import contextlib import copy from pathlib import Path -from typing import TYPE_CHECKING +from typing import TypeAlias -from transformers import AutoTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from .protocol import TokenizerLike -from .registry import TokenizerRegistry -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast -def get_cached_tokenizer( - tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast", -) -> TokenizerLike: +def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: """ By default, transformers will recompute multiple tokenizer properties each time they are called, leading to a significant slowdown. @@ -65,11 +61,10 @@ def __reduce__(self): CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" cached_tokenizer.__class__ = CachedTokenizer - return cached_tokenizer # type: ignore + return cached_tokenizer -@TokenizerRegistry.register("hf") -class HfTokenizer(TokenizerLike): +class CachedHfTokenizer(TokenizerLike): @classmethod def from_pretrained( cls, @@ -79,7 +74,9 @@ def from_pretrained( revision: str | None = None, download_dir: str | None = None, **kwargs, - ) -> "TokenizerLike": + ) -> HfTokenizer: + from transformers import AutoTokenizer + try: tokenizer = AutoTokenizer.from_pretrained( path_or_repo_id, diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index 1f44037dd55e..534b0da484a5 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -3,10 +3,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, cast +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.logger import init_logger from .protocol import TokenizerLike -from .registry import TokenizerRegistry if TYPE_CHECKING: from mistral_common.protocol.instruct.request import ( @@ -15,9 +16,6 @@ from mistral_common.tokens.tokenizers.tekken import Tekkenizer from transformers import BatchEncoding - from vllm.entrypoints.chat_utils import ChatCompletionMessageParam - from vllm.entrypoints.openai.protocol import ChatCompletionRequest - try: # Transformers v5 from transformers.tokenization_mistral_common import MistralCommonBackend @@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: return tokenizer.unk_id -@TokenizerRegistry.register("mistral") class MistralTokenizer(TokenizerLike): @classmethod def from_pretrained( diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index c9575511af8c..ba157d2749ae 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project -import importlib.util -from collections.abc import Callable from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar import huggingface_hub from typing_extensions import assert_never @@ -18,7 +16,6 @@ is_remote_gguf, split_remote_gguf, ) -from vllm.transformers_utils.repo_utils import list_filtered_repo_files from vllm.utils.import_utils import resolve_obj_by_qualname from .protocol import TokenizerLike @@ -28,41 +25,19 @@ logger = init_logger(__name__) -_T = TypeVar("_T", bound=type[TokenizerLike]) +_T = TypeVar("_T", bound=TokenizerLike) class TokenizerRegistry: - # Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class) - REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {} + # Tokenizer name -> (tokenizer module, tokenizer class) + REGISTRY: dict[str, tuple[str, str]] = { + "deepseekv32": ("vllm.tokenizers.deepseekv32", "DeepseekV32Tokenizer"), + "hf": ("vllm.tokenizers.hf", "CachedHfTokenizer"), + "mistral": ("vllm.tokenizers.mistral", "MistralTokenizer"), + } - # In-tree tokenizers @staticmethod - @overload - def register(tokenizer_mode: str) -> Callable[[_T], _T]: ... - - # OOT tokenizers - @staticmethod - @overload - def register(tokenizer_mode: str, module: str, class_name: str) -> None: ... - - @staticmethod - def register( - tokenizer_mode: str, - module: str | None = None, - class_name: str | None = None, - ) -> Callable[[_T], _T] | None: - # In-tree tokenizers - if module is None or class_name is None: - - def wrapper(tokenizer_cls: _T) -> _T: - assert tokenizer_mode not in TokenizerRegistry.REGISTRY - TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls - - return tokenizer_cls - - return wrapper - - # OOT tokenizers + def register(tokenizer_mode: str, module: str, class_name: str) -> None: if tokenizer_mode in TokenizerRegistry.REGISTRY: logger.warning( "%s.%s is already registered for tokenizer_mode=%r. 
" @@ -77,30 +52,26 @@ def wrapper(tokenizer_cls: _T) -> _T: return None @staticmethod - def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike": + def init_tokenizer(tokenizer_mode: str, *args, **kwargs) -> TokenizerLike: if tokenizer_mode not in TokenizerRegistry.REGISTRY: raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.") - item = TokenizerRegistry.REGISTRY[tokenizer_mode] - if isinstance(item, type): - return item.from_pretrained(*args, **kwargs) - - module, class_name = item + module, class_name = TokenizerRegistry.REGISTRY[tokenizer_mode] logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}") - class_ = resolve_obj_by_qualname(f"{module}.{class_name}") - return class_.from_pretrained(*args, **kwargs) + cls_: type[TokenizerLike] = resolve_obj_by_qualname(f"{module}.{class_name}") + return cls_.from_pretrained(*args, **kwargs) def get_tokenizer( + tokenizer_cls: type[_T], tokenizer_name: str | Path, *args, - tokenizer_mode: str = "auto", trust_remote_code: bool = False, revision: str | None = None, download_dir: str | None = None, **kwargs, -) -> TokenizerLike: +) -> _T: """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, @@ -125,16 +96,6 @@ def get_tokenizer( ) tokenizer_name = tokenizer_path - if tokenizer_mode == "slow": - if kwargs.get("use_fast", False): - raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") - - tokenizer_mode = "hf" - kwargs["use_fast"] = False - - if "truncation_side" not in kwargs: - kwargs["truncation_side"] = "left" - # Separate model folder from file path for GGUF models if is_gguf(tokenizer_name): if check_gguf_file(tokenizer_name): @@ -150,56 +111,22 @@ def get_tokenizer( ) kwargs["gguf_file"] = gguf_file - # Try to use official Mistral tokenizer if possible - if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): - allow_patterns = ["tekken.json", "tokenizer.model.v*"] - files_list = list_filtered_repo_files( - model_name_or_path=str(tokenizer_name), - allow_patterns=allow_patterns, - revision=revision, - ) - if len(files_list) > 0: - tokenizer_mode = "mistral" - - # Fallback to HF tokenizer - if tokenizer_mode == "auto": - tokenizer_mode = "hf" - tokenizer_args = (tokenizer_name, *args) - tokenizer_kwargs = dict( + tokenizer_kwargs = dict[str, Any]( trust_remote_code=trust_remote_code, revision=revision, download_dir=download_dir, **kwargs, ) - if tokenizer_mode == "custom": - logger.warning_once( - "TokenizerRegistry now uses `tokenizer_mode` as the registry key " - "instead of `tokenizer_name`. " - "Please update the definition of `.from_pretrained` in " - "your custom tokenizer to accept `args=%s`, `kwargs=%s`. " - "Then, you can pass `tokenizer_mode=%r` instead of " - "`tokenizer_mode='custom'` when initializing vLLM.", - tokenizer_args, - str(tokenizer_kwargs), - tokenizer_name, - ) - - tokenizer_mode = str(tokenizer_name) - - tokenizer = TokenizerRegistry.get_tokenizer( - tokenizer_mode, - *tokenizer_args, - **tokenizer_kwargs, - ) + tokenizer = tokenizer_cls.from_pretrained(*tokenizer_args, **tokenizer_kwargs) if not tokenizer.is_fast: logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." 
) - return tokenizer + return tokenizer # type: ignore cached_get_tokenizer = lru_cache(get_tokenizer) @@ -216,6 +143,9 @@ def cached_tokenizer_from_config(renderer_config: "RendererConfig", **kwargs): def init_tokenizer_from_config(renderer_config: "RendererConfig"): + if renderer_config.skip_tokenizer_init: + return None + runner_type = renderer_config.model_config.runner_type if runner_type == "generate" or runner_type == "draft": truncation_side = "left" @@ -224,10 +154,7 @@ def init_tokenizer_from_config(renderer_config: "RendererConfig"): else: assert_never(runner_type) - return get_tokenizer( - renderer_config.tokenizer, - tokenizer_mode=renderer_config.tokenizer_mode, - trust_remote_code=renderer_config.trust_remote_code, - revision=renderer_config.tokenizer_revision, + return cached_tokenizer_from_config( + renderer_config, truncation_side=truncation_side, ) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b76f9c0595d6..885774a60638 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -24,9 +24,10 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams +from vllm.renderers import RendererLike, renderer_from_config from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config +from vllm.tokenizers import TokenizerLike from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.usage.usage_lib import UsageContext @@ -109,12 +110,8 @@ def __init__( "enabling logging without default stat loggers." ) - if self.renderer_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = init_tokenizer_from_config(self.renderer_config) - - self.input_processor = InputProcessor(self.vllm_config, tokenizer) + renderer = renderer_from_config(self.renderer_config) + self.input_processor = InputProcessor(self.vllm_config, renderer) self.io_processor = get_io_processor( self.vllm_config, self.renderer_config.io_processor_plugin, @@ -715,12 +712,11 @@ def tokenizer(self, tokenizer: TokenizerLike | None) -> None: self.input_processor.tokenizer = tokenizer async def get_tokenizer(self) -> TokenizerLike: - if self.tokenizer is None: - raise ValueError( - "Unable to get tokenizer because `skip_tokenizer_init=True`" - ) + return self.input_processor.get_tokenizer() - return self.tokenizer + @property + def renderer(self) -> RendererLike: + return self.input_processor.renderer async def is_tracing_enabled(self) -> bool: return self.observability_config.otlp_traces_endpoint is not None # type: ignore diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index a2f6ba5be8c1..7e36857fbe8c 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -18,8 +18,10 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams +from vllm.renderers import RendererLike from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest from vllm.v1.metrics.stats import 
MultiModalCacheStats @@ -39,7 +41,7 @@ class InputProcessor: def __init__( self, vllm_config: VllmConfig, - tokenizer: TokenizerLike | None, + renderer: RendererLike, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: self.vllm_config = vllm_config @@ -56,7 +58,7 @@ def __init__( self.input_preprocessor = InputPreprocessor( self.renderer_config, - tokenizer, + renderer, mm_registry, mm_processor_cache=self.mm_processor_cache, ) @@ -69,6 +71,13 @@ def tokenizer(self) -> TokenizerLike | None: def tokenizer(self, tokenizer: TokenizerLike | None) -> None: self.input_preprocessor.tokenizer = tokenizer + def get_tokenizer(self) -> TokenizerLike: + return self.input_preprocessor.get_tokenizer() + + @property + def renderer(self) -> RendererLike: + return self.input_preprocessor.renderer + def _validate_logprobs( self, params: SamplingParams, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index ba0e1cf25cb0..065d72df0464 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,9 +21,10 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams +from vllm.renderers import RendererLike, renderer_from_config from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config +from vllm.tokenizers import TokenizerLike from vllm.tracing import init_tracer from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest @@ -84,12 +85,8 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - if self.renderer_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = init_tokenizer_from_config(self.renderer_config) - - self.input_processor = InputProcessor(self.vllm_config, tokenizer) + renderer = renderer_from_config(self.renderer_config) + self.input_processor = InputProcessor(self.vllm_config, renderer) self.io_processor = get_io_processor( self.vllm_config, self.renderer_config.io_processor_plugin, @@ -364,12 +361,11 @@ def tokenizer(self, tokenizer: TokenizerLike | None) -> None: self.input_processor.tokenizer = tokenizer def get_tokenizer(self) -> TokenizerLike: - if self.tokenizer is None: - raise ValueError( - "Unable to get tokenizer because `skip_tokenizer_init=True`" - ) + return self.input_processor.get_tokenizer() - return self.tokenizer + @property + def renderer(self) -> RendererLike: + return self.input_processor.renderer def do_log_stats(self) -> None: """Log stats if logging is enabled.""" diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index f8a2df43dd90..fa852b2fb79f 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -10,7 +10,7 @@ import vllm.envs from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend,
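
# A usage sketch of the registry pattern this diff converges on: both
# TokenizerRegistry and RendererRegistry now map a mode string to a lazily
# imported (module, class_name) pair, replacing the earlier decorator-based
# in-tree registration. `my_pkg.renderers.MyRenderer` is a hypothetical
# out-of-tree renderer used purely for illustration; the registry calls are
# the ones defined in vllm/renderers/registry.py above.
from vllm.renderers import RendererRegistry

# Register by dotted module path and class name. The class is resolved via
# resolve_obj_by_qualname only when init_renderer() is first called for this
# mode, so registering an out-of-tree renderer does not import it eagerly.
RendererRegistry.register("my_mode", "my_pkg.renderers", "MyRenderer")

# The engine side then builds the renderer from a RendererConfig, e.g.
# (sketch only -- renderer_config and tokenizer_kwargs come from the
# surrounding engine code):
#     renderer = RendererRegistry.init_renderer(
#         "my_mode", renderer_config, tokenizer_kwargs
#     )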
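
# For comparison, the narrowed get_tokenizer() call shape after this
# refactor: the caller now passes the tokenizer class explicitly, and mode
# resolution (the slow / auto / mistral fallbacks) lives in
# renderer_from_config instead. "gpt2" is an arbitrary example checkpoint,
# not something this diff pins; running this requires network access to
# download it.
from vllm.tokenizers import get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer

# from_pretrained() on CachedHfTokenizer returns an HfTokenizer
# (PreTrainedTokenizer | PreTrainedTokenizerFast) wrapped by the
# property-caching shim in vllm/tokenizers/hf.py.
tokenizer = get_tokenizer(CachedHfTokenizer, "gpt2")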