diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index c3a065c56142..13e988307047 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -1,9 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools import hashlib -import multiprocessing import os import pickle import tempfile @@ -15,7 +13,6 @@ import torch import vllm.envs as envs -import vllm.model_executor.layers.activation from vllm.compilation.backends import VllmBackend from vllm.compilation.caching import ( StandaloneCompiledArtifacts, @@ -476,64 +473,57 @@ def test_standalone_compile_correctness(): @create_new_process_for_each_test("spawn") def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): """ - Test that compiling gpt2 twice results in a cache hit and - capture torch dynamic symbol creations to ensure make_symbol - not called on cache hit. - """ + Test that compiling gpt2 twice results in a cache hit. - import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module - from torch.utils._sympy.symbol import make_symbol + Counter values are read from the EngineCore subprocess via + ``LLM.collective_rpc`` so the test works under default V1 + multiprocessing (no shared memory between test and engine). + """ from vllm import LLM - create_symbol_counter = multiprocessing.Value("i", 0) - original_make_symbol = make_symbol + def _snap(self): + from vllm.compilation.counter import compilation_counter - @functools.wraps(original_make_symbol) - def counting_make_symbol(prefix, idx, **kwargs): - with create_symbol_counter.get_lock(): - create_symbol_counter.value += 1 - return original_make_symbol(prefix, idx, **kwargs) - - symbolic_shapes_module.make_symbol = counting_make_symbol - try: - with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname: - m.setenv("VLLM_CACHE_ROOT", tmpdirname) - m.setenv("VLLM_USE_AOT_COMPILE", "1") - # First compilation - initialize model and generate - llm_model = LLM( - model="gpt2", - compilation_config=CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, - ), - max_model_len=256, - ) + return ( + compilation_counter.num_aot_compiles, + compilation_counter.num_aot_artifacts_saved, + compilation_counter.num_aot_artifacts_loaded, + ) - llm_model.generate("Hello, my name is") - assert create_symbol_counter.value == 2 - create_symbol_counter.value = 0 + # collective_rpc(callable) requires pickle-based serialization. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - # Clean up first model - del llm_model - disable_envs_cache() - vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear() + with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + # First compilation - initialize model and generate + llm_model = LLM( + model="gpt2", + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ), + max_model_len=256, + ) - # Second compilation - should hit cache - m.setenv("VLLM_FORCE_AOT_LOAD", "1") - llm_model = LLM( - model="gpt2", - compilation_config=CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, - ), - max_model_len=256, - ) - llm_model.generate("Hello, my name is") + llm_model.generate("Hello, my name is") + assert llm_model.collective_rpc(_snap)[0] == (1, 1, 0) - assert create_symbol_counter.value == 0 + # Clean up first model + del llm_model + disable_envs_cache() - finally: - # Restore original method - symbolic_shapes_module.make_symbol = original_make_symbol + # Second compilation - should hit cache + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + llm_model = LLM( + model="gpt2", + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ), + max_model_len=256, + ) + llm_model.generate("Hello, my name is") + assert llm_model.collective_rpc(_snap)[0] == (0, 0, 1) @pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")