3 changes: 3 additions & 0 deletions docs/source/user_guide/configuration/additional_config.md
@@ -96,6 +96,9 @@ The details of each configuration option are as follows:
| Option | Type | Default | Description |
|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
| `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
| `fuse_norm_quant` | bool | `True` | Whether to enable the fuse_norm_quant pass, which fuses norm and quant operations. |
| `fuse_qknorm_rope` | bool | `True` | Whether to enable the fuse_qknorm_rope pass. Set to `False` if Triton is not available in the environment. |
| `fuse_allreduce_rms` | bool | `False` | Whether to enable the fuse_allreduce_rms pass. Defaults to `False` because it conflicts with sequence parallelism (SP). |

### Example
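
For illustration, a hedged sketch of passing the new options through `additional_config` (the model name and surrounding arguments are placeholders, not part of this change):

```python
from vllm import LLM

# Illustrative only: flag values mirror the defaults documented above.
llm = LLM(
    model="some/model",  # placeholder
    additional_config={
        "npugraph_ex_config": {
            "enable": True,
            "enable_static_kernel": False,
            "fuse_norm_quant": True,
            "fuse_qknorm_rope": False,   # set to False when Triton is unavailable
            "fuse_allreduce_rms": False, # keep False when SP is enabled
        },
    },
)
```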

20 changes: 19 additions & 1 deletion vllm_ascend/ascend_config.py
@@ -235,7 +235,15 @@ class NpugraphExConfig:
These configurations can directly impact the performance and behavior of models deployed on Ascend platforms.
"""

def __init__(self, enable: bool = False, enable_static_kernel: bool = False, **kwargs):
def __init__(
    self,
    enable: bool = False,
    enable_static_kernel: bool = False,
    fuse_norm_quant: bool = True,
    fuse_qknorm_rope: bool = True,
    fuse_allreduce_rms: bool = False,
    **kwargs,
):
"""
Initialize the configuration.

@@ -251,10 +259,20 @@ def __init__(self, enable: bool = False, enable_static_kernel: bool = False, **k
            binary files with the corresponding shapes based on the current batch_size,
            which usually takes some time.
            Default: False
        fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
            When set to True, the system will fuse adjacent norm and quant operations.
            Default: True
        fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
            Default: True
        fuse_allreduce_rms (bool): Whether to enable allreduce and addrmsnorm fusion optimization.
            Default: False
        **kwargs: Additional optional parameters for forward compatibility and configuration extension.
    """
    self.enable = enable
    self.enable_static_kernel = enable_static_kernel
    self.fuse_norm_quant = fuse_norm_quant
    self.fuse_qknorm_rope = fuse_qknorm_rope
    self.fuse_allreduce_rms = fuse_allreduce_rms
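
A hedged usage sketch (the extra keyword below is hypothetical and only exercises the `**kwargs` forward-compatibility path):

```python
# Hedged sketch: unspecified fusion flags keep their documented defaults.
cfg = NpugraphExConfig(
    enable=True,
    fuse_qknorm_rope=False,  # e.g. Triton is not installed
    future_option=1,         # hypothetical key, silently absorbed by **kwargs
)
assert cfg.fuse_norm_quant is True      # default
assert cfg.fuse_allreduce_rms is False  # default; conflicts with SP
```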


class XliteGraphConfig:
16 changes: 15 additions & 1 deletion vllm_ascend/compilation/npu_graph_ex_pass_manager.py
@@ -48,4 +48,18 @@ def add(self, pass_: VllmInductorPass):

def configure(self, config: VllmConfig):
    # By default, we enable the graph fusion and quantization fusion passes.
    self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
    self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
    if self.npugraph_ex_config.get("fuse_norm_quant", True):
        from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass

        self.passes.append(GraphEXAddRMSNormFusionPass(config))

    if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
        from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass

        self.passes.append(GraphEXQKNormRopeFusionPass(config))
if self.npugraph_ex_config.get("fuse_allreduce_rms", True):
from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass

self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
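
Since `configure` reads each flag with `dict.get(key, default)`, here is a hedged sketch of a user-supplied dict and what it gates (values are illustrative):

```python
# Hedged sketch: only flags that resolve to True append a pass to self.passes.
npugraph_ex_config = {
    "fuse_norm_quant": True,    # registers GraphEXAddRMSNormFusionPass
    "fuse_qknorm_rope": False,  # skipped, e.g. Triton unavailable
    # "fuse_allreduce_rms" omitted -> falls back to the .get() default
}
```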
150 changes: 150 additions & 0 deletions vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py
@@ -0,0 +1,150 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torchair
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group

from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check

# The computation-communication tiling block size is 512.
ALLREDUCE_NORM_FUSE_THRESHOLD = 512


class GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern:
    """
    Recognizes the Matmul + AllReduce + AddRMSNorm computation pattern.
    In the fused operator, the AllReduce is optimized into a two-stage
    ReduceScatter + AllGather communication.
    """

    def __init__(self, vllm_config, eps=1e-6):
        self.vllm_config = vllm_config
        self.eps = eps
        device_group = get_tp_group().device_group
        backend = device_group._get_backend(torch.device("npu"))
        self.local_rank = torch.distributed.get_rank(group=device_group)
        self.tp_group_name = backend.get_hccl_comm_name(self.local_rank)
        self.tp_size = get_tensor_model_parallel_world_size()

    def get_inputs(self):
        batch_size, seq_len = 2, 4
        hidden_size = 4096
        x = torch.randn(batch_size, seq_len, hidden_size, device="npu")
        weight = torch.randn(hidden_size, hidden_size, device="npu")
        residual = torch.randn(batch_size, seq_len, hidden_size, device="npu")
        rms_norm_weight = torch.randn(hidden_size, device="npu")
        return [x, weight, residual, rms_norm_weight]

    def register(self):
        def pattern(x, weight, residual, rms_norm_weight):
            mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
            all_reduce_ = tensor_model_parallel_all_reduce(mm)
            output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)
            out0 = output[0]
            out1 = output[2]

            return out0, out1

        def replacement(x, weight, residual, rms_norm_weight):
            out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
                x,
                weight,
                residual,
                rms_norm_weight,
                self.tp_group_name,
                self.tp_size,
                self.local_rank,
                self.eps,
                True,
                False,
Contributor comment (critical):

The `is_allgather_add_out` parameter is hardcoded to `False`. The pattern being replaced produces a full tensor for the second output (`add_out`), as it is derived from an all-reduced tensor. If `is_allgather_add_out=False` causes the fused operator to return a sharded `add_out`, this would be a correctness issue, as subsequent layers expect a full tensor. The corresponding test for this operator uses `is_allgather_add_out=True`. It should likely be `True` here as well to ensure the replacement is correct for middle layers.

Suggested change: replace `False,` with `True,`.
            )
            return out0, out1

        torchair.register_replacement(
            search_fn=pattern,
            replace_fn=replacement,
            example_inputs=self.get_inputs(),
            extra_check=extra_stream_scope_check,
        )


class GraphEXLastLayerMatmulAllReduceAddRMSNormPattern:
    def __init__(self, vllm_config, eps=1e-6):
        self.vllm_config = vllm_config
        self.eps = eps
        device_group = get_tp_group().device_group
        backend = device_group._get_backend(torch.device("npu"))
        self.local_rank = torch.distributed.get_rank(group=device_group)
        self.tp_group_name = backend.get_hccl_comm_name(self.local_rank)
        self.tp_size = get_tensor_model_parallel_world_size()

    def get_inputs(self):
        batch_size, seq_len = 2, 4
        hidden_size = 4096
        x = torch.randn(batch_size, seq_len, hidden_size, device="npu")
        weight = torch.randn(hidden_size, hidden_size, device="npu")
        residual = torch.randn(batch_size, seq_len, hidden_size, device="npu")
        rms_norm_weight = torch.randn(hidden_size, device="npu")
        return [x, weight, residual, rms_norm_weight]

    def register(self):
        def pattern(x, weight, residual, rms_norm_weight):
            mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
            all_reduce_ = tensor_model_parallel_all_reduce(mm)
            output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)

            return output[0]

        def replacement(x, weight, residual, rms_norm_weight):
            out0, _ = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
                x,
                weight,
                residual,
                rms_norm_weight,
                self.tp_group_name,
                self.tp_size,
                self.local_rank,
                self.eps,
                True,
                False,
            )
            return out0

        torchair.register_replacement(
            search_fn=pattern,
            replace_fn=replacement,
            example_inputs=self.get_inputs(),
Contributor comment (critical):

The `example_inputs` argument to `torchair.register_replacement` expects a callable, but it is being passed the result of `self.get_inputs()`. This will cause an error during pattern registration. It should be `self.get_inputs` to pass the method itself.

Suggested change: replace `example_inputs=self.get_inputs(),` with `example_inputs=self.get_inputs,`.
            extra_check=extra_stream_scope_check,
        )
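
Applying the reviewer's suggestion above, a hedged sketch of the corrected call as it would appear inside `register` (assuming `torchair.register_replacement` does expect a callable for `example_inputs`):

```python
# Hedged: pass the bound method itself rather than its return value.
torchair.register_replacement(
    search_fn=pattern,
    replace_fn=replacement,
    example_inputs=self.get_inputs,  # callable, not self.get_inputs()
    extra_check=extra_stream_scope_check,
)
```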


class GraphEXMatmulAllReduceAddRMSNormPass:
    def __init__(self, vllm_config: VllmConfig):
        GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
        GraphEXLastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()

    def __call__(self, graph: torch.fx.Graph):
        pass

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Check if the pass is applicable for the current compile range.
        """
        applicable = compile_range.start > ALLREDUCE_NORM_FUSE_THRESHOLD
        return applicable
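
A hedged sketch of the range gate (the `Range` constructor arguments and the in-scope `vllm_config` are assumptions):

```python
# Hedged: the fusion only applies when the compile range starts above the
# 512-element computation-communication tiling block.
p = GraphEXMatmulAllReduceAddRMSNormPass(vllm_config)  # vllm_config assumed in scope
assert p.is_applicable_for_range(Range(start=1024, end=4096)) is True
assert p.is_applicable_for_range(Range(start=256, end=512)) is False
```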
@@ -56,7 +56,7 @@ def register(self, pm_pass: PatternMatcherPass):
    def pattern(x, weight, residual, rms_norm_weight):
        mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
        all_reduce_ = tensor_model_parallel_all_reduce(mm)
        output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight)
        output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)
Collaborator comment:

This change is caused by the introduction of custom operators. As a result, AddRmsNorm in the graph is replaced with AddRmsNormBias.

        out0 = output[0]
        out1 = output[2]

@@ -103,7 +103,7 @@ def register(self, pm_pass: PatternMatcherPass):
    def pattern(x, weight, residual, rms_norm_weight):
        mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
        all_reduce_ = tensor_model_parallel_all_reduce(mm)
        output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight)
        output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)

        return output[0]
