4 changes: 3 additions & 1 deletion tests/ut/worker/test_worker_v1.py
@@ -1020,6 +1020,8 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
# Verify _dummy_run call count and order (by size descending)
expected_calls = [
unittest.mock.call(16),
unittest.mock.call(8),
unittest.mock.call(4),
unittest.mock.call(1),
]
worker.model_runner._dummy_run.assert_has_calls(expected_calls)
@@ -1028,7 +1030,7 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
worker.model_runner.capture_model.assert_not_called()

# Verify log output
self.assertEqual(mock_logger.info.call_count, 2)
self.assertEqual(mock_logger.info.call_count, 4)

# Verify seed setting
mock_seed_everything.assert_called_once_with(12345)
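For context, the updated assertions are consistent with a warm-up loop that issues one dummy run and one info log per configured compile size, largest first. Below is a minimal sketch of that pattern; the names `warm_up_model` and `compile_sizes` are illustrative stand-ins, not the actual worker code.

```python
import logging

logger = logging.getLogger(__name__)


def warm_up_model(dummy_run, compile_sizes):
    # One dummy run and one info log per compile size, largest first,
    # matching the four expected _dummy_run calls and the info call count
    # asserted in the test above.
    for size in sorted(compile_sizes, reverse=True):
        dummy_run(size)
        logger.info("Warmed up model for compile size %d", size)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    warm_up_model(print, [16, 8, 4, 1])  # four dummy runs, four info logs
```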
13 changes: 6 additions & 7 deletions vllm_ascend/compilation/compiler_interface.py
@@ -26,7 +26,6 @@
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config.utils import Range

from vllm_ascend.ascend_config import get_ascend_config

@@ -47,13 +46,13 @@ def fusion_pass_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:

def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config["graph_fusion_manager"]
graph = current_pass_manager(graph, compile_range)
graph = current_pass_manager(graph, runtime_shape)
return graph

decompositions = select_decomp_table()
@@ -72,7 +71,7 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
# When currently using the FULL_DECODE_ONLY mode,
@@ -125,14 +124,14 @@ def compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:

ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
compile_range, key)
runtime_shape, key)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
compile_range, key)
runtime_shape, key)
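The signature change above replaces the `compile_range` argument with an optional concrete `runtime_shape` that is threaded through to the fusion pass manager. A minimal sketch of that calling convention follows, using toy stand-ins (`ToyFusionManager`, `toy_compile`) rather than vLLM's real config objects.

```python
from typing import Any, Callable, Optional

import torch
from torch import fx


class ToyFusionManager:
    """Stand-in for the graph fusion manager stored in compiler_config."""

    def __call__(self, graph: fx.GraphModule, runtime_shape) -> fx.GraphModule:
        print(f"running fusion passes for runtime_shape={runtime_shape}")
        return graph


def toy_compile(graph: fx.GraphModule,
                example_inputs: list[Any],
                compiler_config: dict[str, Any],
                runtime_shape: Optional[int] = None,
                key: Optional[str] = None) -> Callable:
    # Mirrors the new hook signature: a concrete runtime_shape means a
    # shape-specialized compilation, None means the general dynamic shape.
    return compiler_config["graph_fusion_manager"](graph, runtime_shape)


def toy_fn(x):
    return torch.relu(x) + 1


gm = fx.symbolic_trace(toy_fn)
cfg = {"graph_fusion_manager": ToyFusionManager()}
toy_compile(gm, [torch.randn(4)], cfg, runtime_shape=None)  # general shape
toy_compile(gm, [torch.randn(16)], cfg, runtime_shape=16)   # specialized size
```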
8 changes: 2 additions & 6 deletions vllm_ascend/compilation/graph_fusion_pass_manager.py
@@ -17,7 +17,6 @@
#

from torch import fx as fx
from vllm.compilation.inductor_pass import get_pass_context
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig

@@ -33,13 +32,10 @@ class GraphFusionPassManager:
def __init__(self):
self.passes: list[VllmInductorPass] = []

def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph:
compile_range = get_pass_context().compile_range

def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
for pass_ in self.passes:
if pass_.is_applicable_for_range(compile_range):
if pass_.is_applicable(runtime_shape):
pass_(graph)
graph.recompile()
return graph

def add(self, pass_: VllmInductorPass):
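With this change the manager gates each pass on `is_applicable(runtime_shape)` rather than on a compile range. A hedged sketch of that contract with two toy passes (illustrative classes, not the real fusion passes; the manager's `recompile()` step is omitted):

```python
from torch import fx


class AlwaysOnPass:
    def is_applicable(self, runtime_shape=None) -> bool:
        return True

    def __call__(self, graph: fx.Graph) -> None:
        print("AlwaysOnPass ran")


class StaticShapeOnlyPass:
    def is_applicable(self, runtime_shape=None) -> bool:
        # Only run when compiling for a concrete shape, not the dynamic graph.
        return runtime_shape is not None

    def __call__(self, graph: fx.Graph) -> None:
        print("StaticShapeOnlyPass ran")


def run_passes(passes, graph: fx.Graph, runtime_shape=None) -> fx.Graph:
    # Same applicability gate as GraphFusionPassManager.__call__ above.
    for p in passes:
        if p.is_applicable(runtime_shape):
            p(graph)
    return graph


g = fx.Graph()  # empty placeholder graph, enough for this sketch
run_passes([AlwaysOnPass(), StaticShapeOnlyPass()], g)                    # dynamic: one pass
run_passes([AlwaysOnPass(), StaticShapeOnlyPass()], g, runtime_shape=16)  # both passes run
```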
3 changes: 1 addition & 2 deletions vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
@@ -20,7 +20,6 @@
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger


@@ -309,7 +308,7 @@ def __call__(self, graph: torch.fx.Graph):
logger.debug("Replaced %s patterns", self.matched_count)
self.end_and_log()

def is_applicable_for_range(self, compile_range: Range) -> bool:
def is_applicable(self, runtime_shape: int | None = None) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""
3 changes: 1 addition & 2 deletions vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
@@ -22,7 +22,6 @@
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import Range
from vllm.logger import logger


@@ -284,7 +283,7 @@ def __call__(self, graph: torch.fx.Graph):
pattern_idx += 1
self.end_and_log()

def is_applicable_for_range(self, compile_range: Range) -> bool:
def is_applicable(self, runtime_shape):
"""
Check if the pass is applicable for the current configuration.
"""
14 changes: 14 additions & 0 deletions vllm_ascend/patch/__init__.py
@@ -106,6 +106,20 @@
# Future Plan:
# Remove this patch when vLLM merge the PR.
#
# ** 7. File: platform/patch_compile_backend.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter`
# `vllm.compilation.piecewise_backend.PiecewiseBackend`
# Why:
# vLLM removed the compiled graph for the general (dynamic) shape, which caused operator fusion to fail.
# This issue affects the performance of model inference on Ascend.
# How:
# Recover the compiled graph for the dynamic (general) shape in PiecewiseBackend.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/24252
# Future Plan:
# Remove this patch once the upstream problem is fixed.
#
# * Worker Patch:
# ===============
#
1 change: 1 addition & 0 deletions vllm_ascend/patch/platform/__init__.py
@@ -16,6 +16,7 @@

import os

import vllm_ascend.patch.platform.patch_compile_backend # noqa
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_ec_connector # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
235 changes: 235 additions & 0 deletions vllm_ascend/patch/platform/patch_compile_backend.py
@@ -0,0 +1,235 @@
from collections.abc import Callable
from typing import Any

import torch
import torch.fx as fx
import vllm.compilation.backends
import vllm.compilation.piecewise_backend
from torch._dispatch.python import enable_python_dispatcher
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.piecewise_backend import RangeEntry
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname

logger = init_logger(__name__)


class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter):
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compiles some
submodules specified by `compile_submod_names` with the given
compilation configs.

NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise
graphs. The first graph will handle logging, and the last graph
has some special cudagraph output handling.
"""

def __init__(
self,
module: torch.fx.GraphModule,
compile_submod_names: list[str],
vllm_config: VllmConfig,
vllm_backend: "VllmBackend",
):
super().__init__(module)
from torch._guards import detect_fake_mode

self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False

def run(self, *args):
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
with self.fake_mode, enable_python_dispatcher():
return super().run(*fake_args)

def call_module(
self,
target: torch.fx.node.Target,
args: tuple[torch.fx.node.Argument, ...],
kwargs: dict[str, Any],
) -> Any:
assert isinstance(target, str)

output = super().call_module(target, args, kwargs)

if target in self.compile_submod_names:
index = self.compile_submod_names.index(target)
submod = self.fetch_attr(target)

sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
r1 = Range(start=1, end=max_num_batched_tokens)
compiled_graph_for_dynamic_shape = (
self.vllm_backend.compiler_manager.compile(
submod,
args,
self.vllm_backend.inductor_config,
self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
compile_range=r1,
))

# Lazy import here to avoid circular import
from vllm.compilation.piecewise_backend import PiecewiseBackend

piecewise_backend = PiecewiseBackend(
submod,
self.vllm_config,
index,
len(self.compile_submod_names),
sym_shape_indices,
compiled_graph_for_dynamic_shape,
self.vllm_backend,
)

if (self.compilation_config.cudagraph_mode.
has_piecewise_cudagraphs() and
not self.compilation_config.use_inductor_graph_partition):
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from vllm.compilation.cuda_graph import CUDAGraphOptions

# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())

# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph,
),
)
else:
self.module.__dict__[target] = piecewise_backend

compilation_counter.num_piecewise_capturable_graphs_seen += 1

return output


class AscendPiecewiseBackend:

def __init__(
self,
graph: fx.GraphModule,
vllm_config: VllmConfig,
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend,
):
"""
The backend for piecewise compilation.
It mainly handles the compilation of static shapes and
dispatching based on runtime shape.

We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
"""
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend

self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

self.is_full_graph = total_piecewise_compiles == 1
self.is_encoder_compilation = vllm_backend.is_encoder

self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
# For encoder compilation we use the max int32 value
# to set the upper bound of the compile ranges
max_int32 = 2**31 - 1
last_compile_range = self.compile_ranges[-1]
assert (last_compile_range.end ==
vllm_config.scheduler_config.max_num_batched_tokens)
self.compile_ranges[-1] = Range(start=last_compile_range.start,
end=max_int32)

log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string)

self.compile_sizes = self.compilation_config.compile_sizes
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
logger.debug_once(log_string)

self.sym_shape_indices = sym_shape_indices

        # the entries for the ranges that we need to either compile
        # directly or dispatch to at runtime
self.range_entries: dict[Range, RangeEntry] = {}

# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

# We only keep compilation management inside this class directly.
for size in self.compile_sizes:
range = Range(start=size, end=size)
if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry(compile_range=range, )
self.to_be_compiled_ranges.add(range)

for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(compile_range=range, )

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# that contains the runtime shape.
if runtime_shape in self.compile_sizes:
return self.range_entries[Range(start=runtime_shape,
end=runtime_shape)]
else:
for range in self.compile_ranges:
if runtime_shape in range:
return self.range_entries[range]
return None

def __call__(self, *args) -> Any:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)

assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} "
"[1, max_num_batched_tokens]")

return self.compiled_graph_for_general_shape(*args)


vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter
vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = AscendPiecewiseBackend.__init__
vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = AscendPiecewiseBackend.__call__
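For reference, here is a standalone sketch of the lookup that `_find_range_for_shape` performs: concrete `compile_sizes` take priority over an enclosing compile range. The `Range` class below is a simplified stand-in for `vllm.config.utils.Range`, with inclusive bounds assumed purely for illustration; the sizes and ranges are made-up example values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    """Simplified stand-in for vllm.config.utils.Range (inclusive bounds assumed)."""
    start: int
    end: int

    def __contains__(self, value: int) -> bool:
        return self.start <= value <= self.end


compile_sizes = [4, 16]                             # concrete sizes win
compile_ranges = [Range(1, 512), Range(513, 8192)]  # fallback ranges


def find_range_for_shape(runtime_shape: int) -> Range | None:
    # Prefer an exact-size entry, then the first range containing the shape.
    if runtime_shape in compile_sizes:
        return Range(runtime_shape, runtime_shape)
    for r in compile_ranges:
        if runtime_shape in r:
            return r
    return None


print(find_range_for_shape(16))    # Range(start=16, end=16): a concrete compile size
print(find_range_for_shape(100))   # Range(start=1, end=512): enclosing range
print(find_range_for_shape(9000))  # None: would trip the assert in __call__ above
```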