From 54589f8580eaba9601119484ff765daa484e6425 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 22 Dec 2025 13:15:37 +0000 Subject: [PATCH 01/10] [BugFix][Graph] Change runtime_shape to compile_range Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/compiler_interface.py | 13 +++++++------ .../compilation/graph_fusion_pass_manager.py | 8 +++++--- .../compilation/passes/norm_quant_fusion_pass.py | 4 ++-- .../compilation/passes/qknorm_rope_fusion_pass.py | 8 ++++++-- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 24b85e87986..3c3b8b904cd 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -26,7 +26,7 @@ from torch._inductor.decomposition import select_decomp_table from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface - +from vllm.config.utils import Range from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import COMPILATION_PASS_KEY @@ -47,13 +47,13 @@ def fusion_pass_compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Range, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: def compile_inner(graph, example_inputs): current_pass_manager = compiler_config[COMPILATION_PASS_KEY] - graph = current_pass_manager(graph, runtime_shape) + graph = current_pass_manager(graph, compile_range) return graph decompositions = select_decomp_table() @@ -125,14 +125,15 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + # runtime_shape: Optional[int] = None, + compile_range: Range, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: ascend_config = get_ascend_config() if ascend_config.enable_npugraph_ex: return npugraph_ex_compile(graph, example_inputs, compiler_config, - runtime_shape, key) + compile_range, key) else: return fusion_pass_compile(graph, example_inputs, compiler_config, - runtime_shape, key) + compile_range, key) diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index e311b2602a7..d4ce77bd8b2 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -19,7 +19,7 @@ from torch import fx as fx from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig - +from vllm.compilation.inductor_pass import get_pass_context class GraphFusionPassManager: """ @@ -32,9 +32,11 @@ class GraphFusionPassManager: def __init__(self): self.passes: list[VllmInductorPass] = [] - def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph: + def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph: + compile_range = get_pass_context().compile_range + for pass_ in self.passes: - if pass_.is_applicable(runtime_shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) return graph diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index f929c1a47f9..e51b630da4c 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -21,7 +21,7 @@ from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig from vllm.logger import logger - +from vllm.config.utils import Range class AddRMSNormQuantPattern: @@ -308,7 +308,7 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) self.end_and_log() - def is_applicable(self, runtime_shape: int | None = None) -> bool: + def is_applicable(self, compile_range: Range) -> bool: """ Check if the pass is applicable for the current configuration. """ diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index f8355a15166..b6b2ada6cb2 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -23,7 +23,7 @@ from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import logger - +from vllm.config.utils import Range class QKNormRopeFusionPattern: @@ -272,7 +272,11 @@ def __init__(self, vllm_config: VllmConfig): def __call__(self, graph: torch.fx.Graph): self.begin() + logger.info("before graph fusion") + logger.info(graph.graph) self.matched_count = self.pattern_match_passes.apply(graph) + logger.info("after graph fusion") + logger.info(graph.graph) logger.debug("Fused %s QKNorm and Rope patterns", self.matched_count) logger.debug("Patterns registered for replacement:") pattern_idx = 0 @@ -283,7 +287,7 @@ def __call__(self, graph: torch.fx.Graph): pattern_idx += 1 self.end_and_log() - def is_applicable(self, runtime_shape): + def is_applicable(self, compile_range: Range) -> bool: """ Check if the pass is applicable for the current configuration. """ From c9ccbe0d6a7affae9fc0d9066f584488ddcc7bd1 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 23 Dec 2025 03:23:12 +0000 Subject: [PATCH 02/10] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- .../compilation/passes/norm_quant_fusion_pass.py | 11 +++++------ .../compilation/passes/qknorm_rope_fusion_pass.py | 7 ++----- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index e51b630da4c..6c5e0d70a20 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -21,7 +21,7 @@ from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig from vllm.logger import logger -from vllm.config.utils import Range +from vllm.config.compilation import Range class AddRMSNormQuantPattern: @@ -307,9 +307,8 @@ def __call__(self, graph: torch.fx.Graph): self.matched_count = self.pattern_match_passes.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) self.end_and_log() - - def is_applicable(self, compile_range: Range) -> bool: - """ - Check if the pass is applicable for the current configuration. - """ + + def is_applicable_for_range(self, compile_range: Range) -> bool: return True + + diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index b6b2ada6cb2..897477ddd84 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -23,7 +23,7 @@ from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import logger -from vllm.config.utils import Range +from vllm.config.compilation import Range class QKNormRopeFusionPattern: @@ -287,8 +287,5 @@ def __call__(self, graph: torch.fx.Graph): pattern_idx += 1 self.end_and_log() - def is_applicable(self, compile_range: Range) -> bool: - """ - Check if the pass is applicable for the current configuration. - """ + def is_applicable_for_range(self, compile_range: Range) -> bool: return True From f8e1facf4258be1c2ce383d3e5fcf1167742829f Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 23 Dec 2025 03:32:42 +0000 Subject: [PATCH 03/10] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/worker/worker.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 2529a298deb..78f60fe5e29 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -48,7 +48,7 @@ DraftTokenIds, ModelRunnerOutput) from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.workspace import init_workspace_manager - +from vllm.config import CUDAGraphMode import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.batch_invariant import init_batch_invariance @@ -380,11 +380,22 @@ def compile_or_warm_up_model(self) -> None: self.model_runner.eplb_warmup() warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy() - if not self.model_config.enforce_eager: - warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes - ] + cg_capture_sizes: list[int] = [] + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes] + + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the end of the range to ensure compilation/warmup. + all_sizes = set(cg_capture_sizes) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end) + for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) From 99bd91c8ae0ca82db0613dbfba7ac1a3db4f3e04 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 4 Jan 2026 01:10:52 +0000 Subject: [PATCH 04/10] fix npugraph_ex Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/compiler_interface.py | 4 ++-- .../compilation/graph_fusion_pass_manager.py | 5 +++-- .../passes/norm_quant_fusion_pass.py | 7 +++---- .../passes/qknorm_rope_fusion_pass.py | 3 ++- vllm_ascend/worker/worker.py | 17 ++++++++++------- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 3c3b8b904cd..a560f735556 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -27,6 +27,7 @@ from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface from vllm.config.utils import Range + from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import COMPILATION_PASS_KEY @@ -72,7 +73,7 @@ def npugraph_ex_compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Range, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: # When currently using the FULL_DECODE_ONLY mode, @@ -125,7 +126,6 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - # runtime_shape: Optional[int] = None, compile_range: Range, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index d4ce77bd8b2..2764b760ef9 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -17,9 +17,10 @@ # from torch import fx as fx +from vllm.compilation.inductor_pass import get_pass_context from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig -from vllm.compilation.inductor_pass import get_pass_context + class GraphFusionPassManager: """ @@ -34,7 +35,7 @@ def __init__(self): def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph: compile_range = get_pass_context().compile_range - + for pass_ in self.passes: if pass_.is_applicable_for_range(compile_range): pass_(graph) diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index 6c5e0d70a20..6c5c8b64c90 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -20,8 +20,9 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig -from vllm.logger import logger from vllm.config.compilation import Range +from vllm.logger import logger + class AddRMSNormQuantPattern: @@ -307,8 +308,6 @@ def __call__(self, graph: torch.fx.Graph): self.matched_count = self.pattern_match_passes.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) self.end_and_log() - + def is_applicable_for_range(self, compile_range: Range) -> bool: return True - - diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index 897477ddd84..8931a7495e1 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -22,8 +22,9 @@ from vllm.attention.layer import Attention from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.logger import logger from vllm.config.compilation import Range +from vllm.logger import logger + class QKNormRopeFusionPattern: diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 78f60fe5e29..ca034826b6c 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -28,7 +28,7 @@ import vllm.envs as envs_vllm from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.profiler import dynamic_profile as dp -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized @@ -48,7 +48,7 @@ DraftTokenIds, ModelRunnerOutput) from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.workspace import init_workspace_manager -from vllm.config import CUDAGraphMode + import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.batch_invariant import init_batch_invariance @@ -382,11 +382,14 @@ def compile_or_warm_up_model(self) -> None: or []).copy() cg_capture_sizes: list[int] = [] if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: - cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - cg_capture_sizes = [] if cg_sizes is None else cg_sizes - warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes] + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [ + x for x in warmup_sizes if x not in cg_capture_sizes + ] - compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges( + ) # For each compile_range, if none of the batch sizes # in warmup_sizes or cudagraph_capture_sizes are in the range, # add the end of the range to ensure compilation/warmup. @@ -395,7 +398,7 @@ def compile_or_warm_up_model(self) -> None: for compile_range in compile_ranges: if not any(x in compile_range for x in all_sizes): warmup_sizes.append(compile_range.end) - + for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) From 9c5a49444932d4b1b30845624c622453a2378120 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 4 Jan 2026 01:16:17 +0000 Subject: [PATCH 05/10] fix can not fusion Signed-off-by: wxsIcey <1790571317@qq.com> --- .../compilation/graph_fusion_pass_manager.py | 1 + vllm_ascend/patch/__init__.py | 14 -- vllm_ascend/patch/platform/__init__.py | 1 - .../patch/platform/patch_compile_backend.py | 235 ------------------ 4 files changed, 1 insertion(+), 250 deletions(-) delete mode 100644 vllm_ascend/patch/platform/patch_compile_backend.py diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 2764b760ef9..4e458dde4dd 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -39,6 +39,7 @@ def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph: for pass_ in self.passes: if pass_.is_applicable_for_range(compile_range): pass_(graph) + graph.recompile() return graph def add(self, pass_: VllmInductorPass): diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index abdb631e513..a1037855e28 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -106,20 +106,6 @@ # Future Plan: # Remove this patch when vLLM merge the PR. # -# ** 7. File: platform/patch_compile_backend.py** -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter` -# `vllm.compilation.piecewise_backend.PiecewiseBackend` -# Why: -# vllm removed the compile graph for general shape, which caused operator fusion to fail. -# This issue affects the performance of model inference on Ascend. -# How: -# recover the compiled graph for dynamic_shape in PiecewiseBackend. -# Related PR (if no, explain why): -# https://github.com/vllm-project/vllm/pull/24252 -# Future Plan: -# Remove this patch when fix the problem. -# # * Worker Patch: # =============== # diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index cc33cde177b..49840db39db 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -16,7 +16,6 @@ import os -import vllm_ascend.patch.platform.patch_compile_backend # noqa import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_ec_connector # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa diff --git a/vllm_ascend/patch/platform/patch_compile_backend.py b/vllm_ascend/patch/platform/patch_compile_backend.py deleted file mode 100644 index af8ec53a0de..00000000000 --- a/vllm_ascend/patch/platform/patch_compile_backend.py +++ /dev/null @@ -1,235 +0,0 @@ -from collections.abc import Callable -from typing import Any - -import torch -import torch.fx as fx -import vllm.compilation.backends -import vllm.compilation.piecewise_backend -from torch._dispatch.python import enable_python_dispatcher -from vllm.compilation.backends import VllmBackend -from vllm.compilation.counter import compilation_counter -from vllm.compilation.piecewise_backend import RangeEntry -from vllm.config import CUDAGraphMode, VllmConfig -from vllm.config.utils import Range -from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.utils.import_utils import resolve_obj_by_qualname - -logger = init_logger(__name__) - - -class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter): - """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. - It runs the given graph with fake inputs, and compile some - submodules specified by `compile_submod_names` with the given - compilation configs. - - NOTE: the order in `compile_submod_names` matters, because - it will be used to determine the order of the compiled piecewise - graphs. The first graph will handle logging, and the last graph - has some special cudagraph output handling. - """ - - def __init__( - self, - module: torch.fx.GraphModule, - compile_submod_names: list[str], - vllm_config: VllmConfig, - vllm_backend: "VllmBackend", - ): - super().__init__(module) - from torch._guards import detect_fake_mode - - self.fake_mode = detect_fake_mode() - self.compile_submod_names = compile_submod_names - self.compilation_config = vllm_config.compilation_config - self.vllm_config = vllm_config - self.vllm_backend = vllm_backend - # When True, it annoyingly dumps the torch.fx.Graph on errors. - self.extra_traceback = False - - def run(self, *args): - # maybe instead just assert inputs are fake? - fake_args = [ - self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t - for t in args - ] - with self.fake_mode, enable_python_dispatcher(): - return super().run(*fake_args) - - def call_module( - self, - target: torch.fx.node.Target, - args: tuple[torch.fx.node.Argument, ...], - kwargs: dict[str, Any], - ) -> Any: - assert isinstance(target, str) - - output = super().call_module(target, args, kwargs) - - if target in self.compile_submod_names: - index = self.compile_submod_names.index(target) - submod = self.fetch_attr(target) - - sym_shape_indices = [ - i for i, x in enumerate(args) if isinstance(x, torch.SymInt) - ] - max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens - r1 = Range(start=1, end=max_num_batched_tokens) - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.vllm_backend.inductor_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - compile_range=r1, - )) - - # Lazy import here to avoid circular import - from vllm.compilation.piecewise_backend import PiecewiseBackend - - piecewise_backend = PiecewiseBackend( - submod, - self.vllm_config, - index, - len(self.compile_submod_names), - sym_shape_indices, - compiled_graph_for_dynamic_shape, - self.vllm_backend, - ) - - if (self.compilation_config.cudagraph_mode. - has_piecewise_cudagraphs() and - not self.compilation_config.use_inductor_graph_partition): - # We're using Dynamo-based piecewise splitting, so we wrap - # the whole subgraph with a static graph wrapper. - from vllm.compilation.cuda_graph import CUDAGraphOptions - - # resolve the static graph wrapper class (e.g. CUDAGraphWrapper - # class) as platform dependent. - static_graph_wrapper_class = resolve_obj_by_qualname( - current_platform.get_static_graph_wrapper_cls()) - - # Always assign PIECEWISE runtime mode to the - # CUDAGraphWrapper for piecewise_backend, to distinguish - # it from the FULL cudagraph runtime mode, no matter it - # is wrapped on a full or piecewise fx graph. - self.module.__dict__[target] = static_graph_wrapper_class( - runnable=piecewise_backend, - vllm_config=self.vllm_config, - runtime_mode=CUDAGraphMode.PIECEWISE, - cudagraph_options=CUDAGraphOptions( - debug_log_enable=piecewise_backend.is_first_graph, - gc_disable=not piecewise_backend.is_first_graph, - weak_ref_output=piecewise_backend.is_last_graph, - ), - ) - else: - self.module.__dict__[target] = piecewise_backend - - compilation_counter.num_piecewise_capturable_graphs_seen += 1 - - return output - - -class AscendPiecewiseBackend: - - def __init__( - self, - graph: fx.GraphModule, - vllm_config: VllmConfig, - piecewise_compile_index: int, - total_piecewise_compiles: int, - sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend, - ): - """ - The backend for piecewise compilation. - It mainly handles the compilation of static shapes and - dispatching based on runtime shape. - - We will compile `self.graph` once for the general shape, - and then compile for different shapes specified in - `compilation_config.compile_sizes`. - """ - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape - self.graph = graph - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.piecewise_compile_index = piecewise_compile_index - self.total_piecewise_compiles = total_piecewise_compiles - self.vllm_backend = vllm_backend - - self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 - - self.is_full_graph = total_piecewise_compiles == 1 - self.is_encoder_compilation = vllm_backend.is_encoder - - self.compile_ranges = self.compilation_config.get_compile_ranges() - if self.is_encoder_compilation: - # For encoder compilation we use the max int32 value - # to set the upper bound of the compile ranges - max_int32 = 2**31 - 1 - last_compile_range = self.compile_ranges[-1] - assert (last_compile_range.end == - vllm_config.scheduler_config.max_num_batched_tokens) - self.compile_ranges[-1] = Range(start=last_compile_range.start, - end=max_int32) - - log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" - logger.debug_once(log_string) - - self.compile_sizes = self.compilation_config.compile_sizes - log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" - logger.debug_once(log_string) - - self.sym_shape_indices = sym_shape_indices - - # the entries for ranges that we need to either - self.range_entries: dict[Range, RangeEntry] = {} - - # to_be_compiled_ranges tracks the remaining ranges to compile, - # and updates during the compilation process, so we need to copy it - self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) - - # We only keep compilation management inside this class directly. - for size in self.compile_sizes: - range = Range(start=size, end=size) - if range not in self.compile_ranges: - self.range_entries[range] = RangeEntry(compile_range=range, ) - self.to_be_compiled_ranges.add(range) - - for range in self.compile_ranges: - self.range_entries[range] = RangeEntry(compile_range=range, ) - - def _find_range_for_shape(self, runtime_shape: int) -> Range | None: - # First we try to find the range entry for the concrete compile size - # If not found, we search for the range entry - # that contains the runtime shape. - if runtime_shape in self.compile_sizes: - return self.range_entries[Range(start=runtime_shape, - end=runtime_shape)] - else: - for range in self.compile_ranges: - if runtime_shape in range: - return self.range_entries[range] - return None - - def __call__(self, *args) -> Any: - runtime_shape = args[self.sym_shape_indices[0]] - range_entry = self._find_range_for_shape(runtime_shape) - - assert range_entry is not None, ( - f"Shape out of considered range: {runtime_shape} " - "[1, max_num_batched_tokens]") - - return self.compiled_graph_for_general_shape(*args) - - -vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter -vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = AscendPiecewiseBackend.__init__ -vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = AscendPiecewiseBackend.__call__ From 81fa06aa82779993cd8d53ca43c4b4d3919f2428 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 4 Jan 2026 01:20:52 +0000 Subject: [PATCH 06/10] remove redundant log Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/passes/norm_quant_fusion_pass.py | 3 +++ vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py | 7 +++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index 6c5c8b64c90..eeaccd80680 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -310,4 +310,7 @@ def __call__(self, graph: torch.fx.Graph): self.end_and_log() def is_applicable_for_range(self, compile_range: Range) -> bool: + """ + Check if the pass is applicable for the current configuration. + """ return True diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index 8931a7495e1..ed90c7f8690 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -273,11 +273,7 @@ def __init__(self, vllm_config: VllmConfig): def __call__(self, graph: torch.fx.Graph): self.begin() - logger.info("before graph fusion") - logger.info(graph.graph) self.matched_count = self.pattern_match_passes.apply(graph) - logger.info("after graph fusion") - logger.info(graph.graph) logger.debug("Fused %s QKNorm and Rope patterns", self.matched_count) logger.debug("Patterns registered for replacement:") pattern_idx = 0 @@ -289,4 +285,7 @@ def __call__(self, graph: torch.fx.Graph): self.end_and_log() def is_applicable_for_range(self, compile_range: Range) -> bool: + """ + Check if the pass is applicable for the current configuration. + """ return True From 5b8aed7a420835cc9c0388ef688d590770bceb91 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 4 Jan 2026 02:33:13 +0000 Subject: [PATCH 07/10] fix ut Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/ut/worker/test_worker_v1.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 49d6c86eeb8..c3b41c2b8ae 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -1007,8 +1007,6 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb, # Verify _dummy_run call count and order (by size descending) expected_calls = [ unittest.mock.call(16), - unittest.mock.call(8), - unittest.mock.call(4), unittest.mock.call(1), ] worker.model_runner._dummy_run.assert_has_calls(expected_calls) @@ -1017,7 +1015,7 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb, worker.model_runner.capture_model.assert_not_called() # Verify log output - self.assertEqual(mock_logger.info.call_count, 4) + self.assertEqual(mock_logger.info.call_count, 2) # Verify atb warm up mock_warm_up_atb.assert_called_once() From af8d96fd1336b916130331e0addee330b25f70b1 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 7 Jan 2026 02:17:50 +0000 Subject: [PATCH 08/10] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/worker/worker.py | 37 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index ca034826b6c..0094a0eb549 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -380,24 +380,25 @@ def compile_or_warm_up_model(self) -> None: self.model_runner.eplb_warmup() warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy() - cg_capture_sizes: list[int] = [] - if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: - cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - cg_capture_sizes = [] if cg_sizes is None else cg_sizes - warmup_sizes = [ - x for x in warmup_sizes if x not in cg_capture_sizes - ] - - compile_ranges = self.vllm_config.compilation_config.get_compile_ranges( - ) - # For each compile_range, if none of the batch sizes - # in warmup_sizes or cudagraph_capture_sizes are in the range, - # add the end of the range to ensure compilation/warmup. - all_sizes = set(cg_capture_sizes) - all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) - for compile_range in compile_ranges: - if not any(x in compile_range for x in all_sizes): - warmup_sizes.append(compile_range.end) + if not self.model_config.enforce_eager: + cg_capture_sizes: list[int] = [] + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [ + x for x in warmup_sizes if x not in cg_capture_sizes + ] + + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges( + ) + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the end of the range to ensure compilation/warmup. + all_sizes = set(cg_capture_sizes) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end) for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) From c643874fa1a698f4c716725a379ceeecc04e956e Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 7 Jan 2026 02:22:12 +0000 Subject: [PATCH 09/10] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/ut/worker/test_worker_v1.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index c3b41c2b8ae..49d6c86eeb8 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -1007,6 +1007,8 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb, # Verify _dummy_run call count and order (by size descending) expected_calls = [ unittest.mock.call(16), + unittest.mock.call(8), + unittest.mock.call(4), unittest.mock.call(1), ] worker.model_runner._dummy_run.assert_has_calls(expected_calls) @@ -1015,7 +1017,7 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb, worker.model_runner.capture_model.assert_not_called() # Verify log output - self.assertEqual(mock_logger.info.call_count, 2) + self.assertEqual(mock_logger.info.call_count, 4) # Verify atb warm up mock_warm_up_atb.assert_called_once() From 89f15bd91f2bb8914167b68906d75e17ea205c2d Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 7 Jan 2026 02:42:23 +0000 Subject: [PATCH 10/10] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/compiler_interface.py | 2 +- vllm_ascend/compilation/graph_fusion_pass_manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index a560f735556..21bdedf7648 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -54,7 +54,7 @@ def fusion_pass_compile( def compile_inner(graph, example_inputs): current_pass_manager = compiler_config[COMPILATION_PASS_KEY] - graph = current_pass_manager(graph, compile_range) + graph = current_pass_manager(graph) return graph decompositions = select_decomp_table() diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 4e458dde4dd..d4dab40ee54 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -33,7 +33,7 @@ class GraphFusionPassManager: def __init__(self): self.passes: list[VllmInductorPass] = [] - def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph: + def __call__(self, graph: fx.Graph) -> fx.Graph: compile_range = get_pass_context().compile_range for pass_ in self.passes: