Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import json
from argparse import ArgumentError
from contextlib import nullcontext
from dataclasses import dataclass, field
from contextlib import AbstractContextManager, nullcontext
from typing import Annotated, Literal

import pytest
from pydantic import Field

from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import (
Expand Down Expand Up @@ -96,46 +96,44 @@ def test_get_type(type_hints, type, expected):
],
)
def test_literal_to_kwargs(type_hints, expected):
context = nullcontext()
context: AbstractContextManager[object] = nullcontext()
if expected is Exception:
context = pytest.raises(expected)
with context:
assert literal_to_kwargs(type_hints) == expected


@config
@dataclass
class NestedConfig:
field: int = 1
"""field"""


@config
@dataclass
class DummyConfig:
regular_bool: bool = True
"""Regular bool with default True"""
optional_bool: bool | None = None
"""Optional bool with default None"""
optional_literal: Literal["x", "y"] | None = None
"""Optional literal with default None"""
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
tuple_n: tuple[int, ...] = Field(default_factory=lambda: (1, 2, 3))
"""Tuple with variable length"""
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2))
"""Tuple with fixed length"""
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
list_n: list[int] = Field(default_factory=lambda: [1, 2, 3])
"""List with variable length"""
list_literal: list[Literal[1, 2]] = field(default_factory=list)
list_literal: list[Literal[1, 2]] = Field(default_factory=list)
"""List with literal choices"""
list_union: list[str | type[object]] = field(default_factory=list)
list_union: list[str | type[object]] = Field(default_factory=list)
"""List with union type"""
set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
set_n: set[int] = Field(default_factory=lambda: {1, 2, 3})
"""Set with variable length"""
literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
json_tip: dict = Field(default_factory=dict)
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
nested_config: NestedConfig = Field(default_factory=NestedConfig)
"""Nested config"""


Expand Down Expand Up @@ -195,7 +193,7 @@ def test_get_kwargs():
json_tip = "Should either be a valid JSON string or JSON keys"
assert json_tip in kwargs["json_tip"]["help"]
# nested config should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # type: ignore[call-arg]


@pytest.mark.parametrize(
Expand Down
5 changes: 1 addition & 4 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ class _TestConfigFields:


def test_get_field():
with pytest.raises(ValueError):
get_field(_TestConfigFields, "a")

b = get_field(_TestConfigFields, "b")
assert isinstance(b, Field)
assert b.default is MISSING
Expand Down Expand Up @@ -188,7 +185,7 @@ def test_get_pooling_config():
)
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False)
model_config = ModelConfig(model_id, pooler_config=pooler_config)

assert asdict(model_config.pooler_config) == asdict(pooler_config)
Expand Down
26 changes: 8 additions & 18 deletions tests/tools/test_config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,22 @@

from tools.pre_commit.validate_config import validate_ast

_TestConfig1 = """
_TestConfig1 = '''
@config
class _TestConfig1:
pass
"""

_TestConfig2 = '''
@config
@dataclass
class _TestConfig2:
a: int
"""docstring"""
'''

