Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions vllm_ascend/ops/common_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,8 @@ def maybe_all_reduce_tensor_model_parallel(
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
outputs since each rank only has partial outputs.
"""
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
Expand Down Expand Up @@ -333,6 +329,15 @@ def forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = AscendFusedMoE.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_out, fused_out
Comment on lines +332 to +337
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The call to AscendFusedMoE.forward is incorrect. The MRO for AscendSharedFusedMoE will cause this to resolve to vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward, which returns a single tensor. Attempting to unpack this single tensor into two variables, shared_out and fused_out, will raise a ValueError at runtime.

Given that the implementation logic has been moved into forward_impl, which correctly returns a tuple of two tensors, the forward method should likely just call self.forward_impl.

        return self.forward_impl(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )


def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
self.shared_expert_stream.wait_stream( # type: ignore
Expand All @@ -347,26 +352,15 @@ def forward(
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
shared_out = tensor_model_parallel_all_reduce(shared_out)

_, fused_out = AscendFusedMoE.forward(
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
return shared_out, fused_out

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
shared_output = torch.empty(1)
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_output, fused_output
return shared_out, fused_output


UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
Expand Down
17 changes: 17 additions & 0 deletions vllm_ascend/ops/register_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.utils import direct_register_custom_op

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType


def _maybe_chunk_residual_impl(x: torch.Tensor,
Expand Down Expand Up @@ -147,6 +148,16 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return


def _maybe_all_reduce_tensor_model_parallel_impl(
final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)


direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
Expand Down Expand Up @@ -182,3 +193,9 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")
Loading