Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tests/ut/worker/test_worker_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,8 +1020,6 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
# Verify _dummy_run call count and order (by size descending)
expected_calls = [
unittest.mock.call(16),
unittest.mock.call(8),
unittest.mock.call(4),
unittest.mock.call(1),
]
worker.model_runner._dummy_run.assert_has_calls(expected_calls)
Expand All @@ -1030,7 +1028,7 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
worker.model_runner.capture_model.assert_not_called()

# Verify log output
self.assertEqual(mock_logger.info.call_count, 4)
self.assertEqual(mock_logger.info.call_count, 2)

# Verify seed setting
mock_seed_everything.assert_called_once_with(12345)
Expand Down
13 changes: 7 additions & 6 deletions vllm_ascend/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config.utils import Range

from vllm_ascend.ascend_config import get_ascend_config

Expand All @@ -46,13 +47,13 @@ def fusion_pass_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:

def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config["graph_fusion_manager"]
graph = current_pass_manager(graph, runtime_shape)
graph = current_pass_manager(graph, compile_range)
return graph

decompositions = select_decomp_table()
Expand All @@ -71,7 +72,7 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
# When currently using the FULL_DECODE_ONLY mode,
Expand Down Expand Up @@ -124,14 +125,14 @@ def compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:

ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
runtime_shape, key)
compile_range, key)
Comment thread
MengqingCao marked this conversation as resolved.
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
runtime_shape, key)
compile_range, key)
8 changes: 6 additions & 2 deletions vllm_ascend/compilation/graph_fusion_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#

from torch import fx as fx
from vllm.compilation.inductor_pass import get_pass_context
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig

Expand All @@ -32,10 +33,13 @@ class GraphFusionPassManager:
def __init__(self):
self.passes: list[VllmInductorPass] = []

def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph:
Comment thread
MengqingCao marked this conversation as resolved.
compile_range = get_pass_context().compile_range

for pass_ in self.passes:
if pass_.is_applicable(runtime_shape):
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
graph.recompile()
return graph

def add(self, pass_: VllmInductorPass):
Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger


Expand Down Expand Up @@ -308,7 +309,7 @@ def __call__(self, graph: torch.fx.Graph):
logger.debug("Replaced %s patterns", self.matched_count)
self.end_and_log()

def is_applicable(self, runtime_shape: int | None = None) -> bool:
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
Expand Down
3 changes: 2 additions & 1 deletion vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import Range
from vllm.logger import logger


Expand Down Expand Up @@ -283,7 +284,7 @@ def __call__(self, graph: torch.fx.Graph):
pattern_idx += 1
self.end_and_log()

def is_applicable(self, runtime_shape):
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
Expand Down
14 changes: 0 additions & 14 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,6 @@
# Future Plan:
# Remove this patch when vLLM merges the PR.
#
# ** 7. File: platform/patch_compile_backend.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter`
# `vllm.compilation.piecewise_backend.PiecewiseBackend`
# Why:
# vllm removed the compile graph for general shape, which caused operator fusion to fail.
# This issue affects the performance of model inference on Ascend.
# How:
# recover the compiled graph for dynamic_shape in PiecewiseBackend.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/24252
# Future Plan:
# Remove this patch when the problem is fixed.
#
# * Worker Patch:
# ===============
#
Expand Down
1 change: 0 additions & 1 deletion vllm_ascend/patch/platform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import os

import vllm_ascend.patch.platform.patch_compile_backend # noqa
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_ec_connector # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
Expand Down
235 changes: 0 additions & 235 deletions vllm_ascend/patch/platform/patch_compile_backend.py

This file was deleted.

Loading
Loading