_TestConfig3 = """
_TestConfig2 = """
@config
@dataclass
class _TestConfig3:
class _TestConfig2:
a: int = 1
"""

_TestConfig4 = '''
_TestConfig3 = '''
@config
@dataclass
class _TestConfig4:
class _TestConfig3:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
'''
Expand All @@ -40,10 +31,9 @@ class _TestConfig4:
@pytest.mark.parametrize(
("test_config", "expected_error"),
[
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
(_TestConfig1, "must have a default"),
(_TestConfig2, "must have a docstring"),
(_TestConfig3, "must use a single Literal"),
],
)
def test_config(test_config, expected_error):
Expand Down
2 changes: 1 addition & 1 deletion tests/v1/e2e/test_spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,8 +766,8 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"max_num_seqs": 100, # limit cudagraph capture runtime
},
max_num_seqs=100, # limit cudagraph capture runtime
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
Expand Down
7 changes: 2 additions & 5 deletions tests/v1/structured_output/test_backend_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,8 @@ def test_backend_guidance_rollback_terminated():
# guidance backend. In that case we are in a stopped state, but
# it should be reverted in case EOS is not accepted by the target
# model.
vllm_config = VllmConfig(
decoding_config=StructuredOutputsConfig(
backend="guidance",
)
)
structured_outputs_config = StructuredOutputsConfig(backend="guidance")
vllm_config = VllmConfig(structured_outputs_config=structured_outputs_config)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

backend = GuidanceBackend(
Expand Down
28 changes: 11 additions & 17 deletions tools/pre_commit/validate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,18 @@ class ConfigValidator(ast.NodeVisitor):
def __init__(self): ...

def visit_ClassDef(self, node):
# Validate class with both @config and @dataclass decorators
decorators = [
id
for d in node.decorator_list
if (
isinstance(d, ast.Name)
and ((id := d.id) == "config" or id == "dataclass")
)
or (
isinstance(d, ast.Call)
and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
)
]

if set(decorators) == {"config", "dataclass"}:
# Validate classes with a @config decorator
decorators = set()
for decorator in node.decorator_list:
if isinstance(decorator, ast.Call):
decorator = decorator.func
if isinstance(decorator, ast.Name) and decorator.id == "config":
decorators.add(decorator.id)

if decorators == {"config"}:
validate_class(node)
elif set(decorators) == {"config"}:
fail(f"Class {node.name} with config decorator must be a dataclass.", node)
elif "config" in decorators:
fail(f"config decorator for {node.name} should be used alone", node)
Comment thread
hmellor marked this conversation as resolved.

self.generic_visit(node)

Expand Down
2 changes: 2 additions & 0 deletions vllm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
config,
get_attr_docs,
is_init_field,
replace,
update_config,
)
from vllm.config.vllm import (
Expand Down Expand Up @@ -101,6 +102,7 @@
"config",
"get_attr_docs",
"is_init_field",
"replace",
"update_config",
# From vllm.config.vllm
"VllmConfig",
Expand Down
2 changes: 0 additions & 2 deletions vllm/config/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from typing import Any, Literal

from pydantic import field_validator
from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.v1.attention.backends.registry import AttentionBackendEnum


@config
@dataclass
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""

Expand Down
2 changes: 0 additions & 2 deletions vllm/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import TYPE_CHECKING, Any, Literal

from pydantic import Field, SkipValidation, field_validator
from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.logger import init_logger
Expand Down Expand Up @@ -37,7 +36,6 @@


@config
@dataclass
class CacheConfig:
"""Configuration for the KV cache."""

Expand Down
6 changes: 1 addition & 5 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from pydantic import ConfigDict, Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
from pydantic import Field, TypeAdapter, field_validator

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
Expand Down Expand Up @@ -96,7 +95,6 @@ def __str__(self) -> str:


@config
@dataclass(config=ConfigDict(extra="forbid"))
class PassConfig:
"""Configuration for custom Inductor passes.

Expand Down Expand Up @@ -267,7 +265,6 @@ class DynamicShapesType(str, enum.Enum):


@config
@dataclass(config=ConfigDict(extra="forbid"))
class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes."""

Expand Down Expand Up @@ -311,7 +308,6 @@ def compute_hash(self) -> str:


@config
@dataclass(config=ConfigDict(extra="forbid"))
class CompilationConfig:
"""Configuration for compilation.

Expand Down
4 changes: 1 addition & 3 deletions vllm/config/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@

import torch
from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.utils.hashing import safe_hash

Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@config(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""

Expand Down
3 changes: 0 additions & 3 deletions vllm/config/ec_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from dataclasses import field
from typing import Any, Literal, get_args

from pydantic.dataclasses import dataclass

from vllm.config.utils import config

ECProducer = Literal["ec_producer"]
Expand All @@ -15,7 +13,6 @@


@config
@dataclass
class ECTransferConfig:
"""Configuration for distributed EC cache transfer."""

Expand Down
2 changes: 0 additions & 2 deletions vllm/config/kv_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
from typing import Literal

from pydantic import Field
from pydantic.dataclasses import dataclass

from vllm.config.utils import config


@config
@dataclass
class KVEventsConfig:
"""Configuration for KV event publishing."""

Expand Down
3 changes: 0 additions & 3 deletions vllm/config/kv_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from dataclasses import field
from typing import Any, Literal, get_args

from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.utils.hashing import safe_hash

Expand All @@ -16,7 +14,6 @@


@config
@dataclass
class KVTransferConfig:
"""Configuration for distributed KV cache transfer."""

Expand Down
2 changes: 0 additions & 2 deletions vllm/config/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any

from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.logger import init_logger
Expand All @@ -21,7 +20,6 @@


@config
@dataclass
class LoadConfig:
"""Configuration for loading the model weights."""

Expand Down
4 changes: 1 addition & 3 deletions vllm/config/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from vllm.config.utils import config
Expand All @@ -26,8 +25,7 @@
LoRAExtraVocabSize = Literal[256, 512]


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@config(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig:
"""Configuration for LoRA."""

Expand Down
Loading