
Commit 4d7c007

feat(moe): Add MC2 communication method for MoE layers
This method replaces the previous all-gather approach for small numbers of tokens. The key changes include:

- A new `AscendFusedMoE` layer that handles token splitting, local computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2 method and the existing all-gather method based on the number of input tokens.
- Sharding the MoE communication mask across tensor-parallel ranks.

Signed-off-by: Yizhou Liu <[email protected]>
1 parent 67a222c commit 4d7c007
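
The commit message describes a split / local compute / all-gather flow for the MC2 path. Below is a minimal, self-contained sketch of that flow; `toy_moe`, the tensor sizes, and the single-process loop standing in for the expert-parallel ranks are illustrative assumptions, not code from this commit.

import torch


def toy_moe(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the per-rank expert computation.
    return x * 2.0


def mc2_style_forward(hidden_states: torch.Tensor, world_size: int) -> torch.Tensor:
    num_tokens = hidden_states.shape[0]

    # Pad so the token dimension splits evenly across ranks.
    if num_tokens % world_size != 0:
        pad = world_size - (num_tokens % world_size)
        hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))

    # Each rank keeps only its own shard; a loop plays the role of the ranks here.
    shards = torch.tensor_split(hidden_states, world_size, dim=0)
    local_outputs = [toy_moe(shard) for shard in shards]

    # In the real layer an all-gather collects the per-rank outputs; concatenation
    # plays that role here, and the result is truncated back to the original length.
    out = torch.cat(local_outputs, dim=0)
    return out[:num_tokens]


print(mc2_style_forward(torch.randn(5, 8), world_size=4).shape)  # torch.Size([5, 8])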

File tree: 4 files changed, +143 -6 lines


vllm_ascend/distributed/moe_comm_method.py

Lines changed: 8 additions & 1 deletion
@@ -3,7 +3,9 @@
 import torch
 import torch_npu
 from transformers.configuration_utils import PretrainedConfig
-from vllm.distributed.parallel_state import get_ep_group, get_tp_group
+from vllm.distributed.parallel_state import (
+    get_ep_group, get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size, get_tp_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op

@@ -305,6 +307,11 @@ def _pre_process(
         self.topk_weights = topk_weights.to(torch.float32)
         self.mc2_mask = get_forward_context().mc2_mask

+        tp_size = get_tensor_model_parallel_world_size()
+        split_mc2_mask = torch.tensor_split(self.mc2_mask, tp_size, dim=0)
+        tp_rank = get_tensor_model_parallel_rank()
+        self.mc2_mask = split_mc2_mask[tp_rank]
+
         dispatch_kwargs = {
             "x": hidden_states,
             "expert_ids": self.topk_ids,

vllm_ascend/ops/common_fused_moe.py

Lines changed: 124 additions & 2 deletions
@@ -18,12 +18,16 @@
 from typing import Callable, Optional

 import torch
+import torch.distributed as dist
+from torch import nn
 from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.distributed import get_tp_group
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.fused_moe.layer import \
-    UnquantizedFusedMoEMethod
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, UnquantizedFusedMoEMethod)

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.moe_comm_method import MC2CommImpl
 from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import is_310p
@@ -109,5 +113,123 @@ def forward_oot(
     )


+class AscendFusedMoE(FusedMoE):
+
+    def __init__(
+        self,
+        num_experts,
+        top_k,
+        hidden_size,
+        intermediate_size,
+        params_dtype=None,
+        reduce_results=False,
+        renormalize=True,
+        use_grouped_topk=False,
+        num_expert_group=None,
+        topk_group=None,
+        quant_config=None,
+        tp_size=None,
+        ep_size=None,
+        dp_size=None,
+        prefix="",
+        custom_routing_function=None,
+        scoring_func="softmax",
+        e_score_correction_bias=None,
+        apply_router_weight_on_input=False,
+        activation="silu",
+        enable_eplb=False,
+        num_redundant_experts=0,
+        has_bias=False,
+    ):
+        super().__init__(
+            num_experts,
+            top_k,
+            hidden_size,
+            intermediate_size,
+            params_dtype,
+            reduce_results,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+            quant_config,
+            tp_size,
+            ep_size,
+            dp_size,
+            prefix,
+            custom_routing_function,
+            scoring_func,
+            e_score_correction_bias,
+            apply_router_weight_on_input,
+            activation,
+            enable_eplb,
+            num_redundant_experts,
+            has_bias,
+        )
+
+        self.tp_group = get_tp_group().device_group
+
+    def forward_impl(self, hidden_states: torch.Tensor,
+                     router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        num_tokens, _ = hidden_states.shape
+        forward_context = get_forward_context()
+
+        moe_comm_method = forward_context.moe_comm_method
+        if type(moe_comm_method) is MC2CommImpl:
+            # NOTE: Pad tensors to make sure they can be evenly split.
+            if num_tokens % self.ep_size != 0:
+                pad_size = self.ep_size - (num_tokens % self.ep_size)
+                hidden_states = nn.functional.pad(hidden_states,
+                                                  (0, 0, 0, pad_size))
+                router_logits = nn.functional.pad(router_logits,
+                                                  (0, 0, 0, pad_size))
+
+            split_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.ep_size,
+                                                     dim=0)
+            split_router_logits = torch.tensor_split(router_logits,
+                                                     self.ep_size,
+                                                     dim=0)
+            hidden_states = split_hidden_states[self.ep_rank]
+            router_logits = split_router_logits[self.ep_rank]
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            global_num_experts=self.global_num_experts,
+            expert_map=self.expert_map,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
+            activation=self.activation,
+            apply_router_weight_on_input=self.apply_router_weight_on_input,
+            enable_eplb=self.enable_eplb,
+            expert_load_view=self.expert_load_view,
+            logical_to_physical_map=self.logical_to_physical_map,
+            logical_replica_count=self.logical_replica_count,
+        )
+
+        if type(moe_comm_method) is MC2CommImpl:
+            dist.all_gather(list(split_hidden_states), final_hidden_states,
+                            self.tp_group)
+            final_hidden_states = torch.cat(split_hidden_states, dim=0)
+            if num_tokens % self.ep_size != 0:
+                final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
+                final_hidden_states)
+
+        return final_hidden_states
+
+
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
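
The final aggregation in `forward_impl` relies on `torch.distributed.all_gather`, which fills a list of per-rank buffers with every rank's tensor; in the diff the pre-split views of the padded input serve as those buffers, and `torch.cat` then rebuilds the full batch. A minimal single-process sketch of the call shape, assuming a gloo backend, world size 1, and made-up shapes:

import os

import torch
import torch.distributed as dist

# Single-process process group, just to exercise the collective's API.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

world_size = dist.get_world_size()
local_shard = torch.randn(3, 8)  # this rank's MoE output shard
gathered = [torch.empty_like(local_shard) for _ in range(world_size)]

# all_gather writes every rank's shard into the pre-allocated list;
# concatenating along dim 0 rebuilds the full token batch.
dist.all_gather(gathered, local_shard)
full = torch.cat(gathered, dim=0)
print(full.shape)  # torch.Size([3, 8]) with world_size == 1

dist.destroy_process_group()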

vllm_ascend/utils.py

Lines changed: 3 additions & 0 deletions
@@ -493,6 +493,9 @@ def register_ascend_customop():
     from vllm_ascend.ops.layernorm import AscendRMSNorm
     CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")

+    from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+    CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
+
     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
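
The registration above adds the out-of-tree `AscendFusedMoE` under the op name "FusedMoE", alongside the existing `AscendRMSNorm` registration. The snippet below is a conceptual toy registry illustrating the name-based substitution pattern, not vLLM's `CustomOp` machinery; every name in it is a stand-in.

from typing import Dict, Type

_OOT_REGISTRY: Dict[str, Type] = {}


def register_oot(name: str, cls: Type) -> None:
    # Record a platform-specific override under the upstream op name.
    _OOT_REGISTRY[name] = cls


def resolve(name: str, default: Type) -> Type:
    # Prefer the registered override, fall back to the upstream class.
    return _OOT_REGISTRY.get(name, default)


class FusedMoE:  # upstream layer (stand-in)
    pass


class AscendFusedMoE(FusedMoE):  # NPU-specific override (stand-in)
    pass


register_oot("FusedMoE", AscendFusedMoE)
print(resolve("FusedMoE", FusedMoE).__name__)  # AscendFusedMoE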

vllm_ascend/worker/model_runner_v1.py

Lines changed: 8 additions & 3 deletions
@@ -87,6 +87,7 @@
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      DummyCommImpl,
+                                                     MC2CommImpl,
                                                      MoECommMethod)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform

@@ -360,13 +361,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
             self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

+        self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size
         self.reserved_mc2_mask = torch.zeros(
-            512,
+            self.mc2_tokens_capacity,
             dtype=torch.bool,
             device=self.device,
         )

-        self.moe_comm_method = AllGatherCommImpl
+        self.moe_comm_method = MC2CommImpl

     def _use_aclgraph(self) -> bool:
         return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager

@@ -1557,6 +1559,9 @@ def execute_model(
             intermediate_tensors) = (self._prepare_inputs(
                 scheduler_output, intermediate_tensors))

+        moe_comm_method = (self.moe_comm_method if num_input_tokens
+                           <= self.mc2_tokens_capacity else AllGatherCommImpl)
+
         # Run forward pass
         with ProfileExecuteDuration().capture_async("forward"):
             with set_ascend_forward_context(

@@ -1566,7 +1571,7 @@
                     num_tokens_across_dp=num_tokens_across_dp,
                     with_prefill=self.with_prefill,
                     reserved_mc2_mask=self.reserved_mc2_mask,
-                    moe_comm_method=self.moe_comm_method(
+                    moe_comm_method=moe_comm_method(
                         self.device, self.dtype, self.model_config.hf_config),
                     num_actual_tokens=scheduler_output.
                     total_num_scheduled_tokens):
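
The runner sizes `mc2_tokens_capacity` at 512 tokens per tensor-parallel rank and drops back to the all-gather implementation whenever a batch exceeds that capacity. A minimal sketch of the selection rule, with stub classes standing in for the real comm implementations:

class AllGatherCommImpl: ...


class MC2CommImpl: ...


def select_moe_comm_method(num_input_tokens: int, tensor_parallel_size: int) -> type:
    mc2_tokens_capacity = 512 * tensor_parallel_size
    # MC2 handles small batches; larger batches fall back to all-gather.
    return (MC2CommImpl if num_input_tokens <= mc2_tokens_capacity
            else AllGatherCommImpl)


print(select_moe_comm_method(256, tensor_parallel_size=2).__name__)   # MC2CommImpl
print(select_moe_comm_method(4096, tensor_parallel_size=2).__name__)  # AllGatherCommImpl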
