Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer

from .caching import VllmSerializableFunction
from .caching import (
VllmSerializableFunction,
_compute_code_hash,
compute_env_and_config_hashes,
)
from .compiler_interface import (
CompilerInterface,
EagerAdaptor,
Expand Down Expand Up @@ -587,31 +590,16 @@ def __call__(
vllm_config = self.vllm_config
# Minimal hashing here with existing utilities, reused below.

env_factors = envs.compile_factors()
env_hash = hash_factors(env_factors)
# Compute config/compiler/code hashes once and reuse
config_hash = vllm_config.compute_hash()
env_hash, config_hash, env_factors = compute_env_and_config_hashes(vllm_config)
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
forward_code_files = list(sorted(self.compilation_config.traced_files))
traced_files = set(self.compilation_config.traced_files)
forward_code_files = list(sorted(traced_files))

logger.debug(
"Traced files (to be considered for compilation cache):\n%s",
lazy(lambda: "\n".join(forward_code_files)),
)
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
if filepath == "<string>":
# This means the function was dynamically generated, with
# e.g. exec(). We can't actually check these.
continue
try:
with open(filepath) as f:
hash_content.append(f.read())
except Exception:
logger.warning("Failed to read file %s", filepath)
continue
code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
code_hash = _compute_code_hash(traced_files)
# Clear after consumption
self.compilation_config.traced_files.clear()
if not self.compilation_config.cache_dir:
Expand Down
24 changes: 13 additions & 11 deletions vllm/compilation/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,20 @@ def co_name(self):
return "VllmSerializableFunction"


def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = hash_factors(envs.compile_factors())
factors.append(env_hash)

# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
def compute_env_and_config_hashes(
vllm_config: VllmConfig,
) -> tuple[str, str, dict[str, object]]:
"""
Return the hashed environment factors, config hash, and raw env factors.

Both AOT and JIT cache paths rely on this helper to ensure their cache keys
stay in sync.
"""

env_factors = envs.compile_factors()
env_hash = hash_factors(env_factors)
config_hash = vllm_config.compute_hash()
factors.append(config_hash)
return factors
return env_hash, config_hash, env_factors


def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
Expand Down
7 changes: 4 additions & 3 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_dynamo

from .caching import compute_env_and_config_hashes
from .monitor import start_monitoring_torch_compile

logger = init_logger(__name__)
Expand Down Expand Up @@ -352,10 +353,10 @@ def __call__(self, *args, **kwargs):
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
from .caching import compilation_config_hash_factors

factors: list[str] = compilation_config_hash_factors(self.vllm_config)

# Keep AOT cache key in sync with JIT: env factors + config hash + model.
env_hash, config_hash, _ = compute_env_and_config_hashes(self.vllm_config)
factors: list[str] = [env_hash, config_hash]
factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()

Expand Down