Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 41 additions & 51 deletions tests/compile/test_aot_compile.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import hashlib
import multiprocessing
import os
import pickle
import tempfile
Expand All @@ -15,7 +13,6 @@
import torch

import vllm.envs as envs
import vllm.model_executor.layers.activation
from vllm.compilation.backends import VllmBackend
from vllm.compilation.caching import (
StandaloneCompiledArtifacts,
Expand Down Expand Up @@ -476,64 +473,57 @@ def test_standalone_compile_correctness():
@create_new_process_for_each_test("spawn")
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
"""
Test that compiling gpt2 twice results in a cache hit and
capture torch dynamic symbol creations to ensure make_symbol
not called on cache hit.
"""
Test that compiling gpt2 twice results in a cache hit.

import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module
from torch.utils._sympy.symbol import make_symbol
Counter values are read from the EngineCore subprocess via
``LLM.collective_rpc`` so the test works under default V1
multiprocessing (no shared memory between test and engine).
"""

from vllm import LLM

create_symbol_counter = multiprocessing.Value("i", 0)
original_make_symbol = make_symbol
def _snap(self):
from vllm.compilation.counter import compilation_counter

@functools.wraps(original_make_symbol)
def counting_make_symbol(prefix, idx, **kwargs):
with create_symbol_counter.get_lock():
create_symbol_counter.value += 1
return original_make_symbol(prefix, idx, **kwargs)

symbolic_shapes_module.make_symbol = counting_make_symbol
try:
with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
# First compilation - initialize model and generate
llm_model = LLM(
model="gpt2",
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
max_model_len=256,
)
return (
compilation_counter.num_aot_compiles,
compilation_counter.num_aot_artifacts_saved,
compilation_counter.num_aot_artifacts_loaded,
)

llm_model.generate("Hello, my name is")
assert create_symbol_counter.value == 2
create_symbol_counter.value = 0
# collective_rpc(callable) requires pickle-based serialization.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

# Clean up first model
del llm_model
disable_envs_cache()
vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()
with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
# First compilation - initialize model and generate
llm_model = LLM(
model="gpt2",
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
max_model_len=256,
)

# Second compilation - should hit cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
llm_model = LLM(
model="gpt2",
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
max_model_len=256,
)
llm_model.generate("Hello, my name is")
llm_model.generate("Hello, my name is")
assert llm_model.collective_rpc(_snap)[0] == (1, 1, 0)

assert create_symbol_counter.value == 0
# Clean up first model
del llm_model
disable_envs_cache()

finally:
# Restore original method
symbolic_shapes_module.make_symbol = original_make_symbol
# Second compilation - should hit cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
llm_model = LLM(
model="gpt2",
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
max_model_len=256,
)
llm_model.generate("Hello, my name is")
assert llm_model.collective_rpc(_snap)[0] == (0, 0, 1)


@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
Expand Down
Loading