diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index c90454ed0e95..430db850c303 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -73,6 +73,7 @@ def test_compile_ranges(use_fresh_inductor_cache): Range(start=16, end=16), Range(start=9, end=32), Range(start=64, end=64), + Range(start=128, end=128), Range(start=33, end=8192), ] ) @@ -95,16 +96,16 @@ def test_compile_ranges(use_fresh_inductor_cache): with set_current_vllm_config(vllm_config): model = TestModel(vllm_config=vllm_config, prefix="").eval() - # Number of compilations: 3 for each compile range + 2 compile sizes + # Number of compilations: 3 compile ranges + 3 compile sizes batch_sizes = [1, 4, 16, 24, 48, 64, 8192] with compilation_counter.expect( num_graphs_seen=1, num_piecewise_graphs_seen=1, - num_backend_compilations=5, + num_backend_compilations=6, ): run_model(vllm_config, model, batch_sizes) - assert post_grad_range_checker.num_calls == 5 + assert post_grad_range_checker.num_calls == 6 def test_compile_config_get_compile_ranges(): diff --git a/tests/compile/test_structured_logging.py b/tests/compile/test_structured_logging.py index 059665254f53..7813b7429b1f 100644 --- a/tests/compile/test_structured_logging.py +++ b/tests/compile/test_structured_logging.py @@ -109,9 +109,9 @@ def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache): f"got {len(vllm_piecewise_split_graph)}" ) compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start") - assert len(compile_start_artifacts) == 2, ( - "Expected 2 vllm_piecewise_compile_start " - "(one for dynamic ranges, one for compile size), " + assert len(compile_start_artifacts) == 4, ( + "Expected 4 vllm_piecewise_compile_start " + "(2 subgraphs x 2 ranges each: dynamic + compile size), " f"got {len(compile_start_artifacts)}" ) submod_dumps = capture.get("graph_dump", r"vllm_submod_.*") diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 09fd1f75091e..7b493d9b92a8 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast -import contextvars import dataclasses import hashlib import json @@ -18,7 +17,7 @@ import torch import torch.fx as fx -from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import dynamo_timed from torch._logging._internal import trace_structured import vllm.envs as envs @@ -510,9 +509,9 @@ def wrap_with_cudagraph_if_needed( class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. - It runs the given graph with fake inputs, and compile some - submodules specified by `compile_submod_names` with the given - compilation configs. + It runs the given split graph interpreter, and for each submodule in + `compile_submod_names`, creates a PiecewiseBackend and compiles all + ranges up front. NOTE: the order in `compile_submod_names` matters, because it will be used to determine the order of the compiled piecewise @@ -540,9 +539,6 @@ def __init__( vllm_backend: "VllmBackend", ) -> None: super().__init__(module) - from torch._guards import detect_fake_mode - - self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names self.compilation_config = vllm_config.compilation_config self.vllm_config = vllm_config @@ -552,13 +548,7 @@ def __init__( @instrument(span_name="Inductor compilation") def run(self, *args: Any) -> Any: - # maybe instead just assert inputs are fake? - fake_args = [ - self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t - for t in args - ] - with self.fake_mode, enable_python_dispatcher(): - return super().run(*fake_args) + return super().run(*args) def call_module( self, @@ -614,21 +604,6 @@ def call_module( model_tag: str = "backbone" model_is_encoder: bool = False -_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = ( - contextvars.ContextVar("on_compilation_complete_callback", default=None) -) - - -@contextmanager -def set_on_compilation_complete( - callback: Callable[[], None], -) -> Generator[None, None, None]: - token = _on_compilation_complete_callback.set(callback) - try: - yield - finally: - _on_compilation_complete_callback.reset(token) - @contextmanager def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]: @@ -846,6 +821,7 @@ def list_to_str(lst: list | None) -> str: ), ) + @dynamo_timed("vllm_backend") def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: from .caching import ( VllmSerializableFunction, @@ -1036,11 +1012,24 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: ] # propagate the split graph to the piecewise backend, - # compile submodules with symbolic shapes + # compile submodules with symbolic shapes, and compile all ranges + # up front so that compilation is complete before the callable + # is returned. PiecewiseCompileInterpreter( self.split_gm, submod_names_to_compile, self.vllm_config, self ).run(*fake_args) + # All compilation is done. Save the cache. + time_before_saving = time.perf_counter() + self.compiler_manager.save_to_file() + elapsed = time.perf_counter() - time_before_saving + if elapsed > 1: + logger.info_once( + "Saved compiler manager cache in %.2f seconds.", + elapsed, + scope="local", + ) + from torch._guards import detect_fake_mode fake_mode = detect_fake_mode() diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 3917a4f28cf9..7f3a844a5905 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -313,30 +313,26 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction return fn - # Fall back to standard VllmBackend + # Fall back to standard VllmBackend. + # Use a lazy closure: the backend needs traced_files for cache + # dir computation, but those are only populated after + # _verify_source_unchanged runs in decorators.py (which happens + # after deserialization completes). from vllm.compilation.backends import VllmBackend is_encoder = state.get("is_encoder", False) - vllm_backend: VllmBackend = VllmBackend( - get_current_vllm_config(), state["prefix"], is_encoder - ) + vllm_config = get_current_vllm_config() + compile_inputs = list(state["example_inputs"]) def optimized_call(*example_inputs: Any) -> Any: - """ - On the first run of the optimized call, we rerun the compiler - backend which should result in a cache hit. After the backend - call returns, we just do a one-time replacement of the optimized - call with the compiled function, so that subsequent calls are on - the AOT compiled path. - """ - compile_inputs = [ - inp if inp is not None else example_inputs[i] - for i, inp in enumerate(fn.example_inputs) - ] + vllm_backend: VllmBackend = VllmBackend( + vllm_config, state["prefix"], is_encoder + ) with tracing(TracingContext(fake_mode)): fn.optimized_call = vllm_backend( state["graph_module"], compile_inputs ).optimized_call + fn.vllm_backend = vllm_backend return fn.optimized_call(*example_inputs) fn = cls(**state, optimized_call=optimized_call) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index c6bc5506a589..6645a0681387 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -466,8 +466,12 @@ def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any: "Directly load AOT compilation from path %s", aot_compilation_path ) # Apply partition wrapper context for proper CUDA graph capture + from .monitor import end_monitoring_torch_compile + with maybe_use_cudagraph_partition_wrapper(self.vllm_config): - return self.aot_compiled_fn(self, *args, **kwargs) + output = self.aot_compiled_fn(self, *args, **kwargs) + end_monitoring_torch_compile(self.vllm_config) + return output if self.compiled: assert ( @@ -552,18 +556,19 @@ def patched_inline_call(self_: Any) -> Any: logger.warning("Detected eager backend, disabling AOT compile.") use_aot_compile = False if use_aot_compile: - from vllm.compilation.backends import set_on_compilation_complete - # store the path for saving after warmup self._aot_compilation_path = aot_compilation_path self._aot_cache_dir = cache_dir - # set callback in context so it's available when compilation completes - with set_on_compilation_complete(self.save_aot_compiled_function): - self.aot_compiled_fn = self.aot_compile(*args, **kwargs) - output = self.aot_compiled_fn(self, *args, **kwargs) + self.aot_compiled_fn = self.aot_compile(*args, **kwargs) + # All compilation is done at this point, save the AOT artifact. + self.save_aot_compiled_function() + output = self.aot_compiled_fn(self, *args, **kwargs) else: output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type] + from .monitor import end_monitoring_torch_compile + + end_monitoring_torch_compile(self.vllm_config) self.compiled = True return output diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 43b9ae508a5c..fb9dfa3ac127 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -33,7 +33,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None: total_compile_time: float = time.perf_counter() - torch_compile_start_time if compilation_config.mode == CompilationMode.VLLM_COMPILE: logger.info_once( - "torch.compile takes %.2f s in total", + "torch.compile and initial profiling run took %.2f s in total", total_compile_time, scope="local", ) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index f9eb245893d3..ef2b895757fe 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -5,7 +5,6 @@ import io import json import pickle -import time from collections.abc import Callable from pickle import Pickler from typing import Any @@ -16,7 +15,6 @@ from torch._logging._internal import trace_structured from vllm.compilation.backends import VllmBackend -from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig from vllm.config.utils import Range from vllm.logger import init_logger @@ -24,6 +22,55 @@ logger = init_logger(__name__) +def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]: + """Get fake args directly from graph placeholder nodes.""" + fake_args = [] + for node in graph.graph.nodes: + if node.op == "placeholder": + fake_args.append(node.meta["example_value"]) + else: + break + return fake_args + + +def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]: + """Create example inputs with symbolic dims replaced by a concrete size. + + Used for single-size eager compilation where we need concrete-shaped + inputs but don't have real runtime tensors yet. + """ + from torch._prims_common import compute_required_storage_length + from torch.fx.experimental.symbolic_shapes import is_symbolic + + def concretize(sym_val: Any) -> int: + """Replace all symbolic variables in a SymInt expression with size.""" + if not is_symbolic(sym_val): + return int(sym_val) + expr = sym_val.node.expr + return int(expr.subs({s: size for s in expr.free_symbols})) + + args: list[Any] = [] + for node in graph.graph.nodes: + if node.op != "placeholder": + break + val = node.meta["example_value"] + if isinstance(val, torch.SymInt): + args.append(concretize(val)) + elif isinstance(val, torch.Tensor): + new_shape = tuple(concretize(d) for d in val.shape) + new_strides = tuple(concretize(s) for s in val.stride()) + new_storage_offset = concretize(val.storage_offset()) + needed_size = compute_required_storage_length( + new_shape, new_strides, new_storage_offset + ) + t = torch.empty(needed_size, dtype=val.dtype, device=val.device) + t = t.as_strided(new_shape, new_strides, new_storage_offset) + args.append(t) + else: + args.append(val) + return args + + @dataclasses.dataclass class RangeEntry: compile_range: Range @@ -109,10 +156,6 @@ def __init__( # the entries for ranges that we need to either self.range_entries: dict[Range, RangeEntry] = {} - # to_be_compiled_ranges tracks the remaining ranges to compile, - # and updates during the compilation process, so we need to copy it - self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) - # We only keep compilation management inside this class directly. if self.compile_sizes is not None: for size in self.compile_sizes: @@ -129,7 +172,6 @@ def __init__( self.range_entries[range] = RangeEntry( compile_range=range, ) - self.to_be_compiled_ranges.add(range) for range in self.compile_ranges: self.range_entries[range] = RangeEntry( @@ -139,12 +181,10 @@ def __init__( # Track whether we've logged the graph for this subgraph (only log once) self._graph_logged = False - # get the on_compilation_complete callback from context... - # PiecewiseBackend is created during the first call, - # which is when the context is set (see compilation/decorators.py) - from vllm.compilation.backends import _on_compilation_complete_callback - - self.on_compilation_complete = _on_compilation_complete_callback.get() + if self.graph is not None: + self.compile_all_ranges() + else: + self.load_all_ranges() def get_compiled_graph_wrapper( self, compiled_graph: Callable[..., Any] @@ -161,25 +201,6 @@ def compiled_graph_wrapper(*args: Any) -> Any: return compiled_graph_wrapper - def check_for_ending_compilation(self) -> None: - if self.is_last_graph and not self.to_be_compiled_ranges: - # no specific sizes to compile - # save the hash of the inductor graph for the next run - time_before_saving = time.perf_counter() - self.vllm_backend.compiler_manager.save_to_file() - elapsed = time.perf_counter() - time_before_saving - if elapsed > 1: - logger.info_once( - "Saved compiler manager cache in %.2f seconds.", - elapsed, - scope="local", - ) - - end_monitoring_torch_compile(self.vllm_config) - # Call the completion callback (e.g., to save AOT compiled function) - if self.on_compilation_complete is not None: - self.on_compilation_complete() - def to_bytes(self) -> dict[str, bytes]: class StandaloneCompiledArtifactsPickler(Pickler): def reducer_override(self, obj: object) -> Any: @@ -216,27 +237,54 @@ def serialize(fn: Callable[..., Any]) -> bytes: return out - def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]: - # We need to pass fake example_inputs, otherwise torch.compile - # will fakify the example_inputs potentially causing some non dynamic - # dimension to be be duck shaped to other existing shapes that have hints - # matching their values. - # This is problem because it can lead to unintended specializations! - # if the new wrongly dynamic dim is specialized - # it will force specializing the whole shape - # torch.compile probably should not accept - # non fake tensors as example inputs! - # See issue https://github.com/vllm-project/vllm/issues/27899 - fake_example_inputs = [] - assert self.graph is not None - for node in self.graph.graph.nodes: - # All place holders come first - if node.op == "placeholder": - fake_example_inputs.append(node.meta["example_value"]) + def compile_all_ranges(self) -> None: + """Compile all range entries for this piecewise subgraph up front.""" + assert self.graph is not None, ( + "Cannot compile without a graph. " + "When loading from cache/AOT artifacts, " + "compile_all_ranges should not be called." + ) + + for range_entry in self.range_entries.values(): + if range_entry.compiled: + continue + + self._log_compile_start(range_entry.compile_range) + + if range_entry.compile_range.is_single_size(): + args_list = create_concrete_args( + self.graph, range_entry.compile_range.start + ) else: - break - assert len(fake_example_inputs) == len(args) - return fake_example_inputs + args_list = get_fake_args_from_graph(self.graph) + + # TODO(https://github.com/vllm-project/vllm/issues/35766) + # Can we remove strict_autograd_cache and + # force_non_lazy_backward_lowering overrides? + # I added them explicitly because this is what they are + # set to before the refactor + # (https://github.com/vllm-project/vllm/pull/35472). + # They affect the aotautograd cache key computation + # but they shouldn't have any effect on the actual + # compilation. + config_patches = dict( + bundled_autograd_cache=True, + strict_autograd_cache=False, + ) + if hasattr(torch._functorch.config, "force_non_lazy_backward_lowering"): + config_patches["force_non_lazy_backward_lowering"] = False + with torch._functorch.config.patch(**config_patches): + range_entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args_list, + self.vllm_backend.inductor_config, + self.compilation_config, + compile_range=range_entry.compile_range, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + ) + + range_entry.compiled = True def _log_compile_start(self, compile_range: Range): """Log compilation event for TORCH_TRACE/tlparse.""" @@ -277,44 +325,29 @@ def _log_compile_start(self, compile_range: Range): payload_fn=lambda: self.graph.print_readable(print_output=False), ) - def _maybe_compile_for_range_entry( - self, range_entry: RangeEntry, args: tuple[Any, ...] - ) -> Any: - if not range_entry.compiled: - if self.compiled_runnables is not None: - range_entry.runnable = self.get_compiled_graph_wrapper( - self.compiled_runnables[str(range_entry.compile_range)] - ) - else: - self._log_compile_start(range_entry.compile_range) - - # args are real arguments - # fakify for range, real args for concrete size. - # For concrete size, we clear the shape env in - # compiler_manager.compile() so no need to fakify. - args_list = ( - self._fakify_args(args) - if not range_entry.compile_range.is_single_size() - else list(args) - ) - - with ( - torch._functorch.config.patch("bundled_autograd_cache", True), - ): - range_entry.runnable = self.vllm_backend.compiler_manager.compile( - self.graph, - args_list, - self.vllm_backend.inductor_config, - self.compilation_config, - compile_range=range_entry.compile_range, - graph_index=self.piecewise_compile_index, - num_graphs=self.total_piecewise_compiles, - ) + def load_all_ranges(self) -> None: + """Load all pre-compiled runnables for this piecewise subgraph. + Called during warm start to wrap all cached compiled_runnables + into range_entry.runnable up front, analogous to compile_all_ranges() + for the cold start path. + """ + assert self.compiled_runnables is not None, ( + "load_all_ranges should only be called when compiled_runnables " + "is set (warm start / cache loading path)." + ) + for range_entry in self.range_entries.values(): + if range_entry.compiled: + continue + key = str(range_entry.compile_range) + assert key in self.compiled_runnables, ( + f"Missing compiled runnable for range {range_entry.compile_range}. " + f"Available keys: {list(self.compiled_runnables.keys())}" + ) + range_entry.runnable = self.get_compiled_graph_wrapper( + self.compiled_runnables[key] + ) range_entry.compiled = True - self.to_be_compiled_ranges.remove(range_entry.compile_range) - - self.check_for_ending_compilation() def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None: # First we try to find the range entry for the concrete compile size @@ -338,6 +371,9 @@ def __call__(self, *args: Any) -> Any: assert range_entry is not None, ( f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}" ) - - self._maybe_compile_for_range_entry(range_entry, args) + assert range_entry.compiled, ( + "All ranges should be compiled or loaded up front in " + "PiecewiseBackend.__init__. " + f"range_entry={range_entry.compile_range}" + ) return range_entry.runnable(*args)