
Commit 4c2b38c

Enable Pydantic mypy checks and convert configs to Pydantic dataclasses (#17599)
Signed-off-by: Harry Mellor <[email protected]>
1 parent d781930 commit 4c2b38c

11 files changed: +115 -102 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ repos:
     entry: tools/mypy.sh 0 "local"
     language: python
     types: [python]
-    additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
+    additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
     stages: [pre-commit] # Don't run in CI
   - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
     name: Run mypy for Python 3.9

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -110,6 +110,7 @@ ignore = [
 ]
 
 [tool.mypy]
+plugins = ['pydantic.mypy']
 ignore_missing_imports = true
 check_untyped_defs = true
 follow_imports = "silent"
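
Note: adding `plugins = ['pydantic.mypy']` (together with `pydantic` in the pre-commit hook's additional_dependencies above) teaches mypy about the `__init__` signatures and validators that Pydantic generates. A rough sketch of the kind of constructor checking this enables; the `Point` class is illustrative, not part of this commit:

    from pydantic.dataclasses import dataclass

    @dataclass
    class Point:
        x: int
        y: int

    # With the plugin enabled, mypy flags the bad argument type below.
    Point(x=1, y="not an int")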

tests/lora/test_quant_model.py

Lines changed: 7 additions & 7 deletions

@@ -24,16 +24,16 @@ class ModelWithQuantization:
     MODELS = [
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-            quantization="GPTQ"),
+            quantization="gptq"),
     ]
 else:
     MODELS = [
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
-            quantization="AWQ"),
+            quantization="awq"),
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-            quantization="GPTQ"),
+            quantization="gptq"),
     ]
 
 
@@ -100,7 +100,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
            "#ff8050",
            "#ff8080",
        ]
-    elif model.quantization == "AWQ":
+    elif model.quantization == "awq":
        expected_no_lora_output = [
            "I'm sorry, I don't understand",
            "I'm sorry, I don't understand",
@@ -109,7 +109,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
            "#f07700: A v",
            "#f00000: A v",
        ]
-    elif model.quantization == "GPTQ":
+    elif model.quantization == "gptq":
        expected_no_lora_output = [
            "I'm sorry, I don't have",
            "I'm sorry, I don't have",
@@ -122,7 +122,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
     def expect_match(output, expected_output):
         # HACK: GPTQ lora outputs are just incredibly unstable.
         # Assert that the outputs changed.
-        if (model.quantization == "GPTQ"
+        if (model.quantization == "gptq"
                 and expected_output is expected_lora_output):
             assert output != expected_no_lora_output
         for i, o in enumerate(output):
@@ -172,7 +172,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
                                  model):
     if num_gpus_available < 2:
         pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
-    if model.quantization == "GPTQ":
+    if model.quantization == "gptq":
         pytest.skip("GPTQ lora outputs are just incredibly unstable")
     llm_tp1 = vllm.LLM(
         model=model.model_path,

tests/tracing/test_tracing.py

Lines changed: 1 addition & 1 deletion

@@ -173,7 +173,7 @@ def test_traces_with_detailed_steps(
     llm = LLM(
         model=model,
         otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
-        collect_detailed_traces="all",
+        collect_detailed_traces=["all"],
     )
     prompts = ["This is a short prompt"]
     outputs = llm.generate(prompts, sampling_params=sampling_params)

vllm/config.py

Lines changed: 60 additions & 48 deletions

@@ -11,8 +11,8 @@
 import warnings
 from collections import Counter
 from contextlib import contextmanager
-from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
-                         is_dataclass, replace)
+from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
+                         replace)
 from functools import cached_property
 from importlib.util import find_spec
 from pathlib import Path
@@ -21,9 +21,12 @@
 
 import regex as re
 import torch
+from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
+                      model_validator)
+from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
-from typing_extensions import deprecated
+from typing_extensions import deprecated, runtime_checkable
 
 import vllm.envs as envs
 from vllm import version
@@ -57,10 +60,15 @@
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
     ConfigType = type[DataclassInstance]
 else:
+    PlacementGroup = Any
+    ExecutorBase = Any
     QuantizationConfig = Any
+    BaseModelLoader = Any
+    TensorizerConfig = Any
     ConfigType = type
 
 logger = init_logger(__name__)
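
The `else` branch grows matching runtime aliases because Pydantic now evaluates the annotations on these config dataclasses at import time, so every name used in an annotation must also exist outside of TYPE_CHECKING. A minimal sketch of the pattern, with a made-up `heavy_module`/`HeavyThing`:

    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        # Real import, seen by mypy only.
        from heavy_module import HeavyThing
    else:
        # Cheap runtime stand-in so annotation evaluation still resolves.
        HeavyThing = Any

    def use(thing: "HeavyThing") -> None:
        ...
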
@@ -92,6 +100,7 @@
                               PretrainedConfig]]
 
 
+@runtime_checkable
 class SupportsHash(Protocol):
 
     def compute_hash(self) -> str:
@@ -223,7 +232,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
     """Configuration for the model."""
 
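`config=ConfigDict(arbitrary_types_allowed=True)` is what lets these Pydantic dataclasses keep fields whose types Pydantic has no schema for (torch objects, Hugging Face configs, other vLLM classes): such fields are only checked with isinstance rather than coerced. A self-contained sketch under that assumption; `Engine` and `Holder` are illustrative names only:

    from pydantic import ConfigDict
    from pydantic.dataclasses import dataclass

    class Engine:  # an arbitrary non-Pydantic type
        pass

    @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
    class Holder:
        engine: Engine  # accepted via a plain isinstance check

    Holder(engine=Engine())  # ok
    # Without arbitrary_types_allowed=True, defining Holder raises a
    # schema-generation error for the unknown Engine type.
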
@@ -236,7 +245,7 @@ class ModelConfig:
     task, even if the same model can be used for multiple tasks. When the model
     only supports one task, "auto" can be used to select it; otherwise, you
     must specify explicitly which task to use."""
-    tokenizer: str = None  # type: ignore
+    tokenizer: SkipValidation[str] = None  # type: ignore
     """Name or path of the Hugging Face tokenizer to use. If unspecified, model
     name or path will be used."""
     tokenizer_mode: TokenizerMode = "auto"
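
`tokenizer` (and `max_model_len` below) keep a `None` default that `__post_init__` later replaces, so they are wrapped in `SkipValidation` to stop Pydantic from rejecting the temporary `None` against the declared type. A minimal sketch of the idea with a hypothetical `Cfg` class:

    from pydantic import SkipValidation
    from pydantic.dataclasses import dataclass

    @dataclass
    class Cfg:
        # Declared as str but allowed to start out as None; Pydantic skips
        # validating this field, so constructing with the default succeeds.
        name: SkipValidation[str] = None  # type: ignore

        def __post_init__(self) -> None:
            if self.name is None:
                self.name = "derived-default"

    assert Cfg().name == "derived-default"
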
@@ -284,7 +293,7 @@ class ModelConfig:
     """The specific revision to use for the tokenizer on the Hugging Face Hub.
     It can be a branch name, a tag name, or a commit id. If unspecified, will
     use the default version."""
-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
     """Model context length (prompt and output). If unspecified, will be
     automatically derived from the model config.
@@ -602,6 +611,22 @@ def __post_init__(self) -> None:
602611
self._verify_cuda_graph()
603612
self._verify_bnb_config()
604613

614+
@field_validator("quantization", mode="before")
615+
@classmethod
616+
def validate_quantization_before(cls, value: Any) -> Any:
617+
if isinstance(value, str):
618+
return value.lower()
619+
return value
620+
621+
@model_validator(mode="after")
622+
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
623+
if not isinstance(self.tokenizer, str):
624+
raise ValueError("tokenizer must be a string after __post_init__.")
625+
if not isinstance(self.max_model_len, int):
626+
raise ValueError(
627+
"max_model_len must be an integer after __post_init__.")
628+
return self
629+
605630
@property
606631
def registry(self):
607632
return ModelRegistry
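
The before-validator lower-cases quantization names before they are checked against the `QuantizationMethods` literal (which is why the LoRA tests above now compare against "gptq"/"awq"), and the after-validator asserts that `__post_init__` really did resolve `tokenizer` and `max_model_len`. A reduced sketch of the same two hooks on a hypothetical `QuantCfg`:

    from typing import Any, Literal, Optional

    from pydantic import field_validator, model_validator
    from pydantic.dataclasses import dataclass

    @dataclass
    class QuantCfg:
        quantization: Optional[Literal["gptq", "awq"]] = None

        @field_validator("quantization", mode="before")
        @classmethod
        def lower_quantization(cls, value: Any) -> Any:
            # Runs before the Literal check, so "GPTQ" is normalised and accepted.
            return value.lower() if isinstance(value, str) else value

        @model_validator(mode="after")
        def check_state(self) -> "QuantCfg":
            # Runs after __init__/__post_init__; reject inconsistent state here.
            return self

    assert QuantCfg(quantization="GPTQ").quantization == "gptq"
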
@@ -823,8 +848,7 @@ def _verify_quantization(self) -> None:
             "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
         ]
         if self.quantization is not None:
-            self.quantization = cast(QuantizationMethods,
-                                     self.quantization.lower())
+            self.quantization = cast(QuantizationMethods, self.quantization)
 
         # Parse quantization method from the HF model config, if available.
         quant_cfg = self._parse_quant_hf_config()
@@ -1397,7 +1421,7 @@ def get_and_verify_max_len(self, max_model_len: int):
 class CacheConfig:
     """Configuration for the KV cache."""
 
-    block_size: BlockSize = None  # type: ignore
+    block_size: SkipValidation[BlockSize] = None  # type: ignore
     """Size of a contiguous cache block in number of tokens. This is ignored on
     neuron devices and set to `--max-model-len`. On CUDA devices, only block
     sizes up to 32 are supported. On HPU devices, block size defaults to 128.
@@ -1619,7 +1643,8 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: dict = field(default_factory=dict)
+    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+        default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format."""
     ignore_patterns: Optional[Union[list[str], str]] = None
@@ -1929,19 +1954,19 @@ class SchedulerConfig:
     runner_type: RunnerType = "generate"
     """The runner type to launch for the model."""
 
-    max_num_batched_tokens: int = None  # type: ignore
+    max_num_batched_tokens: SkipValidation[int] = None  # type: ignore
     """Maximum number of tokens to be processed in a single iteration.
 
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
 
-    max_num_seqs: int = None  # type: ignore
+    max_num_seqs: SkipValidation[int] = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.
 
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
 
-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
     """Maximum length of a sequence (including prompt and generated text). This
     is primarily set in `ModelConfig` and that value should be manually
     duplicated here."""
@@ -1980,7 +2005,7 @@ class SchedulerConfig:
     """Apply a delay (of delay factor multiplied by previous
     prompt latency) before scheduling next prompt."""
 
-    enable_chunked_prefill: bool = None  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """If True, prefill requests can be chunked based
     on the remaining max_num_batched_tokens."""
 
@@ -2202,7 +2227,7 @@ def is_multi_step(self) -> bool:
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class DeviceConfig:
     """Configuration for the device to use for vLLM execution."""
 
@@ -2260,8 +2285,8 @@ def __post_init__(self):
             self.device = torch.device(self.device_type)
 
 
-SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-                            "draft_model", "deepseek_mtp"]
+SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
+                            "mlp_speculator", "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
                                       "typical_acceptance_sampler"]
 
@@ -2272,8 +2297,7 @@ class SpeculativeConfig:
     """Configuration for speculative decoding."""
 
     # General speculative decoding control
-    num_speculative_tokens: int = field(default=None,
-                                        init=True)  # type: ignore
+    num_speculative_tokens: SkipValidation[int] = None  # type: ignore
     """The number of speculative tokens, if provided. It will default to the
     number in the draft model config if present, otherwise, it is required."""
     model: Optional[str] = None
@@ -2349,26 +2373,23 @@ class SpeculativeConfig:
     """Specifies the tree structure for speculative token generation.
     """
     # required configuration params passed from engine
-    target_model_config: ModelConfig = field(default=None,
-                                             init=True)  # type: ignore
+    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
     """The configuration of the target model."""
-    target_parallel_config: ParallelConfig = field(default=None,
-                                                   init=True)  # type: ignore
+    target_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
     """The parallel configuration for the target model."""
-    enable_chunked_prefill: bool = field(default=None,
-                                         init=True)  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """Whether vLLM is configured to use chunked prefill or not. Used for
     raising an error since it's not yet compatible with speculative decode."""
-    disable_log_stats: bool = field(default=None, init=True)  # type: ignore
+    disable_log_stats: SkipValidation[bool] = None  # type: ignore
     """Whether to disable the periodic printing of stage times in speculative
     decoding."""
 
     # params generated in the post-init stage
-    draft_model_config: ModelConfig = field(default=None,
-                                            init=True)  # type: ignore
+    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
     """The configuration of the draft model initialized internal."""
-    draft_parallel_config: ParallelConfig = field(default=None,
-                                                  init=True)  # type: ignore
+    draft_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
     """The parallel configuration for the draft model initialized internal."""
 
     def compute_hash(self) -> str:
@@ -2766,7 +2787,7 @@ def __repr__(self) -> str:
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class LoRAConfig:
     """Configuration for LoRA."""
 
@@ -2863,7 +2884,7 @@ def verify_lora_support(self):
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class PromptAdapterConfig:
     """Configuration for PromptAdapters."""
 
@@ -3892,17 +3913,11 @@ def __repr__(self) -> str:
             "pass_config",
             "traced_files",
         }
-        include = dict()
-        for k, v in asdict(self).items():
-            if k in exclude:
-                continue
-            f = get_field(CompilationConfig, k)
-            if (d := f.default) is not MISSING and d == v:
-                continue
-            if (df := f.default_factory) is not MISSING and df() == v:
-                continue
-            include[k] = v
-        return json.dumps(include)
+        # The cast to string is necessary because Pydantic is mocked in docs
+        # builds and sphinx-argparse doesn't know the return type of decode()
+        return str(
+            TypeAdapter(CompilationConfig).dump_json(
+                self, exclude=exclude, exclude_unset=True).decode())
 
     __str__ = __repr__
 
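`TypeAdapter` replaces the hand-rolled default-skipping loop: `dump_json(..., exclude_unset=True)` serialises only the fields the caller actually set, and `validate_json` (next hunk) replaces `cls(**json.loads(...))` on the CLI path. A small round-trip sketch with a hypothetical `Compilation` dataclass:

    from pydantic import TypeAdapter
    from pydantic.dataclasses import dataclass

    @dataclass
    class Compilation:
        level: int = 0
        backend: str = "inductor"

    adapter = TypeAdapter(Compilation)

    # Only explicitly-set fields are emitted, mirroring the old loop that
    # skipped defaults before json.dumps.
    print(adapter.dump_json(Compilation(level=3), exclude_unset=True))  # b'{"level":3}'

    # Parse and validate a JSON string (e.g. from the CLI) back into the class.
    cfg = adapter.validate_json('{"level": 2, "backend": "eager"}')
    assert cfg.level == 2 and cfg.backend == "eager"
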
@@ -3911,7 +3926,7 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
         if cli_value in ["0", "1", "2", "3"]:
             return cls(level=int(cli_value))
-        return cls(**json.loads(cli_value))
+        return TypeAdapter(CompilationConfig).validate_json(cli_value)
 
     def __post_init__(self) -> None:
         count_none = self.custom_ops.count("none")
@@ -4037,7 +4052,7 @@ def set_splitting_ops_for_v1(self):
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
     """Dataclass which contains all vllm-related configuration. This
     simplifies passing around the distinct configurations in the codebase.
@@ -4294,9 +4309,6 @@ def __post_init__(self):
                 "To workaround this limitation, vLLM will set 'ieee' input "
                 "precision for chunked prefill triton kernels.")
 
-        if self.compilation_config is None:
-            self.compilation_config = CompilationConfig()
-
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
         if self.compilation_config.pass_config.enable_async_tp:
