-
-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[torch.compile] caching of config fields should be opt-out by default #26468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
42 commits
Select commit
Hold shift + click to select a range
97a6194
Opt-out hashing for torch.compile cache keys (ModelConfig + envs)
vnadathur 61d580d
fixed sha-256 for backends.py
WorldExplored 69f6880
short refactor and addressing comments
WorldExplored d6dccaf
added lazy logging to logging utils
vnadathur 1359256
Addressed Codex Problems
WorldExplored 3a3af9b
Merge branch 'main' into envhashing
WorldExplored 1485803
solve merge conflict
vnadathur e4db7f4
fixed ignore list in model.py
WorldExplored 0648852
Merge branch 'main' into envhashing
WorldExplored 226a4ae
Merge branch 'main' into envhashing
WorldExplored f357dbf
Merge branch 'main' into envhashing
WorldExplored 35acae2
Merge branch 'main' into envhashing
WorldExplored a8cd228
Update lazy.py
vnadathur 2a77dac
Merge branch 'main' into envhashing
WorldExplored e8e10bf
Merge branch 'main' into envhashing
WorldExplored 7e1cb9f
revised test file
WorldExplored 8537a07
Merge branch 'main' into envhashing
WorldExplored 4d10df3
update test
vnadathur a40c3af
Merge branch 'main' into envhashing
vnadathur fd7be7b
Merge branch 'main' into envhashing
WorldExplored 2b6b27b
addressed reviewer concerns
WorldExplored ad00cb2
Merge branch 'main' into envhashing
vnadathur 345c8cc
Merge branch 'main' into envhashing
WorldExplored 80da26d
fixed precommit
WorldExplored a989f3f
Merge branch 'main' into envhashing
WorldExplored 1caaf89
fixing ignored_factors list
vnadathur cd23a09
fixing logger debug factors
vnadathur 357929a
handle passconfig
vnadathur f5cdc9d
adjust factors
vnadathur 5a1f65e
Merge branch 'main' into envhashing
vnadathur 1947f98
addressed reviewer feedback
WorldExplored b4c6ff9
fixed pre-commit
WorldExplored cd14b82
fixed pre-commit
WorldExplored 40a4c97
addressed concerns
WorldExplored a92a16e
Merge branch 'main' into envhashing
WorldExplored 6dc9a57
add _data_parallel_master_port_list to ignore factors due to failure
vnadathur 653d993
Merge branch 'main' into envhashing
vnadathur a2d5ccf
fixed buildkite
WorldExplored 030143c
Merge branch 'main' into envhashing
vnadathur e8e3b2f
nits and feedback
vnadathur 4fed317
Merge branch 'main' into envhashing
WorldExplored a14319b
fixing CI
vnadathur File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from dataclasses import dataclass | ||
| from enum import Enum | ||
|
|
||
| import pytest | ||
|
|
||
| from vllm.config.utils import get_hash_factors, hash_factors, normalize_value | ||
|
|
||
| # Helpers | ||
|
|
||
|
|
||
| def endswith_fqname(obj, suffix: str) -> bool: | ||
| # normalize_value(type) returns fully-qualified name | ||
| # Compare suffix to avoid brittle import paths. | ||
| out = normalize_value(obj) | ||
| return isinstance(out, str) and out.endswith(suffix) | ||
|
|
||
|
|
||
| def expected_path(p_str: str = ".") -> str: | ||
| import pathlib | ||
|
|
||
| p = pathlib.Path(p_str) | ||
| return p.expanduser().resolve().as_posix() | ||
|
|
||
|
|
||
| # Minimal dataclass to test get_hash_factors. | ||
| # Avoid importing heavy vLLM configs. | ||
| @dataclass | ||
| class SimpleConfig: | ||
| a: object | ||
| b: object | None = None | ||
|
|
||
|
|
||
| class DummyLogprobsMode(Enum): | ||
| RAW_LOGITS = "raw_logits" | ||
|
|
||
|
|
||
| def test_hash_factors_deterministic(): | ||
| """Test that hash_factors produces consistent SHA-256 hashes""" | ||
| factors = {"a": 1, "b": "test"} | ||
| hash1 = hash_factors(factors) | ||
| hash2 = hash_factors(factors) | ||
|
|
||
| assert hash1 == hash2 | ||
| # Dict key insertion order should not affect the hash. | ||
| factors_reordered = {"b": "test", "a": 1} | ||
| assert hash_factors(factors_reordered) == hash1 | ||
| assert len(hash1) == 64 | ||
| assert all(c in "0123456789abcdef" for c in hash1) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "inp, expected", | ||
| [ | ||
| (None, None), | ||
| (True, True), | ||
| (1, 1), | ||
| (1.0, 1.0), | ||
| ("x", "x"), | ||
| (b"ab", "6162"), | ||
| (bytearray(b"ab"), "6162"), | ||
| ([1, 2], (1, 2)), | ||
| ({"b": 2, "a": 1}, (("a", 1), ("b", 2))), | ||
| ], | ||
| ) | ||
| def test_normalize_value_matrix(inp, expected): | ||
| """Parametric input→expected normalization table.""" | ||
| assert normalize_value(inp) == expected | ||
|
|
||
|
|
||
| def test_normalize_value_enum(): | ||
| # Enums normalize to (module.QualName, value). | ||
| # DummyLogprobsMode uses a string payload. | ||
| out = normalize_value(DummyLogprobsMode.RAW_LOGITS) | ||
| assert isinstance(out, tuple) | ||
| assert out[0].endswith("DummyLogprobsMode") | ||
| # Expect string payload 'raw_logits'. | ||
| assert out[1] == "raw_logits" | ||
|
|
||
|
|
||
| def test_normalize_value_set_order_insensitive(): | ||
| # Sets are unordered; normalize_value sorts elements for determinism. | ||
| assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3}) | ||
|
|
||
|
|
||
| def test_normalize_value_path_normalization(): | ||
| from pathlib import Path # local import to avoid global dependency | ||
|
|
||
| # Paths expand/resolve to absolute strings. | ||
| # Stabilizes hashing across working dirs. | ||
| assert normalize_value(Path(".")) == expected_path(".") | ||
|
|
||
|
|
||
| def test_normalize_value_uuid_and_to_json(): | ||
| # Objects may normalize via uuid() or to_json_string(). | ||
| class HasUUID: | ||
| def uuid(self): | ||
| return "test-uuid" | ||
|
|
||
| class ToJson: | ||
| def to_json_string(self): | ||
| return '{"x":1}' | ||
|
|
||
| assert normalize_value(HasUUID()) == "test-uuid" | ||
| assert normalize_value(ToJson()) == '{"x":1}' | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "bad", | ||
| [ | ||
| (lambda x: x), | ||
| (type("CallableInstance", (), {"__call__": lambda self: 0}))(), | ||
| (lambda: (lambda: 0))(), # nested function instance | ||
| ], | ||
| ) | ||
| def test_error_cases(bad): | ||
| """Inputs expected to raise TypeError.""" | ||
| # Reject functions/lambdas/callable instances | ||
| # to avoid under-hashing. | ||
| with pytest.raises(TypeError): | ||
| normalize_value(bad) | ||
|
|
||
|
|
||
| def test_enum_vs_int_disambiguation(): | ||
| # int stays primitive | ||
| nf_int = normalize_value(1) | ||
| assert nf_int == 1 | ||
|
|
||
| # enum becomes ("module.QualName", value) | ||
| nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS) | ||
| assert isinstance(nf_enum, tuple) and len(nf_enum) == 2 | ||
| enum_type, enum_val = nf_enum | ||
| assert enum_type.endswith(".DummyLogprobsMode") | ||
| assert enum_val == "raw_logits" | ||
|
|
||
| # Build factor dicts from configs with int vs enum | ||
| f_int = get_hash_factors(SimpleConfig(1), set()) | ||
| f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set()) | ||
| # The int case remains a primitive value | ||
| assert f_int["a"] == 1 | ||
| # The enum case becomes a tagged tuple ("module.QualName", "raw_logits") | ||
| assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits" | ||
| # Factor dicts must differ so we don't collide primitives with Enums. | ||
| assert f_int != f_enum | ||
| # Hash digests must differ correspondingly | ||
| assert hash_factors(f_int) != hash_factors(f_enum) | ||
|
|
||
| # Hash functions produce stable hex strings | ||
| h_int = hash_factors(f_int) | ||
| h_enum = hash_factors(f_enum) | ||
| assert isinstance(h_int, str) and len(h_int) == 64 | ||
| assert isinstance(h_enum, str) and len(h_enum) == 64 | ||
|
|
||
|
|
||
| def test_classes_are_types(): | ||
| """Types normalize to FQNs; include real vLLM types.""" | ||
| # Only classes allowed; functions/lambdas are rejected. | ||
| # Canonical form is the fully-qualified name. | ||
| assert isinstance(normalize_value(str), str) | ||
|
|
||
| class LocalDummy: | ||
| pass | ||
|
|
||
| assert endswith_fqname(LocalDummy, ".LocalDummy") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,12 +4,14 @@ | |
| import ast | ||
| import dataclasses | ||
| import hashlib | ||
| import json | ||
| import operator | ||
| import os | ||
| import pprint | ||
| import time | ||
| from collections.abc import Callable, Sequence | ||
| from contextlib import contextmanager | ||
| from functools import partial | ||
| from typing import Any | ||
|
|
||
| import torch | ||
|
|
@@ -23,7 +25,9 @@ | |
| should_split, | ||
| ) | ||
| from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig | ||
| from vllm.config.utils import hash_factors | ||
| from vllm.logger import init_logger | ||
| from vllm.logging_utils import lazy | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils.import_utils import resolve_obj_by_qualname | ||
| from vllm.utils.torch_utils import is_torch_equal_or_newer | ||
|
|
@@ -580,35 +584,47 @@ def configure_post_pass(self): | |
| def __call__( | ||
| self, graph: fx.GraphModule, example_inputs | ||
| ) -> VllmSerializableFunction: | ||
| from .caching import _compute_code_hash, compilation_config_hash_factors | ||
|
|
||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| vllm_config = self.vllm_config | ||
| # Minimal hashing here with existing utilities, reused below. | ||
|
|
||
| env_factors = envs.compile_factors() | ||
| env_hash = hash_factors(env_factors) | ||
| # Compute config/compiler/code hashes once and reuse | ||
| config_hash = vllm_config.compute_hash() | ||
vnadathur marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| compiler_hash = self.compiler_manager.compute_hash(vllm_config) | ||
| forward_code_files = list(sorted(self.compilation_config.traced_files)) | ||
|
|
||
| logger.debug( | ||
| "Traced files (to be considered for compilation cache):\n%s", | ||
| lazy(lambda: "\n".join(forward_code_files)), | ||
| ) | ||
| hash_content = [] | ||
| for filepath in forward_code_files: | ||
| hash_content.append(filepath) | ||
| if filepath == "<string>": | ||
| # This means the function was dynamically generated, with | ||
| # e.g. exec(). We can't actually check these. | ||
| continue | ||
| try: | ||
| with open(filepath) as f: | ||
| hash_content.append(f.read()) | ||
| except Exception: | ||
| logger.warning("Failed to read file %s", filepath) | ||
| continue | ||
| code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest() | ||
| # Clear after consumption | ||
| self.compilation_config.traced_files.clear() | ||
| if not self.compilation_config.cache_dir: | ||
| # no provided cache dir, generate one based on the known factors | ||
| # that affects the compilation. if none of the factors change, | ||
| # the cache dir will be the same so that we can reuse the compiled | ||
| # graph. | ||
|
|
||
| factors = compilation_config_hash_factors(vllm_config) | ||
| # 2. factors come from the code files that are traced by Dynamo ( | ||
| # it mainly summarizes how the model is used in forward pass) | ||
| code_hash = _compute_code_hash(self.compilation_config.traced_files) | ||
| self.compilation_config.traced_files.clear() | ||
| factors.append(code_hash) | ||
|
|
||
| # 3. compiler hash | ||
| compiler_hash = self.compiler_manager.compute_hash(vllm_config) | ||
| factors.append(compiler_hash) | ||
|
|
||
| # combine all factors to generate the cache dir | ||
| hash_key = hashlib.md5( | ||
| str(factors).encode(), usedforsecurity=False | ||
| ).hexdigest()[:10] | ||
|
|
||
| factors = [env_hash, config_hash, code_hash, compiler_hash] | ||
| # Use SHA-256 for cache key hashing to be consistent across | ||
| # compute_hash functions. Truncate for a short cache dir name. | ||
| hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we have used the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. will add to this pr: #29117 |
||
| cache_dir = os.path.join( | ||
| envs.VLLM_CACHE_ROOT, | ||
| "torch_compile_cache", | ||
| hash_key, | ||
| envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key | ||
| ) | ||
| self.compilation_config.cache_dir = cache_dir | ||
|
|
||
|
|
@@ -621,6 +637,7 @@ def __call__( | |
| os.makedirs(local_cache_dir, exist_ok=True) | ||
| self.compilation_config.local_cache_dir = local_cache_dir | ||
|
|
||
| # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE. | ||
| disable_cache = not is_compile_cache_enabled( | ||
| self.compilation_config.inductor_compile_config | ||
| ) | ||
|
|
@@ -638,6 +655,50 @@ def __call__( | |
| local_cache_dir, disable_cache, self.prefix | ||
| ) | ||
|
|
||
| # Reuses existing cache key | ||
|
|
||
| logger.debug( | ||
| "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s", | ||
| env_hash, | ||
| config_hash, | ||
| compiler_hash, | ||
| code_hash, | ||
| local_cache_dir, | ||
| ) | ||
|
|
||
| # Persist and log only hash-relevant factors together. | ||
vnadathur marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| try: | ||
| logger.debug( | ||
| "Compile env factors (raw):\n%s\nVllm config hash: %s", | ||
| lazy(partial(pprint.pformat, env_factors, width=120)), | ||
| config_hash, | ||
| ) | ||
| meta_path = os.path.join(local_cache_dir, "cache_key_factors.json") | ||
| if not os.path.exists(meta_path): | ||
| with open(meta_path, "w") as f: | ||
| json.dump( | ||
| { | ||
| "env": env_factors, # raw factors used for env_hash | ||
| "config_hash": config_hash, | ||
| "code_hash": code_hash, | ||
| "compiler_hash": compiler_hash, | ||
| }, | ||
| f, | ||
| indent=2, | ||
| sort_keys=True, | ||
| ) | ||
| except Exception: | ||
| # Best-effort only; metadata write failures are non-fatal. | ||
| logger.warning( | ||
| ( | ||
| "Could not write compile cache metadata at %s; continuing without " | ||
| "metadata. Compiled cache remains valid; diagnostics may be " | ||
| "limited." | ||
| ), | ||
| local_cache_dir, | ||
| exc_info=True, | ||
| ) | ||
|
|
||
| # when dynamo calls the backend, it means the bytecode | ||
| # transform and analysis are done | ||
| compilation_counter.num_graphs_seen += 1 | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.