Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions vllm_ascend/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config.utils import Range

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
Expand All @@ -47,13 +48,13 @@ def fusion_pass_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:

def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
graph = current_pass_manager(graph, runtime_shape)
graph = current_pass_manager(graph)
return graph

decompositions = select_decomp_table()
Expand All @@ -72,7 +73,7 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
# When currently using the FULL_DECODE_ONLY mode,
Expand Down Expand Up @@ -125,14 +126,14 @@ def compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:

ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
runtime_shape, key)
compile_range, key)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
runtime_shape, key)
compile_range, key)
8 changes: 6 additions & 2 deletions vllm_ascend/compilation/graph_fusion_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#

from torch import fx as fx
from vllm.compilation.inductor_pass import get_pass_context
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig

Expand All @@ -32,10 +33,13 @@ class GraphFusionPassManager:
def __init__(self):
self.passes: list[VllmInductorPass] = []

def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
def __call__(self, graph: fx.Graph) -> fx.Graph:
compile_range = get_pass_context().compile_range

Comment thread
wxsIcey marked this conversation as resolved.
for pass_ in self.passes:
if pass_.is_applicable(runtime_shape):
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
graph.recompile()
return graph

def add(self, pass_: VllmInductorPass):
Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger


Expand Down Expand Up @@ -308,7 +309,7 @@ def __call__(self, graph: torch.fx.Graph):
logger.debug("Replaced %s patterns", self.matched_count)
self.end_and_log()

def is_applicable(self, runtime_shape: int | None = None) -> bool:
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import Range
from vllm.logger import logger


Expand Down Expand Up @@ -283,7 +284,7 @@ def __call__(self, graph: torch.fx.Graph):
pattern_idx += 1
self.end_and_log()

def is_applicable(self, runtime_shape):
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
Expand Down
14 changes: 0 additions & 14 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,6 @@
# Future Plan:
# Remove this patch when vLLM merge the PR.
#
# ** 7. File: platform/patch_compile_backend.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter`
# `vllm.compilation.piecewise_backend.PiecewiseBackend`
# Why:
# vllm removed the compile graph for general shape, which caused operator fusion to fail.
# This issue affects the performance of model inference on Ascend.
# How:
# recover the compiled graph for dynamic_shape in PiecewiseBackend.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/24252
# Future Plan:
# Remove this patch when fix the problem.
#
# * Worker Patch:
# ===============
#
Expand Down
1 change: 0 additions & 1 deletion vllm_ascend/patch/platform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import os

import vllm_ascend.patch.platform.patch_compile_backend # noqa
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_ec_connector # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
Expand Down
235 changes: 0 additions & 235 deletions vllm_ascend/patch/platform/patch_compile_backend.py

This file was deleted.

Loading
Loading