from abc import ABC, abstractmethod
+ from typing import Optional

import torch
+ import torch.distributed as dist
+ import torch.nn as nn
import torch_npu
- from transformers.configuration_utils import PretrainedConfig
- from vllm.distributed.parallel_state import (
-     get_ep_group, get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size, get_tp_group)
+ from vllm.distributed import tensor_model_parallel_all_reduce
+ from vllm.distributed.parallel_state import get_ep_group, get_tp_group
from vllm.forward_context import ForwardContext, get_forward_context
+ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm.utils import direct_register_custom_op

from vllm_ascend.distributed.parallel_state import get_mc2_group
class MoECommMethod(ABC):
    """Base class for MoE communication methods."""

-     def __init__(
-         self,
-         device: torch.device,
-         dtype: torch.dtype,
-         hf_config: PretrainedConfig,
-     ):
-         self.device = device
-         self.dtype = dtype
-         self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
-         # global_num_experts may be called num_experts or n_routed_experts in different models.
-         possible_keys = ["num_experts", "n_routed_experts"]
-         for key in possible_keys:
-             if hasattr(hf_config, key):
-                 self.global_num_experts = getattr(hf_config, key)
-                 break
-         else:
-             self.global_num_experts = 0
+     moe_config: Optional[FusedMoEConfig] = None
+
+     def __init__(self, moe_config: Optional[FusedMoEConfig]):
+         self.moe_config = moe_config
+
+     @abstractmethod
+     def prepare(
+             self, hidden_states: torch.Tensor,
+             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """Prepare the MoE communication method.
+
+         This method is called before quant_method.apply to prepare the
+         communication method. It can be used to initialize any necessary
+         resources or configurations.
+         """
+         pass
+
+     @abstractmethod
+     def finalize(self, hidden_states: torch.Tensor,
+                  reduce_results: bool) -> torch.Tensor:
+         """Finalize the MoE communication method.
+
+         This method is called after quant_method.apply to finalize the
+         communication method. It can be used to clean up any resources or
+         configurations.
+         """
+         pass
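+     # Rough call order, per the docstrings above: prepare() runs before
+     # quant_method.apply and finalize() runs after it; permute()/unpermute()
+     # wrap the expert MLP computation in between.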

    @abstractmethod
-     def _pre_process(
+     def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
@@ -69,8 +81,8 @@ def _pre_process(
        pass

    @abstractmethod
-     def _post_process(self, mlp_output: torch.Tensor,
-                       hidden_states: torch.Tensor) -> None:
+     def unpermute(self, mlp_output: torch.Tensor,
+                   hidden_states: torch.Tensor) -> None:
        """Post-process after MLP.

        Args:
@@ -84,7 +96,18 @@ def _post_process(self, mlp_output: torch.Tensor,

class DummyCommImpl(MoECommMethod):

-     def _pre_process(
+     def prepare(
+             self, hidden_states: torch.Tensor,
+             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """Dummy prepare method that does nothing."""
+         return hidden_states, router_logits
+
+     def finalize(self, hidden_states: torch.Tensor,
+                  reduce_results: bool) -> torch.Tensor:
+         """Dummy finalize method that does nothing."""
+         return hidden_states
+
+     def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
@@ -96,8 +119,8 @@ def _pre_process(
        return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
                                         expert_map, num_experts)

-     def _post_process(self, mlp_output: torch.Tensor,
-                       hidden_states: torch.Tensor) -> None:
+     def unpermute(self, mlp_output: torch.Tensor,
+                   hidden_states: torch.Tensor) -> None:
        """Dummy implementation that does nothing."""
        pass

@@ -110,7 +133,22 @@ class NativeAllGatherCommImpl(MoECommMethod):
    But it is a good fallback for scenarios where NPU-specific ops are not available.
    """

-     def _pre_process(
+     def prepare(
+             self, hidden_states: torch.Tensor,
+             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """No-op prepare; returns the inputs unchanged."""
+         return hidden_states, router_logits
+
+     def finalize(self, hidden_states: torch.Tensor,
+                  reduce_results: bool) -> torch.Tensor:
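+         # Reduce across ranks only when the caller asks for it and there is
+         # more than one TP/EP rank; otherwise return the input unchanged.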
+         if reduce_results and (self.moe_config.tp_size > 1
+                                or self.moe_config.ep_size > 1):
+             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+         return hidden_states
+
+     def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
@@ -122,10 +160,10 @@ def _pre_process(

        # Generate token indices and flatten
        token_indices = torch.arange(num_tokens,
-                                      device=self.device,
+                                      device=hidden_states.device,
                                     dtype=torch.int64)
        token_indices = (token_indices.unsqueeze(1).expand(
-             -1, self.top_k_num).reshape(-1))
+             -1, self.moe_config.experts_per_token).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
@@ -140,7 +178,7 @@ def _pre_process(
        # This is a workaround and should be removed after the issue is fixed
        filtered_weights = torch.where(mask, weights_flat,
                                       torch.zeros_like(weights_flat)).to(
-                                            self.dtype)
+                                            topk_weights.dtype)
        filtered_experts = torch.where(
            mask,
            local_experts_flat,
@@ -156,7 +194,7 @@ def _pre_process(
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        token_counts = torch.zeros(num_experts + 1,
-                                    device=self.device,
+                                    device=hidden_states.device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
@@ -169,8 +207,8 @@ def _pre_process(

        return permuted_hidden_states, expert_tokens, group_list_type

-     def _post_process(self, mlp_output: torch.Tensor,
-                       hidden_states: torch.Tensor) -> None:
+     def unpermute(self, mlp_output: torch.Tensor,
+                   hidden_states: torch.Tensor) -> None:
        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)

        final_hidden_states = torch.zeros_like(hidden_states)
@@ -199,7 +237,22 @@ class AllGatherCommImpl(MoECommMethod):
    This is a workaround and should be removed after the issue is fixed.
    """

-     def _pre_process(
+     def prepare(
+             self, hidden_states: torch.Tensor,
+             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """No-op prepare; returns the inputs unchanged."""
+         return hidden_states, router_logits
+
+     def finalize(self, hidden_states: torch.Tensor,
+                  reduce_results: bool) -> torch.Tensor:
+         if reduce_results and (self.moe_config.tp_size > 1
+                                or self.moe_config.ep_size > 1):
+             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+         return hidden_states
+
+     def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
@@ -229,8 +282,8 @@ def _pre_process(
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
-                 active_num=num_tokens * self.top_k_num,
-                 expert_num=self.global_num_experts,
+                 active_num=num_tokens * self.moe_config.experts_per_token,
+                 expert_num=self.moe_config.num_experts,
                expert_tokens_num_type=1,  # Only support `count` mode now
                expert_tokens_num_flag=True,  # Output `expert_tokens`
                active_expert_range=[first_expert_idx, last_expert_idx],
@@ -243,8 +296,8 @@ def _pre_process(

        return permuted_hidden_states, expert_tokens, group_list_type

-     def _post_process(self, mlp_output: torch.Tensor,
-                       hidden_states: torch.Tensor) -> None:
+     def unpermute(self, mlp_output: torch.Tensor,
+                   hidden_states: torch.Tensor) -> None:
        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
@@ -261,19 +314,13 @@ class MC2CommImpl(MoECommMethod):
    Communication and Computation parallelism on Ascend devices.
    """

-     def __init__(
-         self,
-         device: torch.device,
-         dtype: torch.dtype,
-         hf_config: PretrainedConfig,
-     ):
-         super().__init__(device, dtype, hf_config)
+     def __init__(self, moe_config: Optional[FusedMoEConfig]):
+         super().__init__(moe_config)

        # Shared communication configurations
        ep_group = get_mc2_group()
        self.ep_rank_id = ep_group.rank_in_group
        self.ep_world_size = ep_group.world_size
-         self.tp_world_size = get_tp_group().world_size

        device_group = ep_group.device_group
        local_rank = torch.distributed.get_rank(group=device_group)
@@ -286,15 +333,51 @@ def __init__(
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        self.need_extra_args = self.is_ascend_a3  # or is_torchair

-         # Intermediate tensors to be passed from pre_process to post_process
-         self.topk_ids = None
-         self.topk_weights = None
-         self.mc2_mask = None
-         self.assist_info_for_combine = None
-         self.ep_recv_counts = None
-         self.tp_recv_counts = None
+     def prepare(
+             self, hidden_states: torch.Tensor,
+             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         num_tokens, _ = hidden_states.shape
+         self.mc2_mask = get_forward_context().mc2_mask

-     def _pre_process(
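+         # Pad the token dimension so the batch divides evenly across the EP
+         # ranks before it is split below.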
+         if num_tokens % self.moe_config.ep_size != 0:
+             pad_size = self.moe_config.ep_size - (num_tokens %
+                                                   self.moe_config.ep_size)
+             hidden_states = nn.functional.pad(hidden_states,
+                                               (0, 0, 0, pad_size))
+             router_logits = nn.functional.pad(router_logits,
+                                               (0, 0, 0, pad_size))
+
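+         # Each EP rank dispatches only its own slice of the (padded) batch;
+         # the full batch is reassembled in finalize().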
+         split_hidden_states = torch.tensor_split(hidden_states,
+                                                  self.moe_config.ep_size,
+                                                  dim=0)
+         split_router_logits = torch.tensor_split(router_logits,
+                                                  self.moe_config.ep_size,
+                                                  dim=0)
+         split_mc2_mask = torch.tensor_split(self.mc2_mask,
+                                             self.moe_config.ep_size,
+                                             dim=0)
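+         # Stash the shards and the original token count; finalize() gathers
+         # into these buffers and strips the padding added above.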
+         self.num_tokens = num_tokens
+         self.split_hidden_states = split_hidden_states
+
+         hidden_states = split_hidden_states[self.moe_config.ep_rank]
+         router_logits = split_router_logits[self.moe_config.ep_rank]
+         self.mc2_mask = split_mc2_mask[self.moe_config.ep_rank]
+
+         return hidden_states, router_logits
+
+     def finalize(self, hidden_states: torch.Tensor,
+                  reduce_results: bool) -> torch.Tensor:
+         """Gather each rank's output shard back and drop any padding."""
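+         # All-gather the per-rank results into the buffers saved by prepare(),
+         # concatenate them, and trim back to the original token count.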
+         tp_group = get_tp_group().device_group
+         dist.all_gather(list(self.split_hidden_states), hidden_states,
+                         tp_group)
+         final_hidden_states = torch.cat(self.split_hidden_states, dim=0)
+         if self.num_tokens % self.moe_config.ep_size != 0:
+             final_hidden_states = final_hidden_states[:self.num_tokens]
+
+         return final_hidden_states
+
+     def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
@@ -305,19 +388,13 @@ def _pre_process(
        # Store tensors needed for post_process
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights.to(torch.float32)
-         self.mc2_mask = get_forward_context().mc2_mask
-
-         tp_size = get_tensor_model_parallel_world_size()
-         split_mc2_mask = torch.tensor_split(self.mc2_mask, tp_size, dim=0)
-         tp_rank = get_tensor_model_parallel_rank()
-         self.mc2_mask = split_mc2_mask[tp_rank]

        dispatch_kwargs = {
            "x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
-             "moe_expert_num": self.global_num_experts,
+             "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "scales": None,
            "quant_mode": 0,
@@ -352,15 +429,15 @@ def _pre_process(

        return permuted_hidden_states, expert_tokens, group_list_type

-     def _post_process(self, mlp_output: torch.Tensor,
-                       hidden_states: torch.Tensor) -> None:
+     def unpermute(self, mlp_output: torch.Tensor,
+                   hidden_states: torch.Tensor) -> None:
        combine_kwargs = {
            "expand_x": mlp_output,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
-             "moe_expert_num": self.global_num_experts,
+             "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.moe_all_to_all_group_name,