@pytest.mark.parametrize(("return_type", "value", "expected"), [
    (int, "42", 42),
    (int, "None", None),
    (float, "3.14", 3.14),
    (float, "None", None),
    (str, "Hello World!", "Hello World!"),
    (str, "None", None),
    (json.loads, '{"foo":1,"bar":2}', {
        "foo": 1,
        "bar": 2
    }),
    (json.loads, "foo=1,bar=2", {
        "foo": 1,
        "bar": 2
    }),
    (json.loads, "None", None),
])
def test_optional_type(return_type, value, expected):
    """`optional_type` converts values and maps the string "None" to None."""
    parse = optional_type(return_type)
    # Only the legacy ``key=value`` syntax for JSON arguments is expected to
    # emit a DeprecationWarning; every other case must run without one being
    # required.
    uses_legacy_kv = return_type is json.loads and "=" in value
    context = (pytest.warns(DeprecationWarning)
               if uses_legacy_kv else nullcontext())
    with context:
        assert parse(value) == expected


@pytest.mark.parametrize(("type_hint", "target_type", "expected"), [
    (int, int, True),
    (int, float, False),
    (list[int], list, True),
    (list[int], tuple, False),
    (Literal[0, 1], Literal, True),
])
def test_is_type(type_hint, target_type, expected):
    """`is_type` matches both plain types and parameterised generics."""
    assert is_type(type_hint, target_type) == expected


@pytest.mark.parametrize(("type_hints", "target_type", "expected"), [
    ({float, int}, int, True),
    ({int, tuple[int]}, int, True),
    ({int, tuple[int]}, float, False),
    ({str, Literal["x", "y"]}, Literal, True),
])
def test_contains_type(type_hints, target_type, expected):
    """`contains_type` is True when any hint in the set matches."""
    assert contains_type(type_hints, target_type) == expected


@pytest.mark.parametrize(("type_hints", "target_type", "expected"), [
    ({int, float}, int, int),
    ({int, float}, str, None),
    ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
])
def test_get_type(type_hints, target_type, expected):
    """`get_type` returns the matching hint, or None when there is none."""
    assert get_type(type_hints, target_type) == expected


@config
@dataclass
class DummyConfigClass:
    regular_bool: bool = True
    """Regular bool with default True"""
    optional_bool: Optional[bool] = None
    """Optional bool with default None"""
    optional_literal: Optional[Literal["x", "y"]] = None
    """Optional literal with default None"""
    tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
    """Tuple with default (1, 2, 3)"""
    tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
    """Tuple with default (1, 2)"""
    list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
    """List with default [1, 2, 3]"""


@pytest.mark.parametrize(("type_hint", "expected"), [
    (int, False),
    (DummyConfigClass, True),
])
def test_is_not_builtin(type_hint, expected):
    """Builtins are reported as builtin, project classes are not."""
    assert is_not_builtin(type_hint) == expected


def test_get_kwargs():
    """`get_kwargs` derives argparse kwargs from each field's type hint."""
    kwargs = get_kwargs(DummyConfigClass)

    # bools should not have their type set
    assert kwargs["regular_bool"].get("type") is None
    assert kwargs["optional_bool"].get("type") is None
    # optional literals should have None as a choice
    assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
    # tuples should have the correct nargs
    assert kwargs["tuple_n"]["nargs"] == "+"
    assert kwargs["tuple_2"]["nargs"] == 2
    # lists should work
    assert kwargs["list_n"]["type"] is int
    assert kwargs["list_n"]["nargs"] == "+"
