diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index a1ff5fb1196b..c2e8c726c943 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -17,7 +17,8 @@
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
-from .compiler_interface import EagerAdaptor, InductorAdaptor
+from .compiler_interface import (CompilerInterface, EagerAdaptor,
+                                 InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
 from .monitor import end_monitoring_torch_compile
@@ -26,6 +27,19 @@
 logger = init_logger(__name__)
 
 
+def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
+    if compilation_config.use_inductor:
+        if envs.VLLM_TEST_STANDALONE_COMPILE:
+            logger.info("Using InductorStandaloneAdaptor")
+            return InductorStandaloneAdaptor()
+        else:
+            logger.info("Using InductorAdaptor")
+            return InductorAdaptor()
+    else:
+        logger.info("Using EagerAdaptor")
+        return EagerAdaptor()
+
+
 class CompilerManager:
     """
     A manager to manage the compilation process, including
@@ -41,11 +55,11 @@ class CompilerManager:
         support int as key.
     """
 
-    def __init__(self, use_inductor: bool):
+    def __init__(self, compilation_config: CompilationConfig):
         self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
-        cls = InductorAdaptor if use_inductor else EagerAdaptor
-        self.compiler = cls()
         self.is_cache_updated = False
+        self.compilation_config = compilation_config
+        self.compiler = make_compiler(compilation_config)
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
@@ -123,8 +137,15 @@ def compile(self,
 
         # no compiler cached the graph, or the cache is disabled,
         # we need to compile it
+        if isinstance(self.compiler, InductorAdaptor):
+            # Let compile_fx generate a key for us
+            maybe_key = None
+        else:
+            maybe_key = \
+                f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
         compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape)
+            graph, example_inputs, additional_inductor_config, runtime_shape,
+            maybe_key)
 
         assert compiled_graph is not None, "Failed to compile the graph"
 
@@ -336,7 +357,7 @@ def __init__(
         self.compilation_config = vllm_config.compilation_config
 
         self.compiler_manager: CompilerManager = CompilerManager(
-            self.compilation_config.use_inductor)
+            self.compilation_config)
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
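Note on the key scheme above: for every compiler except InductorAdaptor, CompilerManager.compile derives a deterministic artifact key from the runtime shape and the subgraph index. A minimal sketch, with assumed example values (runtime_shape=8, graph_index=0) purely for illustration:

```python
# Sketch of the artifact-key scheme used by CompilerManager.compile above.
# The concrete values here are assumptions for illustration only.
runtime_shape = 8  # a specific compiled batch size; None for the general shape
graph_index = 0    # index of the piecewise-compiled subgraph

key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
print(key)  # artifact_shape_8_subgraph_0
# InductorStandaloneAdaptor saves the compiled artifact to cache_dir/<key>.
```

InductorAdaptor keeps key=None so that compile_fx can continue generating its own FX-graph-cache key.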
""" return None, None @@ -127,23 +132,108 @@ def produce_guards_expression(self, *args, **kwargs): return "" +def get_inductor_factors() -> List[Any]: + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + return factors + + +class InductorStandaloneAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler. + Requires PyTorch 2.8+. + This is not on by default yet, but we plan to turn it on by default for + PyTorch 2.8. + + Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off. + """ + name = "inductor_standalone" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.cache_dir = cache_dir + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> Tuple[Optional[Callable], Optional[Any]]: + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + set_inductor_config(current_config, runtime_shape) + + if isinstance(runtime_shape, int): + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_tracing_context" + + from torch._inductor import standalone_compile + with pass_context(runtime_shape): + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}) + + # Save the compiled artifact to disk in the specified path + assert key is not None + path = os.path.join(self.cache_dir, key) + compiled_graph.save(path=path, format="unpacked") + return compiled_graph, (key, path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + path = handle[1] + inductor_compiled_graph = torch._inductor.CompiledArtifact.load( + path=path, format="unpacked") + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + def compiled_graph_wrapper(*args): + graph_output = inductor_compiled_graph(*args) + # unpack the tuple if needed + # TODO(rzou): the implication is that we're not + # reading the python bytecode correctly in vLLM? + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph_wrapper + + class InductorAdaptor(CompilerInterface): """ - The adaptor for the Inductor compiler, version 2.5 and 2.6. + The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. 
""" name = "inductor" def compute_hash(self, vllm_config: VllmConfig) -> str: - factors: List[Any] = [] - # summarize system state - from torch._inductor.codecache import CacheBase - system_factors = CacheBase.get_system() - factors.append(system_factors) - - # summarize pytorch state - from torch._inductor.codecache import torch_key - torch_factors = torch_key() - factors.append(torch_factors) + factors = get_inductor_factors() hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()[:10] return hash_str @@ -168,23 +258,19 @@ def compile( graph: fx.GraphModule, example_inputs: List[Any], compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None + runtime_shape: Optional[int] = None, + key: Optional[str] = None, ) -> Tuple[Optional[Callable], Optional[Any]]: - current_config = {} from torch._inductor.compile_fx import compile_fx + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) # disable remote cache current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - if compiler_config is not None: - current_config.update(compiler_config) - - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters - # can be beneficial - current_config["max_autotune"] = True - current_config["coordinate_descent_tuning"] = True + set_inductor_config(current_config, runtime_shape) # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -422,6 +508,14 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() +def set_inductor_config(config, runtime_shape): + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + config["max_autotune"] = True + config["coordinate_descent_tuning"] = True + + class EagerAdaptor(CompilerInterface): name = "eager" @@ -430,7 +524,8 @@ def compile( graph: fx.GraphModule, example_inputs: List[Any], compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None + runtime_shape: Optional[int] = None, + key: Optional[str] = None, ) -> Tuple[Optional[Callable], Optional[Any]]: # we don't need to compile the graph, just return the graph itself. # It does not support caching, return None for the handle. diff --git a/vllm/envs.py b/vllm/envs.py index ea40bfff11b5..cdd6debd444c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -261,6 +261,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + # Internal flag to enable/disable Inductor standalone compile + "VLLM_TEST_STANDALONE_COMPILE": + lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0", + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": @@ -789,6 +793,7 @@ def factorize(name: str): "VLLM_USE_TRITON_AWQ", "VLLM_DP_RANK", "VLLM_DP_SIZE", + "VLLM_TEST_STANDALONE_COMPILE", ] for key in environment_variables_to_hash: if key in environment_variables: