From 71adc02d1bbaf4c0044c9df5bdd53676383b7b9f Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 9 Oct 2025 11:36:21 -0700 Subject: [PATCH 01/14] [Feature] Pydantic validation for scheduler.py and structured_outputs.py Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 56 ++++++++++--------------------- vllm/config/structured_outputs.py | 7 ++-- 2 files changed, 22 insertions(+), 41 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 396258aac287..968df3ed08bc 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -2,10 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from dataclasses import InitVar, field from typing import Any, Literal, Union -from pydantic import SkipValidation, model_validator +from pydantic import Field, SkipValidation, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -37,22 +36,22 @@ class SchedulerConfig: This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: SkipValidation[int] = None # type: ignore + max_num_seqs: int = 128 """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: SkipValidation[int] = None # type: ignore + max_model_len: int = 8192 """Maximum length of a sequence (including prompt and generated text). 
This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" - max_num_partial_prefills: int = 1 + max_num_partial_prefills: int = Field(default=1, ge=1) """For chunked prefill, the maximum number of sequences that can be partially prefilled concurrently.""" - max_long_partial_prefills: int = 1 + max_long_partial_prefills: int = Field(default=1, ge=1) """For chunked prefill, the maximum number of prompts longer than long_prefill_token_threshold that will be prefilled concurrently. Setting this less than max_num_partial_prefills will allow shorter prompts to jump @@ -62,7 +61,7 @@ class SchedulerConfig: """For chunked prefill, a request is considered long if the prompt is longer than this number of tokens.""" - num_lookahead_slots: int = 0 + num_lookahead_slots: int = Field(default=0, ge=0) """The number of slots to allocate per sequence per step, beyond the known token ids. This is used in speculative decoding to store KV activations of tokens which may or may not be @@ -71,7 +70,7 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - cuda_graph_sizes: list[int] = field(default_factory=list) + cuda_graph_sizes: list[int] = Field(default_factory=list) """Cuda graph capture sizes 1. if none provided, then default set to [min(max_num_seqs * 2, 512)] 2. if one value is provided, then the capture list would follow the @@ -86,7 +85,7 @@ class SchedulerConfig: is_multimodal_model: bool = False """True if the model is multimodal.""" - is_encoder_decoder: InitVar[bool] = False + is_encoder_decoder: bool = False """True if the model is an encoder-decoder model. Note: This is stored in the ModelConfig, and is used only here to @@ -94,14 +93,14 @@ class SchedulerConfig: """ # TODO (ywang96): Make this configurable. 
- max_num_encoder_input_tokens: int = field(init=False) + max_num_encoder_input_tokens: int = Field(init=False) """Multimodal encoder compute budget, only used in V1. NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" # TODO (ywang96): Make this configurable. - encoder_cache_size: int = field(init=False) + encoder_cache_size: int = Field(init=False) """Multimodal encoder cache size, only used in V1. NOTE: This is not currently configurable. It will be overridden by @@ -120,7 +119,7 @@ class SchedulerConfig: - "priority" means requests are handled based on given priority (lower value means earlier handling) and time of arrival deciding any ties).""" - chunked_prefill_enabled: bool = field(init=False) + chunked_prefill_enabled: bool = Field(init=False) """True if chunked prefill is enabled.""" disable_chunked_mm_input: bool = False @@ -169,14 +168,9 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self, is_encoder_decoder: bool) -> None: - if self.max_model_len is None: - self.max_model_len = 8192 - - if self.max_num_seqs is None: - self.max_num_seqs = 128 - - if is_encoder_decoder: + @model_validator(mode="after") + def _validate_scheduler_config(self) -> Self: + if self.is_encoder_decoder: # Chunked prefill should be disabled for encoder-decoder models. 
self.disable_chunked_mm_input = True self.chunked_prefill_enabled = False @@ -251,8 +245,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: if self.async_scheduling: self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler" - @model_validator(mode="after") - def _verify_args(self) -> Self: if ( self.max_num_batched_tokens < self.max_model_len and not self.chunked_prefill_enabled @@ -281,19 +273,7 @@ def _verify_args(self) -> Self: self.max_num_seqs * self.max_model_len, ) - if self.num_lookahead_slots < 0: - raise ValueError( - "num_lookahead_slots " - f"({self.num_lookahead_slots}) must be greater than or " - "equal to 0." - ) - - if self.max_num_partial_prefills < 1: - raise ValueError( - f"max_num_partial_prefills ({self.max_num_partial_prefills}) " - "must be greater than or equal to 1." - ) - elif self.max_num_partial_prefills > 1: + if self.max_num_partial_prefills > 1: if not self.chunked_prefill_enabled: raise ValueError( "Chunked prefill must be enabled to set " @@ -307,12 +287,10 @@ def _verify_args(self) -> Self: f"than the max_model_len ({self.max_model_len})." ) - if (self.max_long_partial_prefills < 1) or ( - self.max_long_partial_prefills > self.max_num_partial_prefills - ): + if self.max_long_partial_prefills > self.max_num_partial_prefills: raise ValueError( f"max_long_partial_prefills ({self.max_long_partial_prefills}) " - "must be greater than or equal to 1 and less than or equal to " + "must be less than or equal to " f"max_num_partial_prefills ({self.max_num_partial_prefills})." 
) diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 5111c9c77d90..13713e07a3b6 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import Any, Literal +from typing import Any, Literal, Self +from pydantic import model_validator from pydantic.dataclasses import dataclass from vllm.config.utils import config @@ -54,7 +55,8 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self): + @model_validator(mode="after") + def _validate_structured_output_config(self) -> Self: if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"): raise ValueError( "disable_any_whitespace is only supported for " @@ -65,3 +67,4 @@ def __post_init__(self): "disable_additional_properties is only supported " "for the guidance backend." ) + return self From 7de7d59a76c2cfacac92ce6e60b9ed25e36d885f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 10 Oct 2025 20:39:10 +0200 Subject: [PATCH 02/14] Fix `is_encoder_decoder` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 968df3ed08bc..701ef1b02f86 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -85,7 +85,7 @@ class SchedulerConfig: is_multimodal_model: bool = False """True if the model is multimodal.""" - is_encoder_decoder: bool = False + is_encoder_decoder: bool = Field(default=False, init_var=True) """True if the model is an encoder-decoder model. 
Note: This is stored in the ModelConfig, and is used only here to @@ -168,9 +168,9 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - @model_validator(mode="after") - def _validate_scheduler_config(self) -> Self: - if self.is_encoder_decoder: + def __post_init__(self, is_encoder_decoder: bool) -> None: + """Post init to handle init vars.""" + if is_encoder_decoder: # Chunked prefill should be disabled for encoder-decoder models. self.disable_chunked_mm_input = True self.chunked_prefill_enabled = False @@ -181,6 +181,8 @@ def _validate_scheduler_config(self) -> Self: " prefix caching; disabling both." ) + @model_validator(mode="after") + def _validate_scheduler_config(self) -> Self: if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS From b6f2a06a7c88b6ad8f6cfa1be7ea15e80cf934bf Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 10 Oct 2025 20:56:23 +0200 Subject: [PATCH 03/14] Fix deferred defaults Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 26 +++++++++++++++++++++----- vllm/engine/arg_utils.py | 12 +++++++----- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 701ef1b02f86..79463ce4cd72 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -4,7 +4,7 @@ import hashlib from typing import Any, Literal, Union -from pydantic import Field, SkipValidation, model_validator +from pydantic import Field, field_validator, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -30,19 +30,19 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: SkipValidation[int] = None # type: ignore + 
max_num_batched_tokens: int = Field(default=None, ge=1) """Maximum number of tokens to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: int = 128 + max_num_seqs: int = Field(default=None, validate_default=True, ge=1) """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: int = 8192 + max_model_len: int = Field(default=None, validate_default=True, ge=1) """Maximum length of a sequence (including prompt and generated text). This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -78,7 +78,7 @@ class SchedulerConfig: 3. more than one value (e.g. 1 2 128) is provided, then the capture list will follow the provided list.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + enable_chunked_prefill: bool = Field(default=None) """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -168,6 +168,22 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + @field_validator("max_num_seqs", mode="before") + @classmethod + def _validate_max_num_seqs(cls, max_num_seqs: Any | None) -> Any: + if max_num_seqs is None: + logger.warning("max_num_seqs is not set, using arbitrary value 128.") + return 128 + return max_num_seqs + + @field_validator("max_model_len", mode="before") + @classmethod + def _validate_max_model_len(cls, max_model_len: Any | None) -> Any: + if max_model_len is None: + logger.warning("max_model_len is not set, using arbitrary value 8192.") + return 8192 + return max_model_len + def __post_init__(self, is_encoder_decoder: bool) -> None: """Post init to handle 
init vars.""" if is_encoder_decoder: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e66d8dba8ac..5f5ae288b67d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -347,7 +347,7 @@ class EngineArgs: dtype: ModelDType = ModelConfig.dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype seed: Optional[int] = ModelConfig.seed - max_model_len: Optional[int] = ModelConfig.max_model_len + max_model_len: int = get_field(ModelConfig, "max_model_len") cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes") # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without @@ -399,11 +399,11 @@ class EngineArgs: cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes - max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens + max_num_batched_tokens: int = get_field(SchedulerConfig, "max_num_batched_tokens") max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold - max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs + max_num_seqs: int = get_field(SchedulerConfig, "max_num_seqs") max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode disable_log_stats: bool = False @@ -454,7 +454,9 @@ class EngineArgs: model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") ignore_patterns: Union[str, list[str]] = get_field(LoadConfig, "ignore_patterns") - enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: Optional[bool] = get_field( + SchedulerConfig, "enable_chunked_prefill" + ) disable_chunked_mm_input: 
bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( @@ -1718,7 +1720,7 @@ def _set_default_args( incremental_prefill_supported = ( pooling_type is not None and pooling_type.lower() == "last" - and is_causal + and bool(is_causal) ) action = "Enabling" if incremental_prefill_supported else "Disabling" From 5984a27e7aeb7780942275ab8dc4e140697ff81b Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 10 Oct 2025 21:04:12 +0200 Subject: [PATCH 04/14] Use `InitVar` to fix docs build Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 79463ce4cd72..60646e3338f6 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +from dataclasses import InitVar from typing import Any, Literal, Union from pydantic import Field, field_validator, model_validator @@ -85,7 +86,7 @@ class SchedulerConfig: is_multimodal_model: bool = False """True if the model is multimodal.""" - is_encoder_decoder: bool = Field(default=False, init_var=True) + is_encoder_decoder: InitVar[bool] = False """True if the model is an encoder-decoder model. 
Note: This is stored in the ModelConfig, and is used only here to From 4e261986072b7bf20cb1b83aa31914577f2d175c Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Fri, 10 Oct 2025 16:55:11 -0700 Subject: [PATCH 05/14] Fixing tests Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 60646e3338f6..1b2bec2b64db 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -37,13 +37,13 @@ class SchedulerConfig: This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: int = Field(default=None, validate_default=True, ge=1) + max_num_seqs: int = Field(default=None, ge=1) """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: int = Field(default=None, validate_default=True, ge=1) + max_model_len: int = Field(default=None, ge=1) """Maximum length of a sequence (including prompt and generated text). 
This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -171,7 +171,7 @@ def compute_hash(self) -> str: @field_validator("max_num_seqs", mode="before") @classmethod - def _validate_max_num_seqs(cls, max_num_seqs: Any | None) -> Any: + def _validate_max_num_seqs(cls, max_num_seqs: int | None) -> int: if max_num_seqs is None: logger.warning("max_num_seqs is not set, using arbitrary value 128.") return 128 @@ -179,7 +179,7 @@ def _validate_max_num_seqs(cls, max_num_seqs: Any | None) -> Any: @field_validator("max_model_len", mode="before") @classmethod - def _validate_max_model_len(cls, max_model_len: Any | None) -> Any: + def _validate_max_model_len(cls, max_model_len: int | None) -> int: if max_model_len is None: logger.warning("max_model_len is not set, using arbitrary value 8192.") return 8192 @@ -292,12 +292,10 @@ def _validate_scheduler_config(self) -> Self: self.max_num_seqs * self.max_model_len, ) - if self.max_num_partial_prefills > 1: - if not self.chunked_prefill_enabled: - raise ValueError( - "Chunked prefill must be enabled to set " - "max_num_partial_prefills > 1." - ) + if self.max_num_partial_prefills > 1 and not self.chunked_prefill_enabled: + raise ValueError( + "Chunked prefill must be enabled to set max_num_partial_prefills > 1." 
+ ) if self.long_prefill_token_threshold > self.max_model_len: raise ValueError( From cf2ce2ee0aa55152a95ed023e2f5fae68e83f11a Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Fri, 10 Oct 2025 17:19:31 -0700 Subject: [PATCH 06/14] Trying a before model validator Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 1b2bec2b64db..805e281e2cc5 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -5,7 +5,7 @@ from dataclasses import InitVar from typing import Any, Literal, Union -from pydantic import Field, field_validator, model_validator +from pydantic import Field, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -169,21 +169,17 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - @field_validator("max_num_seqs", mode="before") - @classmethod - def _validate_max_num_seqs(cls, max_num_seqs: int | None) -> int: - if max_num_seqs is None: + @model_validator(mode="before") + def _set_defaults(self) -> Self: + if self.max_num_seqs is None: logger.warning("max_num_seqs is not set, using arbitrary value 128.") - return 128 - return max_num_seqs + self.max_num_seqs = 128 - @field_validator("max_model_len", mode="before") - @classmethod - def _validate_max_model_len(cls, max_model_len: int | None) -> int: - if max_model_len is None: + if self.max_model_len is None: logger.warning("max_model_len is not set, using arbitrary value 8192.") - return 8192 - return max_model_len + self.max_model_len = 8192 + + return self def __post_init__(self, is_encoder_decoder: bool) -> None: """Post init to handle init vars.""" From 5f19b02d85d95e6044bc7b584383290c093d7b53 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Fri, 10 Oct 2025 17:33:28 -0700 Subject: [PATCH 07/14] Fix before 
model validator Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 805e281e2cc5..fdb6401d9e9a 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -170,16 +170,17 @@ def compute_hash(self) -> str: return hash_str @model_validator(mode="before") - def _set_defaults(self) -> Self: - if self.max_num_seqs is None: + @classmethod + def _set_defaults(cls, data: Any) -> Any: + if data.get("max_num_seqs") is None: logger.warning("max_num_seqs is not set, using arbitrary value 128.") - self.max_num_seqs = 128 + data["max_num_seqs"] = 128 - if self.max_model_len is None: + if data.get("max_model_len") is None: logger.warning("max_model_len is not set, using arbitrary value 8192.") - self.max_model_len = 8192 + data["max_model_len"] = 8192 - return self + return data def __post_init__(self, is_encoder_decoder: bool) -> None: """Post init to handle init vars.""" From 35c0611d12252cdb48f6f22912854ab9f73470bb Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Fri, 10 Oct 2025 17:42:20 -0700 Subject: [PATCH 08/14] hail mary Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index fdb6401d9e9a..47e0dde195dd 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -169,19 +169,6 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - @model_validator(mode="before") - @classmethod - def _set_defaults(cls, data: Any) -> Any: - if data.get("max_num_seqs") is None: - logger.warning("max_num_seqs is not set, using arbitrary value 128.") - data["max_num_seqs"] = 128 - - if data.get("max_model_len") is None: - logger.warning("max_model_len is not set, using arbitrary value 8192.") - 
data["max_model_len"] = 8192 - - return data - def __post_init__(self, is_encoder_decoder: bool) -> None: """Post init to handle init vars.""" if is_encoder_decoder: @@ -197,6 +184,14 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: @model_validator(mode="after") def _validate_scheduler_config(self) -> Self: + if self.max_num_seqs is None: + logger.warning("max_num_seqs is not set, using arbitrary value 128.") + self.max_num_seqs = 128 + + if self.max_model_len is None: + logger.warning("max_model_len is not set, using arbitrary value 8192.") + self.max_model_len = 8192 + if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS From e310ec414ebee5623b3af4de47443e65f7bdfab5 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Fri, 10 Oct 2025 18:35:07 -0700 Subject: [PATCH 09/14] Adding optionals Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 47e0dde195dd..78d04e552a05 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -3,7 +3,7 @@ import hashlib from dataclasses import InitVar -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union from pydantic import Field, model_validator from pydantic.dataclasses import dataclass @@ -31,19 +31,19 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: int = Field(default=None, ge=1) + max_num_batched_tokens: Optional[int] = Field(default=None, ge=1) """Maximum number of tokens to be processed in a single iteration. This config has no static default. 
If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: int = Field(default=None, ge=1) + max_num_seqs: Optional[int] = Field(default=None, ge=1) """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: int = Field(default=None, ge=1) + max_model_len: Optional[int] = Field(default=None, ge=1) """Maximum length of a sequence (including prompt and generated text). This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" From 900f6e02f157478fcf5f593ebbe4f34e40eea019 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sun, 12 Oct 2025 20:37:17 +0100 Subject: [PATCH 10/14] Revert commits after 5984a27e7aeb7780942275ab8dc4e140697ff81b Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 44 ++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 78d04e552a05..60646e3338f6 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -3,9 +3,9 @@ import hashlib from dataclasses import InitVar -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union -from pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -31,19 +31,19 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: Optional[int] = Field(default=None, ge=1) + max_num_batched_tokens: int = Field(default=None, ge=1) """Maximum number of tokens to be processed in a single iteration. 
This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: Optional[int] = Field(default=None, ge=1) + max_num_seqs: int = Field(default=None, validate_default=True, ge=1) """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: Optional[int] = Field(default=None, ge=1) + max_model_len: int = Field(default=None, validate_default=True, ge=1) """Maximum length of a sequence (including prompt and generated text). This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -169,6 +169,22 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + @field_validator("max_num_seqs", mode="before") + @classmethod + def _validate_max_num_seqs(cls, max_num_seqs: Any | None) -> Any: + if max_num_seqs is None: + logger.warning("max_num_seqs is not set, using arbitrary value 128.") + return 128 + return max_num_seqs + + @field_validator("max_model_len", mode="before") + @classmethod + def _validate_max_model_len(cls, max_model_len: Any | None) -> Any: + if max_model_len is None: + logger.warning("max_model_len is not set, using arbitrary value 8192.") + return 8192 + return max_model_len + def __post_init__(self, is_encoder_decoder: bool) -> None: """Post init to handle init vars.""" if is_encoder_decoder: @@ -184,14 +200,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: @model_validator(mode="after") def _validate_scheduler_config(self) -> Self: - if self.max_num_seqs is None: - logger.warning("max_num_seqs is not set, using arbitrary value 128.") - self.max_num_seqs = 128 - - if self.max_model_len is None: - logger.warning("max_model_len is not set, using arbitrary value 
8192.") - self.max_model_len = 8192 - if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS @@ -284,10 +292,12 @@ def _validate_scheduler_config(self) -> Self: self.max_num_seqs * self.max_model_len, ) - if self.max_num_partial_prefills > 1 and not self.chunked_prefill_enabled: - raise ValueError( - "Chunked prefill must be enabled to set max_num_partial_prefills > 1." - ) + if self.max_num_partial_prefills > 1: + if not self.chunked_prefill_enabled: + raise ValueError( + "Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1." + ) if self.long_prefill_token_threshold > self.max_model_len: raise ValueError( From a803c79ac3cb06082baa9b0b39ef59e6831158f1 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sun, 12 Oct 2025 22:56:32 +0200 Subject: [PATCH 11/14] Skip validation for max_num_batched_tokens because the fallback is dynamic Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index a75a38063406..1c4fe69abd31 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -5,7 +5,7 @@ from dataclasses import InitVar from typing import Any, Literal -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SkipValidation, field_validator, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -31,7 +31,7 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: int = Field(default=None, ge=1) + max_num_batched_tokens: SkipValidation[int] = Field(default=None, ge=1) """Maximum number of tokens to be processed in a single iteration. This config has no static default. 
If left unspecified by the user, it will From 2f7716ae1f409b05448ebd9c0306395759e1b00b Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 30 Oct 2025 17:14:23 -0700 Subject: [PATCH 12/14] Set max_model_len in model_validator Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 98eaa4db3a70..dc6bce8b6773 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -21,6 +21,7 @@ RunnerType = Literal["generate", "pooling", "draft"] SchedulerPolicy = Literal["fcfs", "priority"] +DEFAULT_MAX_MODEL_LEN = 8192 @config @@ -171,14 +172,6 @@ def _validate_max_num_seqs(cls, max_num_seqs: Any | None) -> Any: return 128 return max_num_seqs - @field_validator("max_model_len", mode="before") - @classmethod - def _validate_max_model_len(cls, max_model_len: Any | None) -> Any: - if max_model_len is None: - logger.warning("max_model_len is not set, using arbitrary value 8192.") - return 8192 - return max_model_len - def __post_init__(self, is_encoder_decoder: bool) -> None: """Post init to handle init vars.""" if is_encoder_decoder: @@ -194,6 +187,12 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: @model_validator(mode="after") def _validate_scheduler_config(self) -> Self: + if self.max_model_len is None: + logger.warning( + "max_model_len is not set, using arbitrary value %s.", + DEFAULT_MAX_MODEL_LEN, + ) + self.max_model_len = DEFAULT_MAX_MODEL_LEN if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS From c4415aa73c996def085869d536f0f39e5b6860cf Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 30 Oct 2025 17:20:23 -0700 Subject: [PATCH 13/14] Do the same for max_num_seqs Signed-off-by: Vinay Damodaran --- vllm/config/scheduler.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git
a/vllm/config/scheduler.py b/vllm/config/scheduler.py index dc6bce8b6773..71ebaf3e20d0 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -5,7 +5,7 @@ from dataclasses import InitVar from typing import Any, Literal -from pydantic import Field, field_validator, model_validator +from pydantic import Field, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -22,6 +22,7 @@ RunnerType = Literal["generate", "pooling", "draft"] SchedulerPolicy = Literal["fcfs", "priority"] DEFAULT_MAX_MODEL_LEN = 8192 +DEFAULT_MAX_NUM_SEQS = 128 @config @@ -164,14 +165,6 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - @field_validator("max_num_seqs", mode="before") - @classmethod - def _validate_max_num_seqs(cls, max_num_seqs: Any | None) -> Any: - if max_num_seqs is None: - logger.warning("max_num_seqs is not set, using arbitrary value 128.") - return 128 - return max_num_seqs - def __post_init__(self, is_encoder_decoder: bool) -> None: """Post init to handle init vars.""" if is_encoder_decoder: @@ -193,6 +186,12 @@ def _validate_scheduler_config(self) -> Self: DEFAULT_MAX_MODEL_LEN, ) self.max_model_len = DEFAULT_MAX_MODEL_LEN + if self.max_num_seqs is None: + logger.warning( + "max_num_seqs is not set, using arbitrary value %s.", + DEFAULT_MAX_NUM_SEQS, + ) + self.max_num_seqs = DEFAULT_MAX_NUM_SEQS if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS From d68a8b7b7a9964a025c9ce56c3e640495aae2695 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:32:57 +0100 Subject: [PATCH 14/14] Fix Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 61 +++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git 
a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 71ebaf3e20d0..b837b830e774 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +from collections.abc import Callable from dataclasses import InitVar from typing import Any, Literal -from pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -21,8 +22,6 @@ RunnerType = Literal["generate", "pooling", "draft"] SchedulerPolicy = Literal["fcfs", "priority"] -DEFAULT_MAX_MODEL_LEN = 8192 -DEFAULT_MAX_NUM_SEQS = 128 @config @@ -33,19 +32,19 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: int | None = Field(default=None, ge=1) + max_num_batched_tokens: int = Field(default=None, ge=1) """Maximum number of tokens to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: int | None = Field(default=None, ge=1) + max_num_seqs: int = Field(default=None, ge=1) """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: int | None = Field(default=None, ge=1) + max_model_len: int = Field(default=None, ge=1) """Maximum length of a sequence (including prompt and generated text). 
This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -73,14 +72,6 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - cuda_graph_sizes: list[int] = Field(default_factory=list) - """Cuda graph capture sizes - 1. if none provided, then default set to [min(max_num_seqs * 2, 512)] - 2. if one value is provided, then the capture list would follow the - pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] - 3. more than one value (e.g. 1 2 128) is provided, then the capture list - will follow the provided list.""" - enable_chunked_prefill: bool = Field(default=None) """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -165,8 +156,27 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + @field_validator( + "max_num_batched_tokens", + "max_num_seqs", + "max_model_len", + "enable_chunked_prefill", + mode="wrap", + ) + @classmethod + def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: + """Skip validation if the value is `None` when initialisation is delayed.""" + if value is None: + return value + return handler(value) + def __post_init__(self, is_encoder_decoder: bool) -> None: - """Post init to handle init vars.""" + if self.max_model_len is None: + self.max_model_len = 8192 + + if self.max_num_seqs is None: + self.max_num_seqs = 128 + if is_encoder_decoder: # Chunked prefill should be disabled for encoder-decoder models. self.disable_chunked_mm_input = True @@ -178,20 +188,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: " prefix caching; disabling both." 
) - @model_validator(mode="after") - def _validate_scheduler_config(self) -> Self: - if self.max_model_len is None: - logger.warning( - "max_model_len is not set, using arbitrary value %s.", - DEFAULT_MAX_MODEL_LEN, - ) - self.max_model_len = DEFAULT_MAX_MODEL_LEN - if self.max_num_seqs is None: - logger.warning( - "max_num_seqs is not set, using arbitrary value %s.", - DEFAULT_MAX_NUM_SEQS, - ) - self.max_num_seqs = DEFAULT_MAX_NUM_SEQS if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS @@ -249,6 +245,8 @@ def _validate_scheduler_config(self) -> Self: if self.async_scheduling: self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler" + @model_validator(mode="after") + def _verify_args(self) -> Self: if ( self.max_num_batched_tokens < self.max_model_len and not self.chunked_prefill_enabled @@ -293,9 +291,8 @@ def _validate_scheduler_config(self) -> Self: if self.max_long_partial_prefills > self.max_num_partial_prefills: raise ValueError( - f"max_long_partial_prefills ({self.max_long_partial_prefills}) " - "must be less than or equal to " - f"max_num_partial_prefills ({self.max_num_partial_prefills})." + f"{self.max_long_partial_prefills=} must be less than or equal to " + f"{self.max_num_partial_prefills=}." ) return self