diff --git a/tests/test_config.py b/tests/test_config.py index c0bd4b14ff8d..d63812d4fc05 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1487,3 +1487,275 @@ def test_ir_op_priority_ctx(): # context restored even after exception assert ir.ops.rms_norm.get_priority() == ["vllm_c", "native"] assert ir.ops.fused_add_rms_norm.get_priority() == ["native"] + + +# Config validation tests for field validators added to prevent invalid values + + +@pytest.mark.parametrize("bad", [-1, -5, -100, 0]) +def test_attention_flash_attn_max_num_splits_rejects_non_positive(bad): + """flash_attn_max_num_splits_for_cuda_graph must be positive (> 0). + Negative or zero values would cause issues in CUDA graph setup.""" + from vllm.config.attention import AttentionConfig + + with pytest.raises(ValidationError, match="flash_attn_max_num_splits"): + AttentionConfig(flash_attn_max_num_splits_for_cuda_graph=bad) + + +@pytest.mark.parametrize("good", [1, 16, 32, 64]) +def test_attention_flash_attn_max_num_splits_accepts_positive(good): + from vllm.config.attention import AttentionConfig + + cfg = AttentionConfig(flash_attn_max_num_splits_for_cuda_graph=good) + assert cfg.flash_attn_max_num_splits_for_cuda_graph == good + + +@pytest.mark.parametrize("bad", [-1, -5, 0]) +def test_attention_tq_max_kv_splits_rejects_non_positive(bad): + """tq_max_kv_splits_for_cuda_graph must be positive (> 0).""" + from vllm.config.attention import AttentionConfig + + with pytest.raises(ValidationError, match="tq_max_kv_splits"): + AttentionConfig(tq_max_kv_splits_for_cuda_graph=bad) + + +@pytest.mark.parametrize("good", [1, 16, 32, 64]) +def test_attention_tq_max_kv_splits_accepts_positive(good): + from vllm.config.attention import AttentionConfig + + cfg = AttentionConfig(tq_max_kv_splits_for_cuda_graph=good) + assert cfg.tq_max_kv_splits_for_cuda_graph == good + + +@pytest.mark.parametrize("bad", [0, 8, 15, 17, 31, 33]) +def test_attention_flex_attn_block_size_rejects_invalid(bad): + """flex_attn_block_m/n must be >= 16 and power of 2.""" + from vllm.config.attention import AttentionConfig + + with pytest.raises(ValidationError, match="flex_attn"): + AttentionConfig(flex_attn_block_m=bad) + with pytest.raises(ValidationError, match="flex_attn"): + AttentionConfig(flex_attn_block_n=bad) + + +@pytest.mark.parametrize("good", [16, 32, 64, 128, 256]) +def test_attention_flex_attn_block_size_accepts_valid(good): + from vllm.config.attention import AttentionConfig + + cfg = AttentionConfig(flex_attn_block_m=good, flex_attn_block_n=good) + assert cfg.flex_attn_block_m == good + assert cfg.flex_attn_block_n == good + + +@pytest.mark.parametrize("bad", [0, 3, 7, 15, 17, 31, 33]) +def test_attention_flex_attn_logical_block_size_rejects_non_power_of_2(bad): + """flex_attn_q_block_size/kv_block_size must be power of 2.""" + from vllm.config.attention import AttentionConfig + + with pytest.raises(ValidationError, match="power of 2"): + AttentionConfig(flex_attn_q_block_size=bad) + with pytest.raises(ValidationError, match="power of 2"): + AttentionConfig(flex_attn_kv_block_size=bad) + + +@pytest.mark.parametrize("good", [1, 2, 4, 8, 16, 32, 64, 128]) +def test_attention_flex_attn_logical_block_size_accepts_power_of_2(good): + from vllm.config.attention import AttentionConfig + + cfg = AttentionConfig(flex_attn_q_block_size=good, flex_attn_kv_block_size=good) + assert cfg.flex_attn_q_block_size == good + assert cfg.flex_attn_kv_block_size == good + + +@pytest.mark.parametrize("bad", [-1, -100, 0]) +def test_kv_events_buffer_steps_rejects_non_positive(bad): + """buffer_steps must be positive (> 0).""" + from vllm.config.kv_events import KVEventsConfig + + with pytest.raises(ValidationError, match="buffer_steps"): + KVEventsConfig(buffer_steps=bad) + + +@pytest.mark.parametrize("good", [1, 100, 10_000, 100_000]) +def test_kv_events_buffer_steps_accepts_positive(good): + from vllm.config.kv_events import KVEventsConfig + + cfg = KVEventsConfig(buffer_steps=good) + assert cfg.buffer_steps == good + + +@pytest.mark.parametrize("bad", [-1, -100, 0]) +def test_kv_events_hwm_rejects_non_positive(bad): + """hwm (high water mark) must be positive (> 0).""" + from vllm.config.kv_events import KVEventsConfig + + with pytest.raises(ValidationError, match="hwm"): + KVEventsConfig(hwm=bad) + + +@pytest.mark.parametrize("good", [1, 1000, 100_000]) +def test_kv_events_hwm_accepts_positive(good): + from vllm.config.kv_events import KVEventsConfig + + cfg = KVEventsConfig(hwm=good) + assert cfg.hwm == good + + +@pytest.mark.parametrize("bad", [-1, 0]) +def test_kv_events_max_queue_size_rejects_non_positive(bad): + """max_queue_size must be positive (> 0).""" + from vllm.config.kv_events import KVEventsConfig + + with pytest.raises(ValidationError, match="max_queue_size"): + KVEventsConfig(max_queue_size=bad) + + +@pytest.mark.parametrize("good", [1, 1000, 100_000]) +def test_kv_events_max_queue_size_accepts_positive(good): + from vllm.config.kv_events import KVEventsConfig + + cfg = KVEventsConfig(max_queue_size=good) + assert cfg.max_queue_size == good + + +@pytest.mark.parametrize("bad", [-1.0, -100.0, 0.0, -0.1]) +def test_kv_transfer_buffer_size_rejects_non_positive(bad): + """kv_buffer_size must be positive (> 0).""" + from vllm.config.kv_transfer import KVTransferConfig + + with pytest.raises(ValidationError, match="kv_buffer_size"): + KVTransferConfig(kv_buffer_size=bad) + + +@pytest.mark.parametrize("good", [1.0, 1e9, 1e10]) +def test_kv_transfer_buffer_size_accepts_positive(good): + from vllm.config.kv_transfer import KVTransferConfig + + cfg = KVTransferConfig(kv_buffer_size=good) + assert cfg.kv_buffer_size == good + + +@pytest.mark.parametrize("bad", [-1, -5, -100]) +def test_kv_transfer_rank_rejects_negative(bad): + """kv_rank must be non-negative (>= 0) when set.""" + from vllm.config.kv_transfer import KVTransferConfig + + with pytest.raises(ValidationError, match="kv_rank"): + KVTransferConfig(kv_rank=bad) + + +@pytest.mark.parametrize("good", [0, 1, 2, 10]) +def test_kv_transfer_rank_accepts_non_negative(good): + from vllm.config.kv_transfer import KVTransferConfig + + cfg = KVTransferConfig(kv_rank=good) + assert cfg.kv_rank == good + + +def test_kv_transfer_rank_accepts_none(): + from vllm.config.kv_transfer import KVTransferConfig + + cfg = KVTransferConfig(kv_rank=None) + assert cfg.kv_rank is None + + +@pytest.mark.parametrize("bad", [-1, 0]) +def test_kv_transfer_parallel_size_rejects_non_positive(bad): + """kv_parallel_size must be positive (> 0).""" + from vllm.config.kv_transfer import KVTransferConfig + + with pytest.raises(ValidationError, match="kv_parallel_size"): + KVTransferConfig(kv_parallel_size=bad) + + +@pytest.mark.parametrize("good", [1, 2, 4, 8]) +def test_kv_transfer_parallel_size_accepts_positive(good): + from vllm.config.kv_transfer import KVTransferConfig + + cfg = KVTransferConfig(kv_parallel_size=good) + assert cfg.kv_parallel_size == good + + +@pytest.mark.parametrize("bad", [-1, 0, 65536, 100000]) +def test_kv_transfer_port_rejects_invalid_range(bad): + """kv_port must be in valid port range [1, 65535].""" + from vllm.config.kv_transfer import KVTransferConfig + + with pytest.raises(ValidationError, match="port"): + KVTransferConfig(kv_port=bad) + + +@pytest.mark.parametrize("good", [1, 80, 443, 8080, 14579, 65535]) +def test_kv_transfer_port_accepts_valid_range(good): + from vllm.config.kv_transfer import KVTransferConfig + + cfg = KVTransferConfig(kv_port=good) + assert cfg.kv_port == good + + +@pytest.mark.parametrize("bad", [-1.0, 0.0, -100.0]) +def test_ec_transfer_buffer_size_rejects_non_positive(bad): + """ec_buffer_size must be positive (> 0).""" + from vllm.config.ec_transfer import ECTransferConfig + + with pytest.raises(ValidationError, match="ec_buffer_size"): + ECTransferConfig(ec_buffer_size=bad) + + +@pytest.mark.parametrize("good", [1.0, 1e9, 1e10]) +def test_ec_transfer_buffer_size_accepts_positive(good): + from vllm.config.ec_transfer import ECTransferConfig + + cfg = ECTransferConfig(ec_buffer_size=good) + assert cfg.ec_buffer_size == good + + +@pytest.mark.parametrize("bad", [-1, -5]) +def test_ec_transfer_rank_rejects_negative(bad): + """ec_rank must be non-negative (>= 0) when set.""" + from vllm.config.ec_transfer import ECTransferConfig + + with pytest.raises(ValidationError, match="ec_rank"): + ECTransferConfig(ec_rank=bad) + + +@pytest.mark.parametrize("good", [0, 1, 2]) +def test_ec_transfer_rank_accepts_non_negative(good): + from vllm.config.ec_transfer import ECTransferConfig + + cfg = ECTransferConfig(ec_rank=good) + assert cfg.ec_rank == good + + +@pytest.mark.parametrize("bad", [-1, 0]) +def test_ec_transfer_parallel_size_rejects_non_positive(bad): + """ec_parallel_size must be positive (> 0).""" + from vllm.config.ec_transfer import ECTransferConfig + + with pytest.raises(ValidationError, match="ec_parallel_size"): + ECTransferConfig(ec_parallel_size=bad) + + +@pytest.mark.parametrize("good", [1, 2, 4]) +def test_ec_transfer_parallel_size_accepts_positive(good): + from vllm.config.ec_transfer import ECTransferConfig + + cfg = ECTransferConfig(ec_parallel_size=good) + assert cfg.ec_parallel_size == good + + +@pytest.mark.parametrize("bad", [-1, 0, 65536, 100000]) +def test_ec_transfer_port_rejects_invalid_range(bad): + """ec_port must be in valid port range [1, 65535].""" + from vllm.config.ec_transfer import ECTransferConfig + + with pytest.raises(ValidationError, match="port"): + ECTransferConfig(ec_port=bad) + + +@pytest.mark.parametrize("good", [1, 80, 14579, 65535]) +def test_ec_transfer_port_accepts_valid_range(good): + from vllm.config.ec_transfer import ECTransferConfig + + cfg = ECTransferConfig(ec_port=good) + assert cfg.ec_port == good diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 52ce9f102a6c..15bf19bf6caa 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -109,3 +109,42 @@ def validate_mla_prefill_backend_before(cls, value: Any) -> Any: if isinstance(value, str): return MLAPrefillBackendEnum[value.upper()] return value + + @field_validator("flash_attn_max_num_splits_for_cuda_graph", mode="after") + @classmethod + def _check_flash_attn_max_num_splits(cls, v: int) -> int: + if v <= 0: + raise ValueError( + f"flash_attn_max_num_splits_for_cuda_graph must be " + f"positive (> 0), got {v}." + ) + return v + + @field_validator("tq_max_kv_splits_for_cuda_graph", mode="after") + @classmethod + def _check_tq_max_kv_splits(cls, v: int) -> int: + if v <= 0: + raise ValueError( + f"tq_max_kv_splits_for_cuda_graph must be positive (> 0), got {v}." + ) + return v + + @field_validator("flex_attn_block_m", "flex_attn_block_n", mode="after") + @classmethod + def _check_flex_attn_block_size(cls, v: int | None) -> int | None: + if v is not None: + if v < 16: + raise ValueError(f"flex_attn block size must be >= 16, got {v}.") + # Check if power of 2 + if v & (v - 1) != 0: + raise ValueError(f"flex_attn block size must be a power of 2, got {v}.") + return v + + @field_validator("flex_attn_q_block_size", "flex_attn_kv_block_size", mode="after") + @classmethod + def _check_flex_attn_logical_block_size(cls, v: int | None) -> int | None: + if v is not None and v & (v - 1) != 0: + raise ValueError( + f"flex_attn logical block size must be a power of 2, got {v}." + ) + return v diff --git a/vllm/config/ec_transfer.py b/vllm/config/ec_transfer.py index a3a927d51ec4..6aec5bae1f46 100644 --- a/vllm/config/ec_transfer.py +++ b/vllm/config/ec_transfer.py @@ -5,6 +5,8 @@ from dataclasses import field from typing import Any, Literal, get_args +from pydantic import field_validator + from vllm.config.utils import config ECProducer = Literal["ec_producer", "ec_both"] @@ -75,6 +77,36 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + @field_validator("ec_buffer_size", mode="after") + @classmethod + def _check_ec_buffer_size(cls, v: float) -> float: + if v <= 0: + raise ValueError(f"ec_buffer_size must be positive (> 0), got {v}.") + return v + + @field_validator("ec_rank", mode="after") + @classmethod + def _check_ec_rank(cls, v: int | None) -> int | None: + if v is not None and v < 0: + raise ValueError(f"ec_rank must be non-negative (>= 0) when set, got {v}.") + return v + + @field_validator("ec_parallel_size", mode="after") + @classmethod + def _check_ec_parallel_size(cls, v: int) -> int: + if v <= 0: + raise ValueError(f"ec_parallel_size must be positive (> 0), got {v}.") + return v + + @field_validator("ec_port", mode="after") + @classmethod + def _check_ec_port(cls, v: int) -> int: + if not (1 <= v <= 65535): + raise ValueError( + f"ec_port must be in valid port range [1, 65535], got {v}." + ) + return v + def __post_init__(self) -> None: if self.engine_id is None: self.engine_id = str(uuid.uuid4()) diff --git a/vllm/config/kv_events.py b/vllm/config/kv_events.py index d618bc9a73f3..cefbc271b09d 100644 --- a/vllm/config/kv_events.py +++ b/vllm/config/kv_events.py @@ -4,6 +4,8 @@ from typing import Literal +from pydantic import field_validator + from vllm.config.utils import config @@ -47,6 +49,27 @@ class KVEventsConfig: this topic to receive events. """ + @field_validator("buffer_steps", mode="after") + @classmethod + def _check_buffer_steps(cls, v: int) -> int: + if v <= 0: + raise ValueError(f"buffer_steps must be positive (> 0), got {v}.") + return v + + @field_validator("hwm", mode="after") + @classmethod + def _check_hwm(cls, v: int) -> int: + if v <= 0: + raise ValueError(f"hwm (high water mark) must be positive (> 0), got {v}.") + return v + + @field_validator("max_queue_size", mode="after") + @classmethod + def _check_max_queue_size(cls, v: int) -> int: + if v <= 0: + raise ValueError(f"max_queue_size must be positive (> 0), got {v}.") + return v + def __post_init__(self): if self.publisher is None: self.publisher = "zmq" if self.enable_kv_cache_events else "null" diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index b22af99f703f..0b93f3da9e7f 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -5,6 +5,8 @@ from dataclasses import field from typing import Any, Literal, get_args +from pydantic import field_validator + from vllm.config.utils import config from vllm.utils.hashing import safe_hash @@ -90,6 +92,36 @@ def compute_hash(self) -> str: hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + @field_validator("kv_buffer_size", mode="after") + @classmethod + def _check_kv_buffer_size(cls, v: float) -> float: + if v <= 0: + raise ValueError(f"kv_buffer_size must be positive (> 0), got {v}.") + return v + + @field_validator("kv_rank", mode="after") + @classmethod + def _check_kv_rank(cls, v: int | None) -> int | None: + if v is not None and v < 0: + raise ValueError(f"kv_rank must be non-negative (>= 0) when set, got {v}.") + return v + + @field_validator("kv_parallel_size", mode="after") + @classmethod + def _check_kv_parallel_size(cls, v: int) -> int: + if v <= 0: + raise ValueError(f"kv_parallel_size must be positive (> 0), got {v}.") + return v + + @field_validator("kv_port", mode="after") + @classmethod + def _check_kv_port(cls, v: int) -> int: + if not (1 <= v <= 65535): + raise ValueError( + f"kv_port must be in valid port range [1, 65535], got {v}." + ) + return v + def __post_init__(self) -> None: if self.engine_id is None: self.engine_id = str(uuid.uuid4())