diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1e66f21ff638..27fc3e8a629d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -25,14 +25,17 @@ should_split, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig -from vllm.config.utils import hash_factors from vllm.logger import init_logger from vllm.logging_utils import lazy from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.torch_utils import is_torch_equal_or_newer -from .caching import VllmSerializableFunction +from .caching import ( + VllmSerializableFunction, + _compute_code_hash, + compute_env_and_config_hashes, +) from .compiler_interface import ( CompilerInterface, EagerAdaptor, @@ -587,31 +590,16 @@ def __call__( vllm_config = self.vllm_config # Minimal hashing here with existing utilities, reused below. - env_factors = envs.compile_factors() - env_hash = hash_factors(env_factors) - # Compute config/compiler/code hashes once and reuse - config_hash = vllm_config.compute_hash() + env_hash, config_hash, env_factors = compute_env_and_config_hashes(vllm_config) compiler_hash = self.compiler_manager.compute_hash(vllm_config) - forward_code_files = list(sorted(self.compilation_config.traced_files)) + traced_files = set(self.compilation_config.traced_files) + forward_code_files = list(sorted(traced_files)) logger.debug( "Traced files (to be considered for compilation cache):\n%s", lazy(lambda: "\n".join(forward_code_files)), ) - hash_content = [] - for filepath in forward_code_files: - hash_content.append(filepath) - if filepath == "": - # This means the function was dynamically generated, with - # e.g. exec(). We can't actually check these. - continue - try: - with open(filepath) as f: - hash_content.append(f.read()) - except Exception: - logger.warning("Failed to read file %s", filepath) - continue - code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest() + code_hash = _compute_code_hash(traced_files) # Clear after consumption self.compilation_config.traced_files.clear() if not self.compilation_config.cache_dir: diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 63b7ad7279e3..16c3ade2f0f3 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -135,18 +135,20 @@ def co_name(self): return "VllmSerializableFunction" -def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]: - factors = [] - # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affect the computation graph. - env_hash = hash_factors(envs.compile_factors()) - factors.append(env_hash) - - # 1. factors come from the vllm_config (it mainly summarizes how the - # model is created) +def compute_env_and_config_hashes( + vllm_config: VllmConfig, +) -> tuple[str, str, dict[str, object]]: + """ + Return the hashed environment factors, config hash, and raw env factors. + + Both AOT and JIT cache paths rely on this helper to ensure their cache keys + stay in sync. + """ + + env_factors = envs.compile_factors() + env_hash = hash_factors(env_factors) config_hash = vllm_config.compute_hash() - factors.append(config_hash) - return factors + return env_hash, config_hash, env_factors def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str: diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 11a18c0e6bb7..ee3a34ee7aec 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -29,6 +29,7 @@ from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.torch_utils import supports_dynamo +from .caching import compute_env_and_config_hashes from .monitor import start_monitoring_torch_compile logger = init_logger(__name__) @@ -352,10 +353,10 @@ def __call__(self, *args, **kwargs): serialized backend artifacts), then we need to generate a new AOT compile artifact from scratch. """ - from .caching import compilation_config_hash_factors - - factors: list[str] = compilation_config_hash_factors(self.vllm_config) + # Keep AOT cache key in sync with JIT: env factors + config hash + model. + env_hash, config_hash, _ = compute_env_and_config_hashes(self.vllm_config) + factors: list[str] = [env_hash, config_hash] factors.append(_model_hash_key(self.forward)) hash_key = hashlib.sha256(str(factors).encode()).hexdigest()