Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 0 additions & 37 deletions vllm_ascend/distributed/moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,43 +94,6 @@ def unpermute(self, mlp_output: torch.Tensor,
pass


class DummyCommImpl(MoECommMethod):
    """No-op communication strategy.

    Implements the ``MoECommMethod`` interface without performing any real
    token dispatch: inputs pass through untouched and placeholder outputs
    with the expected shapes are produced where required.
    """

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Pass-through: returns the inputs unchanged."""
        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """Pass-through: no reduction or communication is performed."""
        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Produce shape-correct placeholder outputs without real routing."""
        k = topk_ids.shape[1]
        # Duplicate every token once per selected expert so the output has
        # the (num_tokens * top_k, hidden) shape downstream code expects.
        replicated = hidden_states.repeat_interleave(k, dim=0)
        # Nothing is actually dispatched, so each expert's token count is 0.
        per_expert_counts = torch.zeros((num_experts, ),
                                        dtype=torch.int64,
                                        device=hidden_states.device)
        return replicated, per_expert_counts, 0

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Intentionally a no-op."""
        pass


class AllGatherCommImpl(MoECommMethod):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.
Expand Down
8 changes: 5 additions & 3 deletions vllm_ascend/ops/common_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
MC2CommImpl,
MoECommMethod)
from vllm_ascend.distributed.parallel_state import get_mc2_group
Expand Down Expand Up @@ -230,7 +229,7 @@ def __init__(
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()

for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}:
for method in {AllGatherCommImpl, MC2CommImpl}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
Expand All @@ -241,8 +240,11 @@ def forward_impl(self, hidden_states: torch.Tensor,

forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl":

# TODO: Can we refactor this logic to model_runner?
if not self.moe_config.use_ep:
moe_comm_method_name = "allgathercommimpl"

forward_context.moe_comm_method = getattr(self, moe_comm_method_name)

hidden_states, router_logits = forward_context.moe_comm_method.prepare(
Expand Down
15 changes: 12 additions & 3 deletions vllm_ascend/torchair/torchair_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
register_torchair_model()
torchair_quant_method_register()

def _get_forward_metadata_across_dp_and_pad(
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
"""Override from NPUModelRunner to pad num_tokens"""
Expand All @@ -81,8 +81,17 @@ def _get_forward_metadata_across_dp_and_pad(
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill, enable_dbo

num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
num_tokens, with_prefill, enable_dbo)
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
dtype=torch.int32,
device="npu")
num_tokens_across_dp[self.dp_rank] = num_tokens
num_tokens_across_dp[-2] = int(with_prefill)
num_tokens_across_dp[-1] = int(not enable_dbo)
dist.all_reduce(num_tokens_across_dp,
group=get_dp_group().device_group)
with_prefill = bool(num_tokens_across_dp[-2])
enable_dbo = not bool(num_tokens_across_dp[-1])
num_tokens_across_dp = num_tokens_across_dp[:-2]

if not with_prefill:
max_num_token = num_tokens_across_dp.max().item()
Expand Down
Loading
Loading