Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 272 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,3 +1487,275 @@ def test_ir_op_priority_ctx():
# context restored even after exception
assert ir.ops.rms_norm.get_priority() == ["vllm_c", "native"]
assert ir.ops.fused_add_rms_norm.get_priority() == ["native"]


# Config validation tests for field validators added to prevent invalid values


@pytest.mark.parametrize("bad", (-1, -5, -100, 0))
def test_attention_flash_attn_max_num_splits_rejects_non_positive(bad):
    """Zero/negative ``flash_attn_max_num_splits_for_cuda_graph`` is rejected.

    The field has to be strictly positive (> 0); per the original note,
    negative or zero values would cause issues in CUDA graph setup.
    """
    from vllm.config.attention import AttentionConfig

    expect_rejection = pytest.raises(
        ValidationError, match="flash_attn_max_num_splits"
    )
    with expect_rejection:
        AttentionConfig(flash_attn_max_num_splits_for_cuda_graph=bad)


@pytest.mark.parametrize("good", (1, 16, 32, 64))
def test_attention_flash_attn_max_num_splits_accepts_positive(good):
    """A positive ``flash_attn_max_num_splits_for_cuda_graph`` is accepted.

    The value read back from the config must compare equal to the input.
    """
    from vllm.config.attention import AttentionConfig

    config = AttentionConfig(flash_attn_max_num_splits_for_cuda_graph=good)
    stored = config.flash_attn_max_num_splits_for_cuda_graph
    assert stored == good


@pytest.mark.parametrize("bad", (-1, -5, 0))
def test_attention_tq_max_kv_splits_rejects_non_positive(bad):
    """Zero/negative ``tq_max_kv_splits_for_cuda_graph`` fails validation."""
    from vllm.config.attention import AttentionConfig

    expect_rejection = pytest.raises(ValidationError, match="tq_max_kv_splits")
    with expect_rejection:
        AttentionConfig(tq_max_kv_splits_for_cuda_graph=bad)


@pytest.mark.parametrize("good", (1, 16, 32, 64))
def test_attention_tq_max_kv_splits_accepts_positive(good):
    """A positive ``tq_max_kv_splits_for_cuda_graph`` is accepted.

    The value read back from the config must compare equal to the input.
    """
    from vllm.config.attention import AttentionConfig

    config = AttentionConfig(tq_max_kv_splits_for_cuda_graph=good)
    stored = config.tq_max_kv_splits_for_cuda_graph
    assert stored == good


@pytest.mark.parametrize("bad", (0, 8, 15, 17, 31, 33))
def test_attention_flex_attn_block_size_rejects_invalid(bad):
    """``flex_attn_block_m`` / ``flex_attn_block_n`` reject invalid sizes.

    Valid sizes are powers of 2 that are at least 16. Each field is checked
    on its own, ``flex_attn_block_m`` first.
    """
    from vllm.config.attention import AttentionConfig

    for field_name in ("flex_attn_block_m", "flex_attn_block_n"):
        with pytest.raises(ValidationError, match="flex_attn"):
            AttentionConfig(**{field_name: bad})


@pytest.mark.parametrize("good", (16, 32, 64, 128, 256))
def test_attention_flex_attn_block_size_accepts_valid(good):
    """Powers of 2 from 16 upward are accepted for both block-size fields.

    Both fields are set to the same value in one config and read back.
    """
    from vllm.config.attention import AttentionConfig

    config = AttentionConfig(flex_attn_block_m=good, flex_attn_block_n=good)
    stored_m = config.flex_attn_block_m
    stored_n = config.flex_attn_block_n
    assert stored_m == good
    assert stored_n == good


@pytest.mark.parametrize("bad", (0, 3, 7, 15, 17, 31, 33))
def test_attention_flex_attn_logical_block_size_rejects_non_power_of_2(bad):
    """Logical block sizes that are not a power of 2 fail validation.

    Covers ``flex_attn_q_block_size`` and ``flex_attn_kv_block_size``; each
    field is checked on its own, the query block size first.
    """
    from vllm.config.attention import AttentionConfig

    for field_name in ("flex_attn_q_block_size", "flex_attn_kv_block_size"):
        with pytest.raises(ValidationError, match="power of 2"):
            AttentionConfig(**{field_name: bad})


@pytest.mark.parametrize("good", (1, 2, 4, 8, 16, 32, 64, 128))
def test_attention_flex_attn_logical_block_size_accepts_power_of_2(good):
    """Powers of 2 are accepted for both logical block-size fields.

    Both fields are set to the same value in one config and read back.
    """
    from vllm.config.attention import AttentionConfig

    config = AttentionConfig(
        flex_attn_q_block_size=good,
        flex_attn_kv_block_size=good,
    )
    stored_q = config.flex_attn_q_block_size
    stored_kv = config.flex_attn_kv_block_size
    assert stored_q == good
    assert stored_kv == good


@pytest.mark.parametrize("bad", (-1, -100, 0))
def test_kv_events_buffer_steps_rejects_non_positive(bad):
    """Zero/negative ``buffer_steps`` fails ``KVEventsConfig`` validation."""
    from vllm.config.kv_events import KVEventsConfig

    expect_rejection = pytest.raises(ValidationError, match="buffer_steps")
    with expect_rejection:
        KVEventsConfig(buffer_steps=bad)


@pytest.mark.parametrize("good", (1, 100, 10_000, 100_000))
def test_kv_events_buffer_steps_accepts_positive(good):
    """A positive ``buffer_steps`` is accepted and reads back equal."""
    from vllm.config.kv_events import KVEventsConfig

    config = KVEventsConfig(buffer_steps=good)
    stored = config.buffer_steps
    assert stored == good


@pytest.mark.parametrize("bad", (-1, -100, 0))
def test_kv_events_hwm_rejects_non_positive(bad):
    """Zero/negative ``hwm`` (high water mark) fails validation."""
    from vllm.config.kv_events import KVEventsConfig

    expect_rejection = pytest.raises(ValidationError, match="hwm")
    with expect_rejection:
        KVEventsConfig(hwm=bad)


@pytest.mark.parametrize("good", (1, 1000, 100_000))
def test_kv_events_hwm_accepts_positive(good):
    """A positive ``hwm`` is accepted and reads back equal."""
    from vllm.config.kv_events import KVEventsConfig

    config = KVEventsConfig(hwm=good)
    stored = config.hwm
    assert stored == good


@pytest.mark.parametrize("bad", (-1, 0))
def test_kv_events_max_queue_size_rejects_non_positive(bad):
    """Zero/negative ``max_queue_size`` fails ``KVEventsConfig`` validation."""
    from vllm.config.kv_events import KVEventsConfig

    expect_rejection = pytest.raises(ValidationError, match="max_queue_size")
    with expect_rejection:
        KVEventsConfig(max_queue_size=bad)


@pytest.mark.parametrize("good", (1, 1000, 100_000))
def test_kv_events_max_queue_size_accepts_positive(good):
    """A positive ``max_queue_size`` is accepted and reads back equal."""
    from vllm.config.kv_events import KVEventsConfig

    config = KVEventsConfig(max_queue_size=good)
    stored = config.max_queue_size
    assert stored == good


@pytest.mark.parametrize("bad", (-1.0, -100.0, 0.0, -0.1))
def test_kv_transfer_buffer_size_rejects_non_positive(bad):
    """Zero/negative ``kv_buffer_size`` fails ``KVTransferConfig`` validation."""
    from vllm.config.kv_transfer import KVTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="kv_buffer_size")
    with expect_rejection:
        KVTransferConfig(kv_buffer_size=bad)


@pytest.mark.parametrize("good", (1.0, 1e9, 1e10))
def test_kv_transfer_buffer_size_accepts_positive(good):
    """A positive ``kv_buffer_size`` is accepted and reads back equal."""
    from vllm.config.kv_transfer import KVTransferConfig

    config = KVTransferConfig(kv_buffer_size=good)
    stored = config.kv_buffer_size
    assert stored == good


@pytest.mark.parametrize("bad", (-1, -5, -100))
def test_kv_transfer_rank_rejects_negative(bad):
    """A negative ``kv_rank`` fails validation; it must be >= 0 when set."""
    from vllm.config.kv_transfer import KVTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="kv_rank")
    with expect_rejection:
        KVTransferConfig(kv_rank=bad)


@pytest.mark.parametrize("good", (0, 1, 2, 10))
def test_kv_transfer_rank_accepts_non_negative(good):
    """Zero and positive ``kv_rank`` values are accepted and read back equal."""
    from vllm.config.kv_transfer import KVTransferConfig

    config = KVTransferConfig(kv_rank=good)
    stored = config.kv_rank
    assert stored == good


def test_kv_transfer_rank_accepts_none():
    """Passing ``kv_rank=None`` is accepted and the field stays ``None``."""
    from vllm.config.kv_transfer import KVTransferConfig

    config = KVTransferConfig(kv_rank=None)
    stored = config.kv_rank
    assert stored is None


@pytest.mark.parametrize("bad", (-1, 0))
def test_kv_transfer_parallel_size_rejects_non_positive(bad):
    """Zero/negative ``kv_parallel_size`` fails validation."""
    from vllm.config.kv_transfer import KVTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="kv_parallel_size")
    with expect_rejection:
        KVTransferConfig(kv_parallel_size=bad)


@pytest.mark.parametrize("good", (1, 2, 4, 8))
def test_kv_transfer_parallel_size_accepts_positive(good):
    """A positive ``kv_parallel_size`` is accepted and reads back equal."""
    from vllm.config.kv_transfer import KVTransferConfig

    config = KVTransferConfig(kv_parallel_size=good)
    stored = config.kv_parallel_size
    assert stored == good


