diff --git a/tests/conftest.py b/tests/conftest.py index 25e70319e2cc..0aaf637d41d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype +from vllm.config import TaskOption, _get_and_verify_dtype from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, @@ -1010,20 +1010,6 @@ def vllm_runner(): return VllmRunner -def get_tokenizer_pool_config(tokenizer_group_type): - if tokenizer_group_type is None: - return None - if tokenizer_group_type == "ray": - return TokenizerPoolConfig(pool_size=1, - pool_type="ray", - extra_config={}) - if isinstance(tokenizer_group_type, type): - return TokenizerPoolConfig(pool_size=1, - pool_type=tokenizer_group_type, - extra_config={}) - raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") - - @pytest.fixture() def temporary_enable_log_propagate(): import logging diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py index d605ab734688..8845eb33d207 100644 --- a/tests/lora/test_tokenizer_group.py +++ b/tests/lora/test_tokenizer_group.py @@ -5,17 +5,14 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import get_lora_tokenizer -from vllm.transformers_utils.tokenizer_group import get_tokenizer_group - -from ..conftest import get_tokenizer_pool_config +from vllm.transformers_utils.tokenizer_group import TokenizerGroup @pytest.mark.asyncio @pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_group = TokenizerGroup( tokenizer_id="gpt2", enable_lora=True, max_num_seqs=1, @@ -60,8 +57,7 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path): @pytest.mark.parametrize("max_num_seqs", [1, 2]) @pytest.mark.parametrize("max_loras", [1, 2]) def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(None), + tokenizer_group = TokenizerGroup( tokenizer_id="gpt2", enable_lora=enable_lora, max_num_seqs=max_num_seqs, diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 0f8b98a13581..da6c774f439a 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -10,7 +10,7 @@ from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, @@ -212,7 +212,7 @@ def test_oov_decode(tokenizer, fast): @pytest.fixture def detokenizer(tokenizer_name: str) -> Detokenizer: - init_kwargs = dict( + tokenizer_group = TokenizerGroup( tokenizer_id=tokenizer_name, enable_lora=False, max_num_seqs=100, @@ -222,11 +222,6 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: revision=None, 
) - tokenizer_group = get_tokenizer_group( - None, - **init_kwargs, - ) - return Detokenizer(tokenizer_group) diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py index 5b62f992c1be..bcfa78ed41cf 100644 --- a/tests/tokenization/test_tokenizer_group.py +++ b/tests/tokenization/test_tokenizer_group.py @@ -1,40 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 -import asyncio -import os -import sys -from typing import Optional -from unittest.mock import patch - import pytest from transformers import AutoTokenizer, PreTrainedTokenizerBase -from vllm.transformers_utils.tokenizer_group import (TokenizerGroup, - get_tokenizer_group) -from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( - RayTokenizerGroupPool) - -from ..conftest import get_tokenizer_pool_config - - -class CustomTokenizerGroup(TokenizerGroup): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._i = 0 - - def encode(self, *args, **kwargs): - self._i += 1 - return super().encode(*args, **kwargs) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup @pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", - [None, "ray", CustomTokenizerGroup]) -async def test_tokenizer_group(tokenizer_group_type): +async def test_tokenizer_group(): reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_group = TokenizerGroup( tokenizer_id="gpt2", enable_lora=False, max_num_seqs=1, @@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type): PreTrainedTokenizerBase) assert tokenizer_group.get_lora_tokenizer( None) == await tokenizer_group.get_lora_tokenizer_async(None) - if tokenizer_group_type is CustomTokenizerGroup: - assert tokenizer_group._i > 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_pool(tokenizer_group_type): - reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer_group_pool = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - # Send multiple requests to the tokenizer group pool - # (more than the pool size) - # and check that all requests are processed correctly. 
- num_requests = tokenizer_group_pool.pool_size * 5 - requests = [ - tokenizer_group_pool.encode_async(prompt=f"prompt {i}", - lora_request=None) - for i in range(num_requests) - ] - results = await asyncio.gather(*requests) - expected_results = [ - reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests) - ] - assert results == expected_results - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_ray_pool_env_var_propagation( - tokenizer_group_type): - """Test that env vars from caller process are propagated to - tokenizer Ray actors.""" - env_var = "MY_ENV_VAR" - - class EnvVarCheckerTokenizerGroup(TokenizerGroup): - - def ping(self): - assert os.environ.get(env_var) == "1" - return super().ping() - - class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool): - _worker_cls = EnvVarCheckerTokenizerGroup - - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None) - with pytest.raises(AssertionError): - tokenizer_pool.ping() - - with patch.dict(os.environ, {env_var: "1"}): - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None) - tokenizer_pool.ping() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): - """Test that Ray tokenizer pool group can recover from failures and - if that's not possible, mark itself as unhealthy.""" - - class FailingTokenizerGroup(TokenizerGroup): - - def __init__(self, - *args, - fail_at: Optional[list[int]] = None, - **kwargs): - super().__init__(*args, **kwargs) - self.i = 0 - self.fail_at = fail_at or [] - - def encode(self, *args, **kwargs): - self.i += 1 - if self.i in self.fail_at: - sys.exit(1) - return super().encode(*args, **kwargs) - - class FailingRayTokenizerGroupPool(RayTokenizerGroupPool): - _worker_cls = FailingTokenizerGroup - - # Fail at first iteration - fail_at = [1] - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - fail_at=fail_at) - tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy() - - # Modify fail at to not fail at all (will be re-read when actor is - # re-initialized). - fail_at[0] = 1000 - - # We should recover successfully. - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - - # Check that we have a new actor - assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors) - assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors - - # Fail at first iteration - fail_at = [1] - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - fail_at=fail_at) - - # We should fail after re-initialization. 
- with pytest.raises(RuntimeError): - await tokenizer_group_pool.encode_async(prompt="prompt", - lora_request=None) - - # check_health should raise the same thing - with pytest.raises(RuntimeError): - tokenizer_group_pool.check_health() - - # Ensure that non-ActorDiedErrors are still propagated correctly and do not - # cause a re-initialization. - fail_at = [] - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=2, - fail_at=fail_at) - tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy() - - # Prompt too long error - with pytest.raises(ValueError): - await tokenizer_group_pool.encode_async(prompt="prompt" * 100, - lora_request=None) - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - # Actors should stay the same. - assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index 8872f0388dd2..f8addd920d57 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: tokenizer=tokenizer, tokenizer_group=init_tokenizer_from_configs( vllm_config.model_config, vllm_config.scheduler_config, - vllm_config.parallel_config, vllm_config.lora_config), + vllm_config.lora_config), vllm_config=vllm_config, full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], prompt_tokens=prompt_tokens, diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 1ee93c72cd26..4a23e0c1b212 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -8,8 +8,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors( class DummyOutputProcessorTestVectors: """Dummy test vectors for output processor tests""" tokenizer: GeneralTokenizerType - tokenizer_group: BaseTokenizerGroup + tokenizer_group: TokenizerGroup vllm_config: EngineArgs full_tokens: list[list[int]] # Prompt + generated tokens prompt_tokens: list[list[int]] diff --git a/vllm/config.py b/vllm/config.py index 741ce04d5dff..12e6747c79ff 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -52,8 +52,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.loader import BaseModelLoader - from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) ConfigType = type[DataclassInstance] else: @@ -1407,83 +1405,33 @@ def verify_with_parallel_config( logger.warning("Possibly too large swap space. %s", msg) -PoolType = Literal["ray"] - - @config @dataclass class TokenizerPoolConfig: - """Configuration for the tokenizer pool.""" + """This config is deprecated and will be removed in a future release. - pool_size: int = 0 - """Number of tokenizer workers in the pool to use for asynchronous - tokenization. If 0, will use synchronous tokenization.""" - - pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray" - """Type of tokenizer pool to use for asynchronous tokenization. 
Ignored if - tokenizer_pool_size is 0.""" + Passing these parameters will have no effect. Please remove them from your + configurations. + """ + pool_size: int = 0 + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + pool_type: str = "ray" + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" extra_config: dict = field(default_factory=dict) - """Additional config for the pool. The way the config will be used depends - on the pool type. This should be a JSON string that will be parsed into a - dictionary. Ignored if tokenizer_pool_size is 0.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if self.pool_type not in ("ray", ) and not isinstance( - self.pool_type, type): - raise ValueError(f"Unknown pool type: {self.pool_type}") - if not isinstance(self.extra_config, dict): - raise ValueError("extra_config must be a dictionary.") - - @classmethod - def create_config( - cls, tokenizer_pool_size: int, - tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]], - tokenizer_pool_extra_config: Optional[Union[str, dict]] - ) -> Optional["TokenizerPoolConfig"]: - """Create a TokenizerPoolConfig from the given parameters. - - If tokenizer_pool_size is 0, return None. - - Args: - tokenizer_pool_size: Number of tokenizer workers in the pool. - tokenizer_pool_type: Type of the pool. - tokenizer_pool_extra_config: Additional config for the pool. - The way the config will be used depends on the - pool type. This can be a JSON string (will be parsed). - """ - if tokenizer_pool_size: - if isinstance(tokenizer_pool_extra_config, str): - tokenizer_pool_extra_config_parsed = json.loads( - tokenizer_pool_extra_config) - else: - tokenizer_pool_extra_config_parsed = ( - tokenizer_pool_extra_config or {}) - tokenizer_pool_config = cls(tokenizer_pool_size, - tokenizer_pool_type, - tokenizer_pool_extra_config_parsed) - else: - tokenizer_pool_config = None - return tokenizer_pool_config + def __post_init__(self) -> None: + logger.warning_once( + "TokenizerPoolConfig is deprecated and will be removed in a " + "future release. Passing this parameter will have no effect. " + "Please remove it from your configurations.") class LoadFormat(str, enum.Enum): @@ -1624,8 +1572,8 @@ def data_parallel_rank_local(self, value: int) -> None: """Disable the custom all-reduce kernel and fall back to NCCL.""" tokenizer_pool_config: Optional[TokenizerPoolConfig] = None - """Config for the tokenizer pool. 
If None, will use synchronous - tokenization.""" + """This parameter is deprecated and will be removed in a future release. + Please remove it from your configs""" ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" @@ -2544,7 +2492,6 @@ def create_draft_parallel_config( max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. disable_custom_all_reduce, - tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, ray_workers_use_nsight=target_parallel_config. ray_workers_use_nsight, placement_group=target_parallel_config.placement_group, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 00328f56b713..6d6b5ac02b14 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,9 +7,8 @@ import re import threading from dataclasses import MISSING, dataclass, fields -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, - Optional, Tuple, Type, TypeVar, Union, cast, get_args, - get_origin) +from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type, + TypeVar, Union, cast, get_args, get_origin) import torch from typing_extensions import TypeIs @@ -23,7 +22,7 @@ KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, - PoolType, PrefixCachingHashAlgo, PromptAdapterConfig, + PrefixCachingHashAlgo, PromptAdapterConfig, SchedulerConfig, SchedulerPolicy, SpeculativeConfig, TaskOption, TokenizerPoolConfig, VllmConfig, get_attr_docs, get_field) @@ -39,9 +38,6 @@ # yapf: enable -if TYPE_CHECKING: - from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup - logger = init_logger(__name__) ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] @@ -185,13 +181,12 @@ class EngineArgs: enforce_eager: Optional[bool] = None max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce + # The following three fields are deprecated and will be removed in a future + # release. Setting them will have no effect. Please remove them from your + # configurations. tokenizer_pool_size: int = TokenizerPoolConfig.pool_size - # Note: Specifying a tokenizer pool by passing a class - # is intended for expert use only. The API may change without - # notice. 
- tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \ - TokenizerPoolConfig.pool_type - tokenizer_pool_extra_config: dict[str, Any] = \ + tokenizer_pool_type: str = TokenizerPoolConfig.pool_type + tokenizer_pool_extra_config: dict = \ get_field(TokenizerPoolConfig, "extra_config") limit_mm_per_prompt: dict[str, int] = \ get_field(MultiModalConfig, "limit_per_prompt") @@ -1187,11 +1182,6 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, - tokenizer_pool_config=TokenizerPoolConfig.create_config( - self.tokenizer_pool_size, - self.tokenizer_pool_type, - self.tokenizer_pool_extra_config, - ), ray_workers_use_nsight=self.ray_workers_use_nsight, placement_group=placement_group, distributed_executor_backend=self.distributed_executor_backend, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 67c7e109c9f0..ca8fd83314ae 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -526,8 +526,6 @@ async def add_request_async( ) async def check_health_async(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() self.model_executor.check_health() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4644053785f1..276891489836 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -55,7 +55,7 @@ from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) + TokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import (Counter, Device, deprecate_kwargs, @@ -66,7 +66,6 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput) _R = TypeVar("_R", default=Any) @@ -205,7 +204,7 @@ def validate_outputs( return outputs_ - tokenizer: Optional[BaseTokenizerGroup] + tokenizer: Optional[TokenizerGroup] def __init__( self, @@ -321,11 +320,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.parallel_config.disable_custom_all_reduce, }) - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() - self.cached_scheduler_outputs = [ SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size) @@ -537,21 +531,12 @@ def __del__(self): if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() - def get_tokenizer_group( - self, - group_type: Type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. 
" - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - return tokenizer_group + return self.tokenizer def get_tokenizer( self, @@ -559,11 +544,10 @@ def get_tokenizer( ) -> AnyTokenizer: return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - def _init_tokenizer(self) -> BaseTokenizerGroup: + def _init_tokenizer(self) -> TokenizerGroup: return init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=self.scheduler_config, - parallel_config=self.parallel_config, lora_config=self.lora_config) def _verify_args(self) -> None: @@ -1952,8 +1936,6 @@ def is_sleeping(self) -> bool: return self.model_executor.is_sleeping def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() self.model_executor.check_health() def is_tracing_enabled(self) -> bool: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 6e56cbdbbf8c..eb3ae89394ec 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -101,7 +101,6 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, self.tokenizer = init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=engine_config.scheduler_config, - parallel_config=engine_config.parallel_config, lora_config=engine_config.lora_config) self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 57c7ab73de37..38a541a408fa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -40,7 +40,6 @@ RequestOutputKind, SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) -from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, is_list_of) @@ -253,10 +252,10 @@ def __init__( self.default_sampling_params: Union[dict[str, Any], None] = None def get_tokenizer(self) -> AnyTokenizer: - return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + return self.llm_engine.get_tokenizer_group().tokenizer def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: - tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) + tokenizer_group = self.llm_engine.get_tokenizer_group() # While CachedTokenizer is dynamic, have no choice but # compare class name. 
Misjudgment will arise from diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 669fb96e6653..a4609290a900 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -13,7 +13,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) @@ -27,7 +27,7 @@ class InputPreprocessor: def __init__( self, model_config: ModelConfig, - tokenizer: Optional[BaseTokenizerGroup], + tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: super().__init__() @@ -36,7 +36,7 @@ def __init__( self.tokenizer = tokenizer self.mm_registry = mm_registry - def get_tokenizer_group(self) -> BaseTokenizerGroup: + def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: raise ValueError("You cannot pass text prompts when " "`skip_tokenizer_init` is True") diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 9d1d4bb92e4a..991d5631e64e 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -8,13 +8,13 @@ from .detokenizer_utils import (convert_prompt_ids_to_tokens, detokenize_incrementally) from .tokenizer import AnyTokenizer -from .tokenizer_group import BaseTokenizerGroup +from .tokenizer_group import TokenizerGroup class Detokenizer: """Provides methods to decode the output of a model into text.""" - def __init__(self, tokenizer_group: BaseTokenizerGroup): + def __init__(self, tokenizer_group: TokenizerGroup): self.tokenizer_group = tokenizer_group def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py similarity index 84% rename from vllm/transformers_utils/tokenizer_group/tokenizer_group.py rename to vllm/transformers_utils/tokenizer_group.py index b6e9005bcd24..a829985cb459 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -2,7 +2,7 @@ from typing import List, Optional -from vllm.config import TokenizerPoolConfig +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, get_lora_tokenizer, @@ -10,10 +10,8 @@ get_tokenizer) from vllm.utils import LRUCache -from .base_tokenizer_group import BaseTokenizerGroup - -class TokenizerGroup(BaseTokenizerGroup): +class TokenizerGroup: """A group of tokenizers that can be used for LoRA adapters.""" def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, @@ -27,15 +25,6 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.lora_tokenizers = LRUCache[int, AnyTokenizer]( capacity=max(max_loras, max_num_seqs) if enable_lora else 0) - @classmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "TokenizerGroup": - return cls(**init_kwargs) - - def ping(self) -> bool: - """Check if the tokenizer group is alive.""" - return True - def get_max_input_len(self, lora_request: Optional[LoRARequest] = 
None ) -> Optional[int]: @@ -104,3 +93,18 @@ async def get_lora_tokenizer_async( return tokenizer else: return self.lora_tokenizers[lora_request.lora_int_id] + + +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig]): + return TokenizerGroup( + tokenizer_id=model_config.tokenizer, + enable_lora=bool(lora_config), + max_num_seqs=scheduler_config.max_num_seqs, + max_loras=lora_config.max_loras if lora_config else 0, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision, + truncation_side=model_config.truncation_side) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py deleted file mode 100644 index 9d2209575bd3..000000000000 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional, Type - -from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, TokenizerPoolConfig) -from vllm.executor.ray_utils import ray - -from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup -from .tokenizer_group import TokenizerGroup - -if ray: - from .ray_tokenizer_group import RayTokenizerGroupPool -else: - RayTokenizerGroupPool = None # type: ignore - - -def init_tokenizer_from_configs(model_config: ModelConfig, - scheduler_config: SchedulerConfig, - parallel_config: ParallelConfig, - lora_config: Optional[LoRAConfig]): - init_kwargs = dict(tokenizer_id=model_config.tokenizer, - enable_lora=bool(lora_config), - max_num_seqs=scheduler_config.max_num_seqs, - max_loras=lora_config.max_loras if lora_config else 0, - max_input_length=None, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=model_config.truncation_side) - - return get_tokenizer_group(parallel_config.tokenizer_pool_config, - **init_kwargs) - - -def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> BaseTokenizerGroup: - tokenizer_cls: Type[BaseTokenizerGroup] - if tokenizer_pool_config is None: - tokenizer_cls = TokenizerGroup - elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass( - tokenizer_pool_config.pool_type, BaseTokenizerGroup): - tokenizer_cls = tokenizer_pool_config.pool_type - elif tokenizer_pool_config.pool_type == "ray": - if RayTokenizerGroupPool is None: - raise ImportError( - "RayTokenizerGroupPool is not available. 
Please install " - "the ray package to use the Ray tokenizer group pool.") - tokenizer_cls = RayTokenizerGroupPool - else: - raise ValueError( - f"Unknown pool type: {tokenizer_pool_config.pool_type}") - return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) - - -__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py deleted file mode 100644 index c5108a7fc6eb..000000000000 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC, abstractmethod -from typing import List, Optional - -from vllm.config import TokenizerPoolConfig -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import AnyTokenizer - - -class BaseTokenizerGroup(ABC): - """A group of tokenizers that can be used for LoRA adapters.""" - - @classmethod - @abstractmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "BaseTokenizerGroup": - pass - - @abstractmethod - def ping(self) -> bool: - """Check if the tokenizer group is alive.""" - pass - - @abstractmethod - def get_max_input_len( - self, - lora_request: Optional[LoRARequest] = None, - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - pass - - @abstractmethod - def encode(self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group.""" - pass - - @abstractmethod - async def encode_async( - self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group.""" - pass - - @abstractmethod - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get a tokenizer for a LoRA request.""" - pass - - @abstractmethod - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get a tokenizer for a LoRA request.""" - pass - - def check_health(self): - """Raise exception if the tokenizer group is unhealthy.""" - return diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py deleted file mode 100644 index b048b8094174..000000000000 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -import os -from typing import List, Optional - -try: - from ray.exceptions import ActorDiedError # type: ignore -except ImportError: - # For older versions of Ray - from ray.exceptions import RayActorError as ActorDiedError # type: ignore -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy - -from vllm.config import TokenizerPoolConfig -from vllm.executor.ray_utils import ray -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import AnyTokenizer - -from .base_tokenizer_group import BaseTokenizerGroup -from .tokenizer_group import TokenizerGroup - -logger = init_logger(__name__) - - -class RayTokenizerGroupPool(BaseTokenizerGroup): - """A Ray-based pool of TokenizerGroups for async tokenization.""" - - # Class to use for 
workers making up the pool. - _worker_cls = TokenizerGroup - - @classmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "RayTokenizerGroupPool": - if not tokenizer_pool_config: - raise ValueError("tokenizer_pool_config must not be None.") - ray_actor_options = (tokenizer_pool_config.extra_config or { - "num_cpus": 0 - }) - ray_actor_options.setdefault( - "scheduling_strategy", - NodeAffinitySchedulingStrategy( - node_id=ray.get_runtime_context().get_node_id(), soft=True)) - - # Carry over the env vars to the actors. - # This is necessary for API keys and such. - ray_actor_options.setdefault("runtime_env", {}) - _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) - - init_kwargs["num_actors"] = tokenizer_pool_config.pool_size - init_kwargs["ray_actor_options"] = ray_actor_options - - return cls(**init_kwargs) - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], num_actors: int, - ray_actor_options: dict, **tokenizer_config): - # Store a local copy of the TokenizerGroup for quick access - # to underlying HF tokenizers. - self._tokenizer_config = { - "tokenizer_id": tokenizer_id, - "enable_lora": enable_lora, - "max_num_seqs": max_num_seqs, - "max_input_length": max_input_length, - **tokenizer_config - } - self._local_tokenizer_group = self._worker_cls( - **self._tokenizer_config, ) - - self._ray_tokenizer_group_cls = ray.remote( - self._worker_cls).options(**ray_actor_options) # type: ignore - self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)] - self._idle_actors: Optional[asyncio.Queue] = None - - # If set, actor is unhealthy. Will reraise on the next - # check_health call. - self._exception: Optional[ActorDiedError] = None - - def _init_actor(self) -> ray.ObjectRef: - return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config) - - @property - def pool_size(self) -> int: - return len(self.tokenizer_actors) - - def ping(self): - return ray.get([ - actor.ping.remote() # type: ignore - for actor in self.tokenizer_actors - ]) - - def _ensure_queue_initialized(self): - if self._idle_actors is None: - self._idle_actors = asyncio.Queue() - for actor in self.tokenizer_actors: - self._idle_actors.put_nowait(actor) - - def _finalize_encode(self, actor: ray.ObjectRef, - original_actor: ray.ObjectRef, actor_is_alive: bool): - assert self._idle_actors is not None - # Cleanup the dead actor. - if not actor_is_alive or original_actor is not actor: - self.tokenizer_actors.remove(original_actor) - if actor_is_alive: - # Put the actor back in the queue. - # This is done in a finally block to ensure that the actor is - # always put back in the queue, even if an exception/cancellation - # is raised. - self._idle_actors.put_nowait(actor) - # Add back the new actor. - if original_actor is not actor: - self.tokenizer_actors.append(actor) - - def encode(self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group. - - We pick an idle actor and use it to encode the prompt. - The actor is then put back in the queue for future use. - This is blocking. 
- """ - self.check_health() - self._ensure_queue_initialized() - assert self._idle_actors is not None - - if self._idle_actors.empty(): - raise RuntimeError("No idle actors available.") - actor = self._idle_actors.get_nowait() - actor_is_alive = True - original_actor = actor - try: - ret = ray.get( - actor.encode.remote(prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens)) - except ActorDiedError as e: - # If the actor is dead, we first try to reinitialize it. - logger.warning("%s died with ActorDiedError, reinitializing.", - actor, - exc_info=e) - actor = self._init_actor() - try: - ret = ray.get( - actor.encode.remote(prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens)) - except ActorDiedError as e: - logger.error( - "%s died for second time in a row, marking " - "RayTokenizerGroupPool as unhealthy.", actor) - actor_is_alive = False - if not self._exception: - self._exception = e - self.check_health() - finally: - self._finalize_encode(actor, original_actor, actor_is_alive) - return ret - - async def encode_async( - self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group. - - We pick an idle actor and use it to encode the prompt. - If there are no idle actors, we wait until one becomes - available. - The actor is then put back in the queue for future use. - This is non-blocking. - """ - self.check_health() - self._ensure_queue_initialized() - assert self._idle_actors is not None - - actor = await self._idle_actors.get() - actor_is_alive = True - original_actor = actor - try: - ret = await actor.encode.remote( - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) - except ActorDiedError as e: - # If the actor is dead, we first try to reinitialize it. - logger.warning("%s died with ActorDiedError, reinitializing.", - actor, - exc_info=e) - actor = self._init_actor() - try: - ret = await actor.encode.remote( - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) - except ActorDiedError as e: - logger.error( - "%s died for second time in a row, marking " - "RayTokenizerGroupPool as unhealthy.", actor) - actor_is_alive = False - if not self._exception: - self._exception = e - self.check_health() - finally: - self._finalize_encode(actor, original_actor, actor_is_alive) - return ret - - def get_max_input_len(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - return self._local_tokenizer_group.get_max_input_len(lora_request) - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self._local_tokenizer_group.get_lora_tokenizer(lora_request) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return await self._local_tokenizer_group.get_lora_tokenizer_async( - lora_request) - - def check_health(self): - if self._exception: - raise RuntimeError( - "TokenizerGroupPool is unhealthy.") from self._exception - - -def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: - """Copy over all current process environment variables to the runtime_env. - - The variables in runtime_env will take precedence over the current process - environment variables. 
- - runtime_env will be modified in place.""" - env_vars = os.environ.copy() - runtime_env.setdefault("env_vars", {}) - env_vars.update(runtime_env["env_vars"]) - runtime_env["env_vars"] = env_vars diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1149dfa9ce5c..fdda812c2853 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -81,9 +81,7 @@ def __init__( self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) - self.tokenizer.ping() # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 6fa90b269825..af67408097ab 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -20,7 +20,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) + TokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.utils import Device from vllm.v1.engine.core_client import EngineCoreClient @@ -32,7 +32,6 @@ logger = init_logger(__name__) -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _R = TypeVar("_R", default=Any) @@ -74,9 +73,7 @@ def __init__( self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) - self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, @@ -258,21 +255,12 @@ def wake_up(self, tags: Optional[list[str]] = None): def is_sleeping(self) -> bool: return self.engine_core.is_sleeping() - def get_tokenizer_group( - self, - group_type: type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. 
" - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - return tokenizer_group + return self.tokenizer def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index d652b17e55b3..9c42d34b8a9a 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,7 +8,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor @@ -225,7 +225,7 @@ class OutputProcessor: def __init__( self, - tokenizer: BaseTokenizerGroup, + tokenizer: TokenizerGroup, log_stats: bool, ): self.log_stats = log_stats diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7e6b7ba47035..af09a67635a2 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -17,7 +17,7 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( @@ -31,7 +31,7 @@ class Processor: def __init__( self, vllm_config: VllmConfig, - tokenizer: BaseTokenizerGroup, + tokenizer: TokenizerGroup, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 0edb15558dce..a59ec5efc53e 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -61,9 +61,7 @@ def __init__(self, vllm_config: VllmConfig): tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() self.vllm_config = vllm_config self.vocab_size = vllm_config.model_config.get_vocab_size() diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 1e4470153e30..0b5a1593b3eb 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -35,9 +35,7 @@ def __init__(self, vllm_config: VllmConfig): tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() self.disable_any_whitespace = False backend_options = GuidedDecodingParams(