From 51a462874001dbf1cf728b7e40380798b0efcf2b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 03:14:01 +0000 Subject: [PATCH 01/75] Add entrypoints to stricter checks --- .github/workflows/mypy.yaml | 1 - format.sh | 1 - pyproject.toml | 1 + 3 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 721c9c026cf1..0ae67941a860 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -37,7 +37,6 @@ jobs: mypy vllm/core --follow-imports skip mypy vllm/distributed --follow-imports skip mypy vllm/engine --follow-imports skip - mypy vllm/entrypoints --follow-imports skip mypy vllm/executor --follow-imports skip mypy vllm/lora --follow-imports skip mypy vllm/model_executor --follow-imports skip diff --git a/format.sh b/format.sh index 71697cffacfb..b08ee7877c5a 100755 --- a/format.sh +++ b/format.sh @@ -101,7 +101,6 @@ mypy vllm/attention --follow-imports skip mypy vllm/core --follow-imports skip mypy vllm/distributed --follow-imports skip mypy vllm/engine --follow-imports skip -mypy vllm/entrypoints --follow-imports skip mypy vllm/executor --follow-imports skip mypy vllm/lora --follow-imports skip mypy vllm/model_executor --follow-imports skip diff --git a/pyproject.toml b/pyproject.toml index cd5d196a1620..f054fb35a8a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ files = [ "vllm/*.py", "vllm/adapter_commons", "vllm/assets", + "vllm/entrypoints", "vllm/inputs", "vllm/logging", "vllm/multimodal", From 8ab3ba9223127ebda3a379e9213da245f5143179 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 03:15:26 +0000 Subject: [PATCH 02/75] Fix `swap_space` and `cpu_offload_gb` only accepting ints; clean up code in the process --- vllm/config.py | 15 +++++++-------- vllm/engine/arg_utils.py | 6 +++--- vllm/entrypoints/llm.py | 2 +- vllm/executor/cpu_executor.py | 7 +++---- vllm/executor/openvino_executor.py | 9 ++++----- vllm/utils.py | 3 +++ vllm/worker/tpu_worker.py | 4 ++-- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fd48cc3a6b37..c11dc0c0cf23 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,8 +11,8 @@ from vllm.model_executor.models import ModelRegistry from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_openvino, is_tpu, is_xpu, +from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, + is_cpu, is_hip, is_neuron, is_openvino, is_tpu, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -25,7 +25,6 @@ logger = init_logger(__name__) -_GB = 1 << 30 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _PP_SUPPORTED_MODELS = [ @@ -437,7 +436,7 @@ def __init__( self, block_size: int, gpu_memory_utilization: float, - swap_space: int, + swap_space: float, cache_dtype: str, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, @@ -446,7 +445,7 @@ def __init__( ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization - self.swap_space_bytes = swap_space * _GB + self.swap_space_bytes = swap_space * GiB_bytes self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype self.sliding_window = sliding_window @@ -506,9 +505,9 @@ def verify_with_parallel_config( num_gpus_per_node = parallel_config.tensor_parallel_size cpu_memory_usage = 
self.swap_space_bytes * num_gpus_per_node - msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of " - f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " - "allocated for the swap space.") + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2737b50927f6..dd31a3be1840 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -56,8 +56,8 @@ class EngineArgs: enable_prefix_caching: bool = False disable_sliding_window: bool = False use_v2_block_manager: bool = False - swap_space: int = 4 # GiB - cpu_offload_gb: int = 0 # GiB + swap_space: float = 4 # GiB + cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 @@ -318,7 +318,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.seed, help='Random seed for operations.') parser.add_argument('--swap-space', - type=int, + type=float, default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU.') parser.add_argument( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1..4bbf1805243a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -116,7 +116,7 @@ def __init__( tokenizer_revision: Optional[str] = None, seed: int = 0, gpu_memory_utilization: float = 0.9, - swap_space: int = 4, + swap_space: float = 4, cpu_offload_gb: float = 0, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 3229e5ad20af..f58aaf8a55b9 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_open_port, +from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, get_vllm_instance_id, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -332,7 +332,6 @@ def _verify_and_get_scheduler_config( def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: - _GB = 1 << 30 if config.enable_prefix_caching: logger.warning("Prefix caching is not supported on CPU, disable it.") config.enable_prefix_caching = False @@ -341,11 +340,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: if kv_cache_space >= 0: if kv_cache_space == 0: - config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore + config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " "for CPU backend is not set, using 4 by default.") else: - config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore + config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore else: raise RuntimeError( "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index c52a1c9839d7..7df515a2a5ce 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -10,8 +10,8 @@ from vllm.logger import 
init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) +from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, + get_open_port, make_async) logger = init_logger(__name__) @@ -165,14 +165,13 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE if kv_cache_space >= 0: - _GB = 1 << 30 if kv_cache_space == 0: - config.openvino_kvcache_space_bytes = 4 * _GB # type: ignore + config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore logger.warning( "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " "for OpenVINO backend is not set, using 4 by default.") else: - config.openvino_kvcache_space_bytes = kv_cache_space * _GB # type: ignore + config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore else: raise RuntimeError( "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" diff --git a/vllm/utils.py b/vllm/utils.py index 38e1782a51ab..127131146405 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -32,6 +32,9 @@ logger = init_logger(__name__) +GiB_bytes = 1 << 30 +"""The number of bytes in one gibibyte (GiB).""" + STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 17fa5c35457c..bedea4c40a55 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -140,8 +140,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. # Calculate the CPU KV cache size based on the config. - num_cpu_blocks = (self.cache_config.swap_space_bytes // - block_size_bytes) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + block_size_bytes) num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. 
return num_tpu_blocks, num_cpu_blocks From 7efaa82f5785ce78a9488b95115ec12f9c0317f1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 03:39:38 +0000 Subject: [PATCH 03/75] Update mypy version --- .github/workflows/mypy.yaml | 2 +- requirements-lint.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 0ae67941a860..8bc55f7dceee 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install mypy==1.9.0 + pip install mypy==1.11.1 pip install types-setuptools pip install types-PyYAML pip install types-requests diff --git a/requirements-lint.txt b/requirements-lint.txt index bd34227d3e82..d0b2fef6deae 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -8,7 +8,7 @@ isort==5.13.2 clang-format==18.1.5 # type checking -mypy==1.9.0 +mypy==1.11.1 types-PyYAML types-requests types-setuptools From e5b6784766313ce791660c047d266d635b6b9dc7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 05:23:13 +0000 Subject: [PATCH 04/75] Improve typing of tokenizer and hf config --- vllm/engine/async_llm_engine.py | 5 +-- vllm/engine/llm_engine.py | 33 +++++++++++++------ vllm/entrypoints/llm.py | 21 +++++------- vllm/entrypoints/openai/serving_chat.py | 10 +++--- vllm/entrypoints/openai/serving_completion.py | 8 ++--- vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/inputs/registry.py | 8 ++--- vllm/model_executor/models/internvl.py | 6 ++-- vllm/model_executor/models/minicpmv.py | 8 ++--- vllm/model_executor/models/phi3v.py | 4 +-- vllm/transformers_utils/detokenizer.py | 14 ++++---- vllm/transformers_utils/tokenizer.py | 15 ++++----- 12 files changed, 72 insertions(+), 62 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d3f9a0ab00f1..88b94602c7a5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -282,8 +282,9 @@ async def process_model_inputs_async( inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + missing_msg= + "prompts must be None if skip_tokenizer_init is True") prompt_token_ids = await tokenizer.encode_async( request_id=request_id, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1efe2206abe8..8960e5ce04d1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,7 +3,9 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, TypeVar, Union +from typing import Set, Type, Union + +from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, @@ -38,9 +40,9 @@ init_tracer) from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import (AnyTokenizer, - BaseTokenizerGroup, +from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -63,6 +65,7 @@ def 
_load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: return config.to_diff_dict() +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) @@ -468,12 +471,21 @@ def __del__(self): "skip_tokenizer_init is True") def get_tokenizer_group( - self, - fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: - if self.tokenizer is None: - raise ValueError(fail_msg) + self, + group_type: Type[_G] = BaseTokenizerGroup, + *, + missing_msg: str = MISSING_TOKENIZER_GROUP_MSG, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError(missing_msg) + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") - return self.tokenizer + return tokenizer_group def get_tokenizer( self, @@ -581,8 +593,9 @@ def process_model_inputs( inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + missing_msg= + "prompts must be None if skip_tokenizer_init is True") prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4bbf1805243a..765b92a3f6be 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -2,7 +2,6 @@ from typing import ClassVar, List, Optional, Sequence, Union, cast, overload from tqdm import tqdm -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine @@ -14,7 +13,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_kwargs @@ -156,22 +156,19 @@ def __init__( engine_args, usage_context=UsageContext.LLM_CLASS) self.request_counter = Counter() - def get_tokenizer( - self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - return self.llm_engine.tokenizer.tokenizer + def get_tokenizer(self) -> AnyTokenizer: + return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + + def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: + tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) - def set_tokenizer( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - ) -> None: # While CachedTokenizer is dynamic, have no choice but # compare class name. 
Misjudgment will arise from # user-defined tokenizer started with 'Cached' if tokenizer.__class__.__name__.startswith("Cached"): - self.llm_engine.tokenizer.tokenizer = tokenizer + tokenizer_group.tokenizer = tokenizer else: - self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( - tokenizer) + tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) @overload # LEGACY: single (prompt + optional token ids) def generate( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 01843930bf11..359fae8e1969 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -5,7 +5,6 @@ from typing import Union from fastapi import Request -from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -32,6 +31,7 @@ from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid logger = init_logger(__name__) @@ -213,7 +213,7 @@ async def chat_completion_stream_generator( result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) @@ -438,7 +438,7 @@ async def chat_completion_full_generator( result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] @@ -522,7 +522,7 @@ async def chat_completion_full_generator( def _get_top_logprobs( self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], - tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: + tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]: return [ ChatCompletionLogProb(token=(token := self._get_decoded_token( p[1], @@ -540,7 +540,7 @@ def _create_chat_logprobs( self, token_ids: GenericSequence[int], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 854835279168..77eae5b99c42 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -5,7 +5,6 @@ from typing import Tuple, cast from fastapi import Request -from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -30,6 +29,7 @@ from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -227,7 +227,7 @@ async def completion_stream_generator( created_time: int, model_name: str, num_prompts: int, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_texts = [""] * num_choices * num_prompts @@ 
-347,7 +347,7 @@ def request_output_to_completion_response( request_id: str, created_time: int, model_name: str, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> CompletionResponse: choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 @@ -417,7 +417,7 @@ def _create_completion_logprobs( token_ids: GenericSequence[int], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], num_output_top_logprobs: int, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, initial_text_offset: int = 0, ) -> CompletionLogProbs: """Create logprobs for OpenAI Completion API.""" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b374a7946b11..861928f9223e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -29,7 +29,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob -from vllm.transformers_utils.tokenizer_group import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 4a7e5c583291..006dc8e146a6 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,10 +1,10 @@ import functools from dataclasses import dataclass -from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, - TypeVar) +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type from torch import nn from transformers import PretrainedConfig +from typing_extensions import TypeVar from vllm.logger import init_logger @@ -17,7 +17,7 @@ logger = init_logger(__name__) -C = TypeVar("C", bound=PretrainedConfig) +C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) @dataclass(frozen=True) @@ -44,7 +44,7 @@ def get_multimodal_config(self) -> "MultiModalConfig": return multimodal_config - def get_hf_config(self, hf_config_type: Type[C]) -> C: + def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C: """ Get the HuggingFace configuration (:class:`transformers.PretrainedConfig`) of the model, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index f64c78c15f8e..95a989add4cc 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -156,7 +156,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int, def get_max_internvl_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config(PretrainedConfig) + hf_config = ctx.get_hf_config() vision_config = hf_config.vision_config image_size = vision_config.image_size patch_size = vision_config.patch_size @@ -172,7 +172,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs model_config = ctx.model_config - hf_config = ctx.get_hf_config(PretrainedConfig) + hf_config = ctx.get_hf_config() vision_config = hf_config.vision_config image_data = multi_modal_data["image"] @@ -227,7 +227,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): image_feature_size = get_max_internvl_image_tokens(ctx) model_config = ctx.model_config - hf_config = ctx.get_hf_config(PretrainedConfig) + hf_config = ctx.get_hf_config() vision_config = hf_config.vision_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 
8563216d9c39..0e4a4de14311 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -32,7 +32,7 @@ from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ -from transformers.configuration_utils import PretrainedConfig +from transformers import PretrainedConfig from transformers.models.idefics2.modeling_idefics2 import ( Idefics2VisionTransformer) @@ -313,7 +313,7 @@ def _repeat(self, query, N: int): def get_max_minicpmv_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config(PretrainedConfig) + hf_config = ctx.get_hf_config() return getattr(hf_config, "query_num", 64) @@ -329,7 +329,7 @@ def dummy_image_for_minicpmv(hf_config): def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int): - hf_config = ctx.get_hf_config(PretrainedConfig) + hf_config = ctx.get_hf_config() # image_feature_size = get_max_minicpmv_image_tokens(ctx) @@ -381,7 +381,7 @@ class MiniCPMV(nn.Module, SupportsVision): def __init__( self, - config, + config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 75e2f5fc95cb..a10d00cd8af3 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -331,7 +331,7 @@ def get_phi3v_image_feature_size( def get_max_phi3v_image_tokens(ctx: InputContext): return get_phi3v_image_feature_size( - ctx.get_hf_config(PretrainedConfig), + ctx.get_hf_config(), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, ) @@ -381,7 +381,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs model_config = ctx.model_config - hf_config = ctx.get_hf_config(PretrainedConfig) + hf_config = ctx.get_hf_config() image_data = multi_modal_data["image"] if isinstance(image_data, Image.Image): diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 76f418674532..a686c043fffd 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,10 +1,10 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizer from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) + +from .tokenizer_group import AnyTokenizer, BaseTokenizerGroup # Used eg. for marking rejected tokens in spec decoding. 
INVALID_TOKEN_ID = -1 @@ -174,7 +174,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]): def _convert_tokens_to_string_with_added_encoders( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: AnyTokenizer, output_tokens: List[str], skip_special_tokens: bool, spaces_between_special_tokens: bool, @@ -213,7 +213,7 @@ def _convert_tokens_to_string_with_added_encoders( def convert_prompt_ids_to_tokens( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: AnyTokenizer, prompt_ids: List[int], skip_special_tokens: bool = False, ) -> Tuple[List[str], int, int]: @@ -240,7 +240,7 @@ def convert_prompt_ids_to_tokens( # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 # under Apache 2.0 license def detokenize_incrementally( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: AnyTokenizer, all_input_ids: List[int], prev_tokens: Optional[List[str]], prefix_offset: int, diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index c515f46ecc29..06858a5ad7f5 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,9 +1,8 @@ import os -from typing import Optional, Union +from typing import Optional import huggingface_hub -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizerFast from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger @@ -11,12 +10,12 @@ from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.utils import make_async +from .tokenizer_group import AnyTokenizer + logger = init_logger(__name__) -def get_cached_tokenizer( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] -) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: +def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: """Get tokenizer with cached properties. This will patch the tokenizer object in place. @@ -62,7 +61,7 @@ def get_tokenizer( revision: Optional[str] = None, download_dir: Optional[str] = None, **kwargs, -) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: +) -> AnyTokenizer: """Gets a tokenizer for the given model name via HuggingFace or ModelScope. 
""" if VLLM_USE_MODELSCOPE: @@ -133,7 +132,7 @@ def get_tokenizer( def get_lora_tokenizer(lora_request: LoRARequest, *args, - **kwargs) -> Optional[PreTrainedTokenizer]: + **kwargs) -> Optional[AnyTokenizer]: if lora_request is None: return None try: From 2e0fa85501ddcc6c2e6f99eddf1be7942139cbcc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 05:23:49 +0000 Subject: [PATCH 05/75] Fix `encoding_format` --- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_embedding.py | 28 +++++++++++++------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 205860aa8e72..2a10474d5142 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -468,7 +468,7 @@ class EmbeddingRequest(OpenAIBaseModel): # https://platform.openai.com/docs/api-reference/embeddings model: str input: Union[List[int], List[List[int]], str, List[str]] - encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None user: Optional[str] = None diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index bccc90894e79..7e49a60fcff1 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,9 +1,10 @@ import base64 import time -from typing import AsyncIterator, List, Optional, Tuple, cast +from typing import AsyncIterator, List, Literal, Optional, Tuple, Union, cast import numpy as np from fastapi import Request +from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -13,7 +14,7 @@ EmbeddingResponseData, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.logger import init_logger -from vllm.outputs import EmbeddingRequestOutput +from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -21,18 +22,28 @@ TypeTokenIDs = List[int] +def _get_embedding( + output: EmbeddingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[List[float], str]: + if encoding_format == "float": + return output.embedding + elif encoding_format == "base64": + embedding_bytes = np.array(output.embedding).tobytes() + return base64.b64encode(embedding_bytes).decode("utf-8") + + assert_never(encoding_format) + + def request_output_to_embedding_response( final_res_batch: List[EmbeddingRequestOutput], request_id: str, created_time: int, model_name: str, - encoding_format: str) -> EmbeddingResponse: + encoding_format: Literal["float", "base64"]) -> EmbeddingResponse: data: List[EmbeddingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): prompt_token_ids = final_res.prompt_token_ids - embedding = final_res.outputs.embedding - if encoding_format == "base64": - embedding_bytes = np.array(embedding).tobytes() - embedding = base64.b64encode(embedding_bytes).decode("utf-8") + embedding = _get_embedding(final_res.outputs, encoding_format) embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) data.append(embedding_data) @@ -81,8 +92,7 @@ async def create_embedding(self, request: EmbeddingRequest, if error_check_ret is not None: return error_check_ret - encoding_format = (request.encoding_format - if request.encoding_format 
else "float") + encoding_format = request.encoding_format if request.dimensions is not None: return self.create_error_response( "dimensions is currently not supported") From e1f6d4f96bc0ecaafa7a5d5882fc8045a7fe72da Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 05:24:04 +0000 Subject: [PATCH 06/75] Fix misc. --- vllm/entrypoints/llm.py | 16 +++++++++------- vllm/entrypoints/openai/logits_processors.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 8 ++++---- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 765b92a3f6be..7f7ded3d9523 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -487,6 +487,8 @@ def _convert_v1_inputs( inputs: List[PromptInputs] = [] for i in range(num_requests): + item: PromptInputs + if prompts is not None: item = TextPrompt(prompt=prompts[i]) elif prompt_token_ids is not None: @@ -530,12 +532,11 @@ def _validate_and_add_requests( prompt_adapter_request=prompt_adapter_request) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], - LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( @@ -543,7 +544,8 @@ def _add_request( inputs, params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) def _run_engine( self, *, use_tqdm: bool diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 31eb5aa628c5..2bfe7f51a1d9 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -44,7 +44,7 @@ def get_logits_processors( logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], allowed_token_ids: Optional[List[int]], tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]: - logits_processors = [] + logits_processors: List[LogitsProcessor] = [] if logit_bias: try: # Convert token_id to integer diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 359fae8e1969..0889379e87db 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) -from vllm.inputs import PromptInputs +from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -162,9 +162,9 @@ async def create_chat_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - engine_inputs: PromptInputs = { - "prompt_token_ids": prompt_inputs["prompt_token_ids"], - } + engine_inputs = TokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"], + ) if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data From 625e11f8bc046d277aaf782e62a50ce8233fa8ef Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jul 2024 00:19:28 -0700 Subject: [PATCH 07/75] [Bugfix][TPU] Set readonly=True for non-root devices (#6980) --- vllm/worker/tpu_worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 
deletion(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index bedea4c40a55..df1a65f6efad 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -104,7 +104,10 @@ def init_device(self) -> None: # Use persistent cache to avoid XLA recompilation. # NOTE(woosuk): This does not completely eliminate the recompilation # overhead because dynamo does not cache the compiled results. - xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, readonly=False) + # NOTE(woosuk): Set readonly=False only for the rank 0 process to avoid + # race conditions. + xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, + readonly=not self.is_driver_worker) def load_model(self): self.model_runner.load_model() From fb19d3ebbab0e26b9b960b8934c8be2ec7c797a7 Mon Sep 17 00:00:00 2001 From: Fei Date: Wed, 31 Jul 2024 01:16:01 -0700 Subject: [PATCH 08/75] [Bugfix] fix logit processor excceed vocab size issue (#6927) --- vllm/entrypoints/openai/logits_processors.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 2bfe7f51a1d9..b4bc959d41f9 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -58,6 +58,12 @@ def get_logits_processors( "Found token_id in logit_bias that is not " "an integer or string representing an integer") from exc + # Check if token_id is within the vocab size + for token_id, bias in clamped_logit_bias.items(): + if token_id < 0 or token_id >= tokenizer.vocab_size: + raise ValueError("token_id in logit_bias contains " + "out-of-vocab token id") + def logit_bias_logits_processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: for token_id, bias in clamped_logit_bias.items(): From ad9358c9f5bc777ddb4f5ff2249eb773547843da Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 10:52:08 +0000 Subject: [PATCH 09/75] Fix errors when construct sampling params --- vllm/entrypoints/openai/protocol.py | 4 +- vllm/sampling_params.py | 61 +++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2a10474d5142..088d7dae1bfa 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -225,7 +225,7 @@ def to_sampling_params(self, tokenizer=tokenizer, ) - return SamplingParams( + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, presence_penalty=self.presence_penalty, @@ -404,7 +404,7 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer): tokenizer=tokenizer, ) - return SamplingParams( + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, presence_penalty=self.presence_penalty, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 2598325439eb..a47c28f83c45 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -111,6 +111,67 @@ class SamplingParams: (i.e., no truncation). 
""" + @staticmethod + def from_optional( + n: Optional[int] = 1, + best_of: Optional[int] = None, + presence_penalty: Optional[float] = 0.0, + frequency_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, + temperature: Optional[float] = 1.0, + top_p: Optional[float] = 1.0, + top_k: int = -1, + min_p: float = 0.0, + seed: Optional[int] = None, + use_beam_search: bool = False, + length_penalty: float = 1.0, + early_stopping: Union[bool, str] = False, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, + ignore_eos: bool = False, + max_tokens: Optional[int] = 16, + min_tokens: int = 0, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + detokenize: bool = True, + skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + ) -> "SamplingParams": + return SamplingParams( + n=1 if n is None else n, + best_of=best_of, + presence_penalty=0.0 + if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 + if frequency_penalty is None else frequency_penalty, + repetition_penalty=1.0 + if repetition_penalty is None else repetition_penalty, + temperature=1.0 if temperature is None else temperature, + top_p=1.0 if top_p is None else top_p, + top_k=top_k, + min_p=min_p, + seed=seed, + use_beam_search=use_beam_search, + length_penalty=length_penalty, + early_stopping=early_stopping, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_stop_str_in_output, + ignore_eos=ignore_eos, + max_tokens=max_tokens, + min_tokens=min_tokens, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + detokenize=detokenize, + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + def __init__( self, n: int = 1, From ba499d073e1721d465309a15c0485e5f69cbeaec Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 10:54:34 +0000 Subject: [PATCH 10/75] Improve types + format --- vllm/entrypoints/chat_utils.py | 8 ++++-- vllm/entrypoints/openai/cli_args.py | 27 +++++++++++++++++--- vllm/entrypoints/openai/logits_processors.py | 9 ++++--- vllm/entrypoints/openai/serving_chat.py | 9 +++---- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index fbb7f70b55e1..16b3438b433c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -188,5 +188,9 @@ def parse_chat_message_content( messages = [ConversationMessage(role=role, content=content)] return ChatMessageParseResult(messages=messages, mm_futures=[]) - return _parse_chat_message_content_parts(role, content, model_config, - tokenizer) + return _parse_chat_message_content_parts( + role, + content, # type: ignore + model_config, + tokenizer, + ) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a4192937980f..47b147c6e905 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -7,6 +7,7 @@ import argparse import json import ssl +from typing import List, Optional, Sequence, Union from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, @@ -16,8 
+17,17 @@ class LoRAParserAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - lora_list = [] + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None or isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: List[LoRAModulePath] = [] for item in values: name, path = item.split('=') lora_list.append(LoRAModulePath(name, path)) @@ -26,8 +36,17 @@ def __call__(self, parser, namespace, values, option_string=None): class PromptAdapterParserAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - adapter_list = [] + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None or isinstance(values, str): + raise TypeError("Expected values to be a list") + + adapter_list: List[PromptAdapterPath] = [] for item in values: name, path = item.split('=') adapter_list.append(PromptAdapterPath(name, path)) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index b4bc959d41f9..87968eef4c40 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -2,9 +2,9 @@ from typing import Dict, FrozenSet, Iterable, List, Optional, Union import torch -from transformers import PreTrainedTokenizer from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import AnyTokenizer class AllowedTokenIdsLogitsProcessor: @@ -41,9 +41,10 @@ def _get_allowed_token_ids_logits_processor( def get_logits_processors( - logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], - allowed_token_ids: Optional[List[int]], - tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]: + logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], + allowed_token_ids: Optional[List[int]], + tokenizer: AnyTokenizer, +) -> List[LogitsProcessor]: logits_processors: List[LogitsProcessor] = [] if logit_bias: try: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0889379e87db..54722f46199e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,6 +1,6 @@ import time -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, - Optional) +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Final, + List, Optional) from typing import Sequence as GenericSequence from typing import Union @@ -163,8 +163,7 @@ async def create_chat_completion( prompt_adapter_request=prompt_adapter_request) engine_inputs = TokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"], - ) + prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data @@ -217,7 +216,7 @@ async def chat_completion_stream_generator( ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) - chunk_object_type = "chat.completion.chunk" + chunk_object_type: Final = "chat.completion.chunk" first_iteration = True # Send response for each token for each request.n (index) From c596ac9558decd48ca5c37748075b7b41f0d5f10 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 31 Jul 2024 11:05:40 +0000 Subject: [PATCH 11/75] 
Handle `decoded_token=None` + format --- vllm/entrypoints/openai/serving_chat.py | 32 ++++++++++++------- vllm/entrypoints/openai/serving_completion.py | 12 ++++--- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 54722f46199e..289278531be3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -543,8 +543,7 @@ def _create_chat_logprobs( num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" - - logprobs_content = [] + logprobs_content: List[ChatCompletionLogProbsContent] = [] for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] @@ -552,23 +551,32 @@ def _create_chat_logprobs( token = tokenizer.decode(token_id) if self.return_tokens_as_token_ids: token = f"token_id:{token_id}" + logprobs_content.append( ChatCompletionLogProbsContent( token=token, - bytes=list(token.encode("utf-8", errors="replace")))) + bytes=list(token.encode("utf-8", errors="replace")), + )) else: + step_token = step_top_logprobs[token_id] + step_decoded = step_token.decoded_token + logprobs_content.append( ChatCompletionLogProbsContent( token=self._get_decoded_token( - step_top_logprobs[token_id], token_id, tokenizer, - self.return_tokens_as_token_ids), - logprob=max(step_top_logprobs[token_id].logprob, - -9999.0), - bytes=list( - step_top_logprobs[token_id].decoded_token.encode( - "utf-8", errors="replace")), + step_token, + token_id, + tokenizer, + self.return_tokens_as_token_ids, + ), + logprob=max(step_token.logprob, -9999.0), + bytes=None if step_decoded is None else list( + step_decoded.encode("utf-8", errors="replace")), top_logprobs=self._get_top_logprobs( - step_top_logprobs, num_output_top_logprobs, - tokenizer))) + step_top_logprobs, + num_output_top_logprobs, + tokenizer, + ), + )) return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 77eae5b99c42..4e92a4348a01 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -434,17 +434,21 @@ def _create_completion_logprobs( token = tokenizer.decode(token_id) if self.return_tokens_as_token_ids: token = f"token_id:{token_id}" + out_tokens.append(token) out_token_logprobs.append(None) out_top_logprobs.append(None) else: + step_token = step_top_logprobs[token_id] + token = self._get_decoded_token( - step_top_logprobs[token_id], + step_token, token_id, tokenizer, - return_as_token_id=self.return_tokens_as_token_ids) - token_logprob = max(step_top_logprobs[token_id].logprob, - -9999.0) + return_as_token_id=self.return_tokens_as_token_ids, + ) + token_logprob = max(step_token.logprob, -9999.0) + out_tokens.append(token) out_token_logprobs.append(token_logprob) From bebad8cae0d05ab573b0ca22c40b9d81c709dda2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 3 Aug 2024 01:57:56 +0000 Subject: [PATCH 12/75] Fix type errors --- vllm/engine/llm_engine.py | 4 ++-- vllm/entrypoints/openai/logits_processors.py | 8 +++++--- vllm/entrypoints/openai/rpc/client.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index eb9cf69d2029..0989d9abf59f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,8 +40,8 @@ init_tracer) from vllm.transformers_utils.config import 
try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, - init_tokenizer_from_configs) +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 785005880ed8..090f6c91cc75 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor( return AllowedTokenIdsLogitsProcessor(allowed_token_ids) -def logit_bias_logits_processor(logit_bias: Dict[str, - float], token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: +def logit_bias_logits_processor( + logit_bias: Dict[int, float], + token_ids: List[int], + logits: torch.Tensor, +) -> torch.Tensor: for token_id, bias in logit_bias.items(): logits[token_id] += bias return logits diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 45bf88b5bf57..6c6608ae7235 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -159,7 +159,7 @@ async def _get_lora_config_rpc(self): expected_type=LoRAConfig, error_message="Could not get LoRAConfig from RPC Server") - async def _is_tracing_enabled_rpc(self) -> ParallelConfig: + async def _is_tracing_enabled_rpc(self) -> bool: """Get is_tracing_enabled flag from the RPCServer""" return await self._send_get_data_rpc_request( From 50a1136e29d0b8bf93211595002b4138d9c61b9e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 3 Aug 2024 18:34:11 +0000 Subject: [PATCH 13/75] Make decorators typed --- tests/tensorizer_loader/conftest.py | 16 ++++++++++++---- tests/utils.py | 11 ++++++++--- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index b46116391db2..70c4f2d6aa3d 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -1,10 +1,12 @@ import contextlib import functools import gc +from typing import Callable, TypeVar import pytest import ray import torch +from typing_extensions import ParamSpec from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) @@ -22,12 +24,16 @@ def cleanup(): torch.cuda.empty_cache() -def retry_until_skip(n): +_P = ParamSpec("_P") +_R = TypeVar("_R") - def decorator_retry(func): + +def retry_until_skip(n: int): + + def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]: @functools.wraps(func) - def wrapper_retry(*args, **kwargs): + def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R: for i in range(n): try: return func(*args, **kwargs) @@ -35,7 +41,9 @@ def wrapper_retry(*args, **kwargs): gc.collect() torch.cuda.empty_cache() if i == n - 1: - pytest.skip("Skipping test after attempts..") + pytest.skip(f"Skipping test after {n} attempts..") + + raise AssertionError("Code should not be reached") return wrapper_retry diff --git a/tests/utils.py b/tests/utils.py index dd8af8e3afe7..575f4c7df823 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,12 +7,13 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, 
List, Optional import openai import ray import requests from transformers import AutoTokenizer +from typing_extensions import ParamSpec from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) @@ -359,10 +360,14 @@ def wait_for_gpu_memory_to_clear(devices: List[int], time.sleep(5) -def fork_new_process_for_each_test(f): +_P = ParamSpec("_P") + + +def fork_new_process_for_each_test( + f: Callable[_P, None]) -> Callable[_P, None]: @functools.wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Make the process the leader of its own process group # to avoid sending SIGTERM to the parent process os.setpgrp() From 3b0ac79c95f58f259350ba7f499e2e9ac2288dfd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 3 Aug 2024 18:34:21 +0000 Subject: [PATCH 14/75] Format --- vllm/engine/llm_engine.py | 2 +- vllm/entrypoints/llm.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0989d9abf59f..be971021516d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,9 +40,9 @@ init_tracer) from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( BaseTokenizerGroup, init_tokenizer_from_configs) -from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 7f7ded3d9523..8b447cb3cc8a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,7 +13,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_kwargs From eba3863bdd7c2b222855739280696d51c013ef37 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 3 Aug 2024 18:34:34 +0000 Subject: [PATCH 15/75] Fix type errors --- vllm/entrypoints/openai/serving_completion.py | 52 ++++++++++++++----- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 343ce5020611..fb82dc1f563e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -236,6 +236,19 @@ async def completion_stream_generator( f"{request_id}-{prompt_idx}") raise StopAsyncIteration() + prompt_text = res.prompt + assert prompt_text is not None + + prompt_token_ids = res.prompt_token_ids + assert prompt_token_ids is not None + + prompt_logprobs = res.prompt_logprobs + assert prompt_logprobs is not None + + delta_token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[Dict[ + int, Logprob]]]] + for output in res.outputs: i = output.index + prompt_idx * num_choices # TODO(simon): optimize the performance by avoiding full @@ -244,18 +257,20 @@ async def completion_stream_generator( assert request.max_tokens is not None if 
request.echo and request.max_tokens == 0: # only return the prompt - delta_text = res.prompt - delta_token_ids = res.prompt_token_ids - out_logprobs = res.prompt_logprobs + delta_text = prompt_text + delta_token_ids = prompt_token_ids + out_logprobs = prompt_logprobs has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): # echo the prompt and first token - delta_text = res.prompt + output.text - delta_token_ids = (res.prompt_token_ids + - output.token_ids) - out_logprobs = res.prompt_logprobs + (output.logprobs - or []) + delta_text = prompt_text + output.text + delta_token_ids = [ + *prompt_token_ids, *output.token_ids + ] + out_logprobs = [ + *prompt_logprobs, *(output.logprobs or []) + ] has_echoed[i] = True else: # return just the delta @@ -300,7 +315,7 @@ async def completion_stream_generator( and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(res.prompt_token_ids) + prompt_tokens = len(prompt_token_ids) completion_tokens = len(output.token_ids) usage = UsageInfo( prompt_tokens=prompt_tokens, @@ -349,8 +364,16 @@ def request_output_to_completion_response( for final_res in final_res_batch: prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + assert prompt_logprobs is not None + prompt_text = final_res.prompt + assert prompt_text is not None + + token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[Dict[int, + Logprob]]]] for output in final_res.outputs: assert request.max_tokens is not None @@ -359,9 +382,14 @@ def request_output_to_completion_response( out_logprobs = prompt_logprobs output_text = prompt_text elif request.echo and request.max_tokens > 0: - token_ids = prompt_token_ids + list(output.token_ids) - out_logprobs = (prompt_logprobs + output.logprobs - if request.logprobs is not None else None) + token_ids = [*prompt_token_ids, *output.token_ids] + + if request.logprobs is None: + out_logprobs = None + else: + assert output.logprobs is not None + out_logprobs = [*prompt_logprobs, *output.logprobs] + output_text = prompt_text + output.text else: token_ids = output.token_ids From ff49909d306fc326406a8b2cabc63893a6e5efe8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 4 Aug 2024 05:41:15 +0000 Subject: [PATCH 16/75] Fix type errors from merged commits --- vllm/engine/protocol.py | 10 ++++++++-- vllm/entrypoints/openai/protocol.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index fc94ef6662e0..0f679f7688c6 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -40,6 +40,7 @@ async def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: """Generates outputs for a request""" + ... async def encode( self, @@ -50,6 +51,7 @@ async def encode( trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model.""" + ... async def abort(self, request_id: str) -> None: """Abort a request. @@ -60,8 +62,10 @@ async def abort(self, request_id: str) -> None: async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" + ... async def get_decoding_config(self) -> DecodingConfig: + ... 
"""Get the decoding configuration of the vLLM engine.""" async def get_tokenizer( @@ -69,16 +73,18 @@ async def get_tokenizer( lora_request: Optional[LoRARequest] = None, ) -> PreTrainedTokenizer: """Get the appropriate Tokenizer for the request""" + ... async def is_tracing_enabled(self) -> bool: - pass + ... async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None, ) -> None: - pass + ... async def check_health(self) -> None: """Raise if unhealthy""" + ... diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8352fe963da5..751870fb0c04 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -18,6 +18,7 @@ # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) +_LONG_INFO: Union["torch.iinfo", Namespace] try: from sphinx.ext.autodoc.mock import _MockModule From 4f80738d8e0e66b00e47f6c5bfe085a19871997e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 4 Aug 2024 08:25:54 +0000 Subject: [PATCH 17/75] Use more flexible tokenizer type --- vllm/engine/async_llm_engine.py | 5 ++--- vllm/engine/output_processor/interfaces.py | 5 ++--- vllm/engine/output_processor/multi_step.py | 5 ++--- vllm/engine/output_processor/stop_checker.py | 6 ++---- vllm/engine/protocol.py | 7 +++---- vllm/entrypoints/chat_utils.py | 10 +++++----- vllm/entrypoints/openai/protocol.py | 6 +++--- vllm/multimodal/image.py | 5 ++--- vllm/transformers_utils/detokenizer.py | 5 +---- 9 files changed, 22 insertions(+), 32 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 704718b2ddfa..12b5241a3d53 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,8 +4,6 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) -from transformers import PreTrainedTokenizer - import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -24,6 +22,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -508,7 +507,7 @@ def _error_callback(self, exc: Exception) -> None: async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, - ) -> "PreTrainedTokenizer": + ) -> "AnyTokenizer": if self.engine_use_ray: return await self.engine.get_tokenizer.remote( # type: ignore lora_request) diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 92aecebe6ec3..a385f37d807a 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,13 +1,12 @@ from abc import ABC, abstractmethod from typing import Callable, List -from transformers import PreTrainedTokenizer - from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils 
import Counter @@ -29,7 +28,7 @@ def create_output_processor( detokenizer: Detokenizer, scheduler: List[Scheduler], seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], stop_checker: "StopChecker", ): """Create an output processor. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 25d15df9f915..6c472528a7a9 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,8 +1,6 @@ import functools from typing import Callable, List -from transformers import PreTrainedTokenizer - from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) @@ -12,6 +10,7 @@ from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter logger = init_logger(__name__) @@ -36,7 +35,7 @@ def __init__( detokenizer: Detokenizer, scheduler: List[Scheduler], seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], stop_checker: StopChecker, ): self.detokenizer = detokenizer diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 96f0d1142611..0c5f8fb7f5be 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,10 +1,9 @@ from typing import Callable, Optional -from transformers import PreTrainedTokenizer - from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus +from vllm.transformers_utils.tokenizer import AnyTokenizer class StopChecker: @@ -15,8 +14,7 @@ class StopChecker: """ def __init__(self, max_model_len: int, - get_tokenizer_for_seq: Callable[[Sequence], - PreTrainedTokenizer]): + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): # Do not use it directly, but use `self._get_max_model_len`. self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 0f679f7688c6..c580ca8ff469 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,8 +1,6 @@ from typing import (AsyncIterator, List, Mapping, Optional, Protocol, runtime_checkable) -from transformers import PreTrainedTokenizer - from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptInputs @@ -12,6 +10,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput +from vllm.transformers_utils.tokenizer import AnyTokenizer @runtime_checkable @@ -71,8 +70,8 @@ async def get_decoding_config(self) -> DecodingConfig: async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, - ) -> PreTrainedTokenizer: - """Get the appropriate Tokenizer for the request""" + ) -> AnyTokenizer: + """Get the appropriate tokenizer for the request""" ... 
async def is_tracing_enabled(self) -> bool: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 48447f8ce14f..984e2d78c4ca 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -15,13 +15,13 @@ # yapf: enable # pydantic needs the TypedDict from typing_extensions from pydantic import ConfigDict -from transformers import PreTrainedTokenizer from typing_extensions import Required, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import async_get_and_parse_image +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) @@ -93,7 +93,7 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]: @lru_cache(maxsize=None) def _image_token_str(model_config: ModelConfig, - tokenizer: PreTrainedTokenizer) -> Optional[str]: + tokenizer: AnyTokenizer) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) model_type = model_config.hf_config.model_type @@ -126,7 +126,7 @@ def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> ChatMessageParseResult: texts: List[str] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -177,7 +177,7 @@ def _parse_chat_message_content_parts( def _parse_chat_message_content( message: ChatCompletionMessageParam, model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> ChatMessageParseResult: role = message["role"] content = message.get("content") @@ -199,7 +199,7 @@ def _parse_chat_message_content( def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 751870fb0c04..9ca505ff6e99 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,13 +6,13 @@ import torch from pydantic import BaseModel, ConfigDict, Field, model_validator -from transformers import PreTrainedTokenizer from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid # torch is mocked during docs generation, @@ -233,7 +233,7 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params def to_sampling_params( - self, tokenizer: PreTrainedTokenizer, + self, tokenizer: AnyTokenizer, guided_decode_logits_processor: Optional[LogitsProcessor], default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens @@ -418,7 +418,7 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params def to_sampling_params( - self, tokenizer: PreTrainedTokenizer, + self, tokenizer: AnyTokenizer, guided_decode_logits_processor: Optional[LogitsProcessor], default_max_tokens: int) -> SamplingParams: max_tokens = 
self.max_tokens diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 3b37ce9149fb..44faea27af35 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -3,13 +3,12 @@ import torch from PIL import Image -from transformers import PreTrainedTokenizerBase from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from .base import MultiModalInputs, MultiModalPlugin @@ -39,7 +38,7 @@ def repeat_and_pad_token( def repeat_and_pad_image_tokens( - tokenizer: PreTrainedTokenizerBase, + tokenizer: AnyTokenizer, prompt: Optional[str], prompt_token_ids: List[int], *, diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 3541c8b869d4..06dfd59e3ff1 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,7 +1,5 @@ from typing import Dict, List, Optional, Tuple -from transformers import PreTrainedTokenizer - from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from .tokenizer_group import AnyTokenizer, BaseTokenizerGroup @@ -16,8 +14,7 @@ class Detokenizer: def __init__(self, tokenizer_group: BaseTokenizerGroup): self.tokenizer_group = tokenizer_group - def get_tokenizer_for_seq(self, - sequence: Sequence) -> "PreTrainedTokenizer": + def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: """Returns the HF tokenizer to use for a given sequence.""" return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) From ffe97d6408f7d2b1f582ec2969306fff27775193 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 4 Aug 2024 08:27:39 +0000 Subject: [PATCH 18/75] Fix arg --- vllm/entrypoints/openai/rpc/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 7a72a6f732c9..195e98ccf9c8 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -22,8 +22,8 @@ class AsyncEngineRPCServer: def __init__(self, async_engine_args: AsyncEngineArgs, usage_context: UsageContext, port: int): # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, - usage_context) + self.engine = AsyncLLMEngine.from_engine_args( + async_engine_args, usage_context=usage_context) # Initialize context. 
self.context = zmq.asyncio.Context() From 0e4c97d9757dd5f6c3100afc45ff179bb7f1ed8d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 01:31:08 +0000 Subject: [PATCH 19/75] Fix merge --- vllm/config.py | 2 +- vllm/engine/llm_engine.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ffe2739c796b..2dd922b58279 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,7 +12,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (GiB_bytes, STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, cuda_device_count_stateless, get_cpu_memory, is_cpu, is_hip, is_neuron, is_openvino, is_tpu, is_xpu, print_warning_once) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8d9f2732e79a..8745bc4aec6d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -688,8 +688,8 @@ def _tokenize_prompt( * prompt token ids ''' - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + missing_msg="prompts must be None if skip_tokenizer_init is True") prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=prompt, From 8291a9de1b85e84c75620ab28fcefdd6f4709188 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 05:45:40 +0000 Subject: [PATCH 20/75] Remove unnecessary type annotations --- vllm/entrypoints/openai/serving_completion.py | 5 ++--- vllm/entrypoints/openai/serving_embedding.py | 8 +++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index d3fb0b4f6f41..9fda41f37ebe 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -144,9 +144,8 @@ async def create_completion(self, request: CompletionRequest, # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator: AsyncIterator[Tuple[ - int, RequestOutput]] = merge_async_iterators( - *generators, is_cancelled=raw_request.is_disconnected) + result_generator = merge_async_iterators( + *generators, is_cancelled=raw_request.is_disconnected) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. 
In addition, we do not stream the results when use diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index bd43a722f9a8..f1b3be54f1e5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,8 +1,7 @@ import asyncio import base64 import time -from typing import (AsyncGenerator, AsyncIterator, List, Literal, Optional, - Tuple, Union, cast) +from typing import AsyncGenerator, List, Literal, Optional, Union, cast import numpy as np from fastapi import Request @@ -149,9 +148,8 @@ async def create_embedding(self, request: EmbeddingRequest, # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator: AsyncIterator[Tuple[ - int, EmbeddingRequestOutput]] = merge_async_iterators( - *generators, is_cancelled=raw_request.is_disconnected) + result_generator = merge_async_iterators( + *generators, is_cancelled=raw_request.is_disconnected) # Non-streaming response final_res_batch: List[Optional[EmbeddingRequestOutput]] From 46732b8923fb1169623a1fa3e76f721aa7564454 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 05:49:08 +0000 Subject: [PATCH 21/75] Simplify code --- tests/test_utils.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8d22c20bb197..3d2c2a061291 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,8 @@ import asyncio import os import socket -import sys from functools import partial -from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, - Tuple, TypeVar) +from typing import AsyncIterator, Tuple import pytest @@ -13,26 +11,11 @@ from .utils import error_on_warning -if sys.version_info < (3, 10): - if TYPE_CHECKING: - _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any]) - _AwaitableT_co = TypeVar("_AwaitableT_co", - bound=Awaitable[Any], - covariant=True) - - class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]): - - def __anext__(self) -> _AwaitableT_co: - ... - - def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT": - return i.__anext__() - @pytest.mark.asyncio async def test_merge_async_iterators(): - async def mock_async_iterator(idx: int) -> AsyncIterator[str]: + async def mock_async_iterator(idx: int): try: while True: yield f"item from iterator {idx}" @@ -41,7 +24,7 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]: print(f"iterator {idx} cancelled") iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( + merged_iterator = merge_async_iterators( *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False)) async def stream_output(generator: AsyncIterator[Tuple[int, str]]): @@ -56,7 +39,8 @@ async def stream_output(generator: AsyncIterator[Tuple[int, str]]): for iterator in iterators: try: - await asyncio.wait_for(anext(iterator), 1) + # Can use anext() in python >= 3.10 + await asyncio.wait_for(iterator.__anext__(), 1) except StopAsyncIteration: # All iterators should be cancelled and print this message. 
print("Iterator was cancelled normally") From 37ab8346e5349a87bb77b4b7d3cb8a0a0cffb8cc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 05:54:43 +0000 Subject: [PATCH 22/75] Cleanup --- tests/test_utils.py | 6 ++++-- vllm/entrypoints/openai/serving_chat.py | 6 +++--- vllm/entrypoints/openai/serving_completion.py | 11 +++++++---- vllm/entrypoints/openai/serving_embedding.py | 10 +++++++--- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 3d2c2a061291..c157be1c08f8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -24,8 +24,10 @@ async def mock_async_iterator(idx: int): print(f"iterator {idx} cancelled") iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator = merge_async_iterators( - *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False)) + merged_iterator = merge_async_iterators(*iterators, + is_cancelled=partial(asyncio.sleep, + 0, + result=False)) async def stream_output(generator: AsyncIterator[Tuple[int, str]]): async for idx, output in generator: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a60946f75d59..26ac31283629 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -66,9 +66,9 @@ def __init__( async def create_chat_completion( self, request: ChatCompletionRequest, - raw_request: Optional[Request] = None - ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, + ErrorResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9fda41f37ebe..cd9c5ac2240f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -3,7 +3,7 @@ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional) from typing import Sequence as GenericSequence -from typing import Tuple, cast +from typing import Tuple, Union, cast from fastapi import Request @@ -18,7 +18,7 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - UsageInfo) + ErrorResponse, UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, @@ -60,8 +60,11 @@ def __init__( request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) - async def create_completion(self, request: CompletionRequest, - raw_request: Request): + async def create_completion( + self, + request: CompletionRequest, + raw_request: Request, + ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]: """Completion API similar to OpenAI's API. 
See https://platform.openai.com/docs/api-reference/completions/create diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f1b3be54f1e5..f01ab4c4d8f7 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -12,7 +12,8 @@ from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, - EmbeddingResponseData, UsageInfo) + EmbeddingResponseData, + ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.logger import init_logger from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput @@ -82,8 +83,11 @@ def __init__( request_logger=request_logger) self._check_embedding_mode(model_config.embedding_mode) - async def create_embedding(self, request: EmbeddingRequest, - raw_request: Request): + async def create_embedding( + self, + request: EmbeddingRequest, + raw_request: Request, + ) -> Union[EmbeddingResponse, ErrorResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/embeddings/create From 2da334c205584d0a8990673299975dacfd17b8de Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 05:59:01 +0000 Subject: [PATCH 23/75] Fix type errors --- vllm/entrypoints/openai/api_server.py | 36 +++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 88f0bd4ee4db..419b68d14c0c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app from starlette.routing import Mount +from typing_extensions import assert_never import vllm.envs as envs from vllm.config import ModelConfig @@ -28,14 +29,16 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, + CompletionResponse, DetokenizeRequest, DetokenizeResponse, - EmbeddingRequest, ErrorResponse, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, TokenizeRequest, TokenizeResponse) +# yapf: enable from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient from vllm.entrypoints.openai.rpc.server import run_rpc_server -# yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -155,10 +158,11 @@ async def tokenize(request: TokenizeRequest): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: - assert isinstance(generator, TokenizeResponse) + elif isinstance(generator, TokenizeResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + @router.post("/detokenize") async def detokenize(request: DetokenizeRequest): @@ -166,10 +170,11 @@ async def detokenize(request: DetokenizeRequest): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: - assert isinstance(generator, DetokenizeResponse) + elif isinstance(generator, DetokenizeResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + @router.get("/v1/models") async def 
show_available_models(): @@ -191,13 +196,11 @@ async def create_chat_completion(request: ChatCompletionRequest, if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, - media_type="text/event-stream") - else: - assert isinstance(generator, ChatCompletionResponse) + elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): @@ -206,12 +209,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, - media_type="text/event-stream") - else: + elif isinstance(generator, CompletionResponse): return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): @@ -220,9 +222,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: + elif isinstance(generator, EmbeddingResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + def build_app(args: Namespace) -> FastAPI: app = FastAPI(lifespan=lifespan) From 475d84aa9124dc53ebe481e4238cdca810da0bd6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 06:08:52 +0000 Subject: [PATCH 24/75] Fix type error --- vllm/entrypoints/openai/api_server.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 419b68d14c0c..7f96a9f3d8e0 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -89,7 +89,8 @@ async def _force_log(): @asynccontextmanager -async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[AsyncEngineClient]: # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit global engine_args @@ -118,8 +119,11 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: rpc_server_process.start() # Build RPCClient, which conforms to AsyncEngineClient Protocol. - async_engine_client = AsyncEngineRPCClient(port) - await async_engine_client.setup() + # NOTE: Actually, this is not true yet. 
We still need to support + # embedding models via RPC (see TODO above) + rpc_client = AsyncEngineRPCClient(port) + async_engine_client = rpc_client # type: ignore + await rpc_client.setup() try: yield async_engine_client @@ -128,7 +132,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: rpc_server_process.terminate() # Close all open connections to the backend - async_engine_client.close() + rpc_client.close() # Wait for server process to join rpc_server_process.join() From 937a8caaf98f9715304114f840f434be97d13425 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 06:21:26 +0000 Subject: [PATCH 25/75] Clean --- tests/tensorizer_loader/conftest.py | 2 +- vllm/utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index 70c4f2d6aa3d..07b9c6b3c6be 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -41,7 +41,7 @@ def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R: gc.collect() torch.cuda.empty_cache() if i == n - 1: - pytest.skip(f"Skipping test after {n} attempts..") + pytest.skip(f"Skipping test after {n} attempts.") raise AssertionError("Code should not be reached") diff --git a/vllm/utils.py b/vllm/utils.py index 4c4385ae5ab7..c004e412ac34 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1049,10 +1049,10 @@ def cuda_device_count_stateless() -> int: pynvml = None -def with_nvml_context(fn): +def with_nvml_context(fn: Callable[P, T]) -> Callable[P, T]: @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if pynvml is not None: pynvml.nvmlInit() try: @@ -1088,9 +1088,9 @@ def is_full_nvlink(device_ids: List[int]) -> bool: #From: https://stackoverflow.com/a/4104188/2749989 -def run_once(f): +def run_once(f: Callable[P, None]) -> Callable[P, None]: - def wrapper(*args, **kwargs) -> Any: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: if not wrapper.has_run: # type: ignore[attr-defined] wrapper.has_run = True # type: ignore[attr-defined] return f(*args, **kwargs) From 33c9e25c61e140a803a4acbcaddb520a1eeba84e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 07:53:30 +0000 Subject: [PATCH 26/75] Introduce `is_list_of` --- vllm/inputs/data.py | 20 ++++++++++---------- vllm/multimodal/image.py | 6 ++++-- vllm/utils.py | 25 ++++++++++++++++++++++--- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 86c2901dc4c8..9df0bd2041d9 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,8 +1,10 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, - TypedDict, Union, cast, overload) + TypedDict, Union, overload) from typing_extensions import NotRequired +from vllm.utils import is_list_of + if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict @@ -41,25 +43,23 @@ def parse_and_batch_prompt( if len(prompt) == 0: raise ValueError("please provide at least one prompt") - if isinstance(prompt[0], str): + if is_list_of(prompt, str): # case 2: array of strings return [ - ParsedText(content=elem, is_tokens=False) - for elem in cast(List[str], prompt) + ParsedText(content=elem, is_tokens=False) for elem in prompt ] - if isinstance(prompt[0], int): + if is_list_of(prompt, int): # case 3: array of tokens - elem = cast(List[int], prompt) - return [ParsedTokens(content=elem, is_tokens=True)] - if isinstance(prompt[0], list): + return 
[ParsedTokens(content=prompt, is_tokens=True)] + if is_list_of(prompt, list): if len(prompt[0]) == 0: raise ValueError("please provide at least one prompt") - if isinstance(prompt[0][0], int): + if is_list_of(prompt[0], int): # case 4: array of token arrays return [ ParsedTokens(content=elem, is_tokens=True) - for elem in cast(List[List[int]], prompt) + for elem in prompt ] raise ValueError("prompt must be a string, array of strings, " diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index b6a3909e9563..db50229bda31 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -10,6 +10,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils import is_list_of from .base import MultiModalInputs, MultiModalPlugin @@ -113,7 +114,8 @@ def _get_hf_image_processor(self, model_config: ModelConfig): def _default_input_mapper(self, ctx: InputContext, data: object) -> MultiModalInputs: model_config = ctx.model_config - if isinstance(data, (Image.Image, list)): + + if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " @@ -127,7 +129,7 @@ def _default_input_mapper(self, ctx: InputContext, raise return MultiModalInputs(batch_data) - elif isinstance(data, torch.Tensor): + elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): raise NotImplementedError("Embeddings input is not supported yet") raise TypeError(f"Invalid image type: {type(data)}") diff --git a/vllm/utils.py b/vllm/utils.py index 61e3bb0bfc33..413f6ce62276 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -17,15 +17,15 @@ from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, - Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, - Union, overload) + Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, + Type, TypeVar, Union, overload) import numpy as np import numpy.typing as npt import psutil import torch import torch.types -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, TypeGuard, assert_never import vllm.envs as envs from vllm import _custom_ops as ops @@ -807,6 +807,24 @@ def get_dtype_size(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() +# `collections` helpers +def is_list_of( + value: object, + typ: Type[T], + *, + check: Literal["first", "all"] = "first", +) -> TypeGuard[List[T]]: + if not isinstance(value, list): + return False + + if check == "first": + return len(value) == 0 or isinstance(value[0], typ) + elif check == "all": + return all(isinstance(v, typ) for v in value) + + assert_never(check) + + def merge_dicts(dict1: Dict[K, List[T]], dict2: Dict[K, List[T]]) -> Dict[K, List[T]]: """Merge 2 dicts that have key -> List of items. 
@@ -954,6 +972,7 @@ def enable_trace_function_call_for_thread() -> None: enable_trace_function_call(log_path) +# `functools` helpers def identity(value: T) -> T: return value From e6dd6f5b6ea5f2eb7febf631d86983a9f439120d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 07:56:30 +0000 Subject: [PATCH 27/75] Avoid circular imports --- examples/offline_inference_encoder_decoder.py | 4 +- tests/conftest.py | 6 +- tests/test_inputs.py | 2 +- vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 4 +- vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/inputs/__init__.py | 20 ++- vllm/inputs/data.py | 145 +++--------------- vllm/inputs/parse.py | 125 +++++++++++++++ vllm/sequence.py | 2 +- vllm/utils.py | 29 ---- 11 files changed, 171 insertions(+), 172 deletions(-) create mode 100644 vllm/inputs/parse.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 79b284554f17..c05e8e8bb6f1 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -4,8 +4,8 @@ ''' from vllm import LLM, SamplingParams -from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt -from vllm.utils import zip_enc_dec_prompt_lists +from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, + TokensPrompt, zip_enc_dec_prompt_lists) dtype = "float" diff --git a/tests/conftest.py b/tests/conftest.py index c0bf9897c97f..b0adfc58bcda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,13 +21,13 @@ from vllm.connections import global_http_connection from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) -from vllm.inputs import TextPrompt +from vllm.inputs import (TextPrompt, to_enc_dec_tuple_list, + zip_enc_dec_prompt_lists) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, - is_cpu, to_enc_dec_tuple_list, - zip_enc_dec_prompt_lists) + is_cpu) logger = init_logger(__name__) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 887c7101decd..3725d8687f25 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -2,7 +2,7 @@ import pytest -from vllm.inputs import parse_and_batch_prompt +from vllm.inputs.parse import parse_and_batch_prompt STRING_INPUTS = [ '', diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 75c6d7e6c9b2..10913efbd889 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,8 +22,8 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs, - get_prompt_type) +from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs +from vllm.inputs.parse import get_prompt_type from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index eaa157209493..175f418a1294 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,8 +6,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, - parse_and_batch_prompt) +from vllm.inputs import PromptInputs, TextPrompt, 
TokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index df4932d8fe18..8d8b5ea4bdf5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -22,7 +22,7 @@ TokenizeCompletionRequest, TokenizeRequest) # yapf: enable -from vllm.inputs import parse_and_batch_prompt +from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index e22b88f2fc38..1dcd1ad343b3 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ -from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText, - ParsedTokens, PromptInputs, SingletonPromptInputs, - TextPrompt, TokensPrompt, get_prompt_type, - is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt) +from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, + SingletonPromptInputs, TextPrompt, TokensPrompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + zip_enc_dec_prompt_lists) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -14,18 +14,16 @@ """ __all__ = [ - "ParsedText", - "ParsedTokens", - "parse_and_batch_prompt", "TextPrompt", "TokensPrompt", "PromptInputs", + "ExplicitEncoderDecoderPrompt", + "SingletonPromptInputs", "LLMInputs", + "build_explicit_enc_dec_prompt", + "to_enc_dec_tuple_list", + "zip_enc_dec_prompt_lists", "INPUT_REGISTRY", "InputContext", "InputRegistry", - "get_prompt_type", - "is_valid_encoder_decoder_llm_inputs", - "ExplicitEncoderDecoderPrompt", - "SingletonPromptInputs", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 9df0bd2041d9..4cee911b4398 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,71 +1,11 @@ -from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, - TypedDict, Union, overload) +from typing import TYPE_CHECKING, List, Optional, Tuple, TypedDict, Union from typing_extensions import NotRequired -from vllm.utils import is_list_of - if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict -class ParsedText(TypedDict): - content: str - is_tokens: Literal[False] - - -class ParsedTokens(TypedDict): - content: List[int] - is_tokens: Literal[True] - - -# https://github.com/vllm-project/vllm/pull/4028 -@overload -def parse_and_batch_prompt( - prompt: Union[str, List[str]]) -> Sequence[ParsedText]: - ... - - -@overload -def parse_and_batch_prompt( - prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: - ... 
- - -def parse_and_batch_prompt( - prompt: Union[str, List[str], List[int], List[List[int]]], -) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: - if isinstance(prompt, str): - # case 1: a string - return [ParsedText(content=prompt, is_tokens=False)] - - if isinstance(prompt, list): - if len(prompt) == 0: - raise ValueError("please provide at least one prompt") - - if is_list_of(prompt, str): - # case 2: array of strings - return [ - ParsedText(content=elem, is_tokens=False) for elem in prompt - ] - if is_list_of(prompt, int): - # case 3: array of tokens - return [ParsedTokens(content=prompt, is_tokens=True)] - if is_list_of(prompt, list): - if len(prompt[0]) == 0: - raise ValueError("please provide at least one prompt") - - if is_list_of(prompt[0], int): - # case 4: array of token arrays - return [ - ParsedTokens(content=elem, is_tokens=True) - for elem in prompt - ] - - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") - - class TextPrompt(TypedDict): """Schema for a text prompt.""" @@ -150,56 +90,6 @@ class ExplicitEncoderDecoderPrompt(TypedDict): """ -def _has_required_keys( - d: dict, - required_keys: set, -) -> bool: - return required_keys.issubset(d.keys()) - - -def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]: - """ - Get the type-name of the prompt argument instance, given that - isinstance() cannot apply to TypedDict subclasses directly. - If the prompt is None, return 'None' as the type name. - - Arguments: - - * prompt: LLM input prompt or None - - Returns: - - * String representation of prompt type - """ - - if prompt is None: - return 'None' - - required_keys_dict = { - 'TextPrompt': {'prompt'}, - 'TokensPrompt': {'prompt_token_ids'}, - 'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'}, - } - - if isinstance(prompt, dict): - for (ptype, required_keys) in required_keys_dict.items(): - # Ignore type checking in the conditional below because type - # checker does not understand that is_dict(prompt) narrows - # down the possible types - if _has_required_keys( - prompt, # type: ignore - required_keys): - return ptype - - raise ValueError(f"Invalid prompt {prompt}, valid types are " - "required_keys_dict={required_keys_dict}") - - if isinstance(prompt, str): - return "str" - - raise ValueError(f"Invalid prompt {prompt}") - - class LLMInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are @@ -229,13 +119,28 @@ class LLMInputs(TypedDict): """ -def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool: - """ - Return True if the LLMInputs instance has the correct configuration - for encoder/decoder. 
- """ +def build_explicit_enc_dec_prompt( + encoder_prompt: SingletonPromptInputs, + decoder_prompt: SingletonPromptInputs, +) -> ExplicitEncoderDecoderPrompt: + return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt) + + +def zip_enc_dec_prompt_lists( + enc_prompt_list: List[SingletonPromptInputs], + dec_prompt_list: List[SingletonPromptInputs], +) -> List[ExplicitEncoderDecoderPrompt]: + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) + for (encoder_prompt, + decoder_prompt) in zip(enc_prompt_list, dec_prompt_list) + ] + - # True if encoder prompt token ids field exists & - # is not None - return ('encoder_prompt_token_ids' in inputs - and inputs['encoder_prompt_token_ids'] is not None) +def to_enc_dec_tuple_list( + enc_dec_prompts: List[ExplicitEncoderDecoderPrompt], +) -> List[Tuple[PromptInputs, PromptInputs]]: + return [(enc_dec_prompt['encoder_prompt'], + enc_dec_prompt['decoder_prompt']) + for enc_dec_prompt in enc_dec_prompts] diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py new file mode 100644 index 000000000000..42bd9858bcbe --- /dev/null +++ b/vllm/inputs/parse.py @@ -0,0 +1,125 @@ +from typing import (List, Literal, Optional, Sequence, TypedDict, Union, + overload) + +from vllm.utils import is_list_of + +from .data import LLMInputs, PromptInputs + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: List[int] + is_tokens: Literal[True] + + +# https://github.com/vllm-project/vllm/pull/4028 +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt, str): + # case 2: array of strings + return [ + ParsedText(content=elem, is_tokens=False) for elem in prompt + ] + if is_list_of(prompt, int): + # case 3: array of tokens + return [ParsedTokens(content=prompt, is_tokens=True)] + if is_list_of(prompt, list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt[0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in prompt + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +def _has_required_keys( + d: dict, + required_keys: set, +) -> bool: + return required_keys.issubset(d.keys()) + + +def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]: + """ + Get the type-name of the prompt argument instance, given that + isinstance() cannot apply to TypedDict subclasses directly. + If the prompt is None, return 'None' as the type name. 
+ + Arguments: + + * prompt: LLM input prompt or None + + Returns: + + * String representation of prompt type + """ + + if prompt is None: + return 'None' + + required_keys_dict = { + 'TextPrompt': {'prompt'}, + 'TokensPrompt': {'prompt_token_ids'}, + 'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'}, + } + + if isinstance(prompt, dict): + for (ptype, required_keys) in required_keys_dict.items(): + # Ignore type checking in the conditional below because type + # checker does not understand that is_dict(prompt) narrows + # down the possible types + if _has_required_keys( + prompt, # type: ignore + required_keys): + return ptype + + raise ValueError(f"Invalid prompt {prompt}, valid types are " + f"required_keys_dict={required_keys_dict}") + + if isinstance(prompt, str): + return "str" + + raise ValueError(f"Invalid prompt {prompt}") + + +def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool: + """ + Return True if the LLMInputs instance has the correct configuration + for encoder/decoder. + """ + + # True if encoder prompt token ids field exists & + # is not None + return ('encoder_prompt_token_ids' in inputs + and inputs['encoder_prompt_token_ids'] is not None) diff --git a/vllm/sequence.py b/vllm/sequence.py index 634785533382..fbd148001cc7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -11,7 +11,7 @@ import torch -from vllm.inputs import is_valid_encoder_decoder_llm_inputs +from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest diff --git a/vllm/utils.py b/vllm/utils.py index 413f6ce62276..eb88fce4af0c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -29,8 +29,6 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs, - SingletonPromptInputs) from vllm.logger import enable_trace_function_call, init_logger logger = init_logger(__name__) @@ -1164,30 +1162,3 @@ def is_embedding_model_config(model_config) -> bool: ''' return model_config is not None and \ model_config.embedding_mode - - -def build_explicit_enc_dec_prompt( - encoder_prompt: SingletonPromptInputs, - decoder_prompt: SingletonPromptInputs, -) -> ExplicitEncoderDecoderPrompt: - return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, - decoder_prompt=decoder_prompt) - - -def zip_enc_dec_prompt_lists( - enc_prompt_list: List[SingletonPromptInputs], - dec_prompt_list: List[SingletonPromptInputs], -) -> List[ExplicitEncoderDecoderPrompt]: - return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) - for (encoder_prompt, - decoder_prompt) in zip(enc_prompt_list, dec_prompt_list) - ] - - -def to_enc_dec_tuple_list( - enc_dec_prompts: List[ExplicitEncoderDecoderPrompt], -) -> List[Tuple[PromptInputs, PromptInputs]]: - return [(enc_dec_prompt['encoder_prompt'], - enc_dec_prompt['decoder_prompt']) - for enc_dec_prompt in enc_dec_prompts] From f938c8690274bc61f20720ff8b227beb9e980b11 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:09:20 +0000 Subject: [PATCH 28/75] Refactor prompt parsing and extend this to async engine --- .github/workflows/mypy.yaml | 2 +- requirements-common.txt | 2 +- requirements-lint.txt | 2 +- requirements-openvino.txt | 2 +- vllm/engine/async_llm_engine.py | 155 ++++++++++++--- vllm/engine/llm_engine.py | 196 ++++++++----------- vllm/entrypoints/openai/logits_processors.py | 8 +- 
vllm/inputs/__init__.py | 9 +- vllm/inputs/data.py | 16 +- vllm/inputs/parse.py | 68 ++----- 10 files changed, 252 insertions(+), 208 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 8d423657630c..f7b84eebc8b6 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install mypy==1.9.0 + pip install mypy==1.11.1 pip install types-setuptools pip install types-PyYAML pip install types-requests diff --git a/requirements-common.txt b/requirements-common.txt index d8c95bf77240..ebd0fca51919 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.10.3 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 -typing_extensions +typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq gguf == 0.9.1 diff --git a/requirements-lint.txt b/requirements-lint.txt index bd34227d3e82..d0b2fef6deae 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -8,7 +8,7 @@ isort==5.13.2 clang-format==18.1.5 # type checking -mypy==1.9.0 +mypy==1.11.1 types-PyYAML types-requests types-setuptools diff --git a/requirements-openvino.txt b/requirements-openvino.txt index 2dd971d6400b..dc0ae55c9253 100644 --- a/requirements-openvino.txt +++ b/requirements-openvino.txt @@ -22,7 +22,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.10.3 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 -typing_extensions +typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq gguf == 0.9.1 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b4a9520e623e..2200003e4b84 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -5,6 +5,7 @@ Optional, Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer +from typing_extensions import assert_never import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, @@ -16,9 +17,12 @@ from vllm.engine.metrics import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster, ray -from vllm.inputs import LLMInputs, PromptInputs +from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -291,38 +295,140 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() - async def process_model_inputs_async( + async def _tokenize_prompt_async( self, + prompt: str, request_id: str, - inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: - if isinstance(inputs, str): - inputs = {"prompt": inputs} + ) -> List[int]: + ''' + Wrapper around 
application of the model's + tokenizer. + + Arguments: + + * prompt + * request_id + * lora_request + + Returns: + + * prompt token ids + ''' + + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + return await tokenizer.encode_async(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + async def _extract_prompt_components_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> Tuple[Optional[str], List[int], Optional[MultiModalDataDict]]: + ''' + Extract the components of any single encoder or decoder input prompt. - if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + Arguments: - prompt_token_ids = await tokenizer.encode_async( + * request_id + * inputs: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * prompt + * prompt_token_ids + * multi_modal_data + ''' + + if isinstance(inputs, str): + prompt = inputs + prompt_token_ids = await self._tokenize_prompt_async( + prompt, request_id=request_id, - prompt=inputs["prompt"], - lora_request=lora_request) + lora_request=lora_request, + ) + multi_modal_data = None + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: + prompt = None + prompt_token_ids = inputs["prompt_token_ids"] + else: + # NOTE: This extra assignment is required to pass mypy + prompt = parsed_prompt = inputs["prompt"] + prompt_token_ids = await self._tokenize_prompt_async( + parsed_prompt, + request_id=request_id, + lora_request=lora_request, + ) + + multi_modal_data = inputs.get("multi_modal_data") else: - prompt_token_ids = inputs["prompt_token_ids"] + assert_never(inputs) + + return prompt, prompt_token_ids, multi_modal_data + + async def _process_decoder_only_prompt_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + + ( + prompt, + prompt_token_ids, + multi_modal_data, + ) = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + lora_request=lora_request, + ) if prompt_adapter_request: - prompt_token_ids = [ - 0 - ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ - prompt_token_ids + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data) + + async def process_model_inputs_async( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + if self.is_encoder_decoder_model(): + # TODO: Make this async + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder - llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + model_inputs = self._process_encoder_decoder_prompt( + inputs, + request_id=request_id, + ) + else: + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + model_inputs = await self._process_decoder_only_prompt_async( + inputs, + request_id=request_id, 
+ lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) - return self.input_processor(llm_inputs) + return self.input_processor(model_inputs) async def add_request_async( self, @@ -341,10 +447,11 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( + inputs, request_id=request_id, - inputs=inputs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) self._add_processed_request( request_id=request_id, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 10913efbd889..5044ea8d620c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,6 +5,8 @@ from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, TypeVar, Union +from typing_extensions import assert_never + import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -22,10 +24,12 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs -from vllm.inputs.parse import get_prompt_type +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs, + PromptInputs, SingletonPromptInputs) +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -553,7 +557,7 @@ def _get_decoder_start_token_id(self, ) -> Optional[int]: def _add_processed_request( self, request_id: str, - processed_inputs: LLMInputs, + processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -613,7 +617,7 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - _LLMInputComponentsType = Tuple[str, List[int], ] + _LLMInputComponentsType = Tuple[str, List[int]] def _prepare_decoder_input_ids_for_generation( self, @@ -646,7 +650,7 @@ def _prepare_decoder_input_ids_for_generation( if decoder_input_ids is None: # no decoder prompt input -> # use decoder_start_token_id as decoder_input_ids - (decoder_input_ids) = self._get_default_enc_dec_decoder_prompt() + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() if (len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id): @@ -657,8 +661,8 @@ def _prepare_decoder_input_ids_for_generation( def _tokenize_prompt( self, prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[str] = None, + request_id: str, + lora_request: Optional[LoRARequest] = None, ) -> List[int]: ''' Wrapper around application of the model's @@ -678,87 +682,60 @@ def _tokenize_prompt( tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") - prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - return prompt_token_ids + return tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) - def 
_extract_single_prompt_for_enc_dec_input( + def _extract_prompt_components( self, - inputs: Optional[PromptInputs], - request_id: Optional[str] = None, - ptype: Optional[str] = None, - is_encoder_prompt: bool = False, - ) -> Tuple[Optional[str], List[int]]: + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> Tuple[Optional[str], List[int], Optional[MultiModalDataDict]]: ''' - Only for encoder/decoder models: - Extract prompt & prompt_token_ids from any single - encoder or decoder input prompt. For encoder input prompts - in particular, also extract multi-modal data. - - This function handles the following scenarios: - 1. The user supplied a singleton encoder prompt - & the prompt/prompt-token-ids must be extracted. - 2. The user supplied an explicit encoder/decoder - prompt & the prompt/prompt-token-ids must be - extracted from either the encoder and decoder prompts. - - For decoder prompts in particular (scenario 2), special - processing is applied to the returned decoder token ids. + Extract the components of any single encoder or decoder input prompt. Arguments: * request_id - * ptype: str representation of the input prompt type. - If `ptype` is `None`, assume that the prompt - type is unknown and must be inferred. This is the - case for ExplicitEncoderDecoder sub-prompts. * inputs: single encoder or decoder input prompt - * is_encoder_prompt: True if encoder input prompt. - If False, decoder prompt tokens - are preprocessed. + * lora_request: this is only valid for decoder prompts Returns: * prompt * prompt_token_ids + * multi_modal_data ''' - prompt_token_ids = None - ptype = (get_prompt_type(inputs) if ptype is None else ptype) - if inputs is None: - prompt = None - elif ptype == 'str': + if isinstance(inputs, str): prompt = inputs prompt_token_ids = self._tokenize_prompt( prompt, request_id=request_id, + lora_request=lora_request, ) - elif ptype == 'TokensPrompt': - prompt = None - prompt_token_ids = inputs['prompt_token_ids'] + multi_modal_data = None + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: + prompt = None + prompt_token_ids = inputs["prompt_token_ids"] + else: + # NOTE: This extra assignment is required to pass mypy + prompt = parsed_prompt = inputs["prompt"] + prompt_token_ids = self._tokenize_prompt( + parsed_prompt, + request_id=request_id, + lora_request=lora_request, + ) + + multi_modal_data = inputs.get("multi_modal_data") else: - prompt = inputs['prompt'] - prompt_token_ids = self._tokenize_prompt( - prompt, - request_id=request_id, - ) - - if not is_encoder_prompt: - # Apply special pre-processing to - # decoder prompts - prompt_token_ids = (self._prepare_decoder_input_ids_for_generation( - prompt_token_ids, )) - - assert prompt_token_ids is not None + assert_never(inputs) - return ( - prompt, - prompt_token_ids, - ) + return prompt, prompt_token_ids, multi_modal_data - def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]: + def _get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' Specifically for encoder/decoder models: generate a default decoder prompt for when @@ -798,8 +775,8 @@ def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]: def _process_encoder_decoder_prompt( self, inputs: PromptInputs, - request_id: Optional[str] = None, - ) -> LLMInputs: + request_id: str, + ) -> EncoderDecoderLLMInputs: ''' For encoder/decoder models only: Process an input prompt @@ -830,20 +807,17 @@ def _process_encoder_decoder_prompt( Returns: - * `LLMInputs` instance + * 
`EncoderDecoderLLMInputs` instance ''' - ptype = get_prompt_type(inputs) - # Obtain encoder and decoder prompt tokens. Note # that, no matter what, the decoder # prompt type is unknown. - if ptype == "ExplicitEncoderDecoder": + if is_explicit_encoder_decoder_prompt(inputs): # If input is explicit encoder/decoder prompt, # then it remains to be determined what type # of encoder prompt we have extracted_encoder_prompt = inputs.get('encoder_prompt') - encoder_ptype = None # Extract decoder prompt from explicit # encoder/decoder prompt extracted_decoder_prompt = inputs.get('decoder_prompt') @@ -851,7 +825,6 @@ def _process_encoder_decoder_prompt( # If input is singleton encoder prompt, then # we know the encoder prompt type extracted_encoder_prompt = inputs - encoder_ptype = ptype # Decoder prompt is always unknown if # encoder/decoder prompt is not explicit extracted_decoder_prompt = None @@ -865,32 +838,35 @@ def _process_encoder_decoder_prompt( ( encoder_prompt, encoder_prompt_token_ids, - ) = self._extract_single_prompt_for_enc_dec_input( + _, + ) = self._extract_prompt_components( extracted_encoder_prompt, request_id=request_id, - ptype=encoder_ptype, - is_encoder_prompt=True, ) # Invoke helper method to obtain # decoder prompt and prompt token ids. # - # The helper method will detect the decoder - # prompt type. - # # Helper method will also apply special # preprocessing unique to decoder prompts. - ( - decoder_prompt, - decoder_prompt_token_ids, - ) = self._extract_single_prompt_for_enc_dec_input( - extracted_decoder_prompt, - request_id=request_id, - ptype=None, - is_encoder_prompt=False, - ) + if extracted_decoder_prompt is None: + decoder_prompt_token_ids = encoder_prompt_token_ids + decoder_prompt = encoder_prompt + else: + ( + decoder_prompt, + decoder_prompt_token_ids, + _, + ) = self._extract_prompt_components( + extracted_decoder_prompt, + request_id=request_id, + ) - return LLMInputs( + decoder_prompt_token_ids = ( + self._prepare_decoder_input_ids_for_generation( + decoder_prompt_token_ids)) + + return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_token_ids, prompt=decoder_prompt, encoder_prompt_token_ids=encoder_prompt_token_ids, @@ -899,9 +875,9 @@ def _process_encoder_decoder_prompt( def _process_decoder_only_prompt( self, - inputs: PromptInputs, + inputs: SingletonPromptInputs, + request_id: str, lora_request: Optional[LoRARequest] = None, - request_id: Optional[str] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: ''' @@ -912,8 +888,8 @@ def _process_decoder_only_prompt( Arguments: * inputs: input prompt - * lora_request * request_id + * lora_request * prompt_adapter_request Returns: @@ -921,18 +897,15 @@ def _process_decoder_only_prompt( * `LLMInputs` instance ''' - if isinstance(inputs, str): - inputs = {"prompt": inputs} - prompt = inputs.get("prompt") - - if "prompt_token_ids" not in inputs: - prompt_token_ids = self._tokenize_prompt( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - else: - prompt_token_ids = inputs["prompt_token_ids"] + ( + prompt, + prompt_token_ids, + multi_modal_data, + ) = self._extract_prompt_components( + inputs, + request_id=request_id, + lora_request=lora_request, + ) if prompt_adapter_request: prompt_token_ids = ( @@ -941,15 +914,15 @@ def _process_decoder_only_prompt( return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, - multi_modal_data=inputs.get("multi_modal_data")) + multi_modal_data=multi_modal_data) def process_model_inputs( self, - request_id: str, 
inputs: PromptInputs, + request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -960,6 +933,10 @@ def process_model_inputs( request_id=request_id, ) else: + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + # Decoder-only operation model_inputs = self._process_decoder_only_prompt( inputs, @@ -1029,10 +1006,11 @@ def add_request( arrival_time = time.time() processed_inputs = self.process_model_inputs( + inputs, request_id=request_id, - inputs=inputs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) self._add_processed_request( request_id=request_id, diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 84871fc83ef5..c0cd820e30c0 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor( return AllowedTokenIdsLogitsProcessor(allowed_token_ids) -def logit_bias_logits_processor(logit_bias: Dict[str, - float], token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: +def logit_bias_logits_processor( + logit_bias: Dict[int, float], + token_ids: List[int], + logits: torch.Tensor, +) -> torch.Tensor: for token_id, bias in logit_bias.items(): logits[token_id] += bias return logits diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 1dcd1ad343b3..0e1e7c828a71 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ -from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, - SingletonPromptInputs, TextPrompt, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - zip_enc_dec_prompt_lists) +from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, zip_enc_dec_prompt_lists) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -20,6 +20,7 @@ "ExplicitEncoderDecoderPrompt", "SingletonPromptInputs", "LLMInputs", + "EncoderDecoderLLMInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompt_lists", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 4cee911b4398..8732aea3a557 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -103,7 +103,15 @@ class LLMInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - encoder_prompt_token_ids: NotRequired[List[int]] + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class EncoderDecoderLLMInputs(LLMInputs): + encoder_prompt_token_ids: List[int] """The token IDs of the encoder prompt.""" encoder_prompt: NotRequired[Optional[str]] @@ -112,12 +120,6 @@ class LLMInputs(TypedDict): available. """ - multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] - """ - Optional multi-modal data to pass to the model, - if the model supports it. 
- """ - def build_explicit_enc_dec_prompt( encoder_prompt: SingletonPromptInputs, diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 42bd9858bcbe..984140f3651c 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -1,9 +1,11 @@ -from typing import (List, Literal, Optional, Sequence, TypedDict, Union, - overload) +from typing import List, Literal, Sequence, TypedDict, Union, overload + +from typing_extensions import TypeIs from vllm.utils import is_list_of -from .data import LLMInputs, PromptInputs +from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, + LLMInputs, PromptInputs) class ParsedText(TypedDict): @@ -63,63 +65,15 @@ def parse_and_batch_prompt( "array of tokens, or array of token arrays") -def _has_required_keys( - d: dict, - required_keys: set, -) -> bool: - return required_keys.issubset(d.keys()) +def is_explicit_encoder_decoder_prompt( + inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(inputs, dict) and "encoder_prompt" in inputs -def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]: - """ - Get the type-name of the prompt argument instance, given that - isinstance() cannot apply to TypedDict subclasses directly. - If the prompt is None, return 'None' as the type name. - - Arguments: - - * prompt: LLM input prompt or None - - Returns: - - * String representation of prompt type - """ - - if prompt is None: - return 'None' - - required_keys_dict = { - 'TextPrompt': {'prompt'}, - 'TokensPrompt': {'prompt_token_ids'}, - 'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'}, - } - - if isinstance(prompt, dict): - for (ptype, required_keys) in required_keys_dict.items(): - # Ignore type checking in the conditional below because type - # checker does not understand that is_dict(prompt) narrows - # down the possible types - if _has_required_keys( - prompt, # type: ignore - required_keys): - return ptype - - raise ValueError(f"Invalid prompt {prompt}, valid types are " - f"required_keys_dict={required_keys_dict}") - - if isinstance(prompt, str): - return "str" - - raise ValueError(f"Invalid prompt {prompt}") - - -def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool: +def is_valid_encoder_decoder_llm_inputs( + inputs: LLMInputs) -> TypeIs[EncoderDecoderLLMInputs]: """ Return True if the LLMInputs instance has the correct configuration for encoder/decoder. 
""" - - # True if encoder prompt token ids field exists & - # is not None - return ('encoder_prompt_token_ids' in inputs - and inputs['encoder_prompt_token_ids'] is not None) + return "encoder_prompt_token_ids" in inputs From 6332d1eb9ffa14039ac55e084bc1e0b2d291a367 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:16:09 +0000 Subject: [PATCH 29/75] Remove unnecessary comments --- vllm/engine/llm_engine.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5044ea8d620c..f380b6e81789 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -829,12 +829,6 @@ def _process_encoder_decoder_prompt( # encoder/decoder prompt is not explicit extracted_decoder_prompt = None - # Invoke helper function to obtain encoder - # prompt and prompt token ids, either from - # singleton encoder prompt or from the - # encoder sub-prompt of an explicit - # encoder/decode scenario 2), special - # processing is applied to the returned decoder token ids ( encoder_prompt, encoder_prompt_token_ids, @@ -844,11 +838,6 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - # Invoke helper method to obtain - # decoder prompt and prompt token ids. - # - # Helper method will also apply special - # preprocessing unique to decoder prompts. if extracted_decoder_prompt is None: decoder_prompt_token_ids = encoder_prompt_token_ids decoder_prompt = encoder_prompt From 07b4d211ecb080a4a2aa9c455823a37d83e46c1b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:31:19 +0000 Subject: [PATCH 30/75] Enable full async --- vllm/engine/async_llm_engine.py | 92 +++++++++++++++++++-------------- vllm/engine/llm_engine.py | 68 +++++++++++++----------- 2 files changed, 91 insertions(+), 69 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 2200003e4b84..10dba270717e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -301,21 +301,7 @@ async def _tokenize_prompt_async( request_id: str, lora_request: Optional[LoRARequest] = None, ) -> List[int]: - ''' - Wrapper around application of the model's - tokenizer. - - Arguments: - - * prompt - * request_id - * lora_request - - Returns: - - * prompt token ids - ''' - + """Async version of :meth:`_tokenize_prompt`.""" tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") @@ -329,22 +315,7 @@ async def _extract_prompt_components_async( request_id: str, lora_request: Optional[LoRARequest] = None, ) -> Tuple[Optional[str], List[int], Optional[MultiModalDataDict]]: - ''' - Extract the components of any single encoder or decoder input prompt. 
- - Arguments: - - * request_id - * inputs: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts - - Returns: - - * prompt - * prompt_token_ids - * multi_modal_data - ''' - + """Async version of :meth:`_extract_prompt_components`.""" if isinstance(inputs, str): prompt = inputs prompt_token_ids = await self._tokenize_prompt_async( @@ -372,6 +343,51 @@ async def _extract_prompt_components_async( return prompt, prompt_token_ids, multi_modal_data + async def _process_encoder_decoder_prompt_async( + self, + inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + """Async version of :meth:`_process_encoder_decoder_prompt`.""" + explicit_inputs = self._to_explicit_encoder_decoder_prompt(inputs) + extracted_encoder_prompt = explicit_inputs["encoder_prompt"] + extracted_decoder_prompt = explicit_inputs["decoder_prompt"] + + ( + encoder_prompt, + encoder_prompt_token_ids, + _, + ) = await self._extract_prompt_components_async( + extracted_encoder_prompt, + request_id=request_id, + ) + + # Avoid repeated processing if the inputs was originally in singleton + # form, see self._to_explicit_encoder_decoder_prompt + if extracted_decoder_prompt is extracted_encoder_prompt: + decoder_prompt_token_ids = encoder_prompt_token_ids + decoder_prompt = encoder_prompt + else: + ( + decoder_prompt, + decoder_prompt_token_ids, + _, + ) = await self._extract_prompt_components_async( + extracted_decoder_prompt, + request_id=request_id, + ) + + decoder_prompt_token_ids = ( + self._prepare_decoder_input_ids_for_generation( + decoder_prompt_token_ids)) + + return EncoderDecoderLLMInputs( + prompt_token_ids=decoder_prompt_token_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_token_ids, + encoder_prompt=encoder_prompt, + ) + async def _process_decoder_only_prompt_async( self, inputs: SingletonPromptInputs, @@ -379,7 +395,7 @@ async def _process_decoder_only_prompt_async( lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: - + """Async version of :meth:`_process_decoder_only_prompt`.""" ( prompt, prompt_token_ids, @@ -390,10 +406,8 @@ async def _process_decoder_only_prompt_async( lora_request=lora_request, ) - if prompt_adapter_request: - prompt_token_ids = ( - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, @@ -406,12 +420,11 @@ async def process_model_inputs_async( lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Async version of :meth:`process_model_inputs`.""" if self.is_encoder_decoder_model(): - # TODO: Make this async # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder - - model_inputs = self._process_encoder_decoder_prompt( + model_inputs = await self._process_encoder_decoder_prompt_async( inputs, request_id=request_id, ) @@ -440,6 +453,7 @@ async def add_request_async( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: + """Async version of :meth:`add_request`.""" if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") diff --git 
a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f380b6e81789..204a5c28867e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -24,8 +24,9 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs, - PromptInputs, SingletonPromptInputs) +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, + ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, + SingletonPromptInputs) from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -735,6 +736,18 @@ def _extract_prompt_components( return prompt, prompt_token_ids, multi_modal_data + def _apply_prompt_adapter( + self, + prompt_token_ids: List[int], + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return prompt_token_ids + def _get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' Specifically for encoder/decoder models: @@ -769,8 +782,19 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: bos_token_id = self._get_bos_token_id() assert bos_token_id is not None - prompt_token_ids: List[int] = [bos_token_id] - return prompt_token_ids + return [bos_token_id] + + def _to_explicit_encoder_decoder_prompt( + self, + inputs: PromptInputs, + ) -> ExplicitEncoderDecoderPrompt: + if is_explicit_encoder_decoder_prompt(inputs): + return inputs + + return ExplicitEncoderDecoderPrompt( + encoder_prompt=inputs, + decoder_prompt=inputs, + ) def _process_encoder_decoder_prompt( self, @@ -779,8 +803,8 @@ def _process_encoder_decoder_prompt( ) -> EncoderDecoderLLMInputs: ''' For encoder/decoder models only: - Process an input prompt - into an `LLMInputs` instance. + Process an input prompt into an + :class:`EncoderDecoderLLMInputs` instance. There are two types of input prompts: singleton prompts which carry only the @@ -810,24 +834,9 @@ def _process_encoder_decoder_prompt( * `EncoderDecoderLLMInputs` instance ''' - # Obtain encoder and decoder prompt tokens. Note - # that, no matter what, the decoder - # prompt type is unknown. 
- if is_explicit_encoder_decoder_prompt(inputs): - # If input is explicit encoder/decoder prompt, - # then it remains to be determined what type - # of encoder prompt we have - extracted_encoder_prompt = inputs.get('encoder_prompt') - # Extract decoder prompt from explicit - # encoder/decoder prompt - extracted_decoder_prompt = inputs.get('decoder_prompt') - else: - # If input is singleton encoder prompt, then - # we know the encoder prompt type - extracted_encoder_prompt = inputs - # Decoder prompt is always unknown if - # encoder/decoder prompt is not explicit - extracted_decoder_prompt = None + explicit_inputs = self._to_explicit_encoder_decoder_prompt(inputs) + extracted_encoder_prompt = explicit_inputs["encoder_prompt"] + extracted_decoder_prompt = explicit_inputs["decoder_prompt"] ( encoder_prompt, @@ -838,7 +847,9 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - if extracted_decoder_prompt is None: + # Avoid repeated processing if the inputs was originally in singleton + # form, see self._to_explicit_encoder_decoder_prompt + if extracted_decoder_prompt is extracted_encoder_prompt: decoder_prompt_token_ids = encoder_prompt_token_ids decoder_prompt = encoder_prompt else: @@ -896,10 +907,8 @@ def _process_decoder_only_prompt( lora_request=lora_request, ) - if prompt_adapter_request: - prompt_token_ids = ( - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, @@ -916,7 +925,6 @@ def process_model_inputs( if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder - model_inputs = self._process_encoder_decoder_prompt( inputs, request_id=request_id, From e29864cdcc9ff7837b9e05a250d0129b89b6f6e2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:32:44 +0000 Subject: [PATCH 31/75] grammar --- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 10dba270717e..5626f3a2a3d8 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -362,7 +362,7 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - # Avoid repeated processing if the inputs was originally in singleton + # Avoid repeated processing if the input was originally in singleton # form, see self._to_explicit_encoder_decoder_prompt if extracted_decoder_prompt is extracted_encoder_prompt: decoder_prompt_token_ids = encoder_prompt_token_ids diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 204a5c28867e..d9e6f6912fcb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -847,7 +847,7 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - # Avoid repeated processing if the inputs was originally in singleton + # Avoid repeated processing if the input was originally in singleton # form, see self._to_explicit_encoder_decoder_prompt if extracted_decoder_prompt is extracted_encoder_prompt: decoder_prompt_token_ids = encoder_prompt_token_ids From c9dfb401f7963910a1e416fc9c31344b7c31ed32 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:40:10 +0000 Subject: [PATCH 32/75] Add description --- vllm/inputs/data.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff 
--git a/vllm/inputs/data.py b/vllm/inputs/data.py index 8732aea3a557..b65e5d5f0686 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -94,6 +94,8 @@ class LLMInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. + + This includes the data required for decoder-only models. """ prompt_token_ids: List[int] """The token IDs of the prompt.""" @@ -111,6 +113,12 @@ class LLMInputs(TypedDict): class EncoderDecoderLLMInputs(LLMInputs): + """ + The inputs in :class:`~vllm.LLMEngine` before they are + passed to the model executor. + + This includes the required data for encoder-decoder models. + """ encoder_prompt_token_ids: List[int] """The token IDs of the encoder prompt.""" From 123319227828d1c4e1b82ffee7dc67f1012ea7aa Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:50:15 +0000 Subject: [PATCH 33/75] Fix wrong type annotations --- tests/conftest.py | 15 +++++++-------- vllm/inputs/data.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b0adfc58bcda..5bfb8fc132a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,8 +21,8 @@ from vllm.connections import global_http_connection from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) -from vllm.inputs import (TextPrompt, to_enc_dec_tuple_list, - zip_enc_dec_prompt_lists) +from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, + to_enc_dec_tuple_list, zip_enc_dec_prompt_lists) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs @@ -125,9 +125,8 @@ def example_prompts() -> List[str]: @pytest.fixture -def example_encoder_decoder_prompts() \ - -> Dict[DecoderPromptType, - Tuple[List[str], List[Optional[str]]]]: +def example_encoder_decoder_prompts( +) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]: ''' Returns an encoder prompt list and a decoder prompt list, wherein each pair of same-index entries in both lists corresponds to an (encoder prompt, @@ -444,7 +443,7 @@ def generate_greedy_logprobs_limit( def generate_encoder_decoder_greedy_logprobs_limit( self, - encoder_decoder_prompts: Tuple[List[str], List[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt], max_tokens: int, num_logprobs: int, **kwargs: Any, @@ -608,7 +607,7 @@ def generate_w_logprobs( def generate_encoder_decoder_w_logprobs( self, - encoder_decoder_prompts: Tuple[List[str], List[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt], sampling_params: SamplingParams, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ''' @@ -653,7 +652,7 @@ def generate_greedy_logprobs( def generate_encoder_decoder_greedy_logprobs( self, - encoder_decoder_prompts: Tuple[List[str], List[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt], max_tokens: int, num_logprobs: int, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index b65e5d5f0686..1d5b6b3fcdc0 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -150,7 +150,7 @@ def zip_enc_dec_prompt_lists( def to_enc_dec_tuple_list( enc_dec_prompts: List[ExplicitEncoderDecoderPrompt], -) -> List[Tuple[PromptInputs, PromptInputs]]: +) -> List[Tuple[SingletonPromptInputs, SingletonPromptInputs]]: return [(enc_dec_prompt['encoder_prompt'], enc_dec_prompt['decoder_prompt']) for enc_dec_prompt in enc_dec_prompts] From 
dcdebee669a85192d73c89ddb65732fc3c9594a3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:58:23 +0000 Subject: [PATCH 34/75] Remove redundant docs --- vllm/inputs/data.py | 4 ++-- vllm/inputs/parse.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 1d5b6b3fcdc0..d83297a32cbb 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -95,7 +95,7 @@ class LLMInputs(TypedDict): The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. - This includes the data required for decoder-only models. + This specifies the data required for decoder-only models. """ prompt_token_ids: List[int] """The token IDs of the prompt.""" @@ -117,7 +117,7 @@ class EncoderDecoderLLMInputs(LLMInputs): The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. - This includes the required data for encoder-decoder models. + This specifies the required data for encoder-decoder models. """ encoder_prompt_token_ids: List[int] """The token IDs of the encoder prompt.""" diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 984140f3651c..840bc8a49fb3 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -72,8 +72,4 @@ def is_explicit_encoder_decoder_prompt( def is_valid_encoder_decoder_llm_inputs( inputs: LLMInputs) -> TypeIs[EncoderDecoderLLMInputs]: - """ - Return True if the LLMInputs instance has the correct configuration - for encoder/decoder. - """ return "encoder_prompt_token_ids" in inputs From 65db3f1914f0b39bf4b71eba6835f76461587db1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 09:59:37 +0000 Subject: [PATCH 35/75] Be more strict --- vllm/inputs/parse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 840bc8a49fb3..b55f6003d575 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -71,5 +71,6 @@ def is_explicit_encoder_decoder_prompt( def is_valid_encoder_decoder_llm_inputs( - inputs: LLMInputs) -> TypeIs[EncoderDecoderLLMInputs]: + inputs: Union[LLMInputs, EncoderDecoderLLMInputs], +) -> TypeIs[EncoderDecoderLLMInputs]: return "encoder_prompt_token_ids" in inputs From 9ffeb222f84d0ac8139e550b1cbbf5ecefc7e484 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 10:03:28 +0000 Subject: [PATCH 36/75] Fix docs --- vllm/inputs/__init__.py | 2 +- vllm/inputs/data.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0e1e7c828a71..e8f8a40fbd18 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -17,8 +17,8 @@ "TextPrompt", "TokensPrompt", "PromptInputs", - "ExplicitEncoderDecoderPrompt", "SingletonPromptInputs", + "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", "build_explicit_enc_dec_prompt", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index d83297a32cbb..57f3af9d5420 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -43,14 +43,14 @@ class TokensPrompt(TypedDict): which encapsulates multiple prompts, i.e. of the sort which may be utilized for encoder/decoder models when the user desires to express both the encoder & decoder -prompts explicitly, i.e. ExplicitEncoderDecoderPrompt +prompts explicitly, i.e. 
:class:`ExplicitEncoderDecoderPrompt` -A prompt of type SingletonPromptInputs may be employed +A prompt of type :class:`SingletonPromptInputs` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or (3) as a member of a larger data structure encapsulating -more than one prompt, i.e. ExplicitEncoderDecoderPrompt +more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt` """ @@ -61,7 +61,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict): The encoder and decoder prompts, respectively, may formatted according to any of the - SingletonPromptInputs schemas, and are not + :class:`SingletonPromptInputs` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. @@ -69,8 +69,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict): Note that an ExplicitEncoderDecoderPrompt may not be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` - fields of this data structure may not themselves - must be SingletonPromptInputs instances. + fields of this data structure themselves must be + :class:`SingletonPromptInputs` instances. """ encoder_prompt: SingletonPromptInputs From c9e0b081561f6a2299027721302f8b99c1f99174 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 10:05:38 +0000 Subject: [PATCH 37/75] Fix 2 --- vllm/engine/llm_engine.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d9e6f6912fcb..54d81ea5587a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -666,8 +666,7 @@ def _tokenize_prompt( lora_request: Optional[LoRARequest] = None, ) -> List[int]: ''' - Wrapper around application of the model's - tokenizer. + Wrapper around application of the model's tokenizer. Arguments: @@ -831,7 +830,7 @@ def _process_encoder_decoder_prompt( Returns: - * `EncoderDecoderLLMInputs` instance + * :class:`EncoderDecoderLLMInputs` instance ''' explicit_inputs = self._to_explicit_encoder_decoder_prompt(inputs) @@ -882,8 +881,7 @@ def _process_decoder_only_prompt( ) -> LLMInputs: ''' For decoder-only models: - Process an input prompt - into an `LLMInputs` instance. + Process an input prompt into an :class:`LLMInputs` instance. 
Arguments: @@ -894,7 +892,7 @@ def _process_decoder_only_prompt( Returns: - * `LLMInputs` instance + * :class:`LLMInputs` instance ''' ( From 14bca1ff1fa90f44b00aa65a5d664115fecb5c55 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 10:11:59 +0000 Subject: [PATCH 38/75] Disallow multi-modal data for enc/dec models --- vllm/engine/async_llm_engine.py | 13 +++++++++++-- vllm/engine/llm_engine.py | 13 +++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5626f3a2a3d8..e53ebdd35e26 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -356,27 +356,36 @@ async def _process_encoder_decoder_prompt_async( ( encoder_prompt, encoder_prompt_token_ids, - _, + encoder_multi_modal_data, ) = await self._extract_prompt_components_async( extracted_encoder_prompt, request_id=request_id, ) + if encoder_multi_modal_data is not None: + raise ValueError("Multi-modal data is not supported for " + "(language) encoder-decoder models") + # Avoid repeated processing if the input was originally in singleton # form, see self._to_explicit_encoder_decoder_prompt if extracted_decoder_prompt is extracted_encoder_prompt: decoder_prompt_token_ids = encoder_prompt_token_ids decoder_prompt = encoder_prompt + decoder_multi_modal_data = encoder_multi_modal_data else: ( decoder_prompt, decoder_prompt_token_ids, - _, + decoder_multi_modal_data, ) = await self._extract_prompt_components_async( extracted_decoder_prompt, request_id=request_id, ) + if decoder_multi_modal_data is not None: + raise ValueError("Multi-modal data is not supported for " + "(language) encoder-decoder models") + decoder_prompt_token_ids = ( self._prepare_decoder_input_ids_for_generation( decoder_prompt_token_ids)) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 54d81ea5587a..09685b4586d2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -840,27 +840,36 @@ def _process_encoder_decoder_prompt( ( encoder_prompt, encoder_prompt_token_ids, - _, + encoder_multi_modal_data, ) = self._extract_prompt_components( extracted_encoder_prompt, request_id=request_id, ) + if encoder_multi_modal_data is not None: + raise ValueError("Multi-modal data is not supported for " + "(language) encoder-decoder models") + # Avoid repeated processing if the input was originally in singleton # form, see self._to_explicit_encoder_decoder_prompt if extracted_decoder_prompt is extracted_encoder_prompt: decoder_prompt_token_ids = encoder_prompt_token_ids decoder_prompt = encoder_prompt + decoder_multi_modal_data = encoder_multi_modal_data else: ( decoder_prompt, decoder_prompt_token_ids, - _, + decoder_multi_modal_data, ) = self._extract_prompt_components( extracted_decoder_prompt, request_id=request_id, ) + if decoder_multi_modal_data is not None: + raise ValueError("Multi-modal data is not supported for " + "(language) encoder-decoder models") + decoder_prompt_token_ids = ( self._prepare_decoder_input_ids_for_generation( decoder_prompt_token_ids)) From 8fc7099c48935ce0bf253d1a1367f32077a7e6c5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 10:22:45 +0000 Subject: [PATCH 39/75] Improve type narrowing behavior using `TypeIs` --- vllm/model_executor/models/interfaces.py | 22 +++++++++++----------- vllm/utils.py | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 
6fdacd446978..db0d6b429d64 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,7 +1,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, Union, overload, runtime_checkable) -from typing_extensions import TypeGuard +from typing_extensions import TypeIs from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger @@ -37,18 +37,18 @@ def __call__(self, *, multimodal_config: MultiModalConfig) -> None: @overload -def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]: +def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]: ... @overload -def supports_vision(model: object) -> TypeGuard[SupportsVision]: +def supports_vision(model: object) -> TypeIs[SupportsVision]: ... def supports_vision( model: Union[Type[object], object], -) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]: +) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]: if isinstance(model, type): return isinstance(model, _SupportsVisionType) @@ -94,18 +94,18 @@ def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: @overload -def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]: +def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]: ... @overload -def supports_lora(model: object) -> TypeGuard[SupportsLoRA]: +def supports_lora(model: object) -> TypeIs[SupportsLoRA]: ... def supports_lora( model: Union[Type[object], object], -) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: +) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: result = _supports_lora(model) if not result: @@ -137,7 +137,7 @@ def supports_lora( def _supports_lora( model: Union[Type[object], object], -) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: +) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: if isinstance(model, type): return isinstance(model, _SupportsLoRAType) @@ -172,18 +172,18 @@ def __init__(self, @overload -def has_inner_state(model: object) -> TypeGuard[HasInnerState]: +def has_inner_state(model: object) -> TypeIs[HasInnerState]: ... @overload -def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]: +def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]: ... 
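# NOTE: The block below is an illustrative sketch, not part of this patch. It
# shows, assuming typing_extensions >= 4.10 (PEP 742), why `TypeIs` is used
# instead of `TypeGuard` in these helpers: `TypeIs` narrows the checked value
# in *both* branches and intersects with the declared type. The Cat/Dog names
# are hypothetical.
from typing import Union

from typing_extensions import TypeIs


class Cat:
    sound = "meow"


class Dog:
    sound = "woof"


def is_cat(animal: Union[Cat, Dog]) -> TypeIs[Cat]:
    # The runtime check must match the static claim made by the annotation.
    return isinstance(animal, Cat)


def describe(animal: Union[Cat, Dog]) -> str:
    if is_cat(animal):
        return f"cat says {animal.sound}"  # narrowed to Cat
    # With TypeIs, `animal` is narrowed to Dog here; with TypeGuard it
    # would still be Union[Cat, Dog].
    return f"dog says {animal.sound}"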
def has_inner_state( model: Union[Type[object], object] -) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]: +) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]: if isinstance(model, type): return isinstance(model, _HasInnerStateType) diff --git a/vllm/utils.py b/vllm/utils.py index eb88fce4af0c..fcfdfe85ed14 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -25,7 +25,7 @@ import psutil import torch import torch.types -from typing_extensions import ParamSpec, TypeGuard, assert_never +from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs from vllm import _custom_ops as ops @@ -811,7 +811,7 @@ def is_list_of( typ: Type[T], *, check: Literal["first", "all"] = "first", -) -> TypeGuard[List[T]]: +) -> TypeIs[List[T]]: if not isinstance(value, list): return False From 3a8a072d16a6ec4305498ab54abe707dcdee4483 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 11:02:58 +0000 Subject: [PATCH 40/75] Avoid sequential await --- vllm/engine/async_llm_engine.py | 48 +++++++++++------------ vllm/engine/llm_engine.py | 68 +++++++++++++-------------------- 2 files changed, 49 insertions(+), 67 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e53ebdd35e26..8c3d591a5639 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -349,40 +349,36 @@ async def _process_encoder_decoder_prompt_async( request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" - explicit_inputs = self._to_explicit_encoder_decoder_prompt(inputs) - extracted_encoder_prompt = explicit_inputs["encoder_prompt"] - extracted_decoder_prompt = explicit_inputs["decoder_prompt"] - - ( - encoder_prompt, - encoder_prompt_token_ids, - encoder_multi_modal_data, - ) = await self._extract_prompt_components_async( - extracted_encoder_prompt, - request_id=request_id, - ) + if is_explicit_encoder_decoder_prompt(inputs): + encoder_task = self._extract_prompt_components_async( + inputs["encoder_prompt"], + request_id=request_id, + ) - if encoder_multi_modal_data is not None: - raise ValueError("Multi-modal data is not supported for " - "(language) encoder-decoder models") + decoder_task = self._extract_prompt_components_async( + inputs["decoder_prompt"], + request_id=request_id, + ) - # Avoid repeated processing if the input was originally in singleton - # form, see self._to_explicit_encoder_decoder_prompt - if extracted_decoder_prompt is extracted_encoder_prompt: - decoder_prompt_token_ids = encoder_prompt_token_ids - decoder_prompt = encoder_prompt - decoder_multi_modal_data = encoder_multi_modal_data + ( + (encoder_prompt, encoder_prompt_token_ids, encoder_mm_data), + (decoder_prompt, decoder_prompt_token_ids, decoder_mm_data), + ) = await asyncio.gather(encoder_task, decoder_task) else: ( - decoder_prompt, - decoder_prompt_token_ids, - decoder_multi_modal_data, + encoder_prompt, + encoder_prompt_token_ids, + encoder_mm_data, ) = await self._extract_prompt_components_async( - extracted_decoder_prompt, + inputs, request_id=request_id, ) - if decoder_multi_modal_data is not None: + decoder_prompt_token_ids = encoder_prompt_token_ids + decoder_prompt = encoder_prompt + decoder_mm_data = encoder_mm_data + + if encoder_mm_data is not None or decoder_mm_data is not None: raise ValueError("Multi-modal data is not supported for " "(language) encoder-decoder models") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 
09685b4586d2..7501327ef271 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -24,9 +24,8 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, - SingletonPromptInputs) +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs, + PromptInputs, SingletonPromptInputs) from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -783,18 +782,6 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: assert bos_token_id is not None return [bos_token_id] - def _to_explicit_encoder_decoder_prompt( - self, - inputs: PromptInputs, - ) -> ExplicitEncoderDecoderPrompt: - if is_explicit_encoder_decoder_prompt(inputs): - return inputs - - return ExplicitEncoderDecoderPrompt( - encoder_prompt=inputs, - decoder_prompt=inputs, - ) - def _process_encoder_decoder_prompt( self, inputs: PromptInputs, @@ -833,40 +820,39 @@ def _process_encoder_decoder_prompt( * :class:`EncoderDecoderLLMInputs` instance ''' - explicit_inputs = self._to_explicit_encoder_decoder_prompt(inputs) - extracted_encoder_prompt = explicit_inputs["encoder_prompt"] - extracted_decoder_prompt = explicit_inputs["decoder_prompt"] - - ( - encoder_prompt, - encoder_prompt_token_ids, - encoder_multi_modal_data, - ) = self._extract_prompt_components( - extracted_encoder_prompt, - request_id=request_id, - ) - - if encoder_multi_modal_data is not None: - raise ValueError("Multi-modal data is not supported for " - "(language) encoder-decoder models") + if is_explicit_encoder_decoder_prompt(inputs): + ( + encoder_prompt, + encoder_prompt_token_ids, + encoder_mm_data, + ) = self._extract_prompt_components( + inputs["encoder_prompt"], + request_id=request_id, + ) - # Avoid repeated processing if the input was originally in singleton - # form, see self._to_explicit_encoder_decoder_prompt - if extracted_decoder_prompt is extracted_encoder_prompt: - decoder_prompt_token_ids = encoder_prompt_token_ids - decoder_prompt = encoder_prompt - decoder_multi_modal_data = encoder_multi_modal_data - else: ( decoder_prompt, decoder_prompt_token_ids, - decoder_multi_modal_data, + decoder_mm_data, ) = self._extract_prompt_components( - extracted_decoder_prompt, + inputs["decoder_prompt"], request_id=request_id, ) + else: + ( + encoder_prompt, + encoder_prompt_token_ids, + encoder_mm_data, + ) = self._extract_prompt_components( + inputs, + request_id=request_id, + ) + + decoder_prompt_token_ids = encoder_prompt_token_ids + decoder_prompt = encoder_prompt + decoder_mm_data = encoder_mm_data - if decoder_multi_modal_data is not None: + if encoder_mm_data is not None or decoder_mm_data is not None: raise ValueError("Multi-modal data is not supported for " "(language) encoder-decoder models") From ef5327c24506392d2c635cab445fb7614fdfba8c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 11:36:40 +0000 Subject: [PATCH 41/75] Fix type annotations based on test files --- examples/offline_inference_encoder_decoder.py | 6 +-- tests/conftest.py | 23 ++++++---- ...t_basic_distributed_correctness_enc_dec.py | 2 +- tests/models/test_bart.py | 3 +- tests/models/utils.py | 11 ----- vllm/inputs/__init__.py | 4 +- vllm/inputs/data.py | 42 +++++++++++-------- 7 files changed, 46 
insertions(+), 45 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index c05e8e8bb6f1..0f266d791885 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -5,7 +5,7 @@ from vllm import LLM, SamplingParams from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - TokensPrompt, zip_enc_dec_prompt_lists) + TokensPrompt, zip_enc_dec_prompts) dtype = "float" @@ -61,9 +61,9 @@ ) # - Finally, here's a useful helper function for zipping encoder and -# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt +# decoder prompts together into a list of ExplicitEncoderDecoderPrompt # instances -zipped_prompt_list = zip_enc_dec_prompt_lists( +zipped_prompt_list = zip_enc_dec_prompts( ['An encoder prompt', 'Another encoder prompt'], ['A decoder prompt', 'Another decoder prompt']) diff --git a/tests/conftest.py b/tests/conftest.py index 5bfb8fc132a8..5163b5c186e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import os import sys from collections import UserList +from enum import Enum from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union import pytest @@ -14,7 +15,6 @@ AutoModelForVision2Seq, AutoTokenizer, BatchEncoding, BatchFeature) -from tests.models.utils import DecoderPromptType from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig @@ -22,7 +22,7 @@ from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - to_enc_dec_tuple_list, zip_enc_dec_prompt_lists) + to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs @@ -124,6 +124,13 @@ def example_prompts() -> List[str]: return prompts +class DecoderPromptType(Enum): + """For encoder/decoder models only.""" + CUSTOM = 1 + NONE = 2 + EMPTY_STR = 3 + + @pytest.fixture def example_encoder_decoder_prompts( ) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]: @@ -149,11 +156,11 @@ def example_encoder_decoder_prompts( # NONE decoder prompt type return { DecoderPromptType.NONE: - zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts), + zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts), DecoderPromptType.EMPTY_STR: - zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts), + zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts), DecoderPromptType.CUSTOM: - zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts), + zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts), } @@ -443,7 +450,7 @@ def generate_greedy_logprobs_limit( def generate_encoder_decoder_greedy_logprobs_limit( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]], max_tokens: int, num_logprobs: int, **kwargs: Any, @@ -607,7 +614,7 @@ def generate_w_logprobs( def generate_encoder_decoder_w_logprobs( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]], sampling_params: SamplingParams, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ''' @@ -652,7 +659,7 @@ def generate_greedy_logprobs( def generate_encoder_decoder_greedy_logprobs( self, - encoder_decoder_prompts: 
List[ExplicitEncoderDecoderPrompt], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]], max_tokens: int, num_logprobs: int, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py index 69eae62ca732..9850c823ff5d 100644 --- a/tests/distributed/test_basic_distributed_correctness_enc_dec.py +++ b/tests/distributed/test_basic_distributed_correctness_enc_dec.py @@ -11,9 +11,9 @@ import pytest -from tests.models.utils import DecoderPromptType from vllm.utils import cuda_device_count_stateless +from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close from ..utils import fork_new_process_for_each_test diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 9c26b7163ff6..becf1b5b5df9 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -11,8 +11,7 @@ import pytest - from tests.models.utils import DecoderPromptType - + from ..conftest import DecoderPromptType from .utils import check_logprobs_close MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] diff --git a/tests/models/utils.py b/tests/models/utils.py index d96301b853c8..ff29a0ae81d6 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,5 +1,4 @@ import warnings -from enum import Enum from typing import Dict, List, Optional, Sequence, Tuple, Union from vllm.sequence import SampleLogprobs @@ -136,13 +135,3 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) - - -class DecoderPromptType(Enum): - ''' - For encoder/decoder models only - - - ''' - CUSTOM = 1 - NONE = 2 - EMPTY_STR = 3 diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index e8f8a40fbd18..0b08e9691f91 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, zip_enc_dec_prompt_lists) + to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -23,7 +23,7 @@ "EncoderDecoderLLMInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", - "zip_enc_dec_prompt_lists", + "zip_enc_dec_prompts", "INPUT_REGISTRY", "InputContext", "InputRegistry", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 57f3af9d5420..0081d3c0f59b 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, List, Optional, Tuple, TypedDict, Union +from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple, + Union) -from typing_extensions import NotRequired +from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict @@ -53,8 +54,10 @@ class TokensPrompt(TypedDict): more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt` """ +_T = TypeVar("_T", bound=SingletonPromptInputs, default=SingletonPromptInputs) -class ExplicitEncoderDecoderPrompt(TypedDict): + +class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T]): """Represents an encoder/decoder model input prompt, comprising an explicit encoder prompt and a decoder prompt. @@ -73,9 +76,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict): :class:`SingletonPromptInputs` instances. 
""" - encoder_prompt: SingletonPromptInputs + encoder_prompt: _T - decoder_prompt: SingletonPromptInputs + decoder_prompt: Optional[_T] PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] @@ -130,27 +133,30 @@ class EncoderDecoderLLMInputs(LLMInputs): def build_explicit_enc_dec_prompt( - encoder_prompt: SingletonPromptInputs, - decoder_prompt: SingletonPromptInputs, -) -> ExplicitEncoderDecoderPrompt: + encoder_prompt: _T, + decoder_prompt: Optional[_T], +) -> ExplicitEncoderDecoderPrompt[_T]: return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, decoder_prompt=decoder_prompt) -def zip_enc_dec_prompt_lists( - enc_prompt_list: List[SingletonPromptInputs], - dec_prompt_list: List[SingletonPromptInputs], -) -> List[ExplicitEncoderDecoderPrompt]: +def zip_enc_dec_prompts( + enc_prompts: Iterable[_T], + dec_prompts: Iterable[Optional[_T]], +) -> List[ExplicitEncoderDecoderPrompt[_T]]: + """ + Zip encoder and decoder prompts together into a list of + :class:`ExplicitEncoderDecoderPrompt` instances. + """ return [ build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) - for (encoder_prompt, - decoder_prompt) in zip(enc_prompt_list, dec_prompt_list) + for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) ] def to_enc_dec_tuple_list( - enc_dec_prompts: List[ExplicitEncoderDecoderPrompt], -) -> List[Tuple[SingletonPromptInputs, SingletonPromptInputs]]: - return [(enc_dec_prompt['encoder_prompt'], - enc_dec_prompt['decoder_prompt']) + enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T]], +) -> List[Tuple[_T, Optional[_T]]]: + return [(enc_dec_prompt["encoder_prompt"], + enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] From 8a835cc7914235c24fdf226188397d661f8adeb4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 11:37:33 +0000 Subject: [PATCH 42/75] Properly handle `inputs["decoder_prompt"]=None` --- vllm/engine/async_llm_engine.py | 15 +++++++++++---- vllm/engine/llm_engine.py | 23 +++++++++++++++-------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8c3d591a5639..de85953d4e21 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -355,10 +355,17 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - decoder_task = self._extract_prompt_components_async( - inputs["decoder_prompt"], - request_id=request_id, - ) + if (decoder_input := inputs["decoder_prompt"]) is None: + + async def dummy_task(): + return None, None, None + + decoder_task = dummy_task() + else: + decoder_task = self._extract_prompt_components_async( + decoder_input, + request_id=request_id, + ) ( (encoder_prompt, encoder_prompt_token_ids, encoder_mm_data), diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7501327ef271..67f37d6c8a65 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -830,14 +830,21 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - ( - decoder_prompt, - decoder_prompt_token_ids, - decoder_mm_data, - ) = self._extract_prompt_components( - inputs["decoder_prompt"], - request_id=request_id, - ) + if (decoder_input := inputs["decoder_prompt"]) is None: + ( + decoder_prompt, + decoder_prompt_token_ids, + decoder_mm_data, + ) = None, None, None + else: + ( + decoder_prompt, + decoder_prompt_token_ids, + decoder_mm_data, + ) = self._extract_prompt_components( + decoder_input, + request_id=request_id, + ) else: ( 
encoder_prompt, From e0024c29f4570480dc62407671ae09d6ab0826ac Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 11:47:55 +0000 Subject: [PATCH 43/75] Clean --- vllm/engine/async_llm_engine.py | 42 ++++++++++++++++++++------------- vllm/engine/llm_engine.py | 22 ++++++++--------- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index de85953d4e21..e6bc9eef41d6 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -355,33 +355,42 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: - - async def dummy_task(): - return None, None, None - - decoder_task = dummy_task() + decoder_input = inputs["decoder_prompt"] + if decoder_input is None: + ( + encoder_prompt, + encoder_prompt_ids, + encoder_mm_data, + ) = await encoder_task + + ( + decoder_prompt, + decoder_prompt_ids, + decoder_mm_data, + ) = None, None, None else: decoder_task = self._extract_prompt_components_async( decoder_input, request_id=request_id, ) - ( - (encoder_prompt, encoder_prompt_token_ids, encoder_mm_data), - (decoder_prompt, decoder_prompt_token_ids, decoder_mm_data), - ) = await asyncio.gather(encoder_task, decoder_task) + # NOTE: mypy crashes without the intermediate assignment to + # (a, b) + ( + (encoder_prompt, encoder_prompt_ids, encoder_mm_data), + (decoder_prompt, decoder_prompt_ids, decoder_mm_data), + ) = a, b = await asyncio.gather(encoder_task, decoder_task) else: ( encoder_prompt, - encoder_prompt_token_ids, + encoder_prompt_ids, encoder_mm_data, ) = await self._extract_prompt_components_async( inputs, request_id=request_id, ) - decoder_prompt_token_ids = encoder_prompt_token_ids + decoder_prompt_ids = encoder_prompt_ids decoder_prompt = encoder_prompt decoder_mm_data = encoder_mm_data @@ -389,14 +398,13 @@ async def dummy_task(): raise ValueError("Multi-modal data is not supported for " "(language) encoder-decoder models") - decoder_prompt_token_ids = ( - self._prepare_decoder_input_ids_for_generation( - decoder_prompt_token_ids)) + decoder_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_token_ids, + prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids, + encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 67f37d6c8a65..3f3720781d1f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -823,23 +823,24 @@ def _process_encoder_decoder_prompt( if is_explicit_encoder_decoder_prompt(inputs): ( encoder_prompt, - encoder_prompt_token_ids, + encoder_prompt_ids, encoder_mm_data, ) = self._extract_prompt_components( inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + decoder_input = inputs["decoder_prompt"] + if decoder_input is None: ( decoder_prompt, - decoder_prompt_token_ids, + decoder_prompt_ids, decoder_mm_data, ) = None, None, None else: ( decoder_prompt, - decoder_prompt_token_ids, + decoder_prompt_ids, decoder_mm_data, ) = self._extract_prompt_components( decoder_input, @@ -848,14 +849,14 @@ def _process_encoder_decoder_prompt( else: ( encoder_prompt, - encoder_prompt_token_ids, + encoder_prompt_ids, encoder_mm_data, ) = self._extract_prompt_components( inputs, 
request_id=request_id, ) - decoder_prompt_token_ids = encoder_prompt_token_ids + decoder_prompt_ids = encoder_prompt_ids decoder_prompt = encoder_prompt decoder_mm_data = encoder_mm_data @@ -863,14 +864,13 @@ def _process_encoder_decoder_prompt( raise ValueError("Multi-modal data is not supported for " "(language) encoder-decoder models") - decoder_prompt_token_ids = ( - self._prepare_decoder_input_ids_for_generation( - decoder_prompt_token_ids)) + decoder_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_token_ids, + prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids, + encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, ) From 76af1724f5f18aa4f3a31fb7c212b9158567163e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 11:55:28 +0000 Subject: [PATCH 44/75] Clean --- vllm/engine/async_llm_engine.py | 46 ++++++++++++--------------------- vllm/engine/llm_engine.py | 44 +++++++++++++------------------ 2 files changed, 34 insertions(+), 56 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e6bc9eef41d6..5b9d49f513a1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -13,7 +13,8 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import LLMEngine +from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, + PromptComponents) from vllm.engine.metrics import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster, ray @@ -22,7 +23,6 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalDataDict from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -314,7 +314,7 @@ async def _extract_prompt_components_async( inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> Tuple[Optional[str], List[int], Optional[MultiModalDataDict]]: + ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" if isinstance(inputs, str): prompt = inputs @@ -349,50 +349,36 @@ async def _process_encoder_decoder_prompt_async( request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + if is_explicit_encoder_decoder_prompt(inputs): encoder_task = self._extract_prompt_components_async( inputs["encoder_prompt"], request_id=request_id, ) - decoder_input = inputs["decoder_prompt"] - if decoder_input is None: - ( - encoder_prompt, - encoder_prompt_ids, - encoder_mm_data, - ) = await encoder_task - - ( - decoder_prompt, - decoder_prompt_ids, - decoder_mm_data, - ) = None, None, None + if (decoder_input := inputs["decoder_prompt"]) is None: + encoder_comps = await encoder_task + decoder_comps = None, None, None else: decoder_task = self._extract_prompt_components_async( decoder_input, request_id=request_id, ) - # NOTE: mypy crashes without the intermediate assignment to - # (a, b) - ( - (encoder_prompt, 
encoder_prompt_ids, encoder_mm_data), - (decoder_prompt, decoder_prompt_ids, decoder_mm_data), - ) = a, b = await asyncio.gather(encoder_task, decoder_task) + encoder_comps, decoder_comps = await asyncio.gather( + encoder_task, decoder_task) else: - ( - encoder_prompt, - encoder_prompt_ids, - encoder_mm_data, - ) = await self._extract_prompt_components_async( + encoder_comps = await self._extract_prompt_components_async( inputs, request_id=request_id, ) - decoder_prompt_ids = encoder_prompt_ids - decoder_prompt = encoder_prompt - decoder_mm_data = encoder_mm_data + decoder_comps = encoder_comps + + encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps if encoder_mm_data is not None or decoder_mm_data is not None: raise ValueError("Multi-modal data is not supported for " diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f3720781d1f..66a870d99a83 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -70,6 +70,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) +PromptComponents = Tuple[Optional[str], List[int], + Optional[MultiModalDataDict]] +DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], + Optional[MultiModalDataDict]] + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -690,7 +695,7 @@ def _extract_prompt_components( inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> Tuple[Optional[str], List[int], Optional[MultiModalDataDict]]: + ) -> PromptComponents: ''' Extract the components of any single encoder or decoder input prompt. @@ -820,45 +825,32 @@ def _process_encoder_decoder_prompt( * :class:`EncoderDecoderLLMInputs` instance ''' + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + if is_explicit_encoder_decoder_prompt(inputs): - ( - encoder_prompt, - encoder_prompt_ids, - encoder_mm_data, - ) = self._extract_prompt_components( + encoder_comps = self._extract_prompt_components( inputs["encoder_prompt"], request_id=request_id, ) - decoder_input = inputs["decoder_prompt"] - if decoder_input is None: - ( - decoder_prompt, - decoder_prompt_ids, - decoder_mm_data, - ) = None, None, None + if (decoder_input := inputs["decoder_prompt"]) is None: + decoder_comps = None, None, None else: - ( - decoder_prompt, - decoder_prompt_ids, - decoder_mm_data, - ) = self._extract_prompt_components( + decoder_comps = self._extract_prompt_components( decoder_input, request_id=request_id, ) else: - ( - encoder_prompt, - encoder_prompt_ids, - encoder_mm_data, - ) = self._extract_prompt_components( + encoder_comps = self._extract_prompt_components( inputs, request_id=request_id, ) - decoder_prompt_ids = encoder_prompt_ids - decoder_prompt = encoder_prompt - decoder_mm_data = encoder_mm_data + decoder_comps = encoder_comps + + encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps if encoder_mm_data is not None or decoder_mm_data is not None: raise ValueError("Multi-modal data is not supported for " From 5c16f2e90f4c93c782676f4580eca2ec5f7c3c3b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 12:00:41 +0000 Subject: [PATCH 45/75] Fix incorrect decoder inputs in singleton case --- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- vllm/inputs/data.py | 2 +- 
3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5b9d49f513a1..973721c0f928 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -375,7 +375,7 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - decoder_comps = encoder_comps + decoder_comps = None, None, None encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 66a870d99a83..9b2cd3b5430d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -847,7 +847,7 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - decoder_comps = encoder_comps + decoder_comps = None, None, None encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 0081d3c0f59b..d7883a7a60fc 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -69,7 +69,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T]): Only the encoder prompt may have multi-modal data. - Note that an ExplicitEncoderDecoderPrompt may not + Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be From e239ba9deefd32697251eaa0efc51c6e07a67d16 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 12:05:24 +0000 Subject: [PATCH 46/75] Clean --- vllm/engine/llm_engine.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9b2cd3b5430d..dec326210070 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -533,7 +533,7 @@ def _get_eos_token_id(self, return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - def _get_decoder_start_token_id(self, ) -> Optional[int]: + def _get_decoder_start_token_id(self) -> Optional[int]: ''' Obtain the decoder start token id employed by an encoder/decoder model. 
Returns None for non-encoder/decoder models or if the @@ -648,8 +648,7 @@ def _prepare_decoder_input_ids_for_generation( * Processed token list """ - decoder_start_token_id: Optional[int] = ( - self._get_decoder_start_token_id()) + decoder_start_token_id = self._get_decoder_start_token_id() assert decoder_start_token_id is not None if decoder_input_ids is None: From 4b0e3dff5ab0d4975fe8facbec51fd3ecd59ed69 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 12:12:43 +0000 Subject: [PATCH 47/75] Move functions to a more appropriate place --- vllm/config.py | 10 ++++++++++ vllm/engine/llm_engine.py | 7 +++---- vllm/utils.py | 20 -------------------- vllm/worker/worker.py | 6 ++---- 4 files changed, 15 insertions(+), 28 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ec6d587e7925..d912f17a0aa3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -457,6 +457,16 @@ def _get_num_seqlen_agnostic_layers( if t != "attention" ]) + @property + def is_encoder_decoder_model(self) -> bool: + """Extract the HF encoder/decoder model flag.""" + return getattr(self.hf_config, "is_encoder_decoder", False) + + @property + def is_embedding_model(self) -> bool: + """Extract the embedding model flag.""" + return self.embedding_mode + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dec326210070..6edc002457fe 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -47,8 +47,7 @@ AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import (Counter, is_embedding_model_config, - is_encoder_decoder_model_config) +from vllm.utils import Counter from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -1563,7 +1562,7 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) def is_encoder_decoder_model(self): - return is_encoder_decoder_model_config(self.model_config) + return self.model_config.is_encoder_decoder_model def is_embedding_model(self): - return is_embedding_model_config(self.model_config) + return self.model_config.is_embedding_model diff --git a/vllm/utils.py b/vllm/utils.py index fcfdfe85ed14..782b13920e91 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1142,23 +1142,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) - - -def is_encoder_decoder_model_config(model_config) -> bool: - ''' - Extract the HF encoder/decoder model flag from the ModelConfig instance. - Return False if model_config is None. - ''' - return model_config is not None and \ - getattr(model_config.hf_config, - "is_encoder_decoder", - False) - - -def is_embedding_model_config(model_config) -> bool: - ''' - Extract the embedding model flag from the ModelConfig instance. - Return False if model_config is None. 
- ''' - return model_config is not None and \ - model_config.embedding_mode diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ad6f6750ff98..45751eceacbc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -19,8 +19,6 @@ from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (is_embedding_model_config, - is_encoder_decoder_model_config) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -113,10 +111,10 @@ def __init__( self.gpu_cache: Optional[List[List[torch.Tensor]]] = None def _is_encoder_decoder_model(self): - return is_encoder_decoder_model_config(self.model_config) + return self.model_config.is_encoder_decoder_model def _is_embedding_model(self): - return is_embedding_model_config(self.model_config) + return self.model_config.is_embedding_model def init_device(self) -> None: if self.device_config.device.type == "cuda": From 53f7f50d717e2da4783a063fe96951a437a504e6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 12:19:52 +0000 Subject: [PATCH 48/75] Remove outdated comment --- vllm/inputs/parse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index b55f6003d575..b5e8ef786059 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -18,7 +18,6 @@ class ParsedTokens(TypedDict): is_tokens: Literal[True] -# https://github.com/vllm-project/vllm/pull/4028 @overload def parse_and_batch_prompt( prompt: Union[str, List[str]]) -> Sequence[ParsedText]: From 3afdbc548cafdaf0bea2ba72a011ecdda693035f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 12:55:18 +0000 Subject: [PATCH 49/75] Fix mismatch between hf and vllm output text --- tests/models/test_bart.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index becf1b5b5df9..9bca5a86f124 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -2,6 +2,8 @@ Run `pytest tests/models/test_bart.py`.
""" +from typing import List, Optional, Tuple + from vllm.utils import is_cpu if not is_cpu(): @@ -11,21 +13,31 @@ import pytest + from vllm.sequence import SampleLogprobs + from ..conftest import DecoderPromptType from .utils import check_logprobs_close MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] - DECODER_PROMPT_TYPES = ([ - DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR, - DecoderPromptType.NONE - ]) + def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, + ): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "</s>" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "<s>" + hf_output_str + + return output_ids, hf_output_str, out_logprobs @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) - @pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES) + @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) def test_models( hf_runner, vllm_runner, @@ -145,8 +157,13 @@ def test_models( hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE else 0) - check_logprobs_close(outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) From c61b01f0f6f2c501e25a3af2e2a38702892964cc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 13:10:11 +0000 Subject: [PATCH 50/75] Factor out duplicate code --- vllm/engine/async_llm_engine.py | 17 +------------- vllm/engine/llm_engine.py | 39 +++++++++++++++++++-------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 973721c0f928..ecf75a27bb11 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -377,22 +377,7 @@ async def _process_encoder_decoder_prompt_async( decoder_comps = None, None, None - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal data is not supported for " - "(language) encoder-decoder models") - - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) - - return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - ) + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) async def _process_decoder_only_prompt_async( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6edc002457fe..c9261be5a4d4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -785,6 +785,28 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: assert bos_token_id is not None return [bos_token_id] + def _build_enc_dec_llm_inputs( + self, + encoder_comps: PromptComponents, + decoder_comps: DecoderPromptComponents, + ) ->
EncoderDecoderLLMInputs: + encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + + if encoder_mm_data is not None or decoder_mm_data is not None: + raise ValueError("Multi-modal data is not supported for " + "(language) encoder-decoder models") + + decoder_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + + return EncoderDecoderLLMInputs( + prompt_token_ids=decoder_prompt_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_ids, + encoder_prompt=encoder_prompt, + ) + def _process_encoder_decoder_prompt( self, inputs: PromptInputs, @@ -847,22 +869,7 @@ def _process_encoder_decoder_prompt( decoder_comps = None, None, None - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal data is not supported for " - "(language) encoder-decoder models") - - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) - - return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - ) + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) def _process_decoder_only_prompt( self, From f8ed373f506abe057c070d61a3a2b60910e91c77 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 13:13:46 +0000 Subject: [PATCH 51/75] Factor out more duplicate code --- vllm/engine/async_llm_engine.py | 16 +++++----------- vllm/engine/llm_engine.py | 30 +++++++++++++++++++----------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index ecf75a27bb11..af606292c35b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -387,22 +387,16 @@ async def _process_decoder_only_prompt_async( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" - ( - prompt, - prompt_token_ids, - multi_modal_data, - ) = await self._extract_prompt_components_async( + prompt_comps = await self._extract_prompt_components_async( inputs, request_id=request_id, lora_request=lora_request, ) - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) - - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data) + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) async def process_model_inputs_async( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c9261be5a4d4..70917333efa3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -871,6 +871,20 @@ def _process_encoder_decoder_prompt( return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + def _build_decoder_only_llm_inputs( + self, + prompt_comps: PromptComponents, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + prompt, prompt_token_ids, multi_modal_data = prompt_comps + + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + 
multi_modal_data=multi_modal_data) + def _process_decoder_only_prompt( self, inputs: SingletonPromptInputs, @@ -894,22 +908,16 @@ def _process_decoder_only_prompt( * :class:`LLMInputs` instance ''' - ( - prompt, - prompt_token_ids, - multi_modal_data, - ) = self._extract_prompt_components( + prompt_comps = self._extract_prompt_components( inputs, request_id=request_id, lora_request=lora_request, ) - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) - - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data) + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) def process_model_inputs( self, From a4df70ab9715ca40fed135f606df48cdb29270b2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 13:16:00 +0000 Subject: [PATCH 52/75] Remove default values to avoid accidentally miss those arguments --- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index af606292c35b..21643852b029 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -299,7 +299,7 @@ async def _tokenize_prompt_async( self, prompt: str, request_id: str, - lora_request: Optional[LoRARequest] = None, + lora_request: Optional[LoRARequest], ) -> List[int]: """Async version of :meth:`_tokenize_prompt`.""" tokenizer = self.get_tokenizer_group("prompts must be None if " diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 70917333efa3..1bf7e220e713 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -625,7 +625,7 @@ def stop_remote_worker_execution_loop(self) -> None: def _prepare_decoder_input_ids_for_generation( self, - decoder_input_ids: Optional[List[int]] = None, + decoder_input_ids: Optional[List[int]], ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. @@ -665,7 +665,7 @@ def _tokenize_prompt( self, prompt: str, request_id: str, - lora_request: Optional[LoRARequest] = None, + lora_request: Optional[LoRARequest], ) -> List[int]: ''' Wrapper around application of the model's tokenizer. 
@@ -740,7 +740,7 @@ def _extract_prompt_components( def _apply_prompt_adapter( self, prompt_token_ids: List[int], - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> List[int]: if prompt_adapter_request: prompt_token_ids = ( @@ -874,7 +874,7 @@ def _process_encoder_decoder_prompt( def _build_decoder_only_llm_inputs( self, prompt_comps: PromptComponents, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> LLMInputs: prompt, prompt_token_ids, multi_modal_data = prompt_comps From 5240bb335abffc3ce65c1e1b96e2eeebf0544fa6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 14:55:43 +0000 Subject: [PATCH 53/75] Add test for serving encoder/decoder model with OpenAI server --- .../openai/test_encoder_decoder.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/entrypoints/openai/test_encoder_decoder.py diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py new file mode 100644 index 000000000000..85f1c6f18bf3 --- /dev/null +++ b/tests/entrypoints/openai/test_encoder_decoder.py @@ -0,0 +1,50 @@ +import openai +import pytest + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "facebook/bart-base" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--dtype", + "bfloat16", + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=2, total_tokens=7) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 From d321c82ee490048ba40e8be749c0fb42b91ee6a8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 7 Aug 2024 15:47:37 +0000 Subject: [PATCH 54/75] Use two type variables --- tests/conftest.py | 6 ++--- vllm/entrypoints/chat_utils.py | 5 ++-- vllm/inputs/data.py | 44 +++++++++++++++++++++++----------- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5163b5c186e7..d565da5a1019 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -450,7 +450,7 @@ def generate_greedy_logprobs_limit( def generate_encoder_decoder_greedy_logprobs_limit( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, **kwargs: Any, @@ -614,7 +614,7 @@ def generate_w_logprobs( def generate_encoder_decoder_w_logprobs( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: 
SamplingParams, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ''' @@ -659,7 +659,7 @@ def generate_greedy_logprobs( def generate_encoder_decoder_greedy_logprobs( self, - encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 12634c326185..1197c70d88ae 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,8 +2,7 @@ from dataclasses import dataclass from functools import lru_cache from pathlib import Path -from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union, - cast, final) +from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast # yapf conflicts with isort for this block # yapf: disable @@ -59,7 +58,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): CustomChatCompletionMessageParam] -@final # So that it should be compatible with Dict[str, str] +# TODO: Make fields ReadOnly once mypy supports it class ConversationMessage(TypedDict): role: str content: str diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index d7883a7a60fc..75ab0c770155 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -54,10 +54,18 @@ class TokensPrompt(TypedDict): more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt` """ -_T = TypeVar("_T", bound=SingletonPromptInputs, default=SingletonPromptInputs) - - -class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T]): +_T1_co = TypeVar("_T1_co", + bound=SingletonPromptInputs, + default=SingletonPromptInputs, + covariant=True) +_T2_co = TypeVar("_T2_co", + bound=SingletonPromptInputs, + default=SingletonPromptInputs, + covariant=True) + + +# TODO: Make fields ReadOnly once mypy supports it +class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): """Represents an encoder/decoder model input prompt, comprising an explicit encoder prompt and a decoder prompt. @@ -76,9 +84,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T]): :class:`SingletonPromptInputs` instances. """ - encoder_prompt: _T + encoder_prompt: _T1_co - decoder_prompt: Optional[_T] + decoder_prompt: Optional[_T2_co] PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] @@ -132,18 +140,26 @@ class EncoderDecoderLLMInputs(LLMInputs): """ +_T1 = TypeVar("_T1", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) +_T2 = TypeVar("_T2", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) + + def build_explicit_enc_dec_prompt( - encoder_prompt: _T, - decoder_prompt: Optional[_T], -) -> ExplicitEncoderDecoderPrompt[_T]: + encoder_prompt: _T1, + decoder_prompt: Optional[_T2], +) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, decoder_prompt=decoder_prompt) def zip_enc_dec_prompts( - enc_prompts: Iterable[_T], - dec_prompts: Iterable[Optional[_T]], -) -> List[ExplicitEncoderDecoderPrompt[_T]]: + enc_prompts: Iterable[_T1], + dec_prompts: Iterable[Optional[_T2]], +) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of :class:`ExplicitEncoderDecoderPrompt` instances. 
@@ -155,8 +171,8 @@ def zip_enc_dec_prompts( def to_enc_dec_tuple_list( - enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T]], -) -> List[Tuple[_T, Optional[_T]]]: + enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], +) -> List[Tuple[_T1, Optional[_T2]]]: return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] From e4c5c21e492c2d66b90afcc54472185d0de2c97c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 8 Aug 2024 02:23:10 +0000 Subject: [PATCH 55/75] Update error message --- vllm/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1bf7e220e713..dcaf375f9b15 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -794,8 +794,8 @@ def _build_enc_dec_llm_inputs( decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal data is not supported for " - "(language) encoder-decoder models") + raise ValueError("Multi-modal encoder-decoder models are " + "not supported yet") decoder_prompt_ids = ( self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) From f912f251fd8f512f6aa76edb7090d471ec8ad7b1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 8 Aug 2024 02:39:37 +0000 Subject: [PATCH 56/75] Format --- vllm/engine/llm_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2877ead09388..bef7833857c3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,9 +5,7 @@ from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union -from typing_extensions import TypeVar - -from typing_extensions import assert_never +from typing_extensions import TypeVar, assert_never import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, From 7da52f58498786b40c5f848d400003ea59074196 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 8 Aug 2024 13:04:34 +0000 Subject: [PATCH 57/75] Fix circular import problem --- vllm/transformers_utils/tokenizer_group/__init__.py | 3 +-- .../tokenizer_group/tokenizer_group.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index eeab19899b02..9a4149251d74 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -8,8 +8,7 @@ from .tokenizer_group import TokenizerGroup if ray: - from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( - RayTokenizerGroupPool) + from .ray_tokenizer_group import RayTokenizerGroupPool else: RayTokenizerGroupPool = None # type: ignore diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index a5186e48068e..e2c665871ae2 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -2,9 +2,6 @@ from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import (get_lora_tokenizer, - get_lora_tokenizer_async, - get_tokenizer) from vllm.utils import LRUCache from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup @@ 
-15,6 +12,9 @@ class TokenizerGroup(BaseTokenizerGroup): def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int], **tokenizer_config): + # Avoid circular import + from vllm.transformers_utils.tokenizer import get_tokenizer + self.tokenizer_id = tokenizer_id self.tokenizer_config = tokenizer_config self.enable_lora = enable_lora @@ -73,6 +73,9 @@ def get_lora_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: + # Avoid circular import + from vllm.transformers_utils.tokenizer import get_lora_tokenizer + if not lora_request or not self.enable_lora: return self.tokenizer if lora_request.lora_int_id not in self.lora_tokenizers: @@ -87,6 +90,9 @@ async def get_lora_tokenizer_async( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: + # Avoid circular import + from vllm.transformers_utils.tokenizer import get_lora_tokenizer_async + if not lora_request or not self.enable_lora: return self.tokenizer if lora_request.lora_int_id not in self.lora_tokenizers: From f475a58ea14132a72b7b7284826af393450e1a00 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 8 Aug 2024 15:48:22 +0000 Subject: [PATCH 58/75] Fix incorrect assertion --- vllm/entrypoints/openai/serving_completion.py | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index cd9c5ac2240f..5d6e481669ee 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -232,12 +232,6 @@ async def completion_stream_generator( prompt_text = res.prompt assert prompt_text is not None - prompt_token_ids = res.prompt_token_ids - assert prompt_token_ids is not None - - prompt_logprobs = res.prompt_logprobs - assert prompt_logprobs is not None - delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[Dict[ int, Logprob]]]] @@ -251,18 +245,19 @@ async def completion_stream_generator( if request.echo and request.max_tokens == 0: # only return the prompt delta_text = prompt_text - delta_token_ids = prompt_token_ids - out_logprobs = prompt_logprobs + delta_token_ids = res.prompt_token_ids + out_logprobs = res.prompt_logprobs has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): + assert res.prompt_logprobs is not None # echo the prompt and first token delta_text = prompt_text + output.text delta_token_ids = [ - *prompt_token_ids, *output.token_ids + *res.prompt_token_ids, *output.token_ids ] out_logprobs = [ - *prompt_logprobs, *(output.logprobs or []) + *res.prompt_logprobs, *(output.logprobs or []), ] has_echoed[i] = True else: @@ -308,7 +303,7 @@ async def completion_stream_generator( and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(prompt_token_ids) + prompt_tokens = len(res.prompt_token_ids) completion_tokens = len(output.token_ids) usage = UsageInfo( prompt_tokens=prompt_tokens, @@ -356,11 +351,6 @@ def request_output_to_completion_response( num_generated_tokens = 0 for final_res in final_res_batch: - prompt_token_ids = final_res.prompt_token_ids - - prompt_logprobs = final_res.prompt_logprobs - assert prompt_logprobs is not None - prompt_text = final_res.prompt assert prompt_text is not None @@ -371,17 +361,23 @@ def request_output_to_completion_response( for output in final_res.outputs: assert 
request.max_tokens is not None if request.echo and request.max_tokens == 0: - token_ids = prompt_token_ids - out_logprobs = prompt_logprobs + token_ids = final_res.prompt_token_ids + out_logprobs = final_res.prompt_logprobs output_text = prompt_text elif request.echo and request.max_tokens > 0: - token_ids = [*prompt_token_ids, *output.token_ids] + token_ids = [ + *final_res.prompt_token_ids, *output.token_ids + ] if request.logprobs is None: out_logprobs = None else: + assert final_res.prompt_logprobs is not None assert output.logprobs is not None - out_logprobs = [*prompt_logprobs, *output.logprobs] + out_logprobs = [ + *final_res.prompt_logprobs, + *output.logprobs, + ] output_text = prompt_text + output.text else: @@ -409,7 +405,7 @@ def request_output_to_completion_response( ) choices.append(choice_data) - num_prompt_tokens += len(prompt_token_ids) + num_prompt_tokens += len(final_res.prompt_token_ids) num_generated_tokens += sum( len(output.token_ids) for output in final_res.outputs) From f03b9396ffc2dc4a3bc1d4357cd5a2774e189186 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 8 Aug 2024 15:52:23 +0000 Subject: [PATCH 59/75] format --- vllm/entrypoints/openai/serving_completion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 5d6e481669ee..f2fa1bf4fee7 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -257,7 +257,8 @@ async def completion_stream_generator( *res.prompt_token_ids, *output.token_ids ] out_logprobs = [ - *res.prompt_logprobs, *(output.logprobs or []), + *res.prompt_logprobs, + *(output.logprobs or []), ] has_echoed[i] = True else: From 47baabdc6d9b4e9a8ff6af38840cdf15ce8ca9b0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 9 Aug 2024 02:52:08 +0000 Subject: [PATCH 60/75] Fix newly-introduced type errors --- vllm/entrypoints/api_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index f6e8a417b648..83c213f8cefa 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -61,6 +61,7 @@ async def generate(request: Request) -> Response: async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: prompt = request_output.prompt + assert prompt is not None text_outputs = [ prompt + output.text for output in request_output.outputs ] @@ -115,6 +116,7 @@ async def run_server(args: Namespace, logger.info("args: %s", args) app = await init_app(args, llm_engine) + assert engine is not None shutdown_task = await serve_http( app, From b8e69b76468e28a351be898428923e1682993d29 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 9 Aug 2024 02:52:43 +0000 Subject: [PATCH 61/75] fix --- vllm/entrypoints/api_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 83c213f8cefa..6127177b4d88 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -81,6 +81,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: assert final_output is not None prompt = final_output.prompt + assert prompt is not None text_outputs = [prompt + output.text for output in final_output.outputs] ret = {"text": text_outputs} return JSONResponse(ret) From eb7312e138aef072ba61a9e7973f894ca3b02dc2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 9 Aug 2024 10:26:46 
+0000 Subject: [PATCH 62/75] Simplify --- vllm/entrypoints/openai/serving_completion.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index f2fa1bf4fee7..aaa026907653 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -229,6 +229,8 @@ async def completion_stream_generator( try: async for prompt_idx, res in result_generator: + prompt_token_ids = res.prompt_token_ids + prompt_logprobs = res.prompt_logprobs prompt_text = res.prompt assert prompt_text is not None @@ -245,19 +247,19 @@ async def completion_stream_generator( if request.echo and request.max_tokens == 0: # only return the prompt delta_text = prompt_text - delta_token_ids = res.prompt_token_ids - out_logprobs = res.prompt_logprobs + delta_token_ids = prompt_token_ids + out_logprobs = prompt_logprobs has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): - assert res.prompt_logprobs is not None + assert prompt_logprobs is not None # echo the prompt and first token delta_text = prompt_text + output.text delta_token_ids = [ - *res.prompt_token_ids, *output.token_ids + *prompt_token_ids, *output.token_ids ] out_logprobs = [ - *res.prompt_logprobs, + *prompt_logprobs, *(output.logprobs or []), ] has_echoed[i] = True @@ -304,7 +306,7 @@ async def completion_stream_generator( and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(res.prompt_token_ids) + prompt_tokens = len(prompt_token_ids) completion_tokens = len(output.token_ids) usage = UsageInfo( prompt_tokens=prompt_tokens, @@ -352,6 +354,8 @@ def request_output_to_completion_response( num_generated_tokens = 0 for final_res in final_res_batch: + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt assert prompt_text is not None @@ -362,21 +366,19 @@ def request_output_to_completion_response( for output in final_res.outputs: assert request.max_tokens is not None if request.echo and request.max_tokens == 0: - token_ids = final_res.prompt_token_ids - out_logprobs = final_res.prompt_logprobs + token_ids = prompt_token_ids + out_logprobs = prompt_logprobs output_text = prompt_text elif request.echo and request.max_tokens > 0: - token_ids = [ - *final_res.prompt_token_ids, *output.token_ids - ] + token_ids = [*prompt_token_ids, *output.token_ids] if request.logprobs is None: out_logprobs = None else: - assert final_res.prompt_logprobs is not None + assert prompt_logprobs is not None assert output.logprobs is not None out_logprobs = [ - *final_res.prompt_logprobs, + *prompt_logprobs, *output.logprobs, ] @@ -406,7 +408,7 @@ def request_output_to_completion_response( ) choices.append(choice_data) - num_prompt_tokens += len(final_res.prompt_token_ids) + num_prompt_tokens += len(prompt_token_ids) num_generated_tokens += sum( len(output.token_ids) for output in final_res.outputs) From 83fba8a6fd508cf5008018651ed2b79b08f654fa Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 12 Aug 2024 12:38:52 +0000 Subject: [PATCH 63/75] Avoid circular import --- vllm/transformers_utils/detokenizer.py | 3 ++- vllm/transformers_utils/tokenizer.py | 7 ++++--- .../tokenizer_group/base_tokenizer_group.py | 14 ++++++-------- .../tokenizer_group/ray_tokenizer_group.py | 3 ++- .../tokenizer_group/tokenizer_group.py 
| 15 +++++---------- 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 06dfd59e3ff1..b7624c471cdb 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -2,7 +2,8 @@ from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from .tokenizer_group import AnyTokenizer, BaseTokenizerGroup +from .tokenizer import AnyTokenizer +from .tokenizer_group import BaseTokenizerGroup # Used eg. for marking rejected tokens in spec decoding. INVALID_TOKEN_ID = -1 diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index ed5ee226c530..0271aa809320 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -3,7 +3,8 @@ from typing import Optional, Union import huggingface_hub -from transformers import AutoTokenizer, PreTrainedTokenizerFast +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger @@ -11,10 +12,10 @@ from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.utils import make_async -from .tokenizer_group import AnyTokenizer - logger = init_logger(__name__) +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: """Get tokenizer with cached properties. diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index abbcdf2807f6..8f78ef65bbf1 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -1,12 +1,9 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Union - -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from typing import List, Optional from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest - -AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +from vllm.transformers_utils.tokenizer import AnyTokenizer class BaseTokenizerGroup(ABC): @@ -24,9 +21,10 @@ def ping(self) -> bool: pass @abstractmethod - def get_max_input_len(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: + def get_max_input_len( + self, + lora_request: Optional[LoRARequest] = None, + ) -> Optional[int]: """Get the maximum input length for the LoRA request.""" pass diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 79081c04ddc1..9a999a0d6067 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -13,8 +13,9 @@ from vllm.executor.ray_utils import ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import AnyTokenizer -from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup +from .base_tokenizer_group import BaseTokenizerGroup from .tokenizer_group import TokenizerGroup logger = init_logger(__name__) diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index e2c665871ae2..e516eeabaade 100644 --- 
a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -2,9 +2,13 @@ from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + get_lora_tokenizer, + get_lora_tokenizer_async, + get_tokenizer) from vllm.utils import LRUCache -from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup +from .base_tokenizer_group import BaseTokenizerGroup class TokenizerGroup(BaseTokenizerGroup): @@ -12,9 +16,6 @@ class TokenizerGroup(BaseTokenizerGroup): def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int], **tokenizer_config): - # Avoid circular import - from vllm.transformers_utils.tokenizer import get_tokenizer - self.tokenizer_id = tokenizer_id self.tokenizer_config = tokenizer_config self.enable_lora = enable_lora @@ -73,9 +74,6 @@ def get_lora_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - # Avoid circular import - from vllm.transformers_utils.tokenizer import get_lora_tokenizer - if not lora_request or not self.enable_lora: return self.tokenizer if lora_request.lora_int_id not in self.lora_tokenizers: @@ -90,9 +88,6 @@ async def get_lora_tokenizer_async( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - # Avoid circular import - from vllm.transformers_utils.tokenizer import get_lora_tokenizer_async - if not lora_request or not self.enable_lora: return self.tokenizer if lora_request.lora_int_id not in self.lora_tokenizers: From 4e3f014e26878506e63b82ac7b735c46f3d39369 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 13 Aug 2024 02:26:17 +0000 Subject: [PATCH 64/75] Fix incorrect assertion --- vllm/entrypoints/openai/serving_completion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index aaa026907653..dcec223c613e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -232,7 +232,6 @@ async def completion_stream_generator( prompt_token_ids = res.prompt_token_ids prompt_logprobs = res.prompt_logprobs prompt_text = res.prompt - assert prompt_text is not None delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[Dict[ @@ -245,6 +244,7 @@ async def completion_stream_generator( assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_text is not None # only return the prompt delta_text = prompt_text delta_token_ids = prompt_token_ids @@ -252,6 +252,7 @@ async def completion_stream_generator( has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): + assert prompt_text is not None assert prompt_logprobs is not None # echo the prompt and first token delta_text = prompt_text + output.text @@ -357,7 +358,6 @@ def request_output_to_completion_response( prompt_token_ids = final_res.prompt_token_ids prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt - assert prompt_text is not None token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[Dict[int, @@ -366,10 +366,12 @@ def request_output_to_completion_response( for output in final_res.outputs: assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_text is not None token_ids = prompt_token_ids out_logprobs 
= prompt_logprobs output_text = prompt_text elif request.echo and request.max_tokens > 0: + assert prompt_text is not None token_ids = [*prompt_token_ids, *output.token_ids] if request.logprobs is None: From b9e8f00b101475a723a09b80a2d67949a0462f69 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 14 Aug 2024 16:18:40 +0000 Subject: [PATCH 65/75] Add type annotation --- vllm/entrypoints/openai/serving_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index cfd64f4a472c..b0f70ff43e22 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -183,7 +183,7 @@ async def create_embedding( return response - def _check_embedding_mode(self, embedding_mode: bool): + def _check_embedding_mode(self, embedding_mode: bool) -> bool: if not embedding_mode: logger.warning( "embedding_mode is False. Embedding API will not work.") From 516aa3bd282ada358c4c0804be04f454ed9c5057 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 16 Aug 2024 04:42:33 +0000 Subject: [PATCH 66/75] Clean up validation logic --- vllm/entrypoints/openai/protocol.py | 78 ++++++++++++------- vllm/entrypoints/openai/serving_chat.py | 10 --- vllm/entrypoints/openai/serving_completion.py | 9 --- 3 files changed, 51 insertions(+), 46 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 411fbd039172..c46f5cf8ce66 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -243,6 +243,10 @@ def to_sampling_params( if max_tokens is None: max_tokens = default_max_tokens + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.top_logprobs + # We now allow logprobs being true without top_logrobs. 
logits_processors = get_logits_processors( logit_bias=self.logit_bias, @@ -266,8 +270,7 @@ def to_sampling_params( stop=self.stop, stop_token_ids=self.stop_token_ids, logprobs=self.top_logprobs if self.logprobs else None, - prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else - (self.top_logprobs if self.echo else None), + prompt_logprobs=prompt_logprobs, ignore_eos=self.ignore_eos, max_tokens=max_tokens, min_tokens=self.min_tokens, @@ -281,14 +284,36 @@ def to_sampling_params( truncate_prompt_tokens=self.truncate_prompt_tokens, ) - @model_validator(mode='before') + @model_validator(mode="before") @classmethod - def validate_stream_options(cls, values): - if (values.get('stream_options') is not None - and not values.get('stream')): + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): raise ValueError( - "stream_options can only be set if stream is true") - return values + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data @model_validator(mode="before") @classmethod @@ -321,19 +346,6 @@ def check_tool_choice(cls, data): "When using `tool_choice`, `tools` must be set.") return data - @model_validator(mode="before") - @classmethod - def check_logprobs(cls, data): - if "top_logprobs" in data and data["top_logprobs"] is not None: - if "logprobs" not in data or data["logprobs"] is False: - raise ValueError( - "when using `top_logprobs`, `logprobs` must be set to true." 
- ) - elif data["top_logprobs"] < 0: - raise ValueError( - "`top_logprobs` must be a value a positive value.") - return data - class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -430,6 +442,10 @@ def to_sampling_params( if max_tokens is None: max_tokens = default_max_tokens + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + echo_without_generation = self.echo and self.max_tokens == 0 logits_processors = get_logits_processors( @@ -459,8 +475,7 @@ def to_sampling_params( min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, - prompt_logprobs=self.prompt_logprobs - if self.prompt_logprobs else self.logprobs if self.echo else None, + prompt_logprobs=prompt_logprobs, skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, @@ -486,9 +501,17 @@ def check_guided_decoding_count(cls, data): @model_validator(mode="before") @classmethod def check_logprobs(cls, data): - if "logprobs" in data and data[ - "logprobs"] is not None and not data["logprobs"] >= 0: - raise ValueError("if passed, `logprobs` must be a positive value.") + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("`logprobs` must be a positive value.") + return data @model_validator(mode="before") @@ -496,7 +519,8 @@ def check_logprobs(cls, data): def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): raise ValueError( - "Stream options can only be defined when stream is true.") + "Stream options can only be defined when `stream=True`.") + return data diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 57d54e28c23a..4d8e240a88ee 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -83,16 +83,6 @@ async def create_chat_completion( if error_check_ret is not None: return error_check_ret - if request.prompt_logprobs is not None: - if request.stream and request.prompt_logprobs > 0: - return self.create_error_response( - "Prompt_logprobs are not available when stream is enabled") - - if request.prompt_logprobs < 0: - return self.create_error_response( - f"Prompt_logprobs set to invalid " - f"negative value: {request.prompt_logprobs}") - try: ( lora_request, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6a128b7379c8..34f1200753f8 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -87,15 +87,6 @@ async def create_completion( request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) - if request.prompt_logprobs is not None: - if request.stream and request.prompt_logprobs > 0: - return self.create_error_response( - "Prompt_logprobs are not available when stream is enabled") - elif request.prompt_logprobs < 0: - return self.create_error_response( - f"Prompt_logprobs set to invalid negative " - f"value: {request.prompt_logprobs}") - # Schedule the request and get the result 
generator. generators: List[AsyncGenerator[RequestOutput, None]] = [] try: From f4af3040f1f3d1bedcc0b5367a1c4f3a786081e8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 16 Aug 2024 07:17:20 +0000 Subject: [PATCH 67/75] Update tests --- tests/entrypoints/openai/test_completion.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 4d0c6d73518d..0fe35197128b 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -298,14 +298,9 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, if prompt_logprobs is not None: params["extra_body"] = {"prompt_logprobs": prompt_logprobs} - if prompt_logprobs and prompt_logprobs < 0: - with pytest.raises(BadRequestError) as err_info: + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): await client.chat.completions.create(**params) - expected_err_string = ( - "Error code: 400 - {'object': 'error', 'message': " - "'Prompt_logprobs set to invalid negative value: -1'," - " 'type': 'BadRequestError', 'param': None, 'code': 400}") - assert str(err_info.value) == expected_err_string else: completion = await client.chat.completions.create(**params) if prompt_logprobs and prompt_logprobs > 0: @@ -369,14 +364,9 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, if prompt_logprobs is not None: params["extra_body"] = {"prompt_logprobs": prompt_logprobs} - if prompt_logprobs and prompt_logprobs < 0: - with pytest.raises(BadRequestError) as err_info: + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): await client.completions.create(**params) - expected_err_string = ( - "Error code: 400 - {'object': 'error', 'message': " - "'Prompt_logprobs set to invalid negative value: -1'," - " 'type': 'BadRequestError', 'param': None, 'code': 400}") - assert str(err_info.value) == expected_err_string else: completion = await client.completions.create(**params) if prompt_logprobs and prompt_logprobs > 0: From 9cd8fb5edacebf8e5b2500c5dcfe68cfb025e928 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 18 Aug 2024 01:57:05 +0000 Subject: [PATCH 68/75] Fix type error --- vllm/entrypoints/openai/rpc/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index de1309b4a46a..770ee77926df 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -39,7 +39,7 @@ def cleanup(self): self.context.destroy() self.engine.shutdown_background_loop() # Clear the engine reference so that it can be GC'ed. 
- self.engine = None + del self.engine async def get_model_config(self, identity): """Send the ModelConfig""" From 7e09fc8c674860f6e9ce34324dfb6c32e659db2a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 18 Aug 2024 07:01:30 +0000 Subject: [PATCH 69/75] Clean up parsing logic --- vllm/entrypoints/chat_utils.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4c9730310430..49a913deec49 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -3,7 +3,7 @@ from functools import lru_cache from pathlib import Path from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, - Union, cast) + Union) # yapf conflicts with isort for this block # yapf: disable @@ -15,8 +15,8 @@ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) # yapf: enable # pydantic needs the TypedDict from typing_extensions -from pydantic import ConfigDict -from typing_extensions import Required, TypedDict +from pydantic import ConfigDict, TypeAdapter +from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger @@ -49,9 +49,11 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): """The type of the content part.""" -ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, - ChatCompletionContentPartAudioParam, - CustomChatCompletionContentPartParam] +ChatCompletionContentPartParam: TypeAlias = Union[ + OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartAudioParam, + CustomChatCompletionContentPartParam, +] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -150,6 +152,11 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str, return f"{placeholder_token_str}\n{text_prompt}" +_TextParser = TypeAdapter(ChatCompletionContentPartTextParam) +_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam) +_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam) + + def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], @@ -163,7 +170,7 @@ def _parse_chat_message_content_parts( for part in parts: part_type = part["type"] if part_type == "text": - text = cast(ChatCompletionContentPartTextParam, part)["text"] + text = _TextParser.validate_python(part)["text"] texts.append(text) elif part_type == "image_url": modality = "image" @@ -171,8 +178,7 @@ def _parse_chat_message_content_parts( raise NotImplementedError( "Multiple multimodal inputs is currently not supported.") - image_url = cast(ChatCompletionContentPartImageParam, - part)["image_url"] + image_url = _ImageParser.validate_python(part)["image_url"] if image_url.get("detail", "auto") != "auto": logger.warning( @@ -187,8 +193,7 @@ def _parse_chat_message_content_parts( raise NotImplementedError( "Multiple multimodal inputs is currently not supported.") - audio_url = cast(ChatCompletionContentPartAudioParam, - part)["audio_url"] + audio_url = _AudioParser.validate_python(part)["audio_url"] audio_future = async_get_and_parse_audio(audio_url["url"]) mm_futures.append(audio_future) else: From b7dc954d81739647682565a1ead5d39e5c9afe40 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 18 Aug 2024 07:16:28 +0000 Subject: [PATCH 70/75] format --- vllm/entrypoints/chat_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py 
b/vllm/entrypoints/chat_utils.py index 49a913deec49..48fd1333d8f4 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -50,10 +50,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, - ChatCompletionContentPartAudioParam, - CustomChatCompletionContentPartParam, -] + OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + CustomChatCompletionContentPartParam, ] class CustomChatCompletionMessageParam(TypedDict, total=False): From 78161b4d8722c9db5b7f68edc41dfbde4197430e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 18 Aug 2024 23:26:45 +0000 Subject: [PATCH 71/75] Remote quotes --- vllm/engine/async_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e86551239401..3db1dab9bb5f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -652,7 +652,7 @@ def _error_callback(self, exc: Exception) -> None: async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, - ) -> "AnyTokenizer": + ) -> AnyTokenizer: if self.engine_use_ray: return await self.engine.get_tokenizer.remote( # type: ignore lora_request) From 0a9274adda0b5c6c35ab0e2b183fe2be884ecbbd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 18 Aug 2024 23:28:25 +0000 Subject: [PATCH 72/75] Add fallback --- vllm/entrypoints/openai/cli_args.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 5330040e1bb6..94742838b421 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -24,7 +24,9 @@ def __call__( values: Optional[Union[str, Sequence[str]]], option_string: Optional[str] = None, ): - if values is None or isinstance(values, str): + if values is None: + values = [] + if isinstance(values, str): raise TypeError("Expected values to be a list") lora_list: List[LoRAModulePath] = [] @@ -43,7 +45,9 @@ def __call__( values: Optional[Union[str, Sequence[str]]], option_string: Optional[str] = None, ): - if values is None or isinstance(values, str): + if values is None: + values = [] + if isinstance(values, str): raise TypeError("Expected values to be a list") adapter_list: List[PromptAdapterPath] = [] From 1e89169735aad2b329ca3a6b558d0e502a0fd9b6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 18 Aug 2024 23:38:15 +0000 Subject: [PATCH 73/75] Update tests --- tests/entrypoints/openai/test_completion.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 0fe35197128b..8a88420452ec 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -3,7 +3,7 @@ import re import shutil from tempfile import TemporaryDirectory -from typing import Dict, List +from typing import Dict, List, Optional import jsonschema import openai # use the official client for correctness check @@ -274,7 +274,8 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], ) async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str, prompt_logprobs: int): + model_name: str, + prompt_logprobs: Optional[int]): params: Dict = { "messages": 
[{ "role": "system", @@ -303,7 +304,7 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, await client.chat.completions.create(**params) else: completion = await client.chat.completions.create(**params) - if prompt_logprobs and prompt_logprobs > 0: + if prompt_logprobs is not None: assert completion.prompt_logprobs is not None assert len(completion.prompt_logprobs) > 0 else: @@ -356,7 +357,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, (MODEL_NAME, None)]) async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, model_name: str, - prompt_logprobs: int): + prompt_logprobs: Optional[int]): params: Dict = { "prompt": ["A robot may not injure another robot", "My name is"], "model": model_name, @@ -369,7 +370,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, await client.completions.create(**params) else: completion = await client.completions.create(**params) - if prompt_logprobs and prompt_logprobs > 0: + if prompt_logprobs is not None: assert completion.choices[0].prompt_logprobs is not None assert len(completion.choices[0].prompt_logprobs) > 0 From e2ec43c6c0f9c51b6e3387abf9b18664bead5e0d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 18 Aug 2024 23:39:26 +0000 Subject: [PATCH 74/75] Move chat tests to the correct file --- tests/entrypoints/openai/test_chat.py | 84 ++++++++++++++++++++- tests/entrypoints/openai/test_completion.py | 82 -------------------- 2 files changed, 83 insertions(+), 83 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index c96d602b6343..afcb0f44befc 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -1,7 +1,7 @@ # imports for guided decoding tests import json import re -from typing import List +from typing import Dict, List, Optional import jsonschema import openai # use the official client for correctness check @@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name, prompt_logprobs", + [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], +) +async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, + model_name: str, + prompt_logprobs: Optional[int]): + params: Dict = { + "messages": [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Who won the world series in 2020?" + }, { + "role": + "assistant", + "content": + "The Los Angeles Dodgers won the World Series in 2020." + }, { + "role": "user", + "content": "Where was it played?" + }], + "model": + model_name + } + + if prompt_logprobs is not None: + params["extra_body"] = {"prompt_logprobs": prompt_logprobs} + + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): + await client.chat.completions.create(**params) + else: + completion = await client.chat.completions.create(**params) + if prompt_logprobs is not None: + assert completion.prompt_logprobs is not None + assert len(completion.prompt_logprobs) > 0 + else: + assert completion.prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, + model_name: str): + params: Dict = { + "messages": [{ + "role": "system", + "content": "You are a helpful assistant." 
+ }, { + "role": "user", + "content": "Who won the world series in 2020?" + }, { + "role": + "assistant", + "content": + "The Los Angeles Dodgers won the World Series in 2020." + }, { + "role": "user", + "content": "Where was it played?" + }], + "model": + model_name, + "extra_body": { + "prompt_logprobs": 1 + } + } + + completion_1 = await client.chat.completions.create(**params) + + params["extra_body"] = {"prompt_logprobs": 2} + completion_2 = await client.chat.completions.create(**params) + + assert len(completion_1.prompt_logprobs[3]) == 1 + assert len(completion_2.prompt_logprobs[3]) == 2 + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 8a88420452ec..18f41f5fc671 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -268,88 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, assert len(completion.choices[0].text) >= 0 -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name, prompt_logprobs", - [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], -) -async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): - params: Dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name - } - - if prompt_logprobs is not None: - params["extra_body"] = {"prompt_logprobs": prompt_logprobs} - - if prompt_logprobs is not None and prompt_logprobs < 0: - with pytest.raises(BadRequestError): - await client.chat.completions.create(**params) - else: - completion = await client.chat.completions.create(**params) - if prompt_logprobs is not None: - assert completion.prompt_logprobs is not None - assert len(completion.prompt_logprobs) > 0 - else: - assert completion.prompt_logprobs is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str): - params: Dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" 
- }], - "model": - model_name, - "extra_body": { - "prompt_logprobs": 1 - } - } - - completion_1 = await client.chat.completions.create(**params) - - params["extra_body"] = {"prompt_logprobs": 2} - completion_2 = await client.chat.completions.create(**params) - - assert len(completion_1.prompt_logprobs[3]) == 1 - assert len(completion_2.prompt_logprobs[3]) == 2 - - @pytest.mark.asyncio @pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), (MODEL_NAME, 0), From 1f9ea921490bde95a9edc986d84752bc424f4010 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 19 Aug 2024 04:38:37 +0000 Subject: [PATCH 75/75] Update pydantic version --- docs/requirements-docs.txt | 2 +- requirements-common.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 9a5964ec65b9..59f08aac99c1 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -5,7 +5,7 @@ myst-parser==2.0.0 sphinx-argparse==0.4.0 # packages to install to build the documentation -pydantic +pydantic >= 2.8 -f https://download.pytorch.org/whl/cpu torch py-cpuinfo diff --git a/requirements-common.txt b/requirements-common.txt index b6bed8a73d8c..534d63feec2b 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,7 +11,7 @@ fastapi aiohttp openai >= 1.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] -pydantic >= 2.0 # Required for OpenAI server. +pydantic >= 2.8 # Required for OpenAI server. pillow # Required for image processing prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0
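

The validation patches above (PATCH 66 and the follow-up test changes) consolidate the `prompt_logprobs` and `top_logprobs` checks into pydantic `model_validator(mode="before")` hooks on the request models, so invalid values are rejected before a request ever reaches the engine. Below is a minimal, self-contained sketch of that validator pattern; it is illustrative only, and the `ExampleCompletionRequest` model and its fields are placeholders rather than vLLM's actual protocol classes.

# Illustrative sketch of the pydantic v2 "before"-mode validator pattern
# used by the protocol changes above. Class and field names are placeholders,
# not vLLM's real request classes.
from typing import Optional

from pydantic import BaseModel, ValidationError, model_validator


class ExampleCompletionRequest(BaseModel):
    stream: bool = False
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        # Runs on the raw input dict before field parsing, mirroring the
        # checks that were moved out of the per-endpoint serving handlers.
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data


if __name__ == "__main__":
    try:
        ExampleCompletionRequest(prompt_logprobs=-1)
    except ValidationError as exc:
        # The ValueError raised inside the validator is wrapped in a
        # ValidationError; the updated tests expect the server to surface
        # this to clients as a BadRequestError (HTTP 400).
        print(exc)

Keeping these checks on the request models gives the chat and completion endpoints a single source of truth, instead of the per-endpoint checks that PATCH 66 removes from serving_chat.py and serving_completion.py.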