@@ -94,7 +94,7 @@ The details of each configuration option are as follows:

| Name | Type | Default | Description |
|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
-| `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
+| `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
| `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
| `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
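For context, a minimal sketch of how these options are passed at runtime through `additional_config` (the test change below uses the same structure; the offline `LLM` entrypoint and model name here are illustrative assumptions, not part of this PR):

```python
# Minimal sketch, assuming the standard vLLM offline entrypoint.
# npugraph_ex now defaults to enabled, so this shows opting out explicitly.
from vllm import LLM

llm = LLM(
    model="some/model",  # placeholder model name
    additional_config={
        "npugraph_ex_config": {
            "enable": False,
            "fuse_qknorm_rope": False,  # per the table: set False when Triton is unavailable
        }
    },
)
```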
5 changes: 5 additions & 0 deletions tests/e2e/singlecard/test_aclgraph_accuracy.py
@@ -147,6 +147,11 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
"cudagraph_mode": "FULL_DECODE_ONLY"
},
"quantization": cur_case.quantization,
"additional_config": {
"npugraph_ex_config": {
"enable": False
}
},
}
gen_and_valid(runner_kwargs=runner_kwargs,
prompts=cur_case.prompts,
2 changes: 1 addition & 1 deletion tests/ut/test_ascend_config.py
@@ -67,7 +67,7 @@ def test_init_ascend_config_with_additional_config(self, mock_fix_incompatible_c
        self.assertTrue(ascend_config.multistream_overlap_shared_expert)

        npugraph_ex_config = ascend_config.npugraph_ex_config
-        self.assertFalse(npugraph_ex_config.enable)
+        self.assertTrue(npugraph_ex_config.enable)
        self.assertFalse(npugraph_ex_config.enable_static_kernel)

        ascend_compilation_config = ascend_config.ascend_compilation_config
4 changes: 2 additions & 2 deletions vllm_ascend/ascend_config.py
@@ -260,7 +260,7 @@ class NpugraphExConfig:

    def __init__(
        self,
-        enable: bool = False,
+        enable: bool = True,
        enable_static_kernel: bool = False,
        fuse_norm_quant: bool = True,
        fuse_qknorm_rope: bool = True,
@@ -274,7 +274,7 @@ def __init__(
        enable (bool): Whether to enable npugraph_ex backend.
            When set to True, the FX graph generated by Dynamo will be
            optimized and compiled by the npugraph_ex backend.
-            Default: False
+            Default: True
        enable_static_kernel (bool): Whether to enable static kernel.
            Static kernel is suitable for scenarios with purely static shapes
            or minimal shape changes, and can improve network performance.
2 changes: 1 addition & 1 deletion vllm_ascend/compilation/compiler_interface.py
@@ -88,7 +88,7 @@ def npugraph_ex_compile(
    # that can trigger the compilation of static kernel. If this configuration is
    # not applied, new shapes will trigger the compilation of static kernels,
    # affecting program execution.
-    num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0
+    num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
    uniform_decode_query_len = num_spec_tokens + 1
    max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
    decode_cudagraph_batch_sizes = [
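A quick worked example of the sizing arithmetic above, with assumed values (the hunk is truncated here, so only the lines shown are illustrated):

```python
# Worked example with assumed values; the real ones come from vllm_config.
num_spec_tokens = 2  # e.g. speculative_config.num_speculative_tokens
uniform_decode_query_len = num_spec_tokens + 1  # each decode step handles 3 tokens
max_num_seqs = 16  # e.g. scheduler_config.max_num_seqs
max_num_tokens = max_num_seqs * uniform_decode_query_len
print(uniform_decode_query_len, max_num_tokens)  # 3 48
```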
10 changes: 6 additions & 4 deletions vllm_ascend/compilation/npu_graph_ex_pass_manager.py
@@ -19,6 +19,7 @@
from torch import fx as fx
from vllm.config import VllmConfig

+from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.15.0"):
@@ -55,18 +56,19 @@ def add(self, pass_: VllmInductorPass):

    def configure(self, config: VllmConfig):
        # By default, we enable the graph fusion and quantization fusion pass.
-        self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
-        if self.npugraph_ex_config.get("fuse_norm_quant", True):
+        self.npugraph_ex_config = get_ascend_config().npugraph_ex_config
+
+        if self.npugraph_ex_config.fuse_norm_quant:
            from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass

            self.passes.append(GraphEXAddRMSNormFusionPass(config))

-        if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
+        if self.npugraph_ex_config.fuse_qknorm_rope:
            from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass

            self.passes.append(GraphEXQKNormRopeFusionPass(config))

-        if self.npugraph_ex_config.get("fuse_allreduce_rms", True):
+        if self.npugraph_ex_config.fuse_allreduce_rms:
            from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass

            self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
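Since `configure()` now reads the typed `NpugraphExConfig` from `get_ascend_config()` instead of a raw dict, each fusion flag is a plain attribute. A hedged sketch of how the three flags map to the passes registered above (pass names are printed rather than instantiated, to keep the sketch self-contained):

```python
# Sketch only: tabulating the flag -> pass mapping used in configure() above.
# Assumes the ascend config has already been initialized by the engine.
from vllm_ascend.ascend_config import get_ascend_config

cfg = get_ascend_config().npugraph_ex_config
flag_to_pass = {
    "fuse_norm_quant": "GraphEXAddRMSNormFusionPass",
    "fuse_qknorm_rope": "GraphEXQKNormRopeFusionPass",
    "fuse_allreduce_rms": "GraphEXMatmulAllReduceAddRMSNormPass",
}
for flag, pass_name in flag_to_pass.items():
    print(f"{flag}={getattr(cfg, flag)} -> {pass_name}")
```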
14 changes: 14 additions & 0 deletions vllm_ascend/patch/__init__.py
@@ -249,3 +249,17 @@
# make unquantized_gemm as a customop.
# Future Plan:
# Remove this patch when vLLM support the operator as customop.
+#
+# ** 13. File: worker/patch_npugraph_ex_triton.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `torchair.core._concrete_graph.ValuePack`,
+# `torchair.npu_fx_compiler._unpack_meta`,
+# `torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu`
+# Why:
+# In the Triton scenario, the npugraph_ex backend needs to handle the value packs of its input parameters.
+# How:
+# Patch in the required handling logic.
+# Related PR (if no, explain why):
+# https://gitcode.com/Ascend/torchair/pull/2575
+# Future Plan:
+# Remove this patch when the PTA version used by vllm-ascend has been upgraded.
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/__init__.py
@@ -33,3 +33,4 @@
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_v2_egale # noqa
import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
+import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
116 changes: 116 additions & 0 deletions vllm_ascend/patch/worker/patch_npugraph_ex_triton.py
@@ -0,0 +1,116 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import sys

import torch
import torchair
from torch._subclasses.fake_tensor import FakeTensor
from torchair.core._concrete_graph import _is_symlist
from torchair.npu_fx_compiler import _unpack_meta_list


class ValuePack:

    def __init__(self, meta, npu_meta=None) -> None:
        self._meta = meta
        self._npu_meta = meta if npu_meta is None else npu_meta

    @property
    def meta(self):
        return self._meta

    @property
    def npu(self):
        return self._npu_meta

    def __getitem__(self, key):
        if isinstance(self._meta, dict):
            return self._meta.get(key)
        raise ValueError(f"Unsupported meta type for ValuePack __getitem__, key:{key}, type: {type(self._meta)}")

    def __repr__(self) -> str:
        if isinstance(self._meta, FakeTensor):
            meta_str = f"FakeTensor(dtype={self._meta.dtype}, size={list(self._meta.size())})"
        elif isinstance(self._meta, torch.Tensor):
            meta_str = f"torch.Tensor(dtype={self._meta.dtype}, size={list(self._meta.size())})"
        elif isinstance(self._meta, torch.SymInt):
            meta_str = f"torch.SymInt({self._meta})"
        else:
            try:
                meta_str = f"{type(self._meta)}({self._meta})"
            except Exception:
                meta_str = f"{type(self._meta)}"
        return f"Pack(meta:{meta_str} npu:{self._npu_meta})"


def _unpack_meta(args, kwargs):
    unpacked_args = []
    unpacked_kwargs = {}

    def _get_meta_part(arg):
        if isinstance(arg, (list, tuple)) and any(isinstance(v, ValuePack) for v in arg):
            return _unpack_meta_list(arg)
        elif isinstance(arg, dict):
            return {k: v.meta if isinstance(v, ValuePack) else v for k, v in arg.items()}
        elif isinstance(arg, ValuePack):
            return arg.meta
        else:
            return arg

    for arg in args:
        unpacked_args.append(_get_meta_part(arg))

    for key, value in kwargs.items():
        unpacked_kwargs[key] = _get_meta_part(value)

    return list(unpacked_args), unpacked_kwargs


def _unpack_npu(self, args, kwargs):
    unpacked = []
    unpacked_kwargs = {}

    def _get_npu_part(arg):
        if isinstance(arg, (list, tuple)) and len(arg):
            if _is_symlist(arg):
                arg = self._graph.parse_symlist(arg)
            else:
                arg = [(v.npu if isinstance(v, ValuePack) else v) for v in arg]
            return arg
        elif isinstance(arg, dict):
            return {k: v.npu if isinstance(v, ValuePack) else v for k, v in arg.items()}
        elif isinstance(arg, ValuePack):
            return arg.npu
        else:
            return arg

    for arg in args:
        unpacked.append(_get_npu_part(arg))

    for key, value in kwargs.items():
        unpacked_kwargs[key] = _get_npu_part(value)

    return unpacked, unpacked_kwargs


torchair.core._concrete_graph.ValuePack = ValuePack
# ValuePack is referenced by these two modules, so they must be reloaded after the patch is applied.
importlib.reload(sys.modules["torchair.fx_summary"])
importlib.reload(sys.modules["torchair.npu_fx_compiler"])
torchair.npu_fx_compiler._unpack_meta = _unpack_meta
torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu = _unpack_npu
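A small usage sketch of the patched `ValuePack` (meta/npu pairing, dict-style access, and the repr); the values below are illustrative only:

```python
# Illustrative usage of the ValuePack defined above.
import torch

pack = ValuePack({"scale": 0.5}, npu_meta="npu-side handle")
assert pack.meta == {"scale": 0.5}
assert pack.npu == "npu-side handle"
assert pack["scale"] == 0.5  # dict metas support __getitem__

tensor_pack = ValuePack(torch.zeros(2, 3))  # npu side defaults to the meta value
print(tensor_pack)  # Pack(meta:torch.Tensor(dtype=torch.float32, size=[2, 3]) npu:...)
```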
12 changes: 12 additions & 0 deletions vllm_ascend/platform.py
@@ -210,6 +210,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
)

+            npugraph_ex_config = ascend_config.npugraph_ex_config
+            if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
+                from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
+
+                new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
+                new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
+                new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
+                vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
+                logger.debug(
+                    f"set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
+                )

elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")

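To illustrate the added block's effect on `compile_ranges_split_points`: the fusion threshold is appended and the list re-sorted so compile ranges stay ordered. The values below are assumed placeholders, not the constant's real value:

```python
# Placeholder illustration; ALLREDUCE_NORM_FUSE_THREHOLD's actual value lives
# in allreduce_rmsnorm_fusion_pass.
ALLREDUCE_NORM_FUSE_THREHOLD = 1000  # assumed value
split_points = [512, 4096]  # hypothetical existing split points
split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
split_points = sorted(split_points)
print(split_points)  # [512, 1000, 4096]
```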