Skip to content
7 changes: 6 additions & 1 deletion vllm_ascend/ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ class AscendCompilationConfig:
deployed on Ascend platforms.
"""

def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, **kwargs):
def __init__(
self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, fuse_allreduce_rms: bool = False, **kwargs
):
"""
Initialize the configuration.

Expand All @@ -179,10 +181,13 @@ def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False,
Default: True
fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization.
Default: False
fuse_allreduce_rms (bool): Whether to enable allreduce and addrmsnorm fusion optimization.
Default: False
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
"""
self.fuse_norm_quant = fuse_norm_quant
self.fuse_qknorm_rope = HAS_TRITON or fuse_qknorm_rope
self.fuse_allreduce_rms = fuse_allreduce_rms


class XliteGraphConfig:
Expand Down
5 changes: 5 additions & 0 deletions vllm_ascend/compilation/graph_fusion_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,8 @@ def configure(self, config: VllmConfig):
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass

self.passes.append(QKNormRopeFusionPass(config))

# Opt-in only: the fallback matches the documented default of
# AscendCompilationConfig.fuse_allreduce_rms (False), so the fusion pass is not
# registered unless the user explicitly enables it.
if self.ascend_compilation_config.get("fuse_allreduce_rms", False):
    from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass

    self.passes.append(MatmulAllReduceAddRMSNormPass(config))
153 changes: 153 additions & 0 deletions vllm_ascend/compilation/passes/allreduce_rmsnorm_fusion_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import logger

# The computation-communication tiling block of the fused operator is 512.
# This value is compared against the start of a compile range to decide whether
# the fusion pass applies, and it is also imported by vllm_ascend/platform.py as
# a compile-range split point.
# NOTE(review): "THREHOLD" is a misspelling of "THRESHOLD"; the name is imported
# from other modules, so renaming it requires updating every importer together.
ALLREDUCE_NORM_FUSE_THREHOLD = 512


class MiddleLayerMatmulAllReduceAddRMSNormPattern:
    """
    Recognize the Matmul + AllReduce + AddRMSNorm computation pattern in which
    outputs 0 and 2 of ``npu_add_rms_norm`` are both consumed, and replace it
    with the fused ``_C_ascend.matmul_allreduce_add_rmsnorm`` operator.

    AllReduce is optimized in the fusion operator to a two-stage communication
    of ReduceScatter + AllGather.
    """

    def __init__(self, vllm_config, eps=1e-6):
        """
        Capture the tensor-parallel communication details needed by the fused op.

        Args:
            vllm_config: vLLM configuration object; stored as-is.
            eps: Epsilon forwarded to the fused operator.
        """
        tp_device_group = get_tp_group().device_group
        npu_backend = tp_device_group._get_backend(torch.device("npu"))
        rank_in_group = torch.distributed.get_rank(group=tp_device_group)

        self.vllm_config = vllm_config
        self.eps = eps
        self.local_rank = rank_in_group
        self.tp_group_name = npu_backend.get_hccl_comm_name(rank_in_group)
        self.tp_size = get_tensor_model_parallel_world_size()

    def get_inputs(self):
        """Build example NPU tensors used to trace the pattern and its replacement."""
        hidden_size = 4096
        # (batch_size, seq_len, hidden_size)
        activation_shape = (2, 4, hidden_size)

        x = torch.randn(*activation_shape, device="npu")
        weight = torch.randn(hidden_size, hidden_size, device="npu")
        residual = torch.randn(*activation_shape, device="npu")
        rms_norm_weight = torch.randn(hidden_size, device="npu")
        return [x, weight, residual, rms_norm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its fused replacement on ``pm_pass``."""

        def pattern(x, weight, residual, rms_norm_weight):
            gemm_out = torch.ops.vllm.unquantized_gemm(x, weight, None)
            reduced = tensor_model_parallel_all_reduce(gemm_out)
            norm_outputs = torch.ops.npu.npu_add_rms_norm(reduced, residual, rms_norm_weight)

            # Elements 0 and 2 of the add-rms-norm result are the ones consumed downstream.
            return norm_outputs[0], norm_outputs[2]

        def replacement(x, weight, residual, rms_norm_weight):
            fused_out0, fused_out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
                x,
                weight,
                residual,
                rms_norm_weight,
                self.tp_group_name,
                self.tp_size,
                self.local_rank,
                self.eps,
                True,
                False,
            )
            return fused_out0, fused_out1

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs, pm.fwd_only, pm_pass)


class LastLayerMatmulAllReduceAddRMSNormPattern:
    """
    Recognize the Matmul + AllReduce + AddRMSNorm computation pattern in which
    only output 0 of ``npu_add_rms_norm`` is consumed, and replace it with the
    fused ``_C_ascend.matmul_allreduce_add_rmsnorm`` operator.
    """

    def __init__(self, vllm_config, eps=1e-6):
        """
        Capture the tensor-parallel communication details needed by the fused op.

        Args:
            vllm_config: vLLM configuration object; stored as-is.
            eps: Epsilon forwarded to the fused operator.
        """
        group = get_tp_group().device_group
        comm_backend = group._get_backend(torch.device("npu"))
        group_rank = torch.distributed.get_rank(group=group)

        self.vllm_config = vllm_config
        self.eps = eps
        self.local_rank = group_rank
        self.tp_group_name = comm_backend.get_hccl_comm_name(group_rank)
        self.tp_size = get_tensor_model_parallel_world_size()

    def get_inputs(self):
        """Build example NPU tensors used to trace the pattern and its replacement."""
        batch_size = 2
        seq_len = 4
        hidden_size = 4096

        def _activation():
            # Activation-shaped example tensor: (batch_size, seq_len, hidden_size).
            return torch.randn(batch_size, seq_len, hidden_size, device="npu")

        x = _activation()
        weight = torch.randn(hidden_size, hidden_size, device="npu")
        residual = _activation()
        rms_norm_weight = torch.randn(hidden_size, device="npu")
        return [x, weight, residual, rms_norm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its fused replacement on ``pm_pass``."""

        def pattern(x, weight, residual, rms_norm_weight):
            projected = torch.ops.vllm.unquantized_gemm(x, weight, None)
            summed = tensor_model_parallel_all_reduce(projected)
            norm_result = torch.ops.npu.npu_add_rms_norm(summed, residual, rms_norm_weight)

            # Only the first element of the add-rms-norm result is consumed here.
            return norm_result[0]

        def replacement(x, weight, residual, rms_norm_weight):
            normalized, _ = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
                x,
                weight,
                residual,
                rms_norm_weight,
                self.tp_group_name,
                self.tp_size,
                self.local_rank,
                self.eps,
                True,
                False,
            )
            return normalized

        example_inputs = self.get_inputs()
        pm.register_replacement(pattern, replacement, example_inputs, pm.fwd_only, pm_pass)


class MatmulAllReduceAddRMSNormPass(VllmInductorPass):
    """
    Inductor pass that fuses Matmul + AllReduce + AddRMSNorm into one operator.

    The pass and pattern names deliberately keep "AllReduce": the fusion
    operator actually splits the allreduce into reducescatter and allgather.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Create the pattern-matcher pass and register both fusion patterns on it."""
        super().__init__(vllm_config)
        self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="allreduce_rmsnorm_fusion_pass")
        MiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register(self.pattern_match_passes)
        LastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register(self.pattern_match_passes)

    def __call__(self, graph: torch.fx.Graph):
        """
        Apply the registered patterns to ``graph`` and record the match count.

        Args:
            graph: FX graph to rewrite in place.
        """
        import logging

        self.begin()
        self.matched_count = self.pattern_match_passes.apply(graph)
        # Pretty-printing every registered pattern is wasted work unless debug
        # logging is actually enabled, so only do it in that case.
        if logger.isEnabledFor(logging.DEBUG):
            registered = (
                entry for entries in self.pattern_match_passes.patterns.values() for entry in entries
            )
            for pattern_idx, entry in enumerate(registered):
                logger.debug("Pattern %d: %s", pattern_idx, PatternPrettyPrinter.run(entry.pattern))
        logger.debug("Replaced %s patterns", self.matched_count)
        self.end_and_log()

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Check if the pass is applicable for the current configuration.

        Returns:
            True when the compile range starts above ALLREDUCE_NORM_FUSE_THREHOLD.
        """
        return compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
12 changes: 12 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,15 @@
# 1. `vllm.distributed.parallel_state.GroupCoordinator`
# Why:
# vllm doesn't support all_to_all for GroupCoordinator.
# all_reduce in vLLM is not a custom op, which causes the MatmulAllReduceAddRMSNorm fusion to fail.
# How:
# Add all_to_all implementation for GroupCoordinator.
# Make all_reduce a custom op.
# Related PR (if no, explain why):
# No, we should use the vLLM all2all manager to support all_to_all for npu.
# Future Plan:
# Remove this patch when the refactor of all2all manager is done.
# Remove this patch when vLLM supports all_reduce as a custom op.
#
# ** 3. File: worker/patch_minicpm.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -276,3 +279,12 @@
# Future Plan:
# Remove this patch when cann fix the gather bug.
#
# ** 13. File: worker/patch_unquantized_gemm.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.utils.default_unquantized_gemm`
# Why:
# unquantized_gemm in vLLM is not a custom op, which causes the MatmulAllReduceAddRMSNorm fusion to fail.
# How:
# Make unquantized_gemm a custom op.
# Future Plan:
# Remove this patch when vLLM supports the operator as a custom op.
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

# isort: off
import vllm_ascend.patch.platform.patch_sched_yield # noqa
import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa
import vllm_ascend.patch.worker.patch_bert # noqa
import vllm_ascend.patch.worker.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_deepseek # noqa
Expand Down
5 changes: 5 additions & 0 deletions vllm_ascend/patch/worker/patch_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,10 @@ def all_to_all(self,
gather_dim, scatter_sizes,
gather_sizes)

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    """
    All-reduce ``input_`` across this group via the ``vllm.all_reduce`` custom op.

    Args:
        input_: Tensor to reduce.

    Returns:
        ``input_`` itself when the group has a single member; otherwise the
        result of ``torch.ops.vllm.all_reduce`` for this group's ``unique_name``.
    """
    # A single-member group has nothing to reduce.
    if self.world_size == 1:
        return input_
    # Always dispatch through the custom op; the patch notes state that the
    # MatmulAllReduceAddRMSNorm fusion fails when all_reduce is not a custom op.
    return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)


