diff --git a/tests/compile/fullgraph/test_toy_llama.py b/tests/compile/fullgraph/test_toy_llama.py index 915fbc6ce7f3..339d6e340d0f 100644 --- a/tests/compile/fullgraph/test_toy_llama.py +++ b/tests/compile/fullgraph/test_toy_llama.py @@ -26,6 +26,7 @@ VllmConfig, set_current_vllm_config, ) +from vllm.config.utils import get_compile_factors from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -45,16 +46,9 @@ class LlamaConfig: tractable_init: bool = False random_seed: int = 0 - def compute_hash(self) -> str: - factors: list[Any] = [] - for k, v in self.__dict__.items(): - if k == "random_seed": - continue - factors.append((k, v)) - factors.sort() - import hashlib - - return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + def compile_factors(self) -> dict[str, Any]: + ignored = {"random_seed"} + return get_compile_factors(self, ignored) def __post_init__(self): assert self.mlp_size >= self.hidden_size diff --git a/tests/config/test_config_utils.py b/tests/config/test_config_utils.py index 23451c475ea9..72dda47574b0 100644 --- a/tests/config/test_config_utils.py +++ b/tests/config/test_config_utils.py @@ -6,7 +6,7 @@ import pytest -from vllm.config.utils import get_hash_factors, hash_factors, normalize_value +from vllm.config.utils import get_compile_factors, hash_factors, normalize_value # Helpers @@ -25,7 +25,7 @@ def expected_path(p_str: str = ".") -> str: return p.expanduser().resolve().as_posix() -# Minimal dataclass to test get_hash_factors. +# Minimal dataclass to test get_compile_factors. # Avoid importing heavy vLLM configs. @dataclass class SimpleConfig: @@ -136,8 +136,8 @@ def test_enum_vs_int_disambiguation(): assert enum_val == "raw_logits" # Build factor dicts from configs with int vs enum - f_int = get_hash_factors(SimpleConfig(1), set()) - f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set()) + f_int = get_compile_factors(SimpleConfig(1), set()) + f_enum = get_compile_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set()) # The int case remains a primitive value assert f_int["a"] == 1 # The enum case becomes a tagged tuple ("module.QualName", "raw_logits") diff --git a/tests/config/test_multimodal_config.py b/tests/config/test_multimodal_config.py index e5c30f999a05..136adee1b368 100644 --- a/tests/config/test_multimodal_config.py +++ b/tests/config/test_multimodal_config.py @@ -19,11 +19,11 @@ def test_mm_encoder_attn_backend_invalid(): def test_mm_encoder_attn_backend_hash_updates(): - base_hash = MultiModalConfig().compute_hash() - overridden_hash = MultiModalConfig( + base_compile_signature = MultiModalConfig().compile_factors() + overridden_compile_signature = MultiModalConfig( mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN - ).compute_hash() - assert base_hash != overridden_hash + ).compile_factors() + assert base_compile_signature != overridden_compile_signature def test_language_model_only_does_not_affect_mm_hash(): diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index dee7cdde744d..6bf041e75cfb 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -3,7 +3,6 @@ import ast import dataclasses -import hashlib import json import operator import os @@ -13,7 +12,7 @@ from collections.abc import Callable, Generator, Sequence from contextlib import contextmanager from copy import deepcopy -from functools import partial +from pathlib import Path from typing import Any import torch @@ -25,13 +24,18 @@ import vllm.envs as envs from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import DynamicShapesType -from vllm.config.utils import Range, hash_factors +from vllm.config.utils import CompileFactors, Range, hash_factors from vllm.logger import init_logger from vllm.logging_utils import lazy from vllm.platforms import current_platform from vllm.tracing import instrument, instrument_manual from vllm.utils.import_utils import resolve_obj_by_qualname +from .caching import ( + VllmSerializableFunction, + compute_env_and_config_hashes, + get_code_factors, +) from .compiler_interface import ( CompilerInterface, EagerAdaptor, @@ -137,8 +141,8 @@ def __init__(self, compilation_config: CompilationConfig) -> None: self.compiler = make_compiler(compilation_config) self.loaded_artifacts: dict[str, Any] = {} - def compute_hash(self, vllm_config: VllmConfig) -> str: - return self.compiler.compute_hash(vllm_config) + def compile_factors(self, vllm_config: VllmConfig) -> CompileFactors: + return self.compiler.compile_factors(vllm_config) @contextmanager def compile_context(self, compile_range: Range) -> Generator[None, None, None]: @@ -975,41 +979,31 @@ def list_to_str(lst: list | None) -> str: @dynamo_timed("vllm_backend") def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: - from .caching import ( - VllmSerializableFunction, - ) - vllm_config = self.vllm_config self._log_compilation_config() # Minimal hashing here with existing utilities, reused below. - env_factors = envs.compile_factors() - env_hash = hash_factors(env_factors) - # Compute config/compiler/code hashes once and reuse - config_hash = vllm_config.compute_hash() - compiler_hash = self.compiler_manager.compute_hash(vllm_config) - forward_code_files = list(sorted(self.compilation_config.traced_files)) + ( + env_hash, + config_hash, + env_factors, + config_factors, + ) = compute_env_and_config_hashes(vllm_config) + compiler_factors = self.compiler_manager.compile_factors(vllm_config) + compiler_hash = hash_factors(compiler_factors) + traced_files = set(self.compilation_config.traced_files) + forward_code_files = sorted( + (Path(filepath) for filepath in traced_files), key=str + ) logger.debug( "Traced files (to be considered for compilation cache):\n%s", - lazy(lambda: "\n".join(forward_code_files)), + lazy(lambda: "\n".join(map(str, forward_code_files))), ) - hash_content = [] - for filepath in forward_code_files: - hash_content.append(filepath) - if filepath == "": - # This means the function was dynamically generated, with - # e.g. exec(). We can't actually check these. - continue - try: - with open(filepath) as f: - hash_content.append(f.read()) - except (OSError, UnicodeDecodeError): - logger.warning("Failed to read file %s", filepath) - continue - code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest() + code_factors = get_code_factors(forward_code_files) + code_hash = hash_factors({"files": code_factors}) # Clear after consumption self.compilation_config.traced_files.clear() if not self.compilation_config.cache_dir: @@ -1017,10 +1011,15 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: # that affects the compilation. if none of the factors change, # the cache dir will be the same so that we can reuse the compiled # graph. - factors = [env_hash, config_hash, code_hash, compiler_hash] + all_factors = { + "env": env_factors, + "config": config_factors, + "code": {"files": code_factors}, + "compiler": compiler_factors, + } # Use SHA-256 for cache key hashing to be consistent across - # compute_hash functions. Truncate for a short cache dir name. - hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10] + # compile_factors functions. Truncate for a short cache dir name. + hash_key = hash_factors(all_factors)[:10] cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key ) @@ -1071,25 +1070,34 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: # Persist and log only hash-relevant factors together. try: - logger.debug( - "Compile env factors (raw):\n%s\nVllm config hash: %s", - lazy(partial(pprint.pformat, env_factors, width=120)), - config_hash, - ) meta_path = os.path.join(local_cache_dir, "cache_key_factors.json") if not os.path.exists(meta_path): with open(meta_path, "w") as f: json.dump( { "env": env_factors, # raw factors used for env_hash + "config": config_factors, "config_hash": config_hash, - "code_hash": code_hash, + "compiler": compiler_factors, "compiler_hash": compiler_hash, + "code_hash": code_hash, + "code": code_factors, }, f, indent=2, sort_keys=True, ) + logger.debug( + ( + "Persisted compile cache factors to %s " + "(env_keys=%d config_keys=%d compiler_keys=%d code_entries=%d)" + ), + meta_path, + len(env_factors), + len(config_factors), + len(compiler_factors), + len(code_factors), + ) except Exception: # Best-effort only; metadata write failures are non-fatal. logger.warning( diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index c089f02a37ff..6a69728022ea 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -7,6 +7,7 @@ import os import pickle from collections.abc import Callable, Sequence +from pathlib import Path from typing import Any, Literal from unittest.mock import patch @@ -30,9 +31,30 @@ assert isinstance(SerializableCallable, type) + logger = init_logger(__name__) +def get_code_factors(forward_code_files: list[Path]) -> list[dict[str, str]]: + """Return per-file factors for compile cache hashing.""" + code_factors: list[dict[str, str]] = [] + for filepath in forward_code_files: + path_str = str(filepath) + entry: dict[str, str] = {"path": path_str} + if path_str == "": + # Dynamically generated code (e.g., exec); nothing to hash. + code_factors.append(entry) + continue + try: + with filepath.open() as f: + content = f.read() + entry["hash"] = hash_factors({"content": content}) + except Exception: + logger.warning("Failed to read file %s", path_str) + code_factors.append(entry) + return code_factors + + class StandaloneCompiledArtifacts: """Storage for standalone compiled artifacts with content-based deduplication. @@ -400,6 +422,22 @@ def co_name(self) -> Literal["VllmSerializableFunction"]: return "VllmSerializableFunction" +def compute_env_and_config_hashes( + vllm_config: VllmConfig, +) -> tuple[str, str, dict[str, object], dict[str, object]]: + """ + Return the hashed environment factors, config hash, and raw factors. + Both AOT and JIT cache paths rely on this helper to ensure their cache keys + stay in sync. + """ + + env_factors = envs.compile_factors() + env_hash = hash_factors(env_factors) + config_factors = vllm_config.compile_factors() + config_hash = hash_factors(config_factors) + return env_hash, config_hash, env_factors, config_factors + + def reconstruct_serializable_fn_from_mega_artifact( state: dict[str, Any], standalone_compile_artifacts: "StandaloneCompiledArtifacts", diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index bddacfbbc295..d129d08f858f 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -15,9 +15,8 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig -from vllm.config.utils import Range +from vllm.config.utils import CompileFactors, Range, normalize_value from vllm.logger import init_logger -from vllm.utils.hashing import safe_hash from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) @@ -50,17 +49,17 @@ def initialize_cache( """ pass - def compute_hash(self, vllm_config: VllmConfig) -> str: + def compile_factors(self, vllm_config: VllmConfig) -> CompileFactors: """ - Gather all the relevant information from the vLLM config, - to compute a hash so that we can cache the compiled model. + Gather compiler-specific factors that influence the generated code. - See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash] - to check what information - is already considered by default. This function should only - consider the information that is specific to the compiler. + See [`VllmConfig.compile_factors`][vllm.config.VllmConfig.compile_factors] + for the base configuration factors. This method should return any + additional data that uniquely identifies the compiler's contribution to + the cache key. Subclasses must return a dictionary; use an empty dict + when no compiler-specific data is needed. """ - return "" + return {} def compile( self, @@ -156,13 +155,13 @@ def get_inductor_factors() -> list[Any]: # summarize system state from torch._inductor.codecache import CacheBase - system_factors = CacheBase.get_system() + system_factors = normalize_value(CacheBase.get_system()) factors.append(system_factors) # summarize pytorch state from torch._inductor.codecache import torch_key - torch_factors = torch_key() + torch_factors = normalize_value(torch_key()) factors.append(torch_factors) return factors @@ -284,12 +283,8 @@ def __init__(self, save_format: Literal["binary", "unpacked"]) -> None: _patch_standalone_compile_atomic_save() self.save_format = save_format - def compute_hash(self, vllm_config: VllmConfig) -> str: - factors = get_inductor_factors() - hash_str: str = safe_hash( - str(factors).encode(), usedforsecurity=False - ).hexdigest()[:10] - return hash_str + def compile_factors(self, vllm_config: VllmConfig) -> CompileFactors: + return {"inductor_standalone": get_inductor_factors()} def initialize_cache( self, cache_dir: str, disable_cache: bool = False, prefix: str = "" @@ -399,7 +394,6 @@ def compile( # since we can serialize the bytes using to_bytes # and reload it using the key when reading return compiled_graph, None - # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) @@ -468,12 +462,8 @@ class InductorAdaptor(CompilerInterface): name = "inductor" - def compute_hash(self, vllm_config: VllmConfig) -> str: - factors = get_inductor_factors() - hash_str: str = safe_hash( - str(factors).encode(), usedforsecurity=False - ).hexdigest()[:10] - return hash_str + def compile_factors(self, vllm_config: VllmConfig) -> CompileFactors: + return {"inductor": get_inductor_factors()} def initialize_cache( self, cache_dir: str, disable_cache: bool = False, prefix: str = "" diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 0571741419f7..5a662315fdd1 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable -from typing import Any, ParamSpec, TypeVar +from typing import ParamSpec, TypeVar from torch import fx as fx @@ -176,19 +176,13 @@ def uuid(self) -> str: affects compilation caching. Its uuid depends on the UUIDs of all dependent passes and the pass config. See InductorPass for more info. """ - passes = [] - - state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()} - for pass_ in self.passes: - passes.append(pass_.uuid()) - + passes: list[str] = [pass_.uuid() for pass_ in self.passes] passes.append(self.post_cleanup.uuid()) passes.append(self.ir_lowering.uuid()) - passes.append(self.post_cleanup.uuid()) passes.append(self.fix_functionalization.uuid()) + state = {"pass_config": self.pass_config.compile_factors(), "passes": passes} # Include the compile range in the uuid to ensure that inductor # recompiles the graph for the new dynamic compile range. state["compile_range"] = str(get_pass_context().compile_range) - state["passes"] = passes return InductorPass.hash_dict(state) diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 014bb9b22601..b79b5d5a6969 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -5,7 +5,7 @@ from pydantic import field_validator -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config, get_compile_factors from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -46,19 +46,13 @@ class AttentionConfig: use_prefill_query_quantization: bool = False """If set, quantize query for attention in prefill.""" - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. + Provide the factors that affect the compiled computation graph. + All dataclass fields participate; add fields to an ignore set if + they should not influence compilation cache keys. """ - from vllm.config.utils import get_hash_factors, hash_factors - - ignored_factors: set[str] = set() - factors = get_hash_factors(self, ignored_factors) - return hash_factors(factors) + return get_compile_factors(self, set()) @field_validator("backend", mode="before") @classmethod diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 49c8868e709f..cccac6e16d87 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -6,7 +6,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config, get_compile_factors from vllm.logger import init_logger logger = init_logger(__name__) @@ -162,17 +162,20 @@ class CacheConfig: 'native' (vLLM native CPU offloading), 'lmcache'. KV offloading is only activated when kv_offloading_size is set.""" - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. + WARNING: Whenever a new field is added to this config, review + `ignored_factors` to decide whether the field should be excluded. + All other dataclass fields participate in the hash automatically. Provide a hash that uniquely identifies all the configs that affect the structure of the computation graph from input ids/embeddings to the final hidden states, excluding anything before input ids/embeddings and after the final hidden states. + + This config uses an opt-out hash: start from every dataclass field and + then drop the `ignored_factors` below. """ ignored_factors = { # Runtime/derived knobs that don't affect compiled graph shape @@ -193,10 +196,7 @@ def compute_hash(self) -> str: "kv_sharing_fast_prefill", } - from vllm.config.utils import get_hash_factors, hash_factors - - factors = get_hash_factors(self, ignored_factors) - return hash_factors(factors) + return get_compile_factors(self, ignored_factors) def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 916c5a002058..3f5f00771854 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -13,10 +13,10 @@ import vllm.envs as envs from vllm.compilation.passes.inductor_pass import CallableInductorPass, InductorPass from vllm.config.utils import ( + CompileFactors, Range, config, - get_hash_factors, - hash_factors, + get_compile_factors, ) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -203,14 +203,14 @@ def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: return {} return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get(capability.to_int(), {}) - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ Produces a hash unique to the pass configuration. Any new fields that affect compilation should be added to the hash. Any future fields that don't affect compilation should be excluded. """ - return hash_factors(get_hash_factors(self, set())) + return get_compile_factors(self, set()) @field_validator( "fuse_norm_quant", @@ -344,15 +344,11 @@ class DynamicShapesConfig: `True` requires PyTorch 2.10+ """ - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ - Provide a hash for DynamicShapesConfig + Provide the factors used for hashing DynamicShapesConfig. """ - - from vllm.config.utils import get_hash_factors, hash_factors - - factors = get_hash_factors(self, set()) - return hash_factors(factors) + return get_compile_factors(self, set()) @config @@ -703,6 +699,8 @@ class CompilationConfig: Map from layer name to layer objects that need to be accessed outside model code, e.g., Attention, FusedMOE when dp_size>1.""" + bs_to_padded_graph_size: list[int] = field(default_factory=list, init=False) + """Runtime map from batch size to cudagraph padded size.""" static_all_moe_layers: list[str] = field(default_factory=list, init=False) """The names of all the MOE layers in the model """ @@ -726,37 +724,29 @@ class CompilationConfig: "vllm::rocm_aiter_sparse_attn_indexer", ] - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ Provide a hash that uniquely identifies all the configs that affect the structure of the computation graph from input ids/embeddings to the final hidden states, excluding anything before input ids/embeddings and after the final hidden states. - """ - # Opt-out: default-include declared fields; keep a tiny exclude set; - # normalize types; keep SHA-256. For nested opaque configs, include a - # stable identifier (e.g., pass_config.compute_hash()) instead of object id. + This config follows the opt-out hashing pattern: start from every + dataclass field and remove the `ignored_factors` list below. + """ ignored_factors = { # Paths/dirs and runtime/metrics that don’t affect compiled graph "debug_dump_path", "cache_dir", "local_cache_dir", + "bs_to_padded_graph_size", "traced_files", "compilation_time", "static_forward_context", - "pass_config", # handled separately below - "dynamic_shapes_config", # handled separately below } - from vllm.config.utils import get_hash_factors, hash_factors - - factors = get_hash_factors(self, ignored_factors) - - factors["pass_config"] = self.pass_config.compute_hash() - factors["dynamic_shapes_config"] = self.dynamic_shapes_config.compute_hash() - return hash_factors(factors) + return get_compile_factors(self, ignored_factors) def __repr__(self) -> str: exclude: dict[str, bool | dict[str, bool]] = { diff --git a/vllm/config/device.py b/vllm/config/device.py index c20e4d0f288b..2ddc3353e355 100644 --- a/vllm/config/device.py +++ b/vllm/config/device.py @@ -2,13 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import field -from typing import Any, Literal +from typing import Literal import torch from pydantic import ConfigDict, SkipValidation -from vllm.config.utils import config -from vllm.utils.hashing import safe_hash +from vllm.config.utils import CompileFactors, config Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] @@ -27,7 +26,7 @@ class DeviceConfig: """Device type from the current platform. This is set in `__post_init__`.""" - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -42,9 +41,7 @@ def compute_hash(self) -> str: # no factors to consider. # the device/platform information will be summarized # by torch/vllm automatically. - factors: list[Any] = [] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + return {} def __post_init__(self): if self.device == "auto": diff --git a/vllm/config/ec_transfer.py b/vllm/config/ec_transfer.py index a3a927d51ec4..351653b21f83 100644 --- a/vllm/config/ec_transfer.py +++ b/vllm/config/ec_transfer.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib import uuid from dataclasses import field from typing import Any, Literal, get_args -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config ECProducer = Literal["ec_producer", "ec_both"] ECConsumer = Literal["ec_consumer", "ec_both"] @@ -57,7 +56,7 @@ class ECTransferConfig: """The Python module path to dynamically load the EC connector from. Only supported in V1.""" - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -69,11 +68,8 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + # This config does not affect the compiled graph. + return {} def __post_init__(self) -> None: if self.engine_id is None: diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index b22af99f703f..5d9fac81fa73 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -5,8 +5,7 @@ from dataclasses import field from typing import Any, Literal, get_args -from vllm.config.utils import config -from vllm.utils.hashing import safe_hash +from vllm.config.utils import CompileFactors, config KVProducer = Literal["kv_producer", "kv_both"] KVConsumer = Literal["kv_consumer", "kv_both"] @@ -72,7 +71,7 @@ class KVTransferConfig: 'recompute': reschedule the request to recompute failed blocks 'fail': immediately fail the request with an error finish reason (default)""" - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -84,11 +83,8 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + # This config does not affect the compiled graph. + return {} def __post_init__(self) -> None: if self.engine_id is None: diff --git a/vllm/config/load.py b/vllm/config/load.py index 93240ec5fc0f..15e868af8442 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -5,9 +5,8 @@ from pydantic import Field, field_validator -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config from vllm.logger import init_logger -from vllm.utils.hashing import safe_hash if TYPE_CHECKING: from vllm.model_executor.model_loader import LoadFormats @@ -102,7 +101,7 @@ class LoadConfig: the original doc for `map_location` parameter in [`torch.load`][] parameter. """ - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -114,11 +113,8 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + # This config does not affect the compiled graph. + return {} @field_validator("load_format", mode="after") def _lowercase_load_format(cls, load_format: str) -> str: diff --git a/vllm/config/lora.py b/vllm/config/lora.py index bfef0efa3df0..0971a64d82ea 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -7,9 +7,8 @@ from pydantic import ConfigDict, Field, model_validator from typing_extensions import Self -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config, get_compile_factors from vllm.logger import init_logger -from vllm.utils.hashing import safe_hash if TYPE_CHECKING: from vllm.config import ModelConfig @@ -70,7 +69,7 @@ class LoRAConfig: memory usage. Only takes effect when cudagraph_specialize_lora is True. """ - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -82,19 +81,12 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: list[Any] = [] - factors.append(self.max_lora_rank) - factors.append(self.max_loras) - factors.append(self.fully_sharded_loras) - factors.append(self.lora_dtype) - factors.append(self.enable_tower_connector_lora) - # target_modules affects which modules get LoRA applied - factors.append( - tuple(sorted(self.target_modules)) if self.target_modules else None - ) - - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + ignored_factors = { + # Runtime/placement only; does not affect compiled graph + "max_cpu_loras", + "default_mm_loras", + } + return get_compile_factors(self, ignored_factors) @model_validator(mode="after") def _validate_lora_config(self) -> Self: diff --git a/vllm/config/model.py b/vllm/config/model.py index c4ee654fe8bc..f55eb97d6061 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -22,7 +22,7 @@ ) from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType -from vllm.config.utils import config, getattr_iter +from vllm.config.utils import CompileFactors, config, get_compile_factors, getattr_iter from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.tasks import PoolingTask, ScoreType, SupportedTask @@ -325,19 +325,23 @@ class ModelConfig: video_pruning_rate: InitVar[float | None] = None mm_tensor_ipc: InitVar[MMTensorIPC] = None - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. + WARNING: Whenever a new field is added to this config, review + `ignored_factors` to decide whether that field must be excluded. + Every other dataclass field automatically participates in the hash. Provide a hash that uniquely identifies all the configs that affect the structure of the computation graph from input ids/embeddings to the final hidden states, excluding anything before input ids/embeddings and after the final hidden states. + + This config is opt-out hashed: include every dataclass field except for + those explicitly listed in `ignored_factors`. """ ignored_factors = { + "runner", "convert", "tokenizer", "tokenizer_mode", @@ -355,6 +359,7 @@ def compute_hash(self) -> str: "config_format", "hf_token", "hf_overrides", + "logits_processor_pattern", "override_attention_dtype", "logits_processors", "io_processor_plugin", @@ -371,16 +376,12 @@ def compute_hash(self) -> str: "skip_mm_profiling", } - from vllm.config.utils import get_hash_factors, hash_factors - - factors = get_hash_factors(self, ignored_factors) - + factors = get_compile_factors(self, ignored_factors) # NOTE: For some models (e.g, Qwen3-VL), whether the MM code path is enabled - # affects the computation graph of the language model, therefore we add it - # here early. + # affects the computation graph of the language model. if self.multimodal_config: factors["language_model_only"] = self.multimodal_config.language_model_only - return hash_factors(factors) + return factors or {} def _update_nested( self, diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index e66511c92ab2..21a3a2ee85db 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -7,8 +7,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic.dataclasses import dataclass -from vllm.config.utils import config -from vllm.utils.hashing import safe_hash +from vllm.config.utils import CompileFactors, config, normalize_value from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -235,7 +234,7 @@ def _validate_multimodal_config(self): ) return self - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -253,8 +252,8 @@ def compute_hash(self) -> str: else None, self.mm_encoder_tp_mode, ] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + normalized = normalize_value(factors) + return {"factors": normalized} if normalized else {} def get_limit_per_prompt(self, modality: str) -> int: """ diff --git a/vllm/config/observability.py b/vllm/config/observability.py index 84e83c6d4ad2..55ee25f7c695 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -2,14 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from functools import cached_property -from typing import Any, Literal, cast +from typing import Literal, cast from packaging.version import parse from pydantic import Field, field_validator, model_validator from vllm import version -from vllm.config.utils import config -from vllm.utils.hashing import safe_hash +from vllm.config.utils import CompileFactors, config DetailedTraceModules = Literal["model", "worker", "all"] @@ -92,7 +91,7 @@ def collect_model_execute_time(self) -> bool: or "all" in self.collect_detailed_traces ) - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -106,9 +105,7 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + return {} @field_validator("show_hidden_metrics_for_version") @classmethod diff --git a/vllm/config/offload.py b/vllm/config/offload.py index ad65e8acf35a..aacb5ed11c4e 100644 --- a/vllm/config/offload.py +++ b/vllm/config/offload.py @@ -146,8 +146,8 @@ def compute_hash(self) -> str: alter which layers are hooked and how prefetch indices are computed, so the compilation cache must distinguish them. """ - from vllm.config.utils import get_hash_factors, hash_factors + from vllm.config.utils import get_compile_factors, hash_factors - factors = get_hash_factors(self, ignored_factors=set()) + factors = get_compile_factors(self, ignored_factors=set()) hash_str = hash_factors(factors) return hash_str diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 107dcfa273eb..315399097cd3 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -12,7 +12,7 @@ from typing_extensions import Self import vllm.envs as envs -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config, get_compile_factors from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.network_utils import get_open_ports_list @@ -593,7 +593,7 @@ def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> i torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) return tensor.item() - def compute_hash(self): + def compile_factors(self) -> CompileFactors: """ Provide a hash that uniquely identifies all the configs that affect the structure of the computation @@ -603,6 +603,10 @@ def compute_hash(self): This hash is also used for DP worker configuration validation to prevent hangs from mismatched collective communication patterns. + + When adding new fields to this config, review `ignored_factors` to + decide whether they should be excluded. All other dataclass fields are + included automatically by the opt-out hashing scheme. """ ignored_factors = { # Derived/runtime topology, networking, or launch details @@ -635,10 +639,7 @@ def compute_hash(self): "_api_process_rank", } - from vllm.config.utils import get_hash_factors, hash_factors - - factors = get_hash_factors(self, ignored_factors) - return hash_factors(factors) + return get_compile_factors(self, ignored_factors) def __post_init__(self) -> None: # Continue with the rest of the initialization diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 24368c3494e7..76e1c5f6c9b2 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Literal, get_args +from typing import Literal, get_args -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config from vllm.logger import init_logger from vllm.tasks import PoolingTask -from vllm.utils.hashing import safe_hash logger = init_logger(__name__) @@ -133,7 +132,7 @@ def get_tok_pooling_type(self) -> TokenPoolingType: assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig" return self.tok_pooling_type - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -147,6 +146,30 @@ def compute_hash(self) -> str: """ # no factors to consider. # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + # No compile-time factors. + return {} + + +def get_use_activation(o: object): + if (normalize := getattr(o, "normalize", None)) is not None: + logger.warning_once( + "`normalize` is deprecated and will be removed in v0.15. " + "Please use `use_activation` instead." + ) + return normalize + + if (softmax := getattr(o, "softmax", None)) is not None: + logger.warning_once( + "`softmax` is deprecated and will be removed in v0.15. " + "Please use `use_activation` instead." + ) + return softmax + + if (activation := getattr(o, "activation", None)) is not None: + logger.warning_once( + "`activation` is deprecated and will be removed in v0.15. " + "Please use `use_activation` instead." + ) + return activation + + return getattr(o, "use_activation", None) diff --git a/vllm/config/profiler.py b/vllm/config/profiler.py index 68fa78854b45..d931e244f454 100644 --- a/vllm/config/profiler.py +++ b/vllm/config/profiler.py @@ -7,7 +7,8 @@ from pydantic import Field, model_validator from typing_extensions import Self -from vllm.config.utils import config +import vllm.envs as envs +from vllm.config.utils import CompileFactors, config from vllm.logger import init_logger from vllm.utils.hashing import safe_hash @@ -122,6 +123,44 @@ def compute_hash(self) -> str: hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + def compile_factors(self) -> CompileFactors: + # Profiling setup does not affect the computation graph, so hash neutral. + return {} + + def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> str | None: + """Get field from env var if set, with deprecation warning.""" + + if envs.is_set(env_var_name): + value = getattr(envs, env_var_name) + logger.warning_once( + "Using %s environment variable is deprecated and will be removed in " + "v0.15.0 or v1.0.0, whichever is soonest. Please use " + "--profiler-config.%s command line argument or " + "ProfilerConfig(%s=...) config field instead.", + env_var_name, + field_name, + field_name, + ) + return value + return None + + def _set_from_env_if_set( + self, + field_name: str, + env_var_name: str, + to_bool: bool = True, + to_int: bool = False, + ) -> None: + """Set field from env var if set, with deprecation warning.""" + raw_value = self._get_from_env_if_set(field_name, env_var_name) + if raw_value is not None: + value: str | bool | int = raw_value + if to_bool: + value = value == "1" + if to_int: + value = int(value) + setattr(self, field_name, value) + @model_validator(mode="after") def _validate_profiler_config(self) -> Self: has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0 diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 3cd99bb082eb..7d866dfea7cf 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -8,9 +8,8 @@ from pydantic import Field, field_validator from typing_extensions import Self -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config from vllm.logger import init_logger -from vllm.utils.hashing import safe_hash from vllm.utils.import_utils import resolve_obj_by_qualname if TYPE_CHECKING: @@ -186,7 +185,7 @@ def get_scheduler_cls(self) -> type["SchedulerInterface"]: return cast(type["SchedulerInterface"], self.scheduler_cls) return resolve_obj_by_qualname(self.scheduler_cls) - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -198,8 +197,6 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: list[Any] = [] - # max_num_batched_tokens need to be included in the hash due # to two reasons: # 1. LoRA creates static buffers based on max_num_batched_tokens. @@ -209,10 +206,8 @@ def compute_hash(self) -> str: # based on the data sizes. `max_num_batched_tokens` has an # impact on that. For more details, please check # https://github.com/vllm-project/vllm/issues/29585 - factors.append(self.max_num_batched_tokens) - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + return {"max_num_batched_tokens": self.max_num_batched_tokens} @field_validator("scheduler_cls", "async_scheduling", mode="wrap") @classmethod diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f1fda9afd318..92740f631d60 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -12,10 +12,9 @@ from vllm.config.kernel import MoEBackend from vllm.config.model import ModelConfig from vllm.config.parallel import ParallelConfig -from vllm.config.utils import config +from vllm.config.utils import CompileFactors, config, normalize_value from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_text_config -from vllm.utils.hashing import safe_hash from vllm.utils.import_utils import LazyLoader, has_arctic_inference if TYPE_CHECKING: @@ -194,7 +193,7 @@ class SpeculativeConfig: positions equals this value. Only used when rejection_sample_method is 'synthetic'. Must be in [0, 1].""" - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -206,17 +205,13 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: list[Any] = [] # Eagle3 and extract_hidden_states affect the computation graph because - # they return intermediate hidden states in addition to the final hidden state. - uses_aux_hidden_states = self.method in ( - "eagle3", - "extract_hidden_states", - "dflash", - ) - factors.append(uses_aux_hidden_states) + # they return intermediate hidden states in addition to the final hidden + # state. + uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states") + factors: list[Any] = [uses_aux_hidden_states] - # The specific layers used also affect the computation graph + # The specific layers used also affect the computation graph. if uses_aux_hidden_states and self.draft_model_config is not None: layer_ids = getattr( self.draft_model_config.hf_config, @@ -224,11 +219,10 @@ def compute_hash(self) -> str: None, ) if layer_ids is not None: - # Convert to tuple to make it hashable factors.append(tuple(layer_ids)) - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + normalized = normalize_value(factors) + return {"factors": normalized} if normalized else {} @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index e7afbb65bc7f..af4079bf1126 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Literal +from typing import Literal from pydantic import model_validator from typing_extensions import Self -from vllm.config.utils import config -from vllm.utils.hashing import safe_hash +from vllm.config.utils import CompileFactors, config StructuredOutputsBackend = Literal[ "auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer" @@ -41,7 +40,7 @@ class StructuredOutputsConfig: enable_in_reasoning: bool = False """Whether to use structured input for reasoning.""" - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -53,11 +52,7 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + return {} @model_validator(mode="after") def _validate_structured_output_config(self) -> Self: diff --git a/vllm/config/utils.py b/vllm/config/utils.py index a953fcb46e42..9902eefc10e6 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -34,6 +34,7 @@ ConfigType = type[DataclassInstance] ConfigT = TypeVar("ConfigT", bound=DataclassInstance) +CompileFactors = dict[str, object] @overload @@ -199,8 +200,8 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: @runtime_checkable -class SupportsHash(Protocol): - def compute_hash(self) -> str: ... +class SupportsCompileFactors(Protocol): + def compile_factors(self) -> CompileFactors: ... class SupportsMetricsInfo(Protocol): @@ -321,22 +322,37 @@ def normalize_value(x): ) -def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]: +def get_compile_factors(config: ConfigT, ignored_factors: set[str]) -> CompileFactors: """Gets the factors used for hashing a config class. - Includes all dataclass fields not in `ignored_factors`. + - Uses .compile_factors() for nested dataclasses that support it - Errors on non-normalizable values. """ + # dataclasses.fields() skips InitVar entries; __dataclass_fields__ keeps + # them. Include both so ignored_factors can safely name InitVars. + dataclass_fields = getattr(config, "__dataclass_fields__", {}) + field_names = {f.name for f in fields(config)} | set(dataclass_fields) + unknown_ignored = ignored_factors - field_names + if unknown_ignored: + raise ValueError( + f"get_compile_factors: ignored_factors contain unknown fields " + f"{sorted(unknown_ignored)} for {type(config).__name__}" + ) factors: dict[str, object] = {} for dc_field in fields(config): factor = dc_field.name if factor in ignored_factors: continue value = getattr(config, factor, None) + # Nested configs expose factors via compile_factors; unwrap first. + if isinstance(value, SupportsCompileFactors): + factors[factor] = value.compile_factors() + continue try: factors[factor] = normalize_value(value) except TypeError as e: raise TypeError( - f"get_hash_factors: unsupported type for key '{factor}' " + f"get_compile_factors: unsupported type for key '{factor}' " f"({type(value).__name__})" ) from e return factors diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index fad3e0ed240f..0e99e1e82cd7 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -3,13 +3,12 @@ import copy import getpass -import json import os import tempfile import threading import time from contextlib import contextmanager -from dataclasses import is_dataclass +from dataclasses import is_dataclass, replace from datetime import datetime from enum import IntEnum from functools import lru_cache @@ -25,7 +24,6 @@ from vllm.logger import enable_trace_function_call, init_logger from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid -from vllm.utils.hashing import safe_hash from .attention import AttentionConfig from .cache import CacheConfig @@ -46,7 +44,12 @@ from .scheduler import SchedulerConfig from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig from .structured_outputs import StructuredOutputsConfig -from .utils import SupportsHash, config, replace +from .utils import ( + CompileFactors, + SupportsCompileFactors, + config, + hash_factors, +) from .weight_transfer import WeightTransferConfig if TYPE_CHECKING: @@ -313,7 +316,7 @@ class VllmConfig: # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. - additional_config: dict | SupportsHash = Field(default_factory=dict) + additional_config: dict | SupportsCompileFactors = Field(default_factory=dict) """Additional config for specified platform. Different platforms may support different configs. Make sure the configs are valid for the platform you are using. Contents must be hashable.""" @@ -341,7 +344,7 @@ class VllmConfig: remaining requests are aborted once the timeout is reached. """ - def compute_hash(self) -> str: + def compile_factors(self) -> CompileFactors: """ WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if @@ -353,101 +356,48 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: list[Any] = [] - - # summarize vllm config - vllm_factors: list[Any] = [] from vllm import __version__ - vllm_factors.append(__version__) - if self.model_config: - vllm_factors.append(self.model_config.compute_hash()) - if ( - self.compilation_config - and getattr(self.compilation_config, "compile_mm_encoder", False) - and self.model_config.multimodal_config - ): - vllm_factors.append(self.model_config.multimodal_config.compute_hash()) - else: - vllm_factors.append("None") - if self.cache_config: - vllm_factors.append(self.cache_config.compute_hash()) - else: - vllm_factors.append("None") - if self.parallel_config: - vllm_factors.append(self.parallel_config.compute_hash()) - else: - vllm_factors.append("None") - if self.scheduler_config: - vllm_factors.append(self.scheduler_config.compute_hash()) - else: - vllm_factors.append("None") - if self.device_config: - vllm_factors.append(self.device_config.compute_hash()) - else: - vllm_factors.append("None") - if self.load_config: - vllm_factors.append(self.load_config.compute_hash()) - else: - vllm_factors.append("None") - if self.offload_config: - vllm_factors.append(self.offload_config.compute_hash()) - else: - vllm_factors.append("None") - if self.attention_config: - vllm_factors.append(self.attention_config.compute_hash()) - else: - vllm_factors.append("None") - if self.lora_config: - vllm_factors.append(self.lora_config.compute_hash()) - else: - vllm_factors.append("None") - if self.speculative_config: - vllm_factors.append(self.speculative_config.compute_hash()) - else: - vllm_factors.append("None") - if self.structured_outputs_config: - vllm_factors.append(self.structured_outputs_config.compute_hash()) - if self.profiler_config: - vllm_factors.append(self.profiler_config.compute_hash()) - else: - vllm_factors.append("None") - vllm_factors.append(self.observability_config.compute_hash()) - if self.quant_config: - pass # should be captured by model_config.quantization - if self.compilation_config: - vllm_factors.append(self.compilation_config.compute_hash()) - else: - vllm_factors.append("None") - if self.kernel_config: - vllm_factors.append(self.kernel_config.compute_hash()) - else: - vllm_factors.append(None) - if self.kv_transfer_config: - vllm_factors.append(self.kv_transfer_config.compute_hash()) - else: - vllm_factors.append("None") - if self.ec_transfer_config: - vllm_factors.append(self.ec_transfer_config.compute_hash()) - else: - vllm_factors.append("None") + def get_factors(config_obj: SupportsCompileFactors | None) -> CompileFactors: + return {} if config_obj is None else config_obj.compile_factors() + + factors: dict[str, Any] = { + "version": __version__, + "model": get_factors(self.model_config), + "cache": get_factors(self.cache_config), + "parallel": get_factors(self.parallel_config), + "scheduler": get_factors(self.scheduler_config), + "device": get_factors(self.device_config), + "load": get_factors(self.load_config), + "attention": get_factors(self.attention_config), + "speculative": get_factors(self.speculative_config), + "structured_outputs": get_factors(self.structured_outputs_config), + "observability": get_factors(self.observability_config), + "profiler": get_factors(self.profiler_config), + "compilation": get_factors(self.compilation_config), + "kv_transfer": get_factors(self.kv_transfer_config), + "ec_transfer": get_factors(self.ec_transfer_config), + "lora": get_factors(self.lora_config), + } + if self.additional_config: - if isinstance(additional_config := self.additional_config, dict): - additional_config_hash = safe_hash( - json.dumps(additional_config, sort_keys=True).encode(), - usedforsecurity=False, - ).hexdigest() + additional_config = self.additional_config + if isinstance(additional_config, dict): + factors["additional"] = additional_config + elif isinstance(additional_config, SupportsCompileFactors): + factors["additional"] = additional_config.compile_factors() else: - additional_config_hash = additional_config.compute_hash() - vllm_factors.append(additional_config_hash) + raise TypeError( + "additional_config must be a dict or SupportsCompileFactors" + ) else: - vllm_factors.append("None") - factors.append(vllm_factors) + factors["additional"] = {} - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ - :10 - ] - return hash_str + return factors + + def compute_hash(self) -> str: + """Return a stable hash of the compilation-relevant factors.""" + return hash_factors(self.compile_factors()) @property def num_speculative_tokens(self) -> int: diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index c56f8b0364aa..d2c55c81642e 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -35,6 +35,7 @@ from torch.distributed import ProcessGroup, all_reduce from vllm.config import ModelConfig, ParallelConfig +from vllm.config.utils import hash_factors from vllm.distributed.parallel_state import ( get_ep_group, get_eplb_group, @@ -513,7 +514,9 @@ def add_model( communicator=communicator, new_physical_to_logical_map=None, ) - self.model_states[model_config.compute_hash()] = model_state + model_factors = model_config.compile_factors() + model_hash = hash_factors(model_factors) + self.model_states[model_hash] = model_state self.num_valid_physical_experts = model.num_physical_experts def step( @@ -1044,7 +1047,9 @@ def from_mapping( model_config=model_config, ) eplb_state.num_valid_physical_experts = num_valid_physical_experts - eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + model_factors = model_config.compile_factors() + model_hash = hash_factors(model_factors) + eplb_model_state = eplb_state.model_states[model_hash] eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical) (logical_to_physical_map_cpu, logical_replica_count_cpu) = compute_logical_maps( diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 282a92f90196..26a41752b6ed 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -28,7 +28,7 @@ import pickle import weakref from collections import namedtuple -from collections.abc import Callable +from collections.abc import Callable, Sequence from contextlib import contextmanager, nullcontext from dataclasses import dataclass from datetime import timedelta @@ -184,7 +184,7 @@ def patched_fused_scaled_matmul_reduce_scatter_fake( orig_scatter_dim: int, scatter_dim_after_maybe_reshape: int, group_name: str, - output_shape: list[int], + output_shape: Sequence[int], bias: torch.Tensor | None = None, result_scale: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, @@ -236,7 +236,7 @@ def patched_fused_scaled_matmul_reduce_scatter( orig_scatter_dim: int, scatter_dim_after_maybe_reshape: int, group_name: str, - output_shape: list[int], + output_shape: Sequence[int], bias: torch.Tensor | None = None, result_scale: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, diff --git a/vllm/envs.py b/vllm/envs.py index 9e29b53f565e..3e120d4e8857 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1757,10 +1757,8 @@ def validate_environ(hard_fail: bool) -> None: def compile_factors() -> dict[str, object]: - """Return env vars used for torch.compile cache keys. - - Start with every known vLLM env var; drop entries in `ignored_factors`; - hash everything else. This keeps the cache key aligned across workers.""" + """Collect env vars used for torch.compile cache keys.""" + from vllm.config.utils import normalize_value ignored_factors: set[str] = { "MAX_JOBS", @@ -1826,8 +1824,6 @@ def compile_factors() -> dict[str, object]: "NO_COLOR", } - from vllm.config.utils import normalize_value - factors: dict[str, object] = {} for factor, getter in environment_variables.items(): if factor in ignored_factors: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 97976b832097..cf7d1d9204a5 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -10,7 +10,8 @@ from typing import Any import msgspec -from pydantic.dataclasses import dataclass +from pydantic.dataclasses import dataclass as pydantic_dataclass +from typing_extensions import dataclass_transform import vllm.envs as envs from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig @@ -20,6 +21,13 @@ from vllm.utils.mistral import is_mistral_tokenizer from vllm.v1.serial_utils import PydanticMsgspecMixin + +# Keep pydantic runtime behavior while giving mypy dataclass semantics. +@dataclass_transform(field_specifiers=(field,)) +def dataclass(*args, **kwargs): + return pydantic_dataclass(*args, **kwargs) + + logger = init_logger(__name__) _SAMPLING_EPS = 1e-5 diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0fa59579ee76..8500955837c9 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -21,6 +21,7 @@ import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig +from vllm.config.utils import hash_factors from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.envs import enable_envs_cache from vllm.logger import init_logger @@ -990,9 +991,8 @@ def _perform_handshake( "dp_stats_address": dp_stats_address, } if vllm_config.parallel_config.data_parallel_size > 1: - ready_msg["parallel_config_hash"] = ( - vllm_config.parallel_config.compute_hash() - ) + parallel_factors = vllm_config.parallel_config.compile_factors() + ready_msg["parallel_config_hash"] = hash_factors(parallel_factors) handshake_socket.send(msgspec.msgpack.encode(ready_msg)) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 0ce0ed88e414..46fa203284d3 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -19,6 +19,7 @@ from vllm import envs from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.config.utils import hash_factors from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy @@ -1168,6 +1169,8 @@ def wait_for_engine_startup( f"dp lb mode" ) + parallel_factors = parallel_config.compile_factors() + parallel_hash = hash_factors(parallel_factors) if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info. init_message = msgspec.msgpack.encode( @@ -1207,7 +1210,7 @@ def wait_for_engine_startup( # Validate config hash consistency across DP workers for MoE models. if coordinated_dp: worker_config_hash = msg.get("parallel_config_hash") - expected_hash = parallel_config.compute_hash() + expected_hash = parallel_hash if worker_config_hash != expected_hash: raise RuntimeError( f"Configuration mismatch detected for engine "