def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap an argparse ``type`` callable so that it also accepts "no value".

    Args:
        return_type: Callable converting the raw CLI string into the target
            value, e.g. ``int``, ``str`` or ``json.loads``.

    Returns:
        A callable that returns ``None`` for ``""`` and ``"None"`` and
        otherwise delegates to ``return_type``. When ``return_type`` is
        ``json.loads`` and the value is not wrapped in braces, the value is
        parsed with ``nullable_kvs`` instead.

    Raises:
        argparse.ArgumentTypeError: Raised by the returned callable when the
            conversion fails with a ``ValueError``.
    """

    def _optional_type(val: str) -> Optional[T]:
        if val == "" or val == "None":
            return None
        try:
            # re.DOTALL lets ``.`` span newlines so that pretty-printed
            # (multi-line) JSON objects are still recognised as JSON instead
            # of being handed to the ``key=value`` fallback parser.
            if return_type is json.loads and not re.match(
                    r"^{.*}$", val, re.DOTALL):
                return cast(T, nullable_kvs(val))
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    return _optional_type


def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Check if the type hint is a specific type.

    Matches both the bare type (``int``) and a parameterised form whose
    origin is that type (``list[int]`` for ``list``).
    """
    return type_hint is type or get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Check if any of the type hints is (or originates from) ``type``."""
    return any(is_type(type_hint, type) for type_hint in type_hints)


def get_type(type_hints: set[TypeHint],
             type: TypeHintT) -> Optional[TypeHintT]:
    """Get the first hint matching ``type``, or ``None`` if none matches."""
    return next((th for th in type_hints if is_type(th, type)), None)


def is_not_builtin(type_hint: TypeHint) -> bool:
    """Check if the class is not a built-in type."""
    return type_hint.__module__ != "builtins"


def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build ``add_argument`` keyword arguments for every field of ``cls``.

    Args:
        cls: Dataclass whose fields are turned into CLI arguments. The help
            text of each field is looked up in ``get_attr_docs(cls)``.

    Returns:
        Mapping of field name to its keyword arguments: always ``default``
        and ``help``, plus ``type``, ``action``, ``choices`` and ``nargs``
        depending on the field's type hint.

    Raises:
        ValueError: If a field's type hint is not supported.
        AssertionError: If a ``Literal``, ``tuple`` or ``list`` hint is
            malformed (mixed choice types, missing or mixed element types).
    """
    cls_docs = get_attr_docs(cls)
    kwargs = {}
    for cls_field in fields(cls):
        # Get the default value of the field
        default = cls_field.default
        if cls_field.default_factory is not MISSING:
            default = cls_field.default_factory()

        # Get the help text for the field, escaping % for argparse
        name = cls_field.name
        help_text = cls_docs[name].replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        arg_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = arg_kwargs

        # Get the set of possible types for the field
        type_hints: set[TypeHint] = set()
        if get_origin(cls_field.type) is Union:
            type_hints.update(get_args(cls_field.type))
        else:
            type_hints.add(cls_field.type)

        # Set other kwargs based on the type hints
        if contains_type(type_hints, bool):
            # Creates --no- and -- flags
            arg_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            # Creates choices from Literal arguments
            type_hint = get_type(type_hints, Literal)
            choices = list(get_args(type_hint))
            choice_type = type(choices[0])
            # Validate before sorting: sorting mixed types would raise an
            # opaque TypeError instead of this message.
            assert all(type(c) is choice_type for c in choices), (
                "All choices must be of the same type. "
                f"Got {choices} with types {[type(c) for c in choices]}")
            arg_kwargs["choices"] = sorted(choices)
            arg_kwargs["type"] = choice_type
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            # A bare ``tuple`` has no element types to derive ``type`` from.
            assert types, (
                "Tuple type must specify its element types. "
                f"Got {type_hint}.")
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            arg_kwargs["type"] = tuple_type
            # Variadic tuples take one or more values, fixed ones exactly n
            arg_kwargs["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            arg_kwargs["type"] = types[0]
            arg_kwargs["nargs"] = "+"
        elif contains_type(type_hints, int):
            arg_kwargs["type"] = int
        elif contains_type(type_hints, float):
            arg_kwargs["type"] = float
        elif contains_type(type_hints, dict):
            # Dict arguments will always be optional
            arg_kwargs["type"] = optional_type(json.loads)
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            arg_kwargs["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            arg_kwargs["type"] = optional_type(arg_kwargs["type"])
            if arg_kwargs.get("choices"):
                # NOTE(review): argparse checks choices after the ``type``
                # conversion, and the wrapped type turns "None" into the
                # None object -- confirm that the string "None" (rather
                # than None itself) is the intended choice entry.
                arg_kwargs["choices"].append("None")
    return kwargs
the field if it is - optional = is_optional(field.type) - field_type = get_args( - field.type)[0] if optional else field.type - - # Set type, action and choices for the field depending on the - # type of the field - if can_be_type(field_type, bool): - # Creates --no- and -- flags - kwargs[name]["action"] = argparse.BooleanOptionalAction - kwargs[name]["type"] = bool - elif can_be_type(field_type, Literal): - # Creates choices from Literal arguments - if is_type_in_union(field_type, Literal): - field_type = get_type_from_union(field_type, Literal) - choices = get_args(field_type) - kwargs[name]["choices"] = choices - choice_type = type(choices[0]) - assert all(type(c) is choice_type for c in choices), ( - "All choices must be of the same type. " - f"Got {choices} with types {[type(c) for c in choices]}" - ) - kwargs[name]["type"] = choice_type - elif can_be_type(field_type, tuple): - if is_type_in_union(field_type, tuple): - field_type = get_type_from_union(field_type, tuple) - dtypes = get_args(field_type) - dtype = dtypes[0] - assert all( - d is dtype for d in dtypes if d is not Ellipsis - ), ("All non-Ellipsis tuple elements must be of the same " - f"type. Got {dtypes}.") - kwargs[name]["type"] = dtype - kwargs[name]["nargs"] = "+" - elif can_be_type(field_type, int): - kwargs[name]["type"] = optional_int if optional else int - elif can_be_type(field_type, float): - kwargs[name][ - "type"] = optional_float if optional else float - elif can_be_type(field_type, dict): - kwargs[name]["type"] = optional_dict - elif (can_be_type(field_type, str) - or is_custom_type(field_type)): - kwargs[name]["type"] = optional_str if optional else str - else: - raise ValueError( - f"Unsupported type {field.type} for argument {name}. 
") - return kwargs - # Model arguments parser.add_argument( '--model', @@ -390,13 +377,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: 'which task to use.') parser.add_argument( '--tokenizer', - type=optional_str, + type=optional_type(str), default=EngineArgs.tokenizer, help='Name or path of the huggingface tokenizer to use. ' 'If unspecified, model name or path will be used.') parser.add_argument( "--hf-config-path", - type=optional_str, + type=optional_type(str), default=EngineArgs.hf_config_path, help='Name or path of the huggingface config to use. ' 'If unspecified, model name or path will be used.') @@ -408,21 +395,21 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: 'the input. The generated output will contain token ids.') parser.add_argument( '--revision', - type=optional_str, + type=optional_type(str), default=None, help='The specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') parser.add_argument( '--code-revision', - type=optional_str, + type=optional_type(str), default=None, help='The specific revision to use for the model code on ' 'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'commit id. If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', - type=optional_str, + type=optional_type(str), default=None, help='Revision of the huggingface tokenizer to use. ' 'It can be a branch name, a tag name, or a commit id. ' @@ -513,7 +500,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: parser.add_argument( '--logits-processor-pattern', - type=optional_str, + type=optional_type(str), default=None, help='Optional regex pattern specifying valid logits processor ' 'qualified names that can be passed with the `logits_processors` ' @@ -612,7 +599,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Quantization settings. 
parser.add_argument('--quantization', '-q', - type=optional_str, + type=optional_type(str), choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.quantization, help='Method used to quantize the weights. If ' @@ -921,7 +908,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: 'class without changing the existing functions.') parser.add_argument( "--generation-config", - type=optional_str, + type=optional_type(str), default="auto", help="The folder path to the generation config. " "Defaults to 'auto', the generation config will be loaded from " diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index af546c3032af..b3824013f055 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from typing import Optional, Union, get_args -from vllm.engine.arg_utils import AsyncEngineArgs, optional_str +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) from vllm.entrypoints.openai.serving_models import (LoRAModulePath, @@ -79,7 +79,7 @@ def __call__( def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--host", - type=optional_str, + type=optional_type(str), default=None, help="Host name.") parser.add_argument("--port", type=int, default=8000, help="Port number.") @@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=["*"], help="Allowed headers.") parser.add_argument("--api-key", - type=optional_str, + type=optional_type(str), default=None, help="If provided, the server will require this key " "to be presented in the header.") parser.add_argument( "--lora-modules", - type=optional_str, + type=optional_type(str), default=None, nargs='+', action=LoRAParserAction, @@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> 
FlexibleArgumentParser: "\"base_model_name\": \"id\"}``") parser.add_argument( "--prompt-adapters", - type=optional_str, + type=optional_type(str), default=None, nargs='+', action=PromptAdapterParserAction, help="Prompt adapter configurations in the format name=path. " "Multiple adapters can be specified.") parser.add_argument("--chat-template", - type=optional_str, + type=optional_type(str), default=None, help="The file path to the chat template, " "or the template in single-line form " @@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'similar to OpenAI schema. ' 'Example: ``[{"type": "text", "text": "Hello world!"}]``') parser.add_argument("--response-role", - type=optional_str, + type=optional_type(str), default="assistant", help="The role name to return if " "``request.add_generation_prompt=true``.") parser.add_argument("--ssl-keyfile", - type=optional_str, + type=optional_type(str), default=None, help="The file path to the SSL key file.") parser.add_argument("--ssl-certfile", - type=optional_str, + type=optional_type(str), default=None, help="The file path to the SSL cert file.") parser.add_argument("--ssl-ca-certs", - type=optional_str, + type=optional_type(str), default=None, help="The CA certificates file.") parser.add_argument( @@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( "--root-path", - type=optional_str, + type=optional_type(str), default=None, help="FastAPI root_path when app is behind a path based routing proxy." ) parser.add_argument( "--middleware", - type=optional_str, + type=optional_type(str), action="append", default=[], help="Additional ASGI middleware to apply to the app. 
" diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 3ffa5a32c173..fccf459f17dc 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -12,7 +12,7 @@ from prometheus_client import start_http_server from tqdm import tqdm -from vllm.engine.arg_utils import AsyncEngineArgs, optional_str +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger, logger # yapf: disable @@ -61,7 +61,7 @@ def parse_args(): "to the output URL.", ) parser.add_argument("--response-role", - type=optional_str, + type=optional_type(str), default="assistant", help="The role name to return if " "`request.add_generation_prompt=True`.")