diff --git a/tests/conftest.py b/tests/conftest.py index 5fc09b241e23..f02b5a8c0520 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -738,7 +738,7 @@ class VllmRunner: - `block_size`: Set to `16` instead of `None` to reduce memory usage. - `enable_chunked_prefill`: Set to `False` instead of `None` for test reproducibility. - - `enforce_eager`: Set to `False` instead of `None` to test CUDA graph. + - `enforce_eager`: Set to `False` to test CUDA graph. """ def __init__( diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 2c86658022c0..16721ee9ce74 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -8,7 +8,7 @@ import pytest -from vllm.config import PoolerConfig, config +from vllm.config import config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, get_type, is_not_builtin, is_type, literal_to_kwargs, nullable_kvs, @@ -222,17 +222,6 @@ def test_prefix_cache_default(): assert not engine_args.enable_prefix_caching -def test_valid_pooling_config(): - parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) - args = parser.parse_args([ - '--override-pooler-config', - '{"pooling_type": "MEAN"}', - ]) - engine_args = EngineArgs.from_cli_args(args=args) - assert engine_args.override_pooler_config == PoolerConfig( - pooling_type="MEAN", ) - - @pytest.mark.parametrize( ("arg"), [ diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index abc1c05de3c0..0ea71aaf828b 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.linear import LinearBase # noqa: E501 from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization import ( - get_quantization_config, register_quantization_config) + QuantizationMethods, get_quantization_config, register_quantization_config) from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) @@ -54,7 +54,7 @@ def __init__(self, num_bits: int = 8) -> None: """Initialize the quantization config.""" self.num_bits = num_bits - def get_name(self) -> str: + def get_name(self) -> QuantizationMethods: """Name of the quantization method.""" return "custom_quant" diff --git a/tests/test_config.py b/tests/test_config.py index 2e5da8128d99..f2155d954db0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -185,7 +185,7 @@ def test_get_pooling_config(): revision=None, ) - pooling_config = model_config._init_pooler_config(None) + pooling_config = model_config._init_pooler_config() assert pooling_config is not None assert pooling_config.normalize @@ -205,11 +205,12 @@ def test_get_pooling_config_from_args(): dtype="float16", revision=None) - override_config = PoolerConfig(pooling_type='CLS', normalize=True) + override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True) + model_config.override_pooler_config = override_pooler_config - pooling_config = model_config._init_pooler_config(override_config) + pooling_config = model_config._init_pooler_config() assert pooling_config is not None - assert asdict(pooling_config) == asdict(override_config) + assert asdict(pooling_config) == asdict(override_pooler_config) @pytest.mark.skipif(current_platform.is_rocm(), diff --git a/vllm/config.py b/vllm/config.py index abe59734e2d6..f9c5e25a47d4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ 
-16,9 +16,8 @@ replace) from importlib.util import find_spec from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, - Optional, Protocol, TypeVar, Union, cast, get_args, - get_origin) +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, + Protocol, TypeVar, Union, cast, get_args, get_origin) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -211,103 +210,190 @@ def get_field(cls: ConfigType, name: str) -> Field: f"{cls.__name__}.{name} must have a default value or default factory.") -class ModelConfig: - """Configuration for the model. +TokenizerMode = Literal["auto", "slow", "mistral", "custom"] +ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] - Args: - model: Name or path of the huggingface model to use. - It is also used as the content for `model_name` tag in metrics - output when `served_model_name` is not specified. - task: The task to use the model for. Each vLLM instance only supports - one task, even if the same model can be used for multiple tasks. - When the model only supports one task, "auto" can be used to select - it; otherwise, you must specify explicitly which task to use. - tokenizer: Name or path of the huggingface tokenizer to use. - tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if - available, "slow" will always use the slow tokenizer, - "mistral" will always use the tokenizer from `mistral_common`, and - "custom" will use --tokenizer to select the preregistered tokenizer. - trust_remote_code: Trust remote code (e.g., from HuggingFace) when - downloading the model and tokenizer. - allowed_local_media_path: Allowing API requests to read local images or - videos from directories specified by the server file system. - This is a security risk. Should only be enabled in trusted - environments. - dtype: Data type for model weights and activations. The "auto" option - will use FP16 precision for FP32 and FP16 models, and BF16 precision - for BF16 models. - seed: Random seed for reproducibility. - revision: The specific model version to use. It can be a branch name, - a tag name, or a commit id. If unspecified, will use the default - version. - code_revision: The specific revision to use for the model code on - Hugging Face Hub. It can be a branch name, a tag name, or a - commit id. If unspecified, will use the default version. - tokenizer_revision: The specific tokenizer version to use. It can be a - branch name, a tag name, or a commit id. If unspecified, will use - the default version. - max_model_len: Maximum length of a sequence (including prompt and - output). If None, will be derived from the model. - spec_target_max_model_len: Specify the the maximum length for spec - decoding draft models. - quantization: Quantization method that was used to quantize the model - weights. If None, we assume the model weights are not quantized. - enforce_eager: Whether to enforce eager execution. If True, we will - disable CUDA graph and always execute the model in eager mode. - If False, we will use CUDA graph and eager execution in hybrid. - If None, the user did not specify, so default to False. - max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode. Additionally for encoder-decoder models, if the - sequence length of the encoder input is larger than this, we fall - back to the eager mode. - max_logprobs: Maximum number of log probabilities. 
Defaults to 20. - disable_sliding_window: Whether to disable sliding window. If True, - we will disable the sliding window functionality of the model. - If the model does not support sliding window, this argument is - ignored. - skip_tokenizer_init: If true, skip initialization of tokenizer and - detokenizer. - served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, - the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data items per modality - per prompt. Only applicable for multimodal models. - mm_processor_kwargs: Overrides for the multi-modal processor obtained - from `AutoProcessor.from_pretrained`. - disable_mm_preprocessor_cache: If True, disable caching of the - processed multi-modal inputs. - use_async_output_proc: Whether to use async output processor. - Defaults to True. - config_format: The config format which shall be loaded. - Defaults to 'auto' which defaults to 'hf'. - hf_token: The token to use as HTTP bearer authorization for remote files - . If `True`, will use the token generated when running - `huggingface-cli login` (stored in `~/.huggingface`). - hf_overrides: If a dictionary, contains arguments to be forwarded to the - HuggingFace config. If a callable, it is called to update the - HuggingFace config. - override_neuron_config: Initialize non default neuron config or - override default neuron config that are specific to Neuron devices, - this argument will be used to configure the neuron config that - can not be gathered from the vllm arguments. - override_pooler_config: Initialize non default pooling config or - override default pooling config for the pooling model. - logits_processor_pattern: Optional regex pattern specifying valid - logits processor qualified names that can be passed with the - `logits_processors` extra completion argument. Defaults to None, - which allows no processors. - generation_config: Configuration parameter file for generation. - model_impl: Which implementation of the model to use: - "auto" will try to use the vLLM implementation if it exists and - fall back to the Transformers implementation if no vLLM - implementation is available. - "vllm" will use the vLLM model implementation. - "transformers" will use the Transformers model implementation. - override_generation_config: Override the generation config with the - given config. - """ + +@config +@dataclass +class ModelConfig: + """Configuration for the model.""" + + model: str = "facebook/opt-125m" + """Name or path of the Hugging Face model to use. It is also used as the + content for `model_name` tag in metrics output when `served_model_name` is + not specified.""" + task: Literal[TaskOption, Literal["draft"]] = "auto" + """The task to use the model for. Each vLLM instance only supports one + task, even if the same model can be used for multiple tasks. When the model + only supports one task, "auto" can be used to select it; otherwise, you + must specify explicitly which task to use.""" + tokenizer: str = None # type: ignore + """Name or path of the Hugging Face tokenizer to use. 
If unspecified, model + name or path will be used.""" + tokenizer_mode: TokenizerMode = "auto" + """Tokenizer mode:\n + - "auto" will use the fast tokenizer if available.\n + - "slow" will always use the slow tokenizer.\n + - "mistral" will always use the tokenizer from `mistral_common`.\n + - "custom" will use --tokenizer to select the preregistered tokenizer.""" + trust_remote_code: bool = False + """Trust remote code (e.g., from HuggingFace) when downloading the model + and tokenizer.""" + dtype: Union[ModelDType, torch.dtype] = "auto" + """Data type for model weights and activations:\n + - "auto" will use FP16 precision for FP32 and FP16 models, and BF16 + precision for BF16 models.\n + - "half" for FP16. Recommended for AWQ quantization.\n + - "float16" is the same as "half".\n + - "bfloat16" for a balance between precision and range.\n + - "float" is shorthand for FP32 precision.\n + - "float32" for FP32 precision.""" + seed: Optional[int] = None + """Random seed for reproducibility.""" + hf_config_path: Optional[str] = None + """Name or path of the Hugging Face config to use. If unspecified, model + name or path will be used.""" + allowed_local_media_path: str = "" + """Allowing API requests to read local images or videos from directories + specified by the server file system. This is a security risk. Should only + be enabled in trusted environments.""" + revision: Optional[str] = None + """The specific model version to use. It can be a branch name, a tag name, + or a commit id. If unspecified, will use the default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the model code on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + rope_scaling: dict[str, Any] = field(default_factory=dict) + """RoPE scaling configuration in JSON format. For example, + `{"rope_type":"dynamic","factor":2.0}`.""" + rope_theta: Optional[float] = None + """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE + theta improves the performance of the scaled model.""" + tokenizer_revision: Optional[str] = None + """The specific revision to use for the tokenizer on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + max_model_len: int = None # type: ignore + """Model context length (prompt and output). If unspecified, will be + automatically derived from the model config. + + When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable + format. Examples:\n + - 1k -> 1000\n + - 1K -> 1024\n + - 25.6k -> 25,600""" + spec_target_max_model_len: Optional[int] = None + """Specify the maximum length for spec decoding draft models.""" + quantization: Optional[QuantizationMethods] = None + """Method used to quantize the weights. If `None`, we first check the + `quantization_config` attribute in the model config file. If that is + `None`, we assume the model weights are not quantized and use `dtype` to + determine the data type of the weights.""" + enforce_eager: bool = False + """Whether to always use eager-mode PyTorch. If True, we will disable CUDA + graph and always execute the model in eager mode. If False, we will use + CUDA graph and eager execution in hybrid for maximal performance and + flexibility.""" + max_seq_len_to_capture: int = 8192 + """Maximum sequence len covered by CUDA graphs. When a sequence has context + length larger than this, we fall back to eager mode. 
Additionally for + encoder-decoder models, if the sequence length of the encoder input is + larger than this, we fall back to the eager mode.""" + max_logprobs: int = 20 + """Maximum number of log probabilities to return when `logprobs` is + specified in `SamplingParams`. The default value comes from the default for the + OpenAI Chat Completions API.""" + disable_sliding_window: bool = False + """Whether to disable sliding window. If True, we will disable the sliding + window functionality of the model, capping to sliding window size. If the + model does not support sliding window, this argument is ignored.""" + disable_cascade_attn: bool = False + """Disable cascade attention for V1. While cascade attention does not + change the mathematical correctness, disabling it could be useful for + preventing potential numerical issues. Note that even if this is set to + False, cascade attention will be only used when the heuristic tells that + it's beneficial.""" + skip_tokenizer_init: bool = False + """Skip initialization of tokenizer and detokenizer. Expects valid + `prompt_token_ids` and `None` for prompt from the input. The generated + output will contain token ids.""" + served_model_name: Optional[Union[str, list[str]]] = None + """The model name(s) used in the API. If multiple names are provided, the + server will respond to any of the provided names. The model name in the + model field of a response will be the first name in this list. If not + specified, the model name will be the same as the `--model` argument. Note + that this name(s) will also be used in `model_name` tag content of + prometheus metrics; if multiple names are provided, the metrics tag will take the + first one.""" + limit_mm_per_prompt: dict[str, int] = field(default_factory=dict) + """Maximum number of data items per modality per prompt. Only applicable + for multimodal models.""" + use_async_output_proc: bool = True + """Whether to use async output processor.""" + config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value + """The format of the model config to load:\n + - "auto" will try to load the config in hf format if available else it + will try to load in mistral format.\n + - "hf" will load the config in hf format.\n + - "mistral" will load the config in mistral format.""" + hf_token: Optional[Union[bool, str]] = None + """The token to use as HTTP bearer authorization for remote files. If + `True`, will use the token generated when running `huggingface-cli login` + (stored in `~/.huggingface`).""" + hf_overrides: HfOverrides = field(default_factory=dict) + """If a dictionary, contains arguments to be forwarded to the Hugging Face + config. If a callable, it is called to update the HuggingFace config. When + specified via CLI, the argument must be a valid JSON string.""" + mm_processor_kwargs: Optional[dict[str, Any]] = None + """Arguments to be forwarded to the model's processor for multi-modal data, + e.g., image processor. Overrides for the multi-modal processor obtained + from `AutoProcessor.from_pretrained`. The available overrides depend on the + model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. 
+ When specified via CLI, the argument must be a valid JSON string.""" + disable_mm_preprocessor_cache: bool = False + """If `True`, disable caching of the multi-modal preprocessor/mapper (not + recommended).""" + override_neuron_config: dict[str, Any] = field(default_factory=dict) + """Initialize non-default neuron config or override default neuron config + that are specific to Neuron devices, this argument will be used to + configure the neuron config that can not be gathered from the vllm + arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI, + the argument must be a valid JSON string.""" + pooler_config: Optional["PoolerConfig"] = field(init=False) + """Pooler config which controls the behaviour of output pooling in pooling + models.""" + override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None + """Initialize non-default pooling config or override default pooling config + for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`. + When specified via CLI, the argument must be a valid JSON string.""" + logits_processor_pattern: Optional[str] = None + """Optional regex pattern specifying valid logits processor qualified names + that can be passed with the `logits_processors` extra completion argument. + Defaults to `None`, which allows no processors.""" + generation_config: str = "auto" + """The folder path to the generation config. Defaults to `"auto"`, the + generation config will be loaded from model path. If set to `"vllm"`, no + generation config is loaded, vLLM defaults will be used. If set to a folder + path, the generation config will be loaded from the specified folder path. + If `max_new_tokens` is specified in generation config, then it sets a + server-wide limit on the number of output tokens for all requests.""" + override_generation_config: dict[str, Any] = field(default_factory=dict) + """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If + used with `--generation-config auto`, the override parameters will be + merged with the default config from the model. If used with + `--generation-config vllm`, only the override parameters are used. 
+ When specified via CLI, the argument must be a valid JSON string.""" + enable_sleep_mode: bool = False + """Enable sleep mode for the engine (only cuda platform is supported).""" + model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value + """Which implementation of the model to use:\n + - "auto" will try to use the vLLM implementation, if it exists, and fall + back to the Transformers implementation if no vLLM implementation is + available.\n + - "vllm" will use the vLLM model implementation.\n + - "transformers" will use the Transformers model implementation.""" def compute_hash(self) -> str: """ @@ -342,92 +428,43 @@ def compute_hash(self) -> str: assert_hashable(str_factors) return hashlib.sha256(str(factors).encode()).hexdigest() - def __init__( - self, - model: str, - task: Literal[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - hf_config_path: Optional[str] = None, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - disable_cascade_attn: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, list[str]]] = None, - limit_mm_per_prompt: Optional[dict[str, int]] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, - disable_mm_preprocessor_cache: bool = False, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_token: Optional[Union[bool, str]] = None, - hf_overrides: Optional[HfOverrides] = None, - override_neuron_config: Optional[dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None, - generation_config: str = "auto", - enable_sleep_mode: bool = False, - override_generation_config: Optional[dict[str, Any]] = None, - model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, - ) -> None: - self.model = maybe_model_redirect(model) - self.tokenizer = maybe_model_redirect(tokenizer) - - self.hf_config_path = hf_config_path - if isinstance(hf_config_path, str): - self.hf_config_path = maybe_model_redirect(hf_config_path) - - self.tokenizer_mode = tokenizer_mode - self.trust_remote_code = trust_remote_code - self.allowed_local_media_path = allowed_local_media_path - self.seed = seed - self.revision = revision - self.code_revision = code_revision - self.rope_scaling = rope_scaling - self.rope_theta = rope_theta - self.model_impl = model_impl - - if hf_overrides is None: - hf_overrides = {} - - if callable(hf_overrides): + def __post_init__(self) -> None: + self.model = maybe_model_redirect(self.model) + # The tokenizer is consistent with the model by default. 
+ if self.tokenizer is None: + self.tokenizer = self.model + if self.tokenizer_revision is None: + self.tokenizer_revision = self.revision + self.tokenizer = maybe_model_redirect(self.tokenizer) + + if isinstance(self.hf_config_path, str): + self.hf_config_path = maybe_model_redirect(self.hf_config_path) + + if callable(self.hf_overrides): hf_overrides_kw = {} - hf_overrides_fn = hf_overrides + hf_overrides_fn = self.hf_overrides else: - hf_overrides_kw = hf_overrides + hf_overrides_kw = self.hf_overrides hf_overrides_fn = None - if rope_scaling is not None: - hf_override: dict[str, Any] = {"rope_scaling": rope_scaling} + if self.rope_scaling: + hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling} hf_overrides_kw.update(hf_override) - hf_overrides_str = json.dumps(hf_overrides) + hf_overrides_str = json.dumps(hf_overrides_kw) msg = ( "`--rope-scaling` will be removed in a future release. " f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) - if rope_theta is not None: - hf_override = {"rope_theta": rope_theta} + if self.rope_theta is not None: + hf_override = {"rope_theta": self.rope_theta} hf_overrides_kw.update(hf_override) - hf_overrides_str = json.dumps(hf_overrides) + hf_overrides_str = json.dumps(hf_overrides_kw) msg = ( "`--rope-theta` will be removed in a future release. " f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) - self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) + self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer) if (backend := envs.VLLM_ATTENTION_BACKEND ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: @@ -437,20 +474,6 @@ def __init__( "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 "for instructions on how to install it.") - # The tokenizer version is consistent with the model version by default. - if tokenizer_revision is None: - self.tokenizer_revision = revision - else: - self.tokenizer_revision = tokenizer_revision - self.quantization = quantization - self.enforce_eager = enforce_eager - self.max_seq_len_to_capture = max_seq_len_to_capture - self.max_logprobs = max_logprobs - self.disable_sliding_window = disable_sliding_window - self.disable_cascade_attn = disable_cascade_attn - self.skip_tokenizer_init = skip_tokenizer_init - self.enable_sleep_mode = enable_sleep_mode - from vllm.platforms import current_platform if (self.enable_sleep_mode @@ -458,9 +481,12 @@ def __init__( raise ValueError( "Sleep mode is not supported on current platform.") + if isinstance(self.config_format, str): + self.config_format = ConfigFormat(self.config_format) + hf_config = get_config(self.hf_config_path or self.model, - trust_remote_code, revision, code_revision, - config_format) + self.trust_remote_code, self.revision, + self.code_revision, self.config_format) if hf_overrides_kw: logger.info("Overriding HF config with %s", hf_overrides_kw) @@ -476,13 +502,8 @@ def __init__( "attention_chunk_size", None) self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( - self.model, hf_token=hf_token, revision=revision) - self.dtype = _get_and_verify_dtype(self.hf_config, dtype) - self.use_async_output_proc = use_async_output_proc - - # Set enforce_eager to False if the value is unset. 
- if self.enforce_eager is None: - self.enforce_eager = False + self.model, hf_token=self.hf_token, revision=self.revision) + self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype) interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] sliding_window = getattr(self.hf_text_config, "sliding_window", None) @@ -515,18 +536,14 @@ def __init__( self.max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, - max_model_len=max_model_len, + max_model_len=self.max_model_len, disable_sliding_window=self.disable_sliding_window, sliding_window_len=self.get_hf_config_sliding_window(), - spec_target_max_model_len=spec_target_max_model_len, + spec_target_max_model_len=self.spec_target_max_model_len, encoder_config=self.encoder_config) - self.served_model_name = get_served_model_name(model, - served_model_name) - self.multimodal_config = self._init_multimodal_config( - limit_mm_per_prompt=limit_mm_per_prompt, - mm_processor_kwargs=mm_processor_kwargs, - disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, - ) + self.served_model_name = get_served_model_name(self.model, + self.served_model_name) + self.multimodal_config = self._init_multimodal_config() if not self.skip_tokenizer_init: self._verify_tokenizer_mode() @@ -535,24 +552,19 @@ def __init__( self.has_noops = self._init_has_noops() self.has_inner_state = self._init_has_inner_state() - if current_platform.is_neuron(): - self.override_neuron_config = override_neuron_config - else: - self.override_neuron_config = None + if (not current_platform.is_neuron() and self.override_neuron_config): + raise ValueError( + "`override_neuron_config` is only supported on Neuron.") - supported_tasks, task = self._resolve_task(task) + supported_tasks, task = self._resolve_task(self.task) self.supported_tasks = supported_tasks - self.task: Final = task + self.task = task if self.task in ("draft", "generate"): self.truncation_side = "left" else: self.truncation_side = "right" - self.pooler_config = self._init_pooler_config(override_pooler_config) - self.logits_processor_pattern = logits_processor_pattern - - self.generation_config = generation_config - self.override_generation_config = override_generation_config or {} + self.pooler_config = self._init_pooler_config() self._verify_quantization() self._verify_cuda_graph() @@ -591,26 +603,21 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) self.tokenizer = s3_tokenizer.dir - def _init_multimodal_config( - self, - limit_mm_per_prompt: Optional[dict[str, int]], - mm_processor_kwargs: Optional[dict[str, Any]], - disable_mm_preprocessor_cache: bool, - ) -> Optional["MultiModalConfig"]: + def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self.registry.is_multimodal_model(self.architectures): return MultiModalConfig( - limit_per_prompt=limit_mm_per_prompt or {}, - mm_processor_kwargs=mm_processor_kwargs or {}, - disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, - ) + limit_per_prompt=self.limit_mm_per_prompt, + mm_processor_kwargs=self.mm_processor_kwargs, + disable_mm_preprocessor_cache=self. 
+ disable_mm_preprocessor_cache) - if limit_mm_per_prompt: + if self.limit_mm_per_prompt: raise ValueError("`limit_mm_per_prompt` is only supported for " "multimodal models.") - if mm_processor_kwargs: + if self.mm_processor_kwargs: raise ValueError("`mm_processor_kwargs` is only supported for " "multimodal models.") - if disable_mm_preprocessor_cache: + if self.disable_mm_preprocessor_cache: raise ValueError("`disable_mm_preprocessor_cache` is only " "supported for multimodal models.") @@ -620,31 +627,32 @@ def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) - def _init_pooler_config( - self, - override_pooler_config: Optional["PoolerConfig"], - ) -> Optional["PoolerConfig"]: + def _init_pooler_config(self) -> Optional["PoolerConfig"]: if self.runner_type == "pooling": - user_config = override_pooler_config or PoolerConfig() + if isinstance(self.override_pooler_config, dict): + self.override_pooler_config = PoolerConfig( + **self.override_pooler_config) + + pooler_config = self.override_pooler_config or PoolerConfig() base_config = get_pooling_config(self.model, self.revision) if base_config is not None: # Only set values that are not overridden by the user for k, v in base_config.items(): - if getattr(user_config, k) is None: - setattr(user_config, k, v) + if getattr(pooler_config, k) is None: + setattr(pooler_config, k, v) if self.is_matryoshka: - if user_config.normalize is None: - user_config.normalize = True - elif not user_config.normalize: + if pooler_config.normalize is None: + pooler_config.normalize = True + elif not pooler_config.normalize: raise ValueError( "`normalize` must be enabled (set to True) " "for models that are compatible with " "Matryoshka Representation.") - return user_config + return pooler_config return None @@ -662,11 +670,11 @@ def _init_has_inner_state(self) -> bool: return self.registry.model_has_inner_state(self.architectures) def _verify_tokenizer_mode(self) -> None: - tokenizer_mode = self.tokenizer_mode.lower() - if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]: + tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) + if tokenizer_mode not in get_args(TokenizerMode): raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - "either 'auto', 'slow', 'mistral' or 'custom'.") + f"one of {get_args(TokenizerMode)}.") self.tokenizer_mode = tokenizer_mode def _get_preferred_task( @@ -781,7 +789,8 @@ def _verify_quantization(self) -> None: "quark", "nvfp4", "bitblas", "gptq_bitblas" ] if self.quantization is not None: - self.quantization = self.quantization.lower() + self.quantization = cast(QuantizationMethods, + self.quantization.lower()) # Parse quantization method from the HF model config, if available. 
quant_cfg = self._parse_quant_hf_config() @@ -857,8 +866,6 @@ def _verify_quantization(self) -> None: "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: - if self.max_seq_len_to_capture is None: - self.max_seq_len_to_capture = self.max_model_len self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) ROCM_UNSUPPORTED_MODELS = ['mllama'] @@ -1294,7 +1301,7 @@ def supported_runner_types(self) -> set[RunnerType]: @property def runner_type(self) -> RunnerType: - return _TASK_RUNNER[self.task] + return _TASK_RUNNER[cast(_ResolvedTask, self.task)] @property def is_v1_compatible(self) -> bool: @@ -2201,7 +2208,7 @@ class SpeculativeConfig: according to the log probability settings in SamplingParams.""" # Draft model configuration - quantization: Optional[str] = None + quantization: Optional[QuantizationMethods] = None """Quantization method that was used to quantize the draft model weights. If `None`, we assume the model weights are not quantized. Note that it only takes effect when using the draft model-based speculative method.""" @@ -2386,7 +2393,6 @@ def __post_init__(self): code_revision=self.code_revision, tokenizer_revision=self.target_model_config. tokenizer_revision, - max_model_len=None, spec_target_max_model_len=self.target_model_config. max_model_len, quantization=self.quantization, @@ -2793,30 +2799,31 @@ def verify_with_model_config(self, model_config: ModelConfig): class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: dict[str, int] = field(default_factory=dict) + limit_per_prompt: dict[str, int] = get_field(ModelConfig, + "limit_mm_per_prompt") """ The maximum number of input items allowed per prompt for each modality. This should be a JSON string that will be parsed into a dictionary. Defaults to 1 (V0) or 999 (V1) for each modality. For example, to allow up to 16 images and 2 videos per prompt: - :code:`{"images": 16, "videos": 2}` + `{"images": 16, "videos": 2}` """ mm_processor_kwargs: Optional[dict[str, object]] = None """ Overrides for the multi-modal processor obtained from - :meth:`transformers.AutoProcessor.from_pretrained`. + `transformers.AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: - :code:`{"num_crops": 4}`. + `{"num_crops": 4}`. """ disable_mm_preprocessor_cache: bool = False """ - If :code:`True`, disable caching of the processed multi-modal inputs. + If `True`, disable caching of the processed multi-modal inputs. 
""" def compute_hash(self) -> str: @@ -2907,10 +2914,6 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest() return hash_str - @staticmethod - def from_json(json_str: str) -> "PoolerConfig": - return PoolerConfig(**json.loads(json_str)) - _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index be0cd4d3a20d..4f074fcd1b8e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -20,15 +20,16 @@ DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, GuidedDecodingBackendV1, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig, - ObservabilityConfig, ParallelConfig, PoolerConfig, - PrefixCachingHashAlgo, PromptAdapterConfig, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - TaskOption, TokenizerPoolConfig, VllmConfig, - get_attr_docs, get_field) + LoRAConfig, ModelConfig, ModelDType, ModelImpl, + MultiModalConfig, ObservabilityConfig, ParallelConfig, + PoolerConfig, PrefixCachingHashAlgo, + PromptAdapterConfig, SchedulerConfig, SchedulerPolicy, + SpeculativeConfig, TaskOption, TokenizerMode, + TokenizerPoolConfig, VllmConfig, get_attr_docs, + get_field) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.plugins import load_general_plugins from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 @@ -183,6 +184,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name]["nargs"] = "+" elif contains_type(type_hints, int): kwargs[name]["type"] = int + # Special case for large integers + if name in {"max_model_len"}: + kwargs[name]["type"] = human_readable_int elif contains_type(type_hints, float): kwargs[name]["type"] = float elif contains_type(type_hints, dict): @@ -212,22 +216,23 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: @dataclass class EngineArgs: """Arguments for vLLM engine.""" - model: str = 'facebook/opt-125m' - served_model_name: Optional[Union[str, List[str]]] = None - tokenizer: Optional[str] = None - hf_config_path: Optional[str] = None - task: TaskOption = "auto" - skip_tokenizer_init: bool = False - tokenizer_mode: str = 'auto' - trust_remote_code: bool = False - allowed_local_media_path: str = "" + model: str = ModelConfig.model + served_model_name: Optional[Union[ + str, List[str]]] = ModelConfig.served_model_name + tokenizer: Optional[str] = ModelConfig.tokenizer + hf_config_path: Optional[str] = ModelConfig.hf_config_path + task: TaskOption = ModelConfig.task + skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init + tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode + trust_remote_code: bool = ModelConfig.trust_remote_code + allowed_local_media_path: str = ModelConfig.allowed_local_media_path download_dir: Optional[str] = LoadConfig.download_dir load_format: str = LoadConfig.load_format - config_format: ConfigFormat = ConfigFormat.AUTO - dtype: str = 'auto' + config_format: str = ModelConfig.config_format + dtype: ModelDType = ModelConfig.dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype - seed: Optional[int] = None - max_model_len: Optional[int] = None + seed: Optional[int] = ModelConfig.seed + max_model_len: Optional[int] = ModelConfig.max_model_len # Note: Specifying a custom executor backend by passing a class # is 
intended for expert use only. The API may change without # notice. @@ -245,8 +250,8 @@ class EngineArgs: enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching prefix_caching_hash_algo: PrefixCachingHashAlgo = \ CacheConfig.prefix_caching_hash_algo - disable_sliding_window: bool = False - disable_cascade_attn: bool = False + disable_sliding_window: bool = ModelConfig.disable_sliding_window + disable_cascade_attn: bool = ModelConfig.disable_cascade_attn use_v2_block_manager: bool = True swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb @@ -258,18 +263,19 @@ class EngineArgs: long_prefill_token_threshold: int = \ SchedulerConfig.long_prefill_token_threshold max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs - max_logprobs: int = 20 # Default value for OpenAI Chat Completions API + max_logprobs: int = ModelConfig.max_logprobs disable_log_stats: bool = False - revision: Optional[str] = None - code_revision: Optional[str] = None - rope_scaling: Optional[Dict[str, Any]] = None - rope_theta: Optional[float] = None - hf_token: Optional[Union[bool, str]] = None - hf_overrides: Optional[HfOverrides] = None - tokenizer_revision: Optional[str] = None - quantization: Optional[str] = None - enforce_eager: Optional[bool] = None - max_seq_len_to_capture: int = 8192 + revision: Optional[str] = ModelConfig.revision + code_revision: Optional[str] = ModelConfig.code_revision + rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling") + rope_theta: Optional[float] = ModelConfig.rope_theta + hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token + hf_overrides: Optional[HfOverrides] = \ + get_field(ModelConfig, "hf_overrides") + tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision + quantization: Optional[QuantizationMethods] = ModelConfig.quantization + enforce_eager: bool = ModelConfig.enforce_eager + max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce # The following three fields are deprecated and will be removed in a future # release. Setting them will have no effect. 
Please remove them from your @@ -280,8 +286,10 @@ class EngineArgs: get_field(TokenizerPoolConfig, "extra_config") limit_mm_per_prompt: dict[str, int] = \ get_field(MultiModalConfig, "limit_per_prompt") - mm_processor_kwargs: Optional[Dict[str, Any]] = None - disable_mm_preprocessor_cache: bool = False + mm_processor_kwargs: Optional[Dict[str, Any]] = \ + MultiModalConfig.mm_processor_kwargs + disable_mm_preprocessor_cache: bool = \ + MultiModalConfig.disable_mm_preprocessor_cache # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -323,7 +331,8 @@ class EngineArgs: DecodingConfig.disable_any_whitespace guided_decoding_disable_additional_properties: bool = \ DecodingConfig.disable_additional_properties - logits_processor_pattern: Optional[str] = None + logits_processor_pattern: Optional[ + str] = ModelConfig.logits_processor_pattern speculative_config: Optional[Dict[str, Any]] = None @@ -331,22 +340,25 @@ class EngineArgs: show_hidden_metrics_for_version: Optional[str] = None otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None - disable_async_output_proc: bool = False + disable_async_output_proc: bool = not ModelConfig.use_async_output_proc scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls - override_neuron_config: Optional[Dict[str, Any]] = None - override_pooler_config: Optional[PoolerConfig] = None + override_neuron_config: dict[str, Any] = \ + get_field(ModelConfig, "override_neuron_config") + override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ + ModelConfig.override_pooler_config compilation_config: Optional[CompilationConfig] = None worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls kv_transfer_config: Optional[KVTransferConfig] = None - generation_config: Optional[str] = "auto" - override_generation_config: Optional[Dict[str, Any]] = None - enable_sleep_mode: bool = False - model_impl: str = "auto" + generation_config: str = ModelConfig.generation_config + enable_sleep_mode: bool = ModelConfig.enable_sleep_mode + override_generation_config: dict[str, Any] = \ + get_field(ModelConfig, "override_generation_config") + model_impl: str = ModelConfig.model_impl calculate_kv_scales: bool = CacheConfig.calculate_kv_scales @@ -356,9 +368,6 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load def __post_init__(self): - if not self.tokenizer: - self.tokenizer = self.model - # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object @@ -375,80 +384,87 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: """Shared CLI arguments for vLLM engine.""" # Model arguments - parser.add_argument( - '--model', - type=str, - default=EngineArgs.model, - help='Name or path of the huggingface model to use.') - parser.add_argument( - '--task', - default=EngineArgs.task, - choices=get_args(TaskOption), - help='The task to use the model for. Each vLLM instance only ' - 'supports one task, even if the same model can be used for ' - 'multiple tasks. When the model only supports one task, ``"auto"`` ' - 'can be used to select it; otherwise, you must specify explicitly ' - 'which task to use.') - parser.add_argument( - '--tokenizer', - type=optional_type(str), - default=EngineArgs.tokenizer, - help='Name or path of the huggingface tokenizer to use. 
' - 'If unspecified, model name or path will be used.') - parser.add_argument( - "--hf-config-path", - type=optional_type(str), - default=EngineArgs.hf_config_path, - help='Name or path of the huggingface config to use. ' - 'If unspecified, model name or path will be used.') - parser.add_argument( - '--skip-tokenizer-init', - action='store_true', - help='Skip initialization of tokenizer and detokenizer. ' - 'Expects valid prompt_token_ids and None for prompt from ' - 'the input. The generated output will contain token ids.') - parser.add_argument( - '--revision', - type=optional_type(str), - default=None, - help='The specific model version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument( - '--code-revision', - type=optional_type(str), - default=None, - help='The specific revision to use for the model code on ' - 'Hugging Face Hub. It can be a branch name, a tag name, or a ' - 'commit id. If unspecified, will use the default version.') - parser.add_argument( - '--tokenizer-revision', - type=optional_type(str), - default=None, - help='Revision of the huggingface tokenizer to use. ' - 'It can be a branch name, a tag name, or a commit id. ' - 'If unspecified, will use the default version.') - parser.add_argument( - '--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow', 'mistral', 'custom'], - help='The tokenizer mode.\n\n* "auto" will use the ' - 'fast tokenizer if available.\n* "slow" will ' - 'always use the slow tokenizer. \n* ' - '"mistral" will always use the `mistral_common` tokenizer. \n* ' - '"custom" will use --tokenizer to select the ' - 'preregistered tokenizer.') - parser.add_argument('--trust-remote-code', - action='store_true', - help='Trust remote code from huggingface.') - parser.add_argument( - '--allowed-local-media-path', - type=str, - help="Allowing API requests to read local images or videos " - "from directories specified by the server file system. " - "This is a security risk. 
" - "Should only be enabled in trusted environments.") + model_kwargs = get_kwargs(ModelConfig) + model_group = parser.add_argument_group( + title="ModelConfig", + description=ModelConfig.__doc__, + ) + model_group.add_argument("--model", **model_kwargs["model"]) + model_group.add_argument("--task", **model_kwargs["task"]) + model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) + model_group.add_argument("--tokenizer-mode", + **model_kwargs["tokenizer_mode"]) + model_group.add_argument("--trust-remote-code", + **model_kwargs["trust_remote_code"]) + model_group.add_argument("--dtype", **model_kwargs["dtype"]) + model_group.add_argument("--seed", **model_kwargs["seed"]) + model_group.add_argument("--hf-config-path", + **model_kwargs["hf_config_path"]) + model_group.add_argument("--allowed-local-media-path", + **model_kwargs["allowed_local_media_path"]) + model_group.add_argument("--revision", **model_kwargs["revision"]) + model_group.add_argument("--code-revision", + **model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", + **model_kwargs["rope_scaling"]) + model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) + model_group.add_argument("--tokenizer-revision", + **model_kwargs["tokenizer_revision"]) + model_group.add_argument("--max-model-len", + **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", + **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", + **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-seq-len-to-capture", + **model_kwargs["max_seq_len_to_capture"]) + model_group.add_argument("--max-logprobs", + **model_kwargs["max_logprobs"]) + model_group.add_argument("--disable-sliding-window", + **model_kwargs["disable_sliding_window"]) + model_group.add_argument("--disable-cascade-attn", + **model_kwargs["disable_cascade_attn"]) + model_group.add_argument("--skip-tokenizer-init", + **model_kwargs["skip_tokenizer_init"]) + model_group.add_argument("--served-model-name", + **model_kwargs["served_model_name"]) + # This one is a special case because it is the + # opposite of ModelConfig.use_async_output_proc + model_group.add_argument( + "--disable-async-output-proc", + action="store_true", + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") + model_group.add_argument("--config-format", + choices=[f.value for f in ConfigFormat], + **model_kwargs["config_format"]) + # This one is a special case because it can bool + # or str. 
TODO: Handle this in get_kwargs + model_group.add_argument("--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"]) + model_group.add_argument("--hf-overrides", + **model_kwargs["hf_overrides"]) + model_group.add_argument("--override-neuron-config", + **model_kwargs["override_neuron_config"]) + model_group.add_argument("--override-pooler-config", + **model_kwargs["override_pooler_config"]) + model_group.add_argument("--logits-processor-pattern", + **model_kwargs["logits_processor_pattern"]) + model_group.add_argument("--generation-config", + **model_kwargs["generation_config"]) + model_group.add_argument("--override-generation-config", + **model_kwargs["override_generation_config"]) + model_group.add_argument("--enable-sleep-mode", + **model_kwargs["enable_sleep_mode"]) + model_group.add_argument("--model-impl", + choices=[f.value for f in ModelImpl], + **model_kwargs["model_impl"]) + # Model loading arguments load_kwargs = get_kwargs(LoadConfig) load_group = parser.add_argument_group( @@ -465,38 +481,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: load_group.add_argument('--use-tqdm-on-load', **load_kwargs["use_tqdm_on_load"]) - parser.add_argument( - '--config-format', - default=EngineArgs.config_format, - choices=[f.value for f in ConfigFormat], - help='The format of the model config to load.\n\n' - '* "auto" will try to load the config in hf format ' - 'if available else it will try to load in mistral format ') - parser.add_argument( - '--dtype', - type=str, - default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], - help='Data type for model weights and activations.\n\n' - '* "auto" will use FP16 precision for FP32 and FP16 models, and ' - 'BF16 precision for BF16 models.\n' - '* "half" for FP16. Recommended for AWQ quantization.\n' - '* "float16" is the same as "half".\n' - '* "bfloat16" for a balance between precision and range.\n' - '* "float" is shorthand for FP32 precision.\n' - '* "float32" for FP32 precision.') - parser.add_argument('--max-model-len', - type=human_readable_int, - default=EngineArgs.max_model_len, - help='Model context length. If unspecified, will ' - 'be automatically derived from the model config. ' - 'Supports k/m/g/K/M/G in human-readable format.\n' - 'Examples:\n' - '- 1k → 1000\n' - '- 1K → 1024\n') - # Guided decoding arguments guided_decoding_kwargs = get_kwargs(DecodingConfig) guided_decoding_group = parser.add_argument_group( @@ -520,26 +504,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: choices=list(ReasoningParserManager.reasoning_parsers), **guided_decoding_kwargs["reasoning_backend"]) - parser.add_argument( - '--logits-processor-pattern', - type=optional_type(str), - default=None, - help='Optional regex pattern specifying valid logits processor ' - 'qualified names that can be passed with the `logits_processors` ' - 'extra completion argument. 
Defaults to None, which allows no '
-                            'processors.')
-        parser.add_argument(
-            '--model-impl',
-            type=str,
-            default=EngineArgs.model_impl,
-            choices=[f.value for f in ModelImpl],
-            help='Which implementation of the model to use.\n\n'
-            '* "auto" will try to use the vLLM implementation if it exists '
-            'and fall back to the Transformers implementation if no vLLM '
-            'implementation is available.\n'
-            '* "vllm" will use the vLLM model implementation.\n'
-            '* "transformers" will use the Transformers model '
-            'implementation.\n')

         # Parallel arguments
         parallel_kwargs = get_kwargs(ParallelConfig)
         parallel_group = parser.add_argument_group(
@@ -592,10 +556,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         cache_group.add_argument('--calculate-kv-scales',
                                  **cache_kwargs["calculate_kv_scales"])

-        parser.add_argument('--disable-sliding-window',
-                            action='store_true',
-                            help='Disables sliding window, '
-                            'capping to sliding window size.')
         parser.add_argument('--use-v2-block-manager',
                             action='store_true',
                             default=True,
@@ -605,73 +565,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             'Setting this flag to True or False'
                             ' has no effect on vLLM behavior.')
-        parser.add_argument('--seed',
-                            type=int,
-                            default=EngineArgs.seed,
-                            help='Random seed for operations.')
-        parser.add_argument(
-            '--max-logprobs',
-            type=int,
-            default=EngineArgs.max_logprobs,
-            help=('Max number of log probs to return logprobs is specified in'
-                  ' SamplingParams.'))
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='Disable logging statistics.')

-        # Quantization settings.
-        parser.add_argument('--quantization',
-                            '-q',
-                            type=optional_type(str),
-                            choices=[*QUANTIZATION_METHODS, None],
-                            default=EngineArgs.quantization,
-                            help='Method used to quantize the weights. If '
-                            'None, we first check the `quantization_config` '
-                            'attribute in the model config file. If that is '
-                            'None, we assume the model weights are not '
-                            'quantized and use `dtype` to determine the data '
-                            'type of the weights.')
-        parser.add_argument(
-            '--rope-scaling',
-            default=None,
-            type=json.loads,
-            help='RoPE scaling configuration in JSON format. '
-            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
-        parser.add_argument('--rope-theta',
-                            default=None,
-                            type=float,
-                            help='RoPE theta. Use with `rope_scaling`. In '
-                            'some cases, changing the RoPE theta improves the '
-                            'performance of the scaled model.')
-        parser.add_argument(
-            '--hf-token',
-            type=str,
-            nargs='?',
-            const=True,
-            default=None,
-            help='The token to use as HTTP bearer authorization'
-            ' for remote files. If `True`, will use the token '
-            'generated when running `huggingface-cli login` '
-            '(stored in `~/.huggingface`).')
-        parser.add_argument('--hf-overrides',
-                            type=json.loads,
-                            default=EngineArgs.hf_overrides,
-                            help='Extra arguments for the HuggingFace config. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary.')
-        parser.add_argument('--enforce-eager',
-                            action='store_true',
-                            help='Always use eager-mode PyTorch. If False, '
-                            'will use eager mode and CUDA graph in hybrid '
-                            'for maximal performance and flexibility.')
-        parser.add_argument('--max-seq-len-to-capture',
-                            type=int,
-                            default=EngineArgs.max_seq_len_to_capture,
-                            help='Maximum sequence length covered by CUDA '
-                            'graphs. When a sequence has context length '
-                            'larger than this, we fall back to eager mode. '
-                            'Additionally for encoder-decoder models, if the '
-                            'sequence length of the encoder input is larger '
-                            'than this, we fall back to the eager mode.')

         # Tokenizer arguments
         tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
@@ -775,20 +671,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "Default to `original/**/*` to avoid repeated loading of llama's "
             "checkpoints.")

-        parser.add_argument(
-            "--served-model-name",
-            nargs="+",
-            type=str,
-            default=None,
-            help="The model name(s) used in the API. If multiple "
-            "names are provided, the server will respond to any "
-            "of the provided names. The model name in the model "
-            "field of a response will be the first name in this "
-            "list. If not specified, the model name will be the "
-            "same as the ``--model`` argument. Noted that this name(s) "
-            "will also be used in `model_name` tag content of "
-            "prometheus metrics, if multiple names provided, metrics "
-            "tag will take the first one.")
         parser.add_argument('--qlora-adapter-name-or-path',
                             type=str,
                             default=None,
@@ -822,13 +704,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "modules. This involves use of possibly costly and or blocking "
             "operations and hence might have a performance impact.")

-        parser.add_argument(
-            '--disable-async-output-proc',
-            action='store_true',
-            default=EngineArgs.disable_async_output_proc,
-            help="Disable async output processing. This may result in "
-            "lower performance.")
-
         # Scheduler arguments
         scheduler_kwargs = get_kwargs(SchedulerConfig)
         scheduler_group = parser.add_argument_group(
@@ -871,19 +746,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument('--scheduler-cls',
                             **scheduler_kwargs["scheduler_cls"])

-        parser.add_argument(
-            '--override-neuron-config',
-            type=json.loads,
-            default=None,
-            help="Override or set neuron device configuration. "
-            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
-        parser.add_argument(
-            '--override-pooler-config',
-            type=PoolerConfig.from_json,
-            default=None,
-            help="Override or set the pooling method for pooling models. "
-            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
-
         parser.add_argument('--compilation-config',
                             '-O',
                             type=CompilationConfig.from_cli,
@@ -920,34 +782,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help='The worker extension class on top of the worker cls, '
             'it is useful if you just want to add new functions to the worker '
             'class without changing the existing functions.')
-        parser.add_argument(
-            "--generation-config",
-            type=optional_type(str),
-            default="auto",
-            help="The folder path to the generation config. "
-            "Defaults to 'auto', the generation config will be loaded from "
-            "model path. If set to 'vllm', no generation config is loaded, "
-            "vLLM defaults will be used. If set to a folder path, the "
-            "generation config will be loaded from the specified folder path. "
-            "If `max_new_tokens` is specified in generation config, then "
-            "it sets a server-wide limit on the number of output tokens "
-            "for all requests.")
-
-        parser.add_argument(
-            "--override-generation-config",
-            type=json.loads,
-            default=None,
-            help="Overrides or sets generation config in JSON format. "
-            "e.g. ``{\"temperature\": 0.5}``. If used with "
-            "--generation-config=auto, the override parameters will be merged "
-            "with the default config from the model. If generation-config is "
-            "None, only the override parameters are used.")
-
-        parser.add_argument("--enable-sleep-mode",
-                            action="store_true",
-                            default=False,
-                            help="Enable sleep mode for the engine. "
-                            "(only cuda platform is supported)")

         parser.add_argument(
             "--additional-config",
@@ -966,16 +800,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "If enabled, the model will be able to generate reasoning content."
         )

-        parser.add_argument(
-            "--disable-cascade-attn",
-            action="store_true",
-            default=False,
-            help="Disable cascade attention for V1. While cascade attention "
-            "does not change the mathematical correctness, disabling it "
-            "could be useful for preventing potential numerical issues. "
-            "Note that even if this is set to False, cascade attention will be "
-            "only used when the heuristic tells that it's beneficial.")
-
         return parser

     @classmethod
@@ -1002,8 +826,7 @@ def create_model_config(self) -> ModelConfig:
             model=self.model,
             hf_config_path=self.hf_config_path,
             task=self.task,
-            # We know this is not None because we set it in __post_init__
-            tokenizer=cast(str, self.tokenizer),
+            tokenizer=self.tokenizer,
             tokenizer_mode=self.tokenizer_mode,
             trust_remote_code=self.trust_remote_code,
             allowed_local_media_path=self.allowed_local_media_path,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 948e8f36e0e6..de03882dadf0 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -13,7 +13,7 @@
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, ModelDType, TokenizerMode
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -31,6 +31,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest, LLMGuidedOptions)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
@@ -162,20 +163,20 @@ def __init__(
         self,
         model: str,
         tokenizer: Optional[str] = None,
-        tokenizer_mode: str = "auto",
+        tokenizer_mode: TokenizerMode = "auto",
         skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
-        dtype: str = "auto",
-        quantization: Optional[str] = None,
+        dtype: ModelDType = "auto",
+        quantization: Optional[QuantizationMethods] = None,
         revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         seed: Optional[int] = None,
         gpu_memory_utilization: float = 0.9,
         swap_space: float = 4,
         cpu_offload_gb: float = 0,
-        enforce_eager: Optional[bool] = None,
+        enforce_eager: bool = False,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
@@ -188,12 +189,7 @@ def __init__(
         compilation_config: Optional[Union[int, dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
-        '''
-        LLM constructor.
-
-        Note: if enforce_eager is unset (enforce_eager is None)
-        it defaults to False.
-        '''
+        """LLM constructor."""

         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
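A minimal usage sketch for the retyped constructor above; the model name is illustrative and not part of this diff:

```python
# The literal-typed parameters accept the same strings as before; they are now
# just checked statically. enforce_eager is a plain bool defaulting to False.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",  # illustrative model, not taken from this diff
    tokenizer_mode="auto",      # TokenizerMode literal
    dtype="auto",               # ModelDType literal
    quantization=None,          # or a QuantizationMethods literal such as "awq"
    enforce_eager=False,        # False keeps the CUDA-graph path enabled
)
```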
diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py
index 10f5241f9a71..0b74e8faff9d 100644
--- a/vllm/model_executor/layers/quantization/aqlm.py
+++ b/vllm/model_executor/layers/quantization/aqlm.py
@@ -12,6 +12,7 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -186,7 +187,7 @@ def __repr__(self) -> str:
                 f"out_group_size={self.out_group_size})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "aqlm"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 227be1497d0e..cfc31ae20549 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -7,6 +7,7 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -44,7 +45,7 @@ def __repr__(self) -> str:
                 f"zero_point={self.zero_point}, "
                 f"modules_to_not_convert={self.modules_to_not_convert})")

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "awq"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index ef4a7765d61e..193e90b85812 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -73,7 +74,7 @@ def __repr__(self) -> str:
                 f"modules_to_not_convert={self.modules_to_not_convert})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "awq_marlin"

     @classmethod
@@ -101,8 +102,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
                    modules_to_not_convert, config)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
         is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                                or user_quant == "awq_marlin")
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index 5ef11546fd41..8cf058b406fb 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -2,11 +2,16 @@
 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

 import torch
 from torch import nn

+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+else:
+    QuantizationMethods = str
+

 class QuantizeMethodBase(ABC):
     """Base class for different quantized methods."""
@@ -66,7 +71,7 @@ def __init__(self):
         self.packed_modules_mapping: Dict[str, List[str]] = dict()

     @abstractmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         """Name of the quantization method."""
         raise NotImplementedError

@@ -99,8 +104,8 @@ def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
         raise NotImplementedError

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         """
         Detects if this quantization method can support a given checkpoint
         format by overriding the user specified quantization method --
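The `TYPE_CHECKING` block added to `base_config.py` is the load-bearing piece: `QuantizationMethods` lives in the package `__init__`, which itself imports `base_config`, so importing it eagerly would presumably be circular; at runtime the alias degrades to plain `str`. The same pattern in isolation:

```python
# Standalone illustration of the runtime-degraded type alias used above.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only type checkers follow this import, so nothing circular happens at runtime.
    from vllm.model_executor.layers.quantization import QuantizationMethods
else:
    QuantizationMethods = str


def describe(name: QuantizationMethods) -> str:
    # mypy sees the Literal type; at runtime the annotation is just str.
    return f"quantization method: {name}"


print(describe("fp8"))
```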
diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py
index 3eaaa6c252ce..ab858d72034a 100644
--- a/vllm/model_executor/layers/quantization/bitblas.py
+++ b/vllm/model_executor/layers/quantization/bitblas.py
@@ -5,6 +5,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
@@ -100,7 +101,7 @@ def __repr__(self) -> str:
                 f"quant_method={self.quant_method})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "bitblas"

     @classmethod
@@ -139,8 +140,8 @@ def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
                    lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_bitblas_format: bool
         is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index f5d32efe8368..a472779d930b 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -7,6 +7,7 @@
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.utils import direct_register_custom_op
@@ -56,7 +57,7 @@ def __repr__(self) -> str:
                 f"llm_int8_skip_modules={self.llm_int8_skip_modules})")

     @classmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "bitsandbytes"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 5be6b22c7b23..0585c09bd84b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -16,6 +16,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
@@ -71,7 +72,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
     def get_min_capability(cls) -> int:
         return 70

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "compressed-tensors"

     def get_quant_method(
diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py
index 67934d37284e..df7ec3376b55 100644
--- a/vllm/model_executor/layers/quantization/deepspeedfp.py
+++ b/vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -7,6 +7,7 @@
 import torch.nn.functional as F

 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -41,8 +42,8 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size}")

     @classmethod
-    def get_name(cls) -> str:
-        return "DeepSpeedFP"
+    def get_name(cls) -> QuantizationMethods:
+        return "deepspeedfp"

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
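A consequence of the `deepspeedfp` rename above: `get_name()` must now return a member of the `QuantizationMethods` literal, whose entries are lowercase registry keys. A hedged round-trip check (that `"deepspeedfp"` is the registry key is implied by this diff but not shown in it):

```python
# Hedged round-trip through the quantization registry after the rename.
from vllm.model_executor.layers.quantization import get_quantization_config

cfg_cls = get_quantization_config("deepspeedfp")
assert cfg_cls.get_name() == "deepspeedfp"
print(cfg_cls.__name__)  # expected: DeepSpeedFPConfig
```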
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index be19b80975ec..cce95941b714 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -8,6 +8,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -20,7 +21,7 @@ def __init__(self) -> None:
         super().__init__()

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "experts_int8"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 7dddc40f3446..1fa2b3a8eeea 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -9,6 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -38,7 +39,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
         self.fp8_linear = Fp8LinearOp()

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fbgemm_fp8"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 01056c37b86c..5515ba27ea19 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -16,6 +16,7 @@
     FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -83,7 +84,7 @@ def __init__(
         self.weight_block_size = weight_block_size

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fp8"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 6b499f81c55f..05058dfaa733 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -31,7 +32,7 @@ def __init__(self, ) -> None:
     def __repr__(self) -> str:
         return ("GGUFConfig()")

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "gguf"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 1c8d6cb1ea79..5059e0cdfd4a 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -10,6 +10,7 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
@@ -79,7 +80,7 @@ def __repr__(self) -> str:
                 f"dynamic={self.dynamic}")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py
index 88cada4c61b8..891d8cdf36af 100644
--- a/vllm/model_executor/layers/quantization/gptq_bitblas.py
+++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py
@@ -7,6 +7,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@@ -123,7 +124,7 @@ def __repr__(self) -> str:
                 f"quant_method={self.quant_method})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_bitblas"

     @classmethod
@@ -151,8 +152,8 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
                    lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)

         is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 52cd0a5b6975..c7f9d95f4c2d 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -11,6 +11,7 @@
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@@ -100,7 +101,7 @@ def __repr__(self) -> str:
                 f"dynamic={self.dynamic}")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_marlin"

     @classmethod
@@ -130,8 +131,8 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
                    lm_head_quantized, dynamic, config)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

         is_valid_user_quant = (user_quant is None or user_quant == "marlin"
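`override_quantization_method` now also advertises a `QuantizationMethods` return. A sketch of probing the GPTQ-to-Marlin upgrade decision directly; the `hf_quant_cfg` keys are typical of a GPTQ checkpoint and are an assumption here, and the result depends on the running platform:

```python
# Illustrative probe of the override hook's narrowed return type (assumptions:
# these config keys satisfy is_gptq_marlin_compatible; result is platform-dependent).
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

hf_quant_cfg = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "desc_act": False,
}
method = GPTQMarlinConfig.override_quantization_method(hf_quant_cfg,
                                                       user_quant=None)
print(method)  # "gptq_marlin" on GPUs with Marlin support, otherwise None
```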
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
index dd747e182e28..1fe08e4b34fe 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -8,6 +8,7 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -85,7 +86,7 @@ def __repr__(self) -> str:
             self.quant_type, self.group_size)

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_marlin_24"

     @classmethod
@@ -108,8 +109,8 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
         return cls(weight_bits, group_size)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         is_marlin_24_format = (
             hf_quant_cfg.get("checkpoint_format") == "marlin_24")

diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py
index 4edc9aa848a1..7bd398137e02 100644
--- a/vllm/model_executor/layers/quantization/hqq_marlin.py
+++ b/vllm/model_executor/layers/quantization/hqq_marlin.py
@@ -8,6 +8,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -50,7 +51,7 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "hqq"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index c09cc13cb276..212af278ff81 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -6,6 +6,7 @@
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -58,7 +59,7 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "ipex"

     @classmethod
@@ -97,8 +98,8 @@ def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
                    lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         if not current_platform.is_cpu() and not current_platform.is_xpu():
             return None

diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index 4cf0c677c079..9ef71a7894d7 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -8,6 +8,7 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -63,7 +64,7 @@ def __repr__(self) -> str:
                 f"lm_head_quantized={self.lm_head_quantized})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "marlin"

     @classmethod
@@ -87,8 +88,8 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
         return cls(group_size, lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_marlin_format: bool
         is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 3de153699155..828447dd1019 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -11,6 +11,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -42,7 +43,7 @@ def __init__(
             " the format is experimental and could change.")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "modelopt"

     @classmethod
@@ -184,8 +185,8 @@ def __init__(
         self.exclude_modules = exclude_modules

     @classmethod
-    def get_name(cls) -> str:
-        return "modelopt_nvfp4"
+    def get_name(cls) -> QuantizationMethods:
+        return "nvfp4"

     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 00c4b661ef2c..b8e3a4364379 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -9,6 +9,7 @@
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -64,7 +65,7 @@ def __init__(self, linear_quant_method: str, weight_bits: int,
         self.modules_to_not_convert = modules_to_not_convert

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "moe_wna16"

     @classmethod
@@ -100,8 +101,8 @@ def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
                    lm_head_quantized, modules_to_not_convert, config)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
         if can_convert and user_quant == "moe_wna16":
             return cls.get_name()
diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py
index f6f66803f816..7933eab2a530 100644
--- a/vllm/model_executor/layers/quantization/neuron_quant.py
+++ b/vllm/model_executor/layers/quantization/neuron_quant.py
@@ -6,6 +6,7 @@
 from torch.nn import Module

+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)

@@ -30,7 +31,7 @@ def __init__(
         self.dequant_dtype = dequant_dtype
         self.quantize_method = quantize_method

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "neuron_quant"

     def get_supported_act_dtypes(self) -> List[str]:
diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py
index 592ffc5dad13..004d74e68b9a 100644
--- a/vllm/model_executor/layers/quantization/ptpc_fp8.py
+++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py
@@ -9,6 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
@@ -50,7 +51,7 @@ def __init__(
             ignored_layers=ignored_layers)

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "ptpc_fp8"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py
index 1e05917a5187..06ff6c71b913 100644
--- a/vllm/model_executor/layers/quantization/qqq.py
+++ b/vllm/model_executor/layers/quantization/qqq.py
@@ -8,6 +8,7 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -84,7 +85,7 @@ def __repr__(self) -> str:
             self.weight_bits, self.group_size)

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "qqq"

     @classmethod
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index cf9108ea72c3..da2312190084 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -8,6 +8,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -47,7 +48,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
     def get_min_capability(cls) -> int:
         return 70

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "quark"

     def get_quant_method(self, layer: torch.nn.Module,
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index 5c2babcf4ab6..751002fa0945 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -6,6 +6,7 @@
 from torch.nn.parameter import Parameter

 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -20,7 +21,7 @@ def __init__(self, torchao_config) -> None:
     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "torchao"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
index 14e5bcf6e5bb..8333c16ce6a1 100644
--- a/vllm/model_executor/layers/quantization/tpu_int8.py
+++ b/vllm/model_executor/layers/quantization/tpu_int8.py
@@ -7,6 +7,7 @@
 from torch.nn.parameter import Parameter

 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import ModelWeightParameter
@@ -27,7 +28,7 @@ def __init__(
                 f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "tpu_int8"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index c5970c71c539..00f4e66bd13e 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -1496,7 +1496,7 @@ def get_rope(
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]

-    if rope_scaling is None:
+    if not rope_scaling:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
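The `rotary_embedding.py` change just above swaps `rope_scaling is None` for `not rope_scaling`, presumably so that an empty dict also selects the plain `RotaryEmbedding` path. The relaxed condition in isolation (hypothetical helper, not vLLM code):

```python
# Hypothetical helper showing what the relaxed branch condition accepts.
def wants_plain_rope(rope_scaling) -> bool:
    # True for None and for {}, i.e. "no scaling requested".
    return not rope_scaling


assert wants_plain_rope(None)
assert wants_plain_rope({})
assert not wants_plain_rope({"rope_type": "dynamic", "factor": 2.0})
```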
diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index 67aaad10fcfe..a7b313f4e502 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -180,7 +180,6 @@ def _get_neuron_config_after_override(default_neuron_config,
                                             NeuronConfig, QuantizationConfig, SparseAttnConfig)

-    overridden_neuron_config = overridden_neuron_config or {}
     sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
     if sparse_attn:
         overridden_neuron_config["sparse_attn"] = SparseAttnConfig(