
Commit f453685

feat(moe): Refactor MoE communication method framework
This commit refactors the MoE communication method framework to improve modularity, clarity, and extensibility. Key changes include:

- **Revised `MoECommMethod` Interface:**
  - Renamed `_pre_process` to `permute` and `_post_process` to `unpermute` for better clarity.
  - Introduced `prepare` and `finalize` methods to encapsulate logic that happens before and after the core MoE computation, such as tensor padding/splitting for MC2 and the final AllReduce.
- **Simplified `AscendFusedMoE`:**
  - The `forward_impl` is significantly simplified by delegating pre- and post-processing logic (padding, splitting, reduction) to the specific `MoECommMethod` implementation.
  - `AscendFusedMoE` now instantiates all communication method objects at initialization and selects the appropriate one at runtime based on a string identifier.
- **Centralized Expert Logic:**
  - Removed `unified_fused_experts` and introduced a new `fused_experts` function in `common_fused_moe.py`.
  - This new function utilizes the `permute`/`unpermute` methods from the `MoECommMethod` abstraction, decoupling the core expert logic from specific communication implementations.
- **Configuration and Invocation:**
  - The communication method is now selected and passed around as a string (e.g., "mc2", "allgather") instead of a class type, simplifying the invocation in the model runner.

These changes result in a cleaner separation of concerns, making the MoE implementation easier to understand, maintain, and extend with new communication strategies.

Signed-off-by: Yizhou Liu <[email protected]>
1 parent 6b1d54f commit f453685
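The call order implied by the revised interface can be read off the method signatures in the diffs below. The following is a minimal sketch of a driver under that reading: the routing step is replaced by a plain top-k and the grouped expert MLP is passed in as a callable, so `moe_forward_sketch`, `apply_mlp`, and the top-k routing are placeholders rather than code from this commit (the real counterpart is `fused_experts` in `common_fused_moe.py`).

```python
import torch


def moe_forward_sketch(comm_impl,   # a MoECommMethod implementation
                       apply_mlp,   # callable: (permuted, expert_tokens, group_list_type) -> Tensor
                       hidden_states: torch.Tensor,
                       router_logits: torch.Tensor,
                       num_experts: int,
                       top_k: int = 2,
                       reduce_results: bool = True) -> torch.Tensor:
    # prepare(): per-method pre-work, e.g. MC2 pads the token dim and keeps
    # only this EP rank's slice.
    hidden_states, router_logits = comm_impl.prepare(hidden_states, router_logits)

    # Routing stand-in: the real layer uses its select_experts path.
    topk_weights, topk_ids = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)

    # permute(): group tokens by expert; returns the permuted activations,
    # per-expert token counts, and the group-list type for the grouped matmul.
    permuted, expert_tokens, group_list_type = comm_impl.permute(
        hidden_states, topk_ids, topk_weights, expert_map=None,
        num_experts=num_experts)

    # Grouped expert MLP (quant_method.apply / fused_experts in the real code).
    mlp_output = apply_mlp(permuted, expert_tokens, group_list_type)

    # unpermute(): scatter expert outputs back to token order, in place.
    comm_impl.unpermute(mlp_output, hidden_states)

    # finalize(): cross-rank reduction / gather plus un-padding where needed.
    return comm_impl.finalize(hidden_states, reduce_results=reduce_results)
```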

File tree

6 files changed: +232 -176 lines changed


tests/e2e/multicard/moe/test_moe_comm.py
Lines changed: 3 additions & 3 deletions

@@ -103,11 +103,11 @@ def test_all_gather_comm_impl(
         native_permuted_hidden,
         native_expert_tokens,
         _,
-    ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
-                                 expert_map, num_experts)
+    ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
+                            num_experts)
     # Simulate MLP output
     native_mlp_output = torch.randn_like(native_permuted_hidden)
-    native_impl._post_process(native_mlp_output, native_hidden_states_out)
+    native_impl.unpermute(native_mlp_output, native_hidden_states_out)

     # --- Run AllGather Implementation ---
     all_gather_hidden_states_out = hidden_states.clone()

vllm_ascend/ascend_forward_context.py
Lines changed: 2 additions & 3 deletions

@@ -11,7 +11,6 @@
                                   set_forward_context)

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.distributed.moe_comm_method import MoECommMethod


 class FusedMoEState(Enum):
@@ -57,7 +56,7 @@ def set_ascend_forward_context(
         with_prefill: bool = True,
         in_profile_run: bool = False,
         reserved_mc2_mask: Optional[torch.Tensor] = None,
-        moe_comm_method: Optional[MoECommMethod] = None,
+        moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None):
@@ -75,7 +74,7 @@ def set_ascend_forward_context(
             batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()
-        forward_context.moe_comm_method = moe_comm_method
+        forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
         forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
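The `+ "commimpl"` suffix turns the string coming from the model runner into the lower-cased class name of an implementation in `moe_comm_method.py` (`AllGatherCommImpl`, `MC2CommImpl`, ...). The actual lookup lives in `AscendFusedMoE`, which is not shown on this page; the stub registry below only illustrates the naming convention.

```python
# Stub classes stand in for the real implementations; only the name mapping
# is the point here.
class AllGatherCommImpl: ...
class MC2CommImpl: ...
class NativeAllGatherCommImpl: ...

_REGISTRY = {cls.__name__.lower(): cls
             for cls in (AllGatherCommImpl, MC2CommImpl, NativeAllGatherCommImpl)}


def resolve(moe_comm_method: str):
    # Mirrors: forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
    return _REGISTRY[moe_comm_method + "commimpl"]


assert resolve("mc2") is MC2CommImpl
assert resolve("allgather") is AllGatherCommImpl
```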

vllm_ascend/distributed/moe_comm_method.py
Lines changed: 142 additions & 65 deletions

@@ -1,12 +1,14 @@
 from abc import ABC, abstractmethod
+from typing import Optional

 import torch
+import torch.distributed as dist
+import torch.nn as nn
 import torch_npu
-from transformers.configuration_utils import PretrainedConfig
-from vllm.distributed.parallel_state import (
-    get_ep_group, get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size, get_tp_group)
+from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed.parallel_state import get_ep_group, get_tp_group
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 from vllm.utils import direct_register_custom_op

 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -16,26 +18,36 @@
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""

-    def __init__(
-        self,
-        device: torch.device,
-        dtype: torch.dtype,
-        hf_config: PretrainedConfig,
-    ):
-        self.device = device
-        self.dtype = dtype
-        self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
-        # global_num_experts may be called num_experts or n_routed_experts in different models.
-        possible_keys = ["num_experts", "n_routed_experts"]
-        for key in possible_keys:
-            if hasattr(hf_config, key):
-                self.global_num_experts = getattr(hf_config, key)
-                break
-        else:
-            self.global_num_experts = 0
+    moe_config: FusedMoEConfig = None
+
+    def __init__(self, moe_config: Optional[FusedMoEConfig]):
+        self.moe_config = moe_config
+
+    @abstractmethod
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Prepare the MoE communication method.
+
+        This method is called before quant_method.apply to prepare the
+        communication method. It can be used to initialize any necessary
+        resources or configurations.
+        """
+        pass
+
+    @abstractmethod
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """Finalize the MoE communication method.
+
+        This method is called after quant_method.apply to finalize the
+        communication method. It can be used to clean up any resources or
+        configurations.
+        """
+        pass

     @abstractmethod
-    def _pre_process(
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -69,8 +81,8 @@ def _pre_process(
         pass

     @abstractmethod
-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
         """Post-process after MLP.

         Args:
@@ -84,7 +96,18 @@ def _post_process(self, mlp_output: torch.Tensor,

 class DummyCommImpl(MoECommMethod):

-    def _pre_process(
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Dummy prepare method that does nothing."""
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """Dummy finalize method that does nothing."""
+        return hidden_states
+
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -96,8 +119,8 @@ def _pre_process(
         return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
                                          expert_map, num_experts)

-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
         """Dummy implementation that does nothing."""
         pass

@@ -110,7 +133,22 @@ class NativeAllGatherCommImpl(MoECommMethod):
     But it is a good fallback for scenarios where NPU-specific ops are not available.
     """

-    def _pre_process(
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Dummy prepare method that does nothing."""
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        if reduce_results and (self.moe_config.tp_size > 1
+                               or self.moe_config.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                hidden_states)
+
+        return final_hidden_states
+
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -122,10 +160,10 @@ def _pre_process(

         # Generate token indices and flatten
         token_indices = torch.arange(num_tokens,
-                                     device=self.device,
+                                     device=hidden_states.device,
                                      dtype=torch.int64)
         token_indices = (token_indices.unsqueeze(1).expand(
-            -1, self.top_k_num).reshape(-1))
+            -1, self.moe_config.experts_per_token).reshape(-1))

         # Flatten token-to-expert mappings and map to local experts
         weights_flat = topk_weights.view(-1)
@@ -140,7 +178,7 @@ def _pre_process(
         # This is a workaround and should be removed after the issue is fixed
         filtered_weights = torch.where(mask, weights_flat,
                                        torch.zeros_like(weights_flat)).to(
-                                           self.dtype)
+                                           topk_weights.dtype)
         filtered_experts = torch.where(
             mask,
             local_experts_flat,
@@ -156,7 +194,7 @@ def _pre_process(
         # This is equivalent to but faster than:
         # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
         token_counts = torch.zeros(num_experts + 1,
-                                   device=self.device,
+                                   device=hidden_states.device,
                                    dtype=torch.int64)
         ones = torch.ones_like(filtered_experts, dtype=torch.int64)
         token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
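Aside on the hunk above: the `scatter_add_` into a `num_experts + 1` buffer is a drop-in for `torch.bincount`, with the extra slot apparently serving as a spill bucket for masked-out entries that were mapped to `num_experts`. A self-contained check of that equivalence (the expert ids below are made up):

```python
import torch

num_experts = 4
# One expert id per (token, top-k) slot; `num_experts` marks filtered-out slots.
filtered_experts = torch.tensor([0, 2, 2, 4, 1, 4, 0, 3])

token_counts = torch.zeros(num_experts + 1, dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts, ones)

reference = torch.bincount(filtered_experts, minlength=num_experts + 1)
assert torch.equal(token_counts, reference)
print(token_counts[:-1])  # per-expert counts, dropping the spill bucket -> tensor([2, 1, 2, 1])
```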
@@ -169,8 +207,8 @@ def _pre_process(

         return permuted_hidden_states, expert_tokens, group_list_type

-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
         mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)

         final_hidden_states = torch.zeros_like(hidden_states)
@@ -199,7 +237,22 @@ class AllGatherCommImpl(MoECommMethod):
     This is a workaround and should be removed after the issue is fixed.
     """

-    def _pre_process(
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Dummy prepare method that does nothing."""
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        if reduce_results and (self.moe_config.tp_size > 1
+                               or self.moe_config.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                hidden_states)
+
+        return final_hidden_states
+
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -229,8 +282,8 @@ def _pre_process(
         torch_npu.npu_moe_init_routing_v2(
             hidden_states,
             topk_ids,
-            active_num=num_tokens * self.top_k_num,
-            expert_num=self.global_num_experts,
+            active_num=num_tokens * self.moe_config.experts_per_token,
+            expert_num=self.moe_config.num_experts,
             expert_tokens_num_type=1,  # Only support `count` mode now
             expert_tokens_num_flag=True,  # Output `expert_tokens`
             active_expert_range=[first_expert_idx, last_expert_idx],
@@ -243,8 +296,8 @@ def _pre_process(

         return permuted_hidden_states, expert_tokens, group_list_type

-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
         hidden_states[:] = torch_npu.npu_moe_token_unpermute(
             permuted_tokens=mlp_output,
             sorted_indices=self.expanded_row_idx,
@@ -261,19 +314,13 @@ class MC2CommImpl(MoECommMethod):
     Communication and Computation parallelism on Ascend devices.
     """

-    def __init__(
-        self,
-        device: torch.device,
-        dtype: torch.dtype,
-        hf_config: PretrainedConfig,
-    ):
-        super().__init__(device, dtype, hf_config)
+    def __init__(self, moe_config: Optional[FusedMoEConfig]):
+        super().__init__(moe_config)

         # Shared communication configurations
         ep_group = get_mc2_group()
         self.ep_rank_id = ep_group.rank_in_group
         self.ep_world_size = ep_group.world_size
-        self.tp_world_size = get_tp_group().world_size

         device_group = ep_group.device_group
         local_rank = torch.distributed.get_rank(group=device_group)
@@ -286,15 +333,51 @@ def __init__(
         self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
         self.need_extra_args = self.is_ascend_a3  # or is_torchair

-        # Intermediate tensors to be passed from pre_process to post_process
-        self.topk_ids = None
-        self.topk_weights = None
-        self.mc2_mask = None
-        self.assist_info_for_combine = None
-        self.ep_recv_counts = None
-        self.tp_recv_counts = None
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        num_tokens, _ = hidden_states.shape
+        self.mc2_mask = get_forward_context().mc2_mask

-    def _pre_process(
+        if num_tokens % self.moe_config.ep_size != 0:
+            pad_size = self.moe_config.ep_size - (num_tokens %
+                                                  self.moe_config.ep_size)
+            hidden_states = nn.functional.pad(hidden_states,
+                                              (0, 0, 0, pad_size))
+            router_logits = nn.functional.pad(router_logits,
+                                              (0, 0, 0, pad_size))
+
+        split_hidden_states = torch.tensor_split(hidden_states,
+                                                 self.moe_config.ep_size,
+                                                 dim=0)
+        split_router_logits = torch.tensor_split(router_logits,
+                                                 self.moe_config.ep_size,
+                                                 dim=0)
+        split_mc2_mask = torch.tensor_split(self.mc2_mask,
+                                            self.moe_config.ep_size,
+                                            dim=0)
+        self.num_tokens = num_tokens
+        self.split_hidden_states = split_hidden_states
+
+        hidden_states = split_hidden_states[self.moe_config.ep_rank]
+        router_logits = split_router_logits[self.moe_config.ep_rank]
+        self.mc2_mask = split_mc2_mask[self.moe_config.ep_rank]
+
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """Dummy finalize method that does nothing."""
+        tp_group = get_tp_group().device_group
+        dist.all_gather(list(self.split_hidden_states), hidden_states,
+                        tp_group)
+        final_hidden_states = torch.cat(self.split_hidden_states, dim=0)
+        if self.num_tokens % self.moe_config.ep_size != 0:
+            final_hidden_states = final_hidden_states[:self.num_tokens]
+
+        return final_hidden_states
+
+    def permute(
         self,
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
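The shape bookkeeping that `MC2CommImpl.prepare` and `finalize` perform around the EP group, reduced to a single-process sketch: the `dist.all_gather` over the TP device group is simulated by reusing the split list, so only the pad / split / re-assemble arithmetic is shown (group sizes and shapes below are made up).

```python
import torch
import torch.nn as nn

ep_size, ep_rank = 4, 1                      # assumed EP group layout
num_tokens, hidden = 10, 8                   # 10 is not a multiple of ep_size
hidden_states = torch.randn(num_tokens, hidden)

# prepare(): pad the token dim up to a multiple of ep_size, split, keep our slice.
pad_size = (-num_tokens) % ep_size           # 2 extra rows here
padded = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
split_hidden_states = list(torch.tensor_split(padded, ep_size, dim=0))
local_chunk = split_hidden_states[ep_rank]   # this rank's (3, 8) slice

# ... routing, permute, expert MLP, unpermute all run on `local_chunk` ...

# finalize(): gather every rank's slice back, concatenate, drop the padding.
gathered = split_hidden_states               # stand-in for dist.all_gather
final = torch.cat(gathered, dim=0)[:num_tokens]
assert final.shape == hidden_states.shape
```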
@@ -305,19 +388,13 @@ def _pre_process(
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights.to(torch.float32)
-        self.mc2_mask = get_forward_context().mc2_mask
-
-        tp_size = get_tensor_model_parallel_world_size()
-        split_mc2_mask = torch.tensor_split(self.mc2_mask, tp_size, dim=0)
-        tp_rank = get_tensor_model_parallel_rank()
-        self.mc2_mask = split_mc2_mask[tp_rank]

         dispatch_kwargs = {
             "x": hidden_states,
             "expert_ids": self.topk_ids,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": self.global_num_experts,
+            "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
             "quant_mode": 0,
@@ -352,15 +429,15 @@ def _pre_process(

         return permuted_hidden_states, expert_tokens, group_list_type

-    def _post_process(self, mlp_output: torch.Tensor,
-                      hidden_states: torch.Tensor) -> None:
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
         combine_kwargs = {
             "expand_x": mlp_output,
             "expert_ids": self.topk_ids,
             "expert_scales": self.topk_weights,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": self.global_num_experts,
+            "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "ep_send_counts": self.ep_recv_counts,
             "group_ep": self.moe_all_to_all_group_name,
