@@ -11,8 +11,8 @@
 import warnings
 from collections import Counter
 from contextlib import contextmanager
-from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
-                         is_dataclass, replace)
+from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
+                         replace)
 from functools import cached_property
 from importlib.util import find_spec
 from pathlib import Path
@@ -21,9 +21,12 @@

 import regex as re
 import torch
+from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
+                      model_validator)
+from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
-from typing_extensions import deprecated
+from typing_extensions import deprecated, runtime_checkable

 import vllm.envs as envs
 from vllm import version
@@ -57,10 +60,15 @@
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

     ConfigType = type[DataclassInstance]
 else:
+    PlacementGroup = Any
+    ExecutorBase = Any
     QuantizationConfig = Any
+    BaseModelLoader = Any
+    TensorizerConfig = Any
     ConfigType = type

 logger = init_logger(__name__)
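Unlike stdlib dataclasses, Pydantic evaluates field annotations at runtime, which is why the `else:` branch above binds every `TYPE_CHECKING`-only name to `Any`. A minimal sketch of the pattern (toy names, not from this diff):

```python
# Pydantic resolves annotations at runtime, so names imported only under
# TYPE_CHECKING need a runtime placeholder; binding them to Any keeps the
# annotation resolvable without paying for the heavy import.
from typing import TYPE_CHECKING, Any

from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass

if TYPE_CHECKING:
    from concurrent.futures import Executor  # stand-in for a heavy import
else:
    Executor = Any  # placeholder so Pydantic can resolve the annotation


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class Wrapper:
    executor: SkipValidation[Executor] = None  # type: ignore


print(Wrapper())  # Wrapper(executor=None)
```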
@@ -92,6 +100,7 @@
                                       PretrainedConfig]]


+@runtime_checkable
 class SupportsHash(Protocol):

     def compute_hash(self) -> str:
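`runtime_checkable` makes the protocol usable with `isinstance()`, which tests structurally for the presence of `compute_hash`. A quick illustration (`DummyConfig` is hypothetical):

```python
from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class SupportsHash(Protocol):

    def compute_hash(self) -> str: ...


class DummyConfig:  # hypothetical; any object with compute_hash passes

    def compute_hash(self) -> str:
        return "0" * 64


assert isinstance(DummyConfig(), SupportsHash)  # checks method names only
```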
@@ -223,7 +232,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
     """Configuration for the model."""

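`arbitrary_types_allowed=True` matters because several config fields hold types Pydantic has no built-in validator for (e.g. `torch.device`); without the flag the class definition itself raises a schema error, with it such fields are merely `isinstance`-checked. A minimal sketch with a toy config:

```python
import torch
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ToyDeviceConfig:  # hypothetical, for illustration
    device: torch.device = torch.device("cpu")


print(ToyDeviceConfig())  # ToyDeviceConfig(device=device(type='cpu'))
try:
    ToyDeviceConfig(device="cpu")  # a str is not an instance of torch.device
except Exception as e:
    print(type(e).__name__)  # ValidationError
```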
@@ -236,7 +245,7 @@ class ModelConfig:
     task, even if the same model can be used for multiple tasks. When the model
     only supports one task, "auto" can be used to select it; otherwise, you
     must specify explicitly which task to use."""
-    tokenizer: str = None  # type: ignore
+    tokenizer: SkipValidation[str] = None  # type: ignore
     """Name or path of the Hugging Face tokenizer to use. If unspecified, model
     name or path will be used."""
     tokenizer_mode: TokenizerMode = "auto"
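`SkipValidation` keeps the `str` annotation for readers and type checkers while telling Pydantic not to validate the field, so the `None` placeholder survives until `__post_init__` fills it in (the model-validator hunk further down then checks the result). A minimal sketch with toy field names:

```python
from pydantic import SkipValidation
from pydantic.dataclasses import dataclass


@dataclass
class ToyModelConfig:  # hypothetical, for illustration
    model: str = "facebook/opt-125m"
    tokenizer: SkipValidation[str] = None  # type: ignore

    def __post_init__(self) -> None:
        # Pydantic dataclasses call __post_init__ after field validation,
        # so the placeholder can be resolved here.
        if self.tokenizer is None:
            self.tokenizer = self.model


assert ToyModelConfig().tokenizer == "facebook/opt-125m"
```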
@@ -284,7 +293,7 @@ class ModelConfig:
284293 """The specific revision to use for the tokenizer on the Hugging Face Hub.
285294 It can be a branch name, a tag name, or a commit id. If unspecified, will
286295 use the default version."""
287- max_model_len : int = None # type: ignore
296+ max_model_len : SkipValidation [ int ] = None # type: ignore
288297 """Model context length (prompt and output). If unspecified, will be
289298 automatically derived from the model config.
290299
@@ -602,6 +611,22 @@ def __post_init__(self) -> None:
         self._verify_cuda_graph()
         self._verify_bnb_config()

+    @field_validator("quantization", mode="before")
+    @classmethod
+    def validate_quantization_before(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.lower()
+        return value
+
+    @model_validator(mode="after")
+    def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
+        if not isinstance(self.tokenizer, str):
+            raise ValueError("tokenizer must be a string after __post_init__.")
+        if not isinstance(self.max_model_len, int):
+            raise ValueError(
+                "max_model_len must be an integer after __post_init__.")
+        return self
+
     @property
     def registry(self):
         return ModelRegistry
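For context on ordering: the `mode="before"` field validator sees raw input before any field validation, `__post_init__` runs after the fields are validated, and the `mode="after"` model validator runs last, which is why it can insist that the placeholders are gone. A runnable toy version:

```python
from typing import Any, Optional

from pydantic import SkipValidation, field_validator, model_validator
from pydantic.dataclasses import dataclass


@dataclass
class ToyConfig:  # hypothetical, for illustration
    quantization: Optional[str] = None
    tokenizer: SkipValidation[str] = None  # type: ignore

    def __post_init__(self) -> None:
        if self.tokenizer is None:
            self.tokenizer = "default-tokenizer"

    @field_validator("quantization", mode="before")
    @classmethod
    def _lower(cls, value: Any) -> Any:
        # Runs on the raw input, before field validation.
        return value.lower() if isinstance(value, str) else value

    @model_validator(mode="after")
    def _check(self) -> "ToyConfig":
        # Runs after __post_init__, so the placeholder must be resolved.
        assert isinstance(self.tokenizer, str)
        return self


assert ToyConfig(quantization="AWQ").quantization == "awq"
```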
@@ -823,8 +848,7 @@ def _verify_quantization(self) -> None:
823848 "quark" , "modelopt_fp4" , "bitblas" , "gptq_bitblas"
824849 ]
825850 if self .quantization is not None :
826- self .quantization = cast (QuantizationMethods ,
827- self .quantization .lower ())
851+ self .quantization = cast (QuantizationMethods , self .quantization )
828852
829853 # Parse quantization method from the HF model config, if available.
830854 quant_cfg = self ._parse_quant_hf_config ()
@@ -1397,7 +1421,7 @@ def get_and_verify_max_len(self, max_model_len: int):
 class CacheConfig:
     """Configuration for the KV cache."""

-    block_size: BlockSize = None  # type: ignore
+    block_size: SkipValidation[BlockSize] = None  # type: ignore
     """Size of a contiguous cache block in number of tokens. This is ignored on
     neuron devices and set to `--max-model-len`. On CUDA devices, only block
     sizes up to 32 are supported. On HPU devices, block size defaults to 128.
@@ -1619,7 +1643,8 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: dict = field(default_factory=dict)
+    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+        default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format."""
     ignore_patterns: Optional[Union[list[str], str]] = None
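With the field now a `Union`, Pydantic accepts whichever member the input matches (`TensorizerConfig` instances pass through via `arbitrary_types_allowed`). A hedged sketch with a stand-in class:

```python
from dataclasses import field
from typing import Union

from pydantic import ConfigDict
from pydantic.dataclasses import dataclass


class ToyExtraConfig:  # stand-in for TensorizerConfig, for illustration
    pass


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ToyLoadConfig:
    model_loader_extra_config: Union[dict, ToyExtraConfig] = field(
        default_factory=dict)


cfg = ToyLoadConfig(model_loader_extra_config={"a": 1})
assert cfg.model_loader_extra_config == {"a": 1}  # dict branch of the Union
assert isinstance(ToyLoadConfig().model_loader_extra_config, dict)
```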
@@ -1929,19 +1954,19 @@ class SchedulerConfig:
     runner_type: RunnerType = "generate"
     """The runner type to launch for the model."""

-    max_num_batched_tokens: int = None  # type: ignore
+    max_num_batched_tokens: SkipValidation[int] = None  # type: ignore
     """Maximum number of tokens to be processed in a single iteration.

     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""

-    max_num_seqs: int = None  # type: ignore
+    max_num_seqs: SkipValidation[int] = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.

     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""

-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
     """Maximum length of a sequence (including prompt and generated text). This
     is primarily set in `ModelConfig` and that value should be manually
     duplicated here."""
@@ -1980,7 +2005,7 @@ class SchedulerConfig:
19802005 """Apply a delay (of delay factor multiplied by previous
19812006 prompt latency) before scheduling next prompt."""
19822007
1983- enable_chunked_prefill : bool = None # type: ignore
2008+ enable_chunked_prefill : SkipValidation [ bool ] = None # type: ignore
19842009 """If True, prefill requests can be chunked based
19852010 on the remaining max_num_batched_tokens."""
19862011
@@ -2202,7 +2227,7 @@ def is_multi_step(self) -> bool:


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class DeviceConfig:
     """Configuration for the device to use for vLLM execution."""

@@ -2260,8 +2285,8 @@ def __post_init__(self):
             self.device = torch.device(self.device_type)


-SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-                            "draft_model", "deepseek_mtp"]
+SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
+                            "mlp_speculator", "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
                                       "typical_acceptance_sampler"]

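Under Pydantic, `Literal`-typed fields such as these are enforced at construction time rather than serving as documentation only. A toy illustration (names hypothetical):

```python
from typing import Literal, Optional

from pydantic import ValidationError
from pydantic.dataclasses import dataclass

ToyMethod = Literal["ngram", "eagle", "eagle3"]


@dataclass
class ToySpecConfig:  # hypothetical, for illustration
    method: Optional[ToyMethod] = None


ToySpecConfig(method="eagle3")  # accepted: member of the Literal
try:
    ToySpecConfig(method="eagle4")
except ValidationError:
    print("rejected: not a member of the Literal")
```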
@@ -2272,8 +2297,7 @@ class SpeculativeConfig:
22722297 """Configuration for speculative decoding."""
22732298
22742299 # General speculative decoding control
2275- num_speculative_tokens : int = field (default = None ,
2276- init = True ) # type: ignore
2300+ num_speculative_tokens : SkipValidation [int ] = None # type: ignore
22772301 """The number of speculative tokens, if provided. It will default to the
22782302 number in the draft model config if present, otherwise, it is required."""
22792303 model : Optional [str ] = None
@@ -2349,26 +2373,23 @@ class SpeculativeConfig:
23492373 """Specifies the tree structure for speculative token generation.
23502374 """
23512375 # required configuration params passed from engine
2352- target_model_config : ModelConfig = field (default = None ,
2353- init = True ) # type: ignore
2376+ target_model_config : SkipValidation [ModelConfig ] = None # type: ignore
23542377 """The configuration of the target model."""
2355- target_parallel_config : ParallelConfig = field ( default = None ,
2356- init = True ) # type: ignore
2378+ target_parallel_config : SkipValidation [
2379+ ParallelConfig ] = None # type: ignore
23572380 """The parallel configuration for the target model."""
2358- enable_chunked_prefill : bool = field (default = None ,
2359- init = True ) # type: ignore
2381+ enable_chunked_prefill : SkipValidation [bool ] = None # type: ignore
23602382 """Whether vLLM is configured to use chunked prefill or not. Used for
23612383 raising an error since it's not yet compatible with speculative decode."""
2362- disable_log_stats : bool = field ( default = None , init = True ) # type: ignore
2384+ disable_log_stats : SkipValidation [ bool ] = None # type: ignore
23632385 """Whether to disable the periodic printing of stage times in speculative
23642386 decoding."""
23652387
23662388 # params generated in the post-init stage
2367- draft_model_config : ModelConfig = field (default = None ,
2368- init = True ) # type: ignore
2389+ draft_model_config : SkipValidation [ModelConfig ] = None # type: ignore
23692390 """The configuration of the draft model initialized internal."""
2370- draft_parallel_config : ParallelConfig = field ( default = None ,
2371- init = True ) # type: ignore
2391+ draft_parallel_config : SkipValidation [
2392+ ParallelConfig ] = None # type: ignore
23722393 """The parallel configuration for the draft model initialized internal."""
23732394
23742395 def compute_hash (self ) -> str :
@@ -2766,7 +2787,7 @@ def __repr__(self) -> str:


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class LoRAConfig:
     """Configuration for LoRA."""

@@ -2863,7 +2884,7 @@ def verify_lora_support(self):


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class PromptAdapterConfig:
     """Configuration for PromptAdapters."""

@@ -3892,17 +3913,11 @@ def __repr__(self) -> str:
38923913 "pass_config" ,
38933914 "traced_files" ,
38943915 }
3895- include = dict ()
3896- for k , v in asdict (self ).items ():
3897- if k in exclude :
3898- continue
3899- f = get_field (CompilationConfig , k )
3900- if (d := f .default ) is not MISSING and d == v :
3901- continue
3902- if (df := f .default_factory ) is not MISSING and df () == v :
3903- continue
3904- include [k ] = v
3905- return json .dumps (include )
3916+ # The cast to string is necessary because Pydantic is mocked in docs
3917+ # builds and sphinx-argparse doesn't know the return type of decode()
3918+ return str (
3919+ TypeAdapter (CompilationConfig ).dump_json (
3920+ self , exclude = exclude , exclude_unset = True ).decode ())
39063921
39073922 __str__ = __repr__
39083923
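`TypeAdapter` wraps the dataclass with Pydantic's (de)serialization machinery: `dump_json()` returns bytes, `exclude` drops the named fields, and `exclude_unset` drops anything the caller never set explicitly, replacing the old hand-rolled compare-against-default loop. A minimal sketch with a toy config:

```python
from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass


@dataclass
class ToyCompilationConfig:  # hypothetical, for illustration
    level: int = 0
    backend: str = ""
    debug_dump_path: str = ""


adapter = TypeAdapter(ToyCompilationConfig)
cfg = ToyCompilationConfig(level=3)
out = adapter.dump_json(cfg, exclude={"debug_dump_path"}, exclude_unset=True)
print(out.decode())  # {"level":3} -- unset and excluded fields dropped
```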
@@ -3911,7 +3926,7 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig":
39113926 """Parse the CLI value for the compilation config."""
39123927 if cli_value in ["0" , "1" , "2" , "3" ]:
39133928 return cls (level = int (cli_value ))
3914- return cls ( ** json . loads (cli_value ) )
3929+ return TypeAdapter ( CompilationConfig ). validate_json (cli_value )
39153930
39163931 def __post_init__ (self ) -> None :
39173932 count_none = self .custom_ops .count ("none" )
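`validate_json()` parses and validates in one step, so malformed input now yields a Pydantic `ValidationError` with field-level detail instead of whatever `cls(**json.loads(...))` happened to raise. Reusing the toy config from the previous sketch:

```python
from pydantic import TypeAdapter, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class ToyCompilationConfig:  # hypothetical, for illustration
    level: int = 0
    backend: str = ""


adapter = TypeAdapter(ToyCompilationConfig)
assert adapter.validate_json('{"level": 3}') == ToyCompilationConfig(level=3)
try:
    adapter.validate_json('{"level": "not-an-int"}')
except ValidationError as e:
    print(e.error_count(), "validation error")  # 1 validation error
```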
@@ -4037,7 +4052,7 @@ def set_splitting_ops_for_v1(self):


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
     """Dataclass which contains all vllm-related configuration. This
     simplifies passing around the distinct configurations in the codebase.
@@ -4294,9 +4309,6 @@ def __post_init__(self):
42944309 "To workaround this limitation, vLLM will set 'ieee' input "
42954310 "precision for chunked prefill triton kernels." )
42964311
4297- if self .compilation_config is None :
4298- self .compilation_config = CompilationConfig ()
4299-
43004312 # async tp is built on top of sequence parallelism
43014313 # and requires it to be enabled.
43024314 if self .compilation_config .pass_config .enable_async_tp :
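The deleted `None` check is presumably dead code after this migration: if `compilation_config` is declared with a `default_factory` (as the Pydantic-style field definitions elsewhere in this diff suggest), the factory runs whenever the caller omits the value, so the attribute can no longer be `None`. A sketch of that assumption:

```python
from dataclasses import field

from pydantic.dataclasses import dataclass


@dataclass
class ToyCompilationConfig:
    level: int = 0


@dataclass
class ToyVllmConfig:  # hypothetical, for illustration
    # Assumption: the real field uses a default_factory like this one.
    compilation_config: ToyCompilationConfig = field(
        default_factory=ToyCompilationConfig)


assert ToyVllmConfig().compilation_config is not None
```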