From 58c14f1aa32a695afd36010b9e53da8dfeba6623 Mon Sep 17 00:00:00 2001 From: whx-sjtu <2952154980@qq.com> Date: Wed, 24 Sep 2025 21:31:51 +0800 Subject: [PATCH 1/4] fix a2 accu problem Signed-off-by: whx-sjtu <2952154980@qq.com> --- vllm_ascend/ops/common_fused_moe.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 554b40e0021..9ca97b97e09 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -333,6 +333,15 @@ def forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + shared_out, fused_out = AscendFusedMoE.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): # Make sure the shared experts stream begins after hidden_states are ready. if self.multistream_overlap_shared_expert: self.shared_expert_stream.wait_stream( # type: ignore @@ -347,8 +356,7 @@ def forward( moe_comm_type = forward_context.moe_comm_type if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: shared_out = tensor_model_parallel_all_reduce(shared_out) - - _, fused_out = AscendFusedMoE.forward( + fused_output = AscendFusedMoE.forward_impl( self, hidden_states=hidden_states, router_logits=router_logits, @@ -356,17 +364,7 @@ def forward( # Make sure the default stream waits for the shared experts stream to finish. if self.multistream_overlap_shared_expert: torch.npu.current_stream().wait_stream(self.shared_expert_stream) - return shared_out, fused_out - - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - shared_output = torch.empty(1) - fused_output = AscendFusedMoE.forward_impl( - self, - hidden_states=hidden_states, - router_logits=router_logits, - ) - return shared_output, fused_output + return shared_out, fused_output UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func From b256ff4c038b6a27de33d2e8a0e41c1ce32d7298 Mon Sep 17 00:00:00 2001 From: whx-sjtu <2952154980@qq.com> Date: Thu, 25 Sep 2025 17:58:33 +0800 Subject: [PATCH 2/4] make maybe_all_reduce_tensor_model_parallel custom op Signed-off-by: whx-sjtu <2952154980@qq.com> --- vllm_ascend/ops/common_fused_moe.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 9ca97b97e09..b7e9a43939a 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -20,6 +20,7 @@ import torch import torch_npu from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.utils import direct_register_custom_op from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce) from vllm.forward_context import get_forward_context @@ -203,12 +204,7 @@ def maybe_all_reduce_tensor_model_parallel( `finalize` function. In `allgathercommimpl`, we still need to all-reduce the outputs since each rank only has partial outputs. """ - forward_context = get_forward_context() - moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: - return final_hidden_states - else: - return tensor_model_parallel_all_reduce(final_hidden_states) + return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states) def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -367,6 +363,22 @@ def forward_impl(self, hidden_states: torch.Tensor, return shared_out, fused_output +def _maybe_all_reduce_tensor_model_parallel_impl( + final_hidden_states: torch.Tensor +) -> torch.Tensor: + forward_context = get_forward_context() + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + +direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", + op_func=_maybe_all_reduce_tensor_model_parallel_impl, + fake_impl=lambda x, label: x, + mutates_args=[], + dispatch_key="PrivateUse1") + UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading UnquantizedFusedMoEMethod.forward_oot = forward_oot From 02da01a9ee193ee7c40341ae8bf70fd37dfaa884 Mon Sep 17 00:00:00 2001 From: whx-sjtu <2952154980@qq.com> Date: Thu, 25 Sep 2025 19:36:38 +0800 Subject: [PATCH 3/4] fix lint Signed-off-by: whx-sjtu <2952154980@qq.com> --- vllm_ascend/ops/common_fused_moe.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index b7e9a43939a..32c7e0579c0 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -20,13 +20,13 @@ import torch import torch_npu from vllm.config import CompilationLevel, get_current_vllm_config -from vllm.utils import direct_register_custom_op from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE +from vllm.utils import direct_register_custom_op from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import MoECommType @@ -204,7 +204,8 @@ def maybe_all_reduce_tensor_model_parallel( `finalize` function. In `allgathercommimpl`, we still need to all-reduce the outputs since each rank only has partial outputs. """ - return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states) + return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -364,8 +365,7 @@ def forward_impl(self, hidden_states: torch.Tensor, def _maybe_all_reduce_tensor_model_parallel_impl( - final_hidden_states: torch.Tensor -) -> torch.Tensor: + final_hidden_states: torch.Tensor) -> torch.Tensor: forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: @@ -373,6 +373,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl( else: return tensor_model_parallel_all_reduce(final_hidden_states) + direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", op_func=_maybe_all_reduce_tensor_model_parallel_impl, fake_impl=lambda x, label: x, From 64d351d5fe31fd1f26a8ce1f9741c606be74d7fe Mon Sep 17 00:00:00 2001 From: whx-sjtu <2952154980@qq.com> Date: Fri, 26 Sep 2025 14:19:36 +0800 Subject: [PATCH 4/4] change register position Signed-off-by: whx-sjtu <2952154980@qq.com> --- vllm_ascend/ops/common_fused_moe.py | 17 ----------------- vllm_ascend/ops/register_custom_ops.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 32c7e0579c0..ac22b69bcc7 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE -from vllm.utils import direct_register_custom_op from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import MoECommType @@ -364,22 +363,6 @@ def forward_impl(self, hidden_states: torch.Tensor, return shared_out, fused_output -def _maybe_all_reduce_tensor_model_parallel_impl( - final_hidden_states: torch.Tensor) -> torch.Tensor: - forward_context = get_forward_context() - moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: - return final_hidden_states - else: - return tensor_model_parallel_all_reduce(final_hidden_states) - - -direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", - op_func=_maybe_all_reduce_tensor_model_parallel_impl, - fake_impl=lambda x, label: x, - mutates_args=[], - dispatch_key="PrivateUse1") - UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index a702b3521d5..438bff1935f 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -10,6 +10,7 @@ from vllm.utils import direct_register_custom_op import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_forward_context import MoECommType def _maybe_chunk_residual_impl(x: torch.Tensor, @@ -147,6 +148,16 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None: return +def _maybe_all_reduce_tensor_model_parallel_impl( + final_hidden_states: torch.Tensor) -> torch.Tensor: + forward_context = get_forward_context() + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + + direct_register_custom_op(op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, fake_impl=lambda x, residual: residual, @@ -182,3 +193,9 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None: fake_impl=_maybe_wait_prefetch_done_impl_fake, mutates_args=[], dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", + op_func=_maybe_all_reduce_tensor_model_parallel_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1")