diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 24b85e87986..21bdedf7648 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -26,6 +26,7 @@ from torch._inductor.decomposition import select_decomp_table from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface +from vllm.config.utils import Range from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import COMPILATION_PASS_KEY @@ -47,13 +48,13 @@ def fusion_pass_compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Range, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: def compile_inner(graph, example_inputs): current_pass_manager = compiler_config[COMPILATION_PASS_KEY] - graph = current_pass_manager(graph, runtime_shape) + graph = current_pass_manager(graph) return graph decompositions = select_decomp_table() @@ -72,7 +73,7 @@ def npugraph_ex_compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Range, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: # When currently using the FULL_DECODE_ONLY mode, @@ -125,14 +126,14 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Range, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: ascend_config = get_ascend_config() if ascend_config.enable_npugraph_ex: return npugraph_ex_compile(graph, example_inputs, compiler_config, - runtime_shape, key) + compile_range, key) else: return fusion_pass_compile(graph, example_inputs, compiler_config, - runtime_shape, key) + compile_range, key) diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index e311b2602a7..d4dab40ee54 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -17,6 +17,7 @@ # from torch import fx as fx +from vllm.compilation.inductor_pass import get_pass_context from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig @@ -32,10 +33,13 @@ class GraphFusionPassManager: def __init__(self): self.passes: list[VllmInductorPass] = [] - def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph: + def __call__(self, graph: fx.Graph) -> fx.Graph: + compile_range = get_pass_context().compile_range + for pass_ in self.passes: - if pass_.is_applicable(runtime_shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) + graph.recompile() return graph def add(self, pass_: VllmInductorPass): diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index f929c1a47f9..eeaccd80680 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -20,6 +20,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.logger import logger @@ -308,7 +309,7 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) self.end_and_log() - def is_applicable(self, runtime_shape: int | None = None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: """ Check if the pass is applicable for the current configuration. """ diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index f8355a15166..ed90c7f8690 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -22,6 +22,7 @@ from vllm.attention.layer import Attention from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.compilation import Range from vllm.logger import logger @@ -283,7 +284,7 @@ def __call__(self, graph: torch.fx.Graph): pattern_idx += 1 self.end_and_log() - def is_applicable(self, runtime_shape): + def is_applicable_for_range(self, compile_range: Range) -> bool: """ Check if the pass is applicable for the current configuration. """ diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index abdb631e513..a1037855e28 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -106,20 +106,6 @@ # Future Plan: # Remove this patch when vLLM merge the PR. # -# ** 7. File: platform/patch_compile_backend.py** -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter` -# `vllm.compilation.piecewise_backend.PiecewiseBackend` -# Why: -# vllm removed the compile graph for general shape, which caused operator fusion to fail. -# This issue affects the performance of model inference on Ascend. -# How: -# recover the compiled graph for dynamic_shape in PiecewiseBackend. -# Related PR (if no, explain why): -# https://github.com/vllm-project/vllm/pull/24252 -# Future Plan: -# Remove this patch when fix the problem. -# # * Worker Patch: # =============== # diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index cc33cde177b..49840db39db 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -16,7 +16,6 @@ import os -import vllm_ascend.patch.platform.patch_compile_backend # noqa import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_ec_connector # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa diff --git a/vllm_ascend/patch/platform/patch_compile_backend.py b/vllm_ascend/patch/platform/patch_compile_backend.py deleted file mode 100644 index af8ec53a0de..00000000000 --- a/vllm_ascend/patch/platform/patch_compile_backend.py +++ /dev/null @@ -1,235 +0,0 @@ -from collections.abc import Callable -from typing import Any - -import torch -import torch.fx as fx -import vllm.compilation.backends -import vllm.compilation.piecewise_backend -from torch._dispatch.python import enable_python_dispatcher -from vllm.compilation.backends import VllmBackend -from vllm.compilation.counter import compilation_counter -from vllm.compilation.piecewise_backend import RangeEntry -from vllm.config import CUDAGraphMode, VllmConfig -from vllm.config.utils import Range -from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.utils.import_utils import resolve_obj_by_qualname - -logger = init_logger(__name__) - - -class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter): - """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. - It runs the given graph with fake inputs, and compile some - submodules specified by `compile_submod_names` with the given - compilation configs. - - NOTE: the order in `compile_submod_names` matters, because - it will be used to determine the order of the compiled piecewise - graphs. The first graph will handle logging, and the last graph - has some special cudagraph output handling. - """ - - def __init__( - self, - module: torch.fx.GraphModule, - compile_submod_names: list[str], - vllm_config: VllmConfig, - vllm_backend: "VllmBackend", - ): - super().__init__(module) - from torch._guards import detect_fake_mode - - self.fake_mode = detect_fake_mode() - self.compile_submod_names = compile_submod_names - self.compilation_config = vllm_config.compilation_config - self.vllm_config = vllm_config - self.vllm_backend = vllm_backend - # When True, it annoyingly dumps the torch.fx.Graph on errors. - self.extra_traceback = False - - def run(self, *args): - # maybe instead just assert inputs are fake? - fake_args = [ - self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t - for t in args - ] - with self.fake_mode, enable_python_dispatcher(): - return super().run(*fake_args) - - def call_module( - self, - target: torch.fx.node.Target, - args: tuple[torch.fx.node.Argument, ...], - kwargs: dict[str, Any], - ) -> Any: - assert isinstance(target, str) - - output = super().call_module(target, args, kwargs) - - if target in self.compile_submod_names: - index = self.compile_submod_names.index(target) - submod = self.fetch_attr(target) - - sym_shape_indices = [ - i for i, x in enumerate(args) if isinstance(x, torch.SymInt) - ] - max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens - r1 = Range(start=1, end=max_num_batched_tokens) - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.vllm_backend.inductor_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - compile_range=r1, - )) - - # Lazy import here to avoid circular import - from vllm.compilation.piecewise_backend import PiecewiseBackend - - piecewise_backend = PiecewiseBackend( - submod, - self.vllm_config, - index, - len(self.compile_submod_names), - sym_shape_indices, - compiled_graph_for_dynamic_shape, - self.vllm_backend, - ) - - if (self.compilation_config.cudagraph_mode. - has_piecewise_cudagraphs() and - not self.compilation_config.use_inductor_graph_partition): - # We're using Dynamo-based piecewise splitting, so we wrap - # the whole subgraph with a static graph wrapper. - from vllm.compilation.cuda_graph import CUDAGraphOptions - - # resolve the static graph wrapper class (e.g. CUDAGraphWrapper - # class) as platform dependent. - static_graph_wrapper_class = resolve_obj_by_qualname( - current_platform.get_static_graph_wrapper_cls()) - - # Always assign PIECEWISE runtime mode to the - # CUDAGraphWrapper for piecewise_backend, to distinguish - # it from the FULL cudagraph runtime mode, no matter it - # is wrapped on a full or piecewise fx graph. - self.module.__dict__[target] = static_graph_wrapper_class( - runnable=piecewise_backend, - vllm_config=self.vllm_config, - runtime_mode=CUDAGraphMode.PIECEWISE, - cudagraph_options=CUDAGraphOptions( - debug_log_enable=piecewise_backend.is_first_graph, - gc_disable=not piecewise_backend.is_first_graph, - weak_ref_output=piecewise_backend.is_last_graph, - ), - ) - else: - self.module.__dict__[target] = piecewise_backend - - compilation_counter.num_piecewise_capturable_graphs_seen += 1 - - return output - - -class AscendPiecewiseBackend: - - def __init__( - self, - graph: fx.GraphModule, - vllm_config: VllmConfig, - piecewise_compile_index: int, - total_piecewise_compiles: int, - sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend, - ): - """ - The backend for piecewise compilation. - It mainly handles the compilation of static shapes and - dispatching based on runtime shape. - - We will compile `self.graph` once for the general shape, - and then compile for different shapes specified in - `compilation_config.compile_sizes`. - """ - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape - self.graph = graph - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.piecewise_compile_index = piecewise_compile_index - self.total_piecewise_compiles = total_piecewise_compiles - self.vllm_backend = vllm_backend - - self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 - - self.is_full_graph = total_piecewise_compiles == 1 - self.is_encoder_compilation = vllm_backend.is_encoder - - self.compile_ranges = self.compilation_config.get_compile_ranges() - if self.is_encoder_compilation: - # For encoder compilation we use the max int32 value - # to set the upper bound of the compile ranges - max_int32 = 2**31 - 1 - last_compile_range = self.compile_ranges[-1] - assert (last_compile_range.end == - vllm_config.scheduler_config.max_num_batched_tokens) - self.compile_ranges[-1] = Range(start=last_compile_range.start, - end=max_int32) - - log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" - logger.debug_once(log_string) - - self.compile_sizes = self.compilation_config.compile_sizes - log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" - logger.debug_once(log_string) - - self.sym_shape_indices = sym_shape_indices - - # the entries for ranges that we need to either - self.range_entries: dict[Range, RangeEntry] = {} - - # to_be_compiled_ranges tracks the remaining ranges to compile, - # and updates during the compilation process, so we need to copy it - self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) - - # We only keep compilation management inside this class directly. - for size in self.compile_sizes: - range = Range(start=size, end=size) - if range not in self.compile_ranges: - self.range_entries[range] = RangeEntry(compile_range=range, ) - self.to_be_compiled_ranges.add(range) - - for range in self.compile_ranges: - self.range_entries[range] = RangeEntry(compile_range=range, ) - - def _find_range_for_shape(self, runtime_shape: int) -> Range | None: - # First we try to find the range entry for the concrete compile size - # If not found, we search for the range entry - # that contains the runtime shape. - if runtime_shape in self.compile_sizes: - return self.range_entries[Range(start=runtime_shape, - end=runtime_shape)] - else: - for range in self.compile_ranges: - if runtime_shape in range: - return self.range_entries[range] - return None - - def __call__(self, *args) -> Any: - runtime_shape = args[self.sym_shape_indices[0]] - range_entry = self._find_range_for_shape(runtime_shape) - - assert range_entry is not None, ( - f"Shape out of considered range: {runtime_shape} " - "[1, max_num_batched_tokens]") - - return self.compiled_graph_for_general_shape(*args) - - -vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter -vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = AscendPiecewiseBackend.__init__ -vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = AscendPiecewiseBackend.__call__ diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 2529a298deb..0094a0eb549 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -28,7 +28,7 @@ import vllm.envs as envs_vllm from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.profiler import dynamic_profile as dp -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized @@ -381,10 +381,25 @@ def compile_or_warm_up_model(self) -> None: warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy() if not self.model_config.enforce_eager: - warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes - ] + cg_capture_sizes: list[int] = [] + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [ + x for x in warmup_sizes if x not in cg_capture_sizes + ] + + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges( + ) + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the end of the range to ensure compilation/warmup. + all_sizes = set(cg_capture_sizes) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end) + for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size)