@pytest.mark.parametrize("bad", (-1, 0, 65536, 100000))
def test_kv_transfer_port_rejects_invalid_range(bad):
    """``kv_port`` values outside the port range [1, 65535] fail validation."""
    from vllm.config.kv_transfer import KVTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="port")
    with expect_rejection:
        KVTransferConfig(kv_port=bad)


@pytest.mark.parametrize("good", (1, 80, 443, 8080, 14579, 65535))
def test_kv_transfer_port_accepts_valid_range(good):
    """``kv_port`` values from 1 through 65535 are accepted and read back equal."""
    from vllm.config.kv_transfer import KVTransferConfig

    config = KVTransferConfig(kv_port=good)
    stored = config.kv_port
    assert stored == good


@pytest.mark.parametrize("bad", (-1.0, 0.0, -100.0))
def test_ec_transfer_buffer_size_rejects_non_positive(bad):
    """Zero/negative ``ec_buffer_size`` fails ``ECTransferConfig`` validation."""
    from vllm.config.ec_transfer import ECTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="ec_buffer_size")
    with expect_rejection:
        ECTransferConfig(ec_buffer_size=bad)


@pytest.mark.parametrize("good", (1.0, 1e9, 1e10))
def test_ec_transfer_buffer_size_accepts_positive(good):
    """A positive ``ec_buffer_size`` is accepted and reads back equal."""
    from vllm.config.ec_transfer import ECTransferConfig

    config = ECTransferConfig(ec_buffer_size=good)
    stored = config.ec_buffer_size
    assert stored == good


@pytest.mark.parametrize("bad", (-1, -5))
def test_ec_transfer_rank_rejects_negative(bad):
    """A negative ``ec_rank`` fails validation; it must be >= 0 when set."""
    from vllm.config.ec_transfer import ECTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="ec_rank")
    with expect_rejection:
        ECTransferConfig(ec_rank=bad)


@pytest.mark.parametrize("good", (0, 1, 2))
def test_ec_transfer_rank_accepts_non_negative(good):
    """Zero and positive ``ec_rank`` values are accepted and read back equal."""
    from vllm.config.ec_transfer import ECTransferConfig

    config = ECTransferConfig(ec_rank=good)
    stored = config.ec_rank
    assert stored == good


@pytest.mark.parametrize("bad", (-1, 0))
def test_ec_transfer_parallel_size_rejects_non_positive(bad):
    """Zero/negative ``ec_parallel_size`` fails validation."""
    from vllm.config.ec_transfer import ECTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="ec_parallel_size")
    with expect_rejection:
        ECTransferConfig(ec_parallel_size=bad)


@pytest.mark.parametrize("good", (1, 2, 4))
def test_ec_transfer_parallel_size_accepts_positive(good):
    """A positive ``ec_parallel_size`` is accepted and reads back equal."""
    from vllm.config.ec_transfer import ECTransferConfig

    config = ECTransferConfig(ec_parallel_size=good)
    stored = config.ec_parallel_size
    assert stored == good


@pytest.mark.parametrize("bad", (-1, 0, 65536, 100000))
def test_ec_transfer_port_rejects_invalid_range(bad):
    """``ec_port`` values outside the port range [1, 65535] fail validation."""
    from vllm.config.ec_transfer import ECTransferConfig

    expect_rejection = pytest.raises(ValidationError, match="port")
    with expect_rejection:
        ECTransferConfig(ec_port=bad)


@pytest.mark.parametrize("good", (1, 80, 14579, 65535))
def test_ec_transfer_port_accepts_valid_range(good):
    """``ec_port`` values from 1 through 65535 are accepted and read back equal."""
    from vllm.config.ec_transfer import ECTransferConfig

    config = ECTransferConfig(ec_port=good)
    stored = config.ec_port
    assert stored == good
39 changes: 39 additions & 0 deletions vllm/config/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,42 @@ def validate_mla_prefill_backend_before(cls, value: Any) -> Any:
if isinstance(value, str):
return MLAPrefillBackendEnum[value.upper()]
return value

@field_validator("flash_attn_max_num_splits_for_cuda_graph", mode="after")
@classmethod
def _check_flash_attn_max_num_splits(cls, v: int) -> int:
    """Require ``flash_attn_max_num_splits_for_cuda_graph`` to be positive.

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when it is greater than zero.

    Raises:
        ValueError: If ``v`` is zero or negative.
    """
    is_positive = not v <= 0
    if is_positive:
        return v
    raise ValueError(
        f"flash_attn_max_num_splits_for_cuda_graph must be "
        f"positive (> 0), got {v}."
    )

@field_validator("tq_max_kv_splits_for_cuda_graph", mode="after")
@classmethod
def _check_tq_max_kv_splits(cls, v: int) -> int:
    """Require ``tq_max_kv_splits_for_cuda_graph`` to be positive.

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when it is greater than zero.

    Raises:
        ValueError: If ``v`` is zero or negative.
    """
    is_positive = not v <= 0
    if is_positive:
        return v
    raise ValueError(
        f"tq_max_kv_splits_for_cuda_graph must be positive (> 0), got {v}."
    )

@field_validator("flex_attn_block_m", "flex_attn_block_n", mode="after")
@classmethod
def _check_flex_attn_block_size(cls, v: int | None) -> int | None:
    """Validate ``flex_attn_block_m`` / ``flex_attn_block_n``.

    ``None`` is passed through untouched. Any other value must be at least
    16 and a power of 2.

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when it is ``None`` or a valid block size.

    Raises:
        ValueError: If ``v`` is below 16 or is not a power of 2.
    """
    if v is None:
        return v
    if v < 16:
        raise ValueError(f"flex_attn block size must be >= 16, got {v}.")
    # ``v & (v - 1)`` clears the lowest set bit; with v >= 16 the result is
    # zero only when exactly one bit was set, i.e. v is a power of 2.
    has_extra_bits = v & (v - 1) != 0
    if has_extra_bits:
        raise ValueError(f"flex_attn block size must be a power of 2, got {v}.")
    return v

@field_validator("flex_attn_q_block_size", "flex_attn_kv_block_size", mode="after")
@classmethod
def _check_flex_attn_logical_block_size(cls, v: int | None) -> int | None:
    """Validate ``flex_attn_q_block_size`` / ``flex_attn_kv_block_size``.

    ``None`` is passed through untouched. Any other value must be a positive
    power of 2 (1, 2, 4, 8, ...).

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when it is ``None`` or a positive power of 2.

    Raises:
        ValueError: If ``v`` is zero, negative, or not a power of 2.
    """
    if v is None:
        return v
    # ``v & (v - 1)`` clears the lowest set bit, so it is zero for powers of 2.
    # It is also zero for v == 0, which is not a power of 2, so non-positive
    # values have to be rejected explicitly before relying on the bit trick.
    if v <= 0 or v & (v - 1) != 0:
        raise ValueError(
            f"flex_attn logical block size must be a power of 2, got {v}."
        )
    return v
32 changes: 32 additions & 0 deletions vllm/config/ec_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import field
from typing import Any, Literal, get_args

from pydantic import field_validator

from vllm.config.utils import config

ECProducer = Literal["ec_producer", "ec_both"]
Expand Down Expand Up @@ -75,6 +77,36 @@ def compute_hash(self) -> str:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

@field_validator("ec_buffer_size", mode="after")
@classmethod
def _check_ec_buffer_size(cls, v: float) -> float:
    """Reject ``ec_buffer_size`` values that are zero or negative.

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when the ``v <= 0`` check does not trigger.

    Raises:
        ValueError: If ``v <= 0``.
    """
    # Keep the comparison as ``v <= 0`` (not ``v > 0`` inverted) so values
    # that compare False both ways behave exactly as before.
    is_non_positive = v <= 0
    if not is_non_positive:
        return v
    raise ValueError(f"ec_buffer_size must be positive (> 0), got {v}.")

@field_validator("ec_rank", mode="after")
@classmethod
def _check_ec_rank(cls, v: int | None) -> int | None:
    """Reject negative ``ec_rank`` values; ``None`` is allowed.

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when it is ``None`` or not negative.

    Raises:
        ValueError: If ``v`` is set and less than zero.
    """
    if v is None or not v < 0:
        return v
    raise ValueError(f"ec_rank must be non-negative (>= 0) when set, got {v}.")

@field_validator("ec_parallel_size", mode="after")
@classmethod
def _check_ec_parallel_size(cls, v: int) -> int:
    """Require ``ec_parallel_size`` to be positive.

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when it is greater than zero.

    Raises:
        ValueError: If ``v`` is zero or negative.
    """
    is_positive = not v <= 0
    if is_positive:
        return v
    raise ValueError(f"ec_parallel_size must be positive (> 0), got {v}.")

@field_validator("ec_port", mode="after")
@classmethod
def _check_ec_port(cls, v: int) -> int:
    """Require ``ec_port`` to lie in the inclusive range [1, 65535].

    Args:
        v: Field value passed to this ``mode="after"`` validator.

    Returns:
        ``v`` unchanged when it is within the range.

    Raises:
        ValueError: If ``v`` is below 1 or above 65535.
    """
    within_port_range = 1 <= v <= 65535
    if within_port_range:
        return v
    raise ValueError(f"ec_port must be in valid port range [1, 65535], got {v}.")

def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
Expand Down
Loading
Loading