Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/compile/test_compile_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
Range(start=16, end=16),
Range(start=9, end=32),
Range(start=64, end=64),
Range(start=128, end=128),
Copy link
Collaborator Author

@zou3519 zou3519 Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compile_sizes has 128 in it. Previously we would never trigger it because it was not one of the batch sizes. Now we compile everything up front without seeing the batch sizes, so 128 is compiled as well.

Range(start=33, end=8192),
]
)
Expand All @@ -95,16 +96,16 @@ def test_compile_ranges(use_fresh_inductor_cache):

with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix="").eval()
# Number of compilations: 3 for each compile range + 2 compile sizes
# Number of compilations: 3 compile ranges + 3 compile sizes
batch_sizes = [1, 4, 16, 24, 48, 64, 8192]

with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=1,
num_backend_compilations=5,
num_backend_compilations=6,
):
run_model(vllm_config, model, batch_sizes)
assert post_grad_range_checker.num_calls == 5
assert post_grad_range_checker.num_calls == 6


def test_compile_config_get_compile_ranges():
Expand Down
6 changes: 3 additions & 3 deletions tests/compile/test_structured_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
f"got {len(vllm_piecewise_split_graph)}"
)
compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start")
assert len(compile_start_artifacts) == 2, (
"Expected 2 vllm_piecewise_compile_start "
"(one for dynamic ranges, one for compile size), "
assert len(compile_start_artifacts) == 4, (
"Expected 4 vllm_piecewise_compile_start "
"(2 subgraphs x 2 ranges each: dynamic + compile size), "
f"got {len(compile_start_artifacts)}"
)
submod_dumps = capture.get("graph_dump", r"vllm_submod_.*")
Expand Down
51 changes: 20 additions & 31 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import contextvars
import dataclasses
import hashlib
import json
Expand All @@ -18,7 +17,7 @@

import torch
import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import dynamo_timed
from torch._logging._internal import trace_structured

import vllm.envs as envs
Expand Down Expand Up @@ -510,9 +509,9 @@ def wrap_with_cudagraph_if_needed(

class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
compilation configs.
It runs the given split graph interpreter, and for each submodule in
`compile_submod_names`, creates a PiecewiseBackend and compiles all
ranges up front.

NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise
Expand Down Expand Up @@ -540,9 +539,6 @@ def __init__(
vllm_backend: "VllmBackend",
) -> None:
super().__init__(module)
from torch._guards import detect_fake_mode

self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
Expand All @@ -552,13 +548,7 @@ def __init__(

@instrument(span_name="Inductor compilation")
def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
with self.fake_mode, enable_python_dispatcher():
return super().run(*fake_args)
return super().run(*args)

def call_module(
self,
Expand Down Expand Up @@ -614,21 +604,6 @@ def call_module(
model_tag: str = "backbone"
model_is_encoder: bool = False

_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = (
contextvars.ContextVar("on_compilation_complete_callback", default=None)
)


@contextmanager
def set_on_compilation_complete(
callback: Callable[[], None],
) -> Generator[None, None, None]:
token = _on_compilation_complete_callback.set(callback)
try:
yield
finally:
_on_compilation_complete_callback.reset(token)


@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
Expand Down Expand Up @@ -846,6 +821,7 @@ def list_to_str(lst: list | None) -> str:
),
)

@dynamo_timed("vllm_backend")
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
from .caching import (
VllmSerializableFunction,
Expand Down Expand Up @@ -1036,11 +1012,24 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
]

# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
# compile submodules with symbolic shapes, and compile all ranges
# up front so that compilation is complete before the callable
# is returned.
PiecewiseCompileInterpreter(
self.split_gm, submod_names_to_compile, self.vllm_config, self
).run(*fake_args)

# All compilation is done. Save the cache.
time_before_saving = time.perf_counter()
self.compiler_manager.save_to_file()
elapsed = time.perf_counter() - time_before_saving
if elapsed > 1:
logger.info_once(
"Saved compiler manager cache in %.2f seconds.",
elapsed,
scope="local",
)

from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode()
Expand Down
26 changes: 11 additions & 15 deletions vllm/compilation/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,30 +313,26 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction

return fn

# Fall back to standard VllmBackend
# Fall back to standard VllmBackend.
# Use a lazy closure: the backend needs traced_files for cache
# dir computation, but those are only populated after
# _verify_source_unchanged runs in decorators.py (which happens
# after deserialization completes).
from vllm.compilation.backends import VllmBackend

is_encoder = state.get("is_encoder", False)
vllm_backend: VllmBackend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder
)
vllm_config = get_current_vllm_config()
compile_inputs = list(state["example_inputs"])

def optimized_call(*example_inputs: Any) -> Any:
"""
On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend
call returns, we just do a one-time replacement of the optimized
call with the compiled function, so that subsequent calls are on
the AOT compiled path.
"""
compile_inputs = [
inp if inp is not None else example_inputs[i]
for i, inp in enumerate(fn.example_inputs)
]
vllm_backend: VllmBackend = VllmBackend(
vllm_config, state["prefix"], is_encoder
)
with tracing(TracingContext(fake_mode)):
fn.optimized_call = vllm_backend(
state["graph_module"], compile_inputs
).optimized_call
fn.vllm_backend = vllm_backend
return fn.optimized_call(*example_inputs)

fn = cls(**state, optimized_call=optimized_call)
Expand Down
19 changes: 12 additions & 7 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,12 @@ def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
"Directly load AOT compilation from path %s", aot_compilation_path
)
# Apply partition wrapper context for proper CUDA graph capture
from .monitor import end_monitoring_torch_compile

with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
end_monitoring_torch_compile(self.vllm_config)
return output

if self.compiled:
assert (
Expand Down Expand Up @@ -552,18 +556,19 @@ def patched_inline_call(self_: Any) -> Any:
logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False
if use_aot_compile:
from vllm.compilation.backends import set_on_compilation_complete

# store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir
# set callback in context so it's available when compilation completes
with set_on_compilation_complete(self.save_aot_compiled_function):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
# All compilation is done at this point, save the AOT artifact.
self.save_aot_compiled_function()
output = self.aot_compiled_fn(self, *args, **kwargs)
else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]

from .monitor import end_monitoring_torch_compile

end_monitoring_torch_compile(self.vllm_config)
self.compiled = True
return output

Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
total_compile_time: float = time.perf_counter() - torch_compile_start_time
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once(
"torch.compile takes %.2f s in total",
"torch.compile and initial profiling run took %.2f s in total",
total_compile_time,
scope="local",
)
Expand Down
Loading