# Monkey-patch: replace vLLM's GroupCoordinator with the patched subclass defined above.
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
57 changes: 57 additions & 0 deletions vllm_ascend/patch/worker/patch_unquantized_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import vllm.model_executor.layers.utils
from vllm.utils.torch_utils import direct_register_custom_op


def unquantized_gemm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.nn.functional.linear(x, weight, bias)


def unquantized_gemm_fake(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
output_shape = (x.shape[0], weight.shape[0])
return torch.empty(output_shape, dtype=x.dtype, device=x.device)


# Register the real and fake implementations as the "unquantized_gemm" custom
# op (invoked in this module as torch.ops.vllm.unquantized_gemm) under the
# PrivateUse1 dispatch key; the op mutates none of its arguments.
direct_register_custom_op(
    op_name="unquantized_gemm",
    op_func=unquantized_gemm,
    fake_impl=unquantized_gemm_fake,
    mutates_args=[],
    dispatch_key="PrivateUse1",
)

def default_unquantized_gemm(
layer: torch.nn.Module,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if x.device.type == "npu":
return torch.ops.vllm.unquantized_gemm(x, weight, bias)
else:
return torch.nn.functional.linear(x, weight, bias)


# Monkey-patch: point vLLM's default unquantized GEMM at the wrapper defined in this module.
vllm.model_executor.layers.utils.default_unquantized_gemm = default_unquantized_gemm
Comment thread
wxsIcey marked this conversation as resolved.
12 changes: 12 additions & 0 deletions vllm_ascend/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
else ascend_compilation_config
)

# Only reserve a dedicated compile range when the fusion is explicitly enabled;
# the fallback matches the documented default of
# AscendCompilationConfig.fuse_allreduce_rms (False).
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", False):
    from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD

    # Work on a copy so the configured list is not mutated in place, and
    # tolerate an unset value.
    # NOTE(review): assumes compile_ranges_split_points may be None at this point — confirm against vLLM.
    new_compile_ranges_split_points = list(vllm_config.compilation_config.compile_ranges_split_points or [])
    # Avoid inserting the threshold twice if it is already a split point.
    if ALLREDUCE_NORM_FUSE_THREHOLD not in new_compile_ranges_split_points:
        new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
    new_compile_ranges_split_points.sort()
    vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
    # Lazy %-style argument: the original message lacked the f-prefix and
    # logged the literal "{new_compile_ranges_split_points}" placeholder.
    logger.debug(
        "set compile_ranges_split_points to %s for matmul and allreduce fusion",
        new_compile_ranges_split_points,
    )

elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
if model_config is None:
Expand Down