@@ -134,6 +134,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
     self.use_aclgraph = (vllm_config.compilation_config.level
                          == CompilationLevel.PIECEWISE
                          and not vllm_config.model_config.enforce_eager)
+    self.transpose = True
 
 
 def forward_oot_v01011(
@@ -261,13 +262,22 @@ def forward_oot(
 
 def process_weights_after_loading(self, layer):
     super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
-    w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+    if self.transpose:
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
 
-    w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+
+        self.transpose = False
+    else:
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data)
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data)
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
 
     if not is_310p():
         layer.w13_weight.data = torch_npu.npu_format_cast(
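
The self.transpose flag added in the first hunk appears to make this patched process_weights_after_loading safe to call more than once: the first call transposes w13/w2 into the layout the NPU kernels expect, and later calls (for example when weights are pushed again during RL weight synchronization) only re-wrap the already-transposed data instead of flipping it back. A minimal sketch of that one-shot guard, outside the patch, with illustrative shapes and names:

import torch


class TransposeOnce:
    # Toy stand-in for the patched method's transpose-once behaviour.
    def __init__(self):
        self.transpose = True  # mirrors the flag set in the patched __init__

    def process(self, w13: torch.Tensor) -> torch.Tensor:
        if self.transpose:
            # First load: [experts, 2 * intermediate, hidden] -> [experts, hidden, 2 * intermediate]
            w13 = w13.transpose(1, 2).contiguous()
            self.transpose = False
        # Later reloads (e.g. RL weight sync) keep the existing layout.
        return w13


m = TransposeOnce()
w = torch.randn(4, 16, 8)      # [experts, 2 * intermediate, hidden]
first = m.process(w)           # -> torch.Size([4, 8, 16])
second = m.process(first)      # layout preserved on the second call
assert first.shape == second.shape == (4, 8, 16)
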
@@ -358,12 +368,11 @@ def __init__(
             num_redundant_experts,
             has_bias,
         )
-
         setup_token_dispatchers(self.moe_config.ep_size,
                                 top_k=self.top_k,
                                 num_experts=self.global_num_experts,
                                 num_local_experts=self.local_num_experts)
-
+        self.hidden_size = hidden_size
         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
@@ -430,6 +439,61 @@ def forward_impl(self, hidden_states: torch.Tensor,
 
         return final_hidden_states
 
+    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
+        # Ensure training and inference weight shapes match during RL weight updates
+        if (
+            loaded_weight.shape[1] != expert_data.shape[1] and \
+            loaded_weight.shape[0] != expert_data.shape[0]
+        ):
+            shard_dim = int(not shard_dim)
+            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
+        return loaded_weight, shard_dim
+
+    def _load_w13(self,
+                  expert_data: torch.Tensor,
+                  shard_dim: int,
+                  shard_id: str,
+                  loaded_weight: torch.Tensor,
+                  tp_rank: int,
+                  load_full: bool = False):
+        # Index the loaded weight for tp sharding.
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        loaded_weight, shard_dim = self.transpose_weight(
+            loaded_weight, expert_data, shard_dim)
+        shard_size = expert_data.shape[shard_dim] // 2
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+        # w3, up_proj: Load into second logical weight of w13.
+        else:
+            assert shard_id == "w3"
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+        expert_data.copy_(loaded_weight)
+
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel" so tp sharding on input_dim
+        # Narrow parameter and load.
+        loaded_weight, shard_dim = self.transpose_weight(
+            loaded_weight, expert_data, shard_dim)
+        shard_size = expert_data.shape[shard_dim]
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
+        # w2, down_proj: Load into only logical weight of w2.
+        expert_data.copy_(loaded_weight)
+
 
 class AscendSharedFusedMoE(AscendFusedMoE):
 
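
The three loader helpers added above work together to accept expert weights whose two dimensions are swapped relative to the in-memory parameter (the comment ties this to RL weight updates): when neither dimension of the incoming tensor matches the target, transpose_weight flips the tensor and the tensor-parallel shard dimension with it, so the per-rank narrow in _load_w13/_load_w2 still slices the correct axis before the copy. A standalone sketch of that behaviour for the w2 (down_proj) case, outside the patch, with illustrative shapes and a hypothetical helper name:

import torch


def _transpose_if_needed(loaded_weight, expert_data, shard_dim):
    # Mirrors transpose_weight: if neither dimension lines up, the incoming
    # weight is in the other (training-side) layout, so flip the tensor and
    # the shard dimension together.
    if (loaded_weight.shape[1] != expert_data.shape[1]
            and loaded_weight.shape[0] != expert_data.shape[0]):
        shard_dim = int(not shard_dim)
        loaded_weight = loaded_weight.transpose(0, 1).contiguous()
    return loaded_weight, shard_dim


hidden, intermediate, tp_size, tp_rank = 8, 24, 2, 1

# Inference-side down_proj slice for one expert, assumed already rearranged
# to [intermediate // tp, hidden] by process_weights_after_loading.
expert_w2 = torch.empty(intermediate // tp_size, hidden)
# Incoming full weight in the training/checkpoint layout [hidden, intermediate],
# normally sharded along the intermediate axis (shard_dim=1).
loaded = torch.randn(hidden, intermediate)

loaded, shard_dim = _transpose_if_needed(loaded, expert_w2, shard_dim=1)
# Both dims mismatched, so the weight was flipped to [intermediate, hidden]
# and shard_dim moved to 0; the TP narrow below still slices the right axis.
shard_size = expert_w2.shape[shard_dim]
expert_w2.copy_(loaded.narrow(shard_dim, shard_size * tp_rank, shard_size))
assert shard_dim == 0 and expert_w2.shape == (intermediate // tp_size, hidden)
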