@@ -939,23 +939,26 @@ def grouped_topk(hidden_states: torch.Tensor,
     else:
         raise ValueError(f"Unsupported scoring function: {scoring_func}")
 
+    num_token = scores.shape[0]
     if e_score_correction_bias is not None:
         # Store original scores before applying correction bias. We use biased
         # scores for expert selection but original scores for routing weights
         original_scores = scores
         scores = scores + e_score_correction_bias.unsqueeze(0)
-
-    num_token = scores.shape[0]
-    group_scores = scores.view(num_token, num_expert_group,
-                               -1).max(dim=-1).values  # [n, n_group]
+        group_scores = (scores.view(num_token, num_expert_group,
+                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
+    else:
+        group_scores = scores.view(num_token, num_expert_group,
+                                   -1).max(dim=-1).values  # [n, n_group]
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                            sorted=False)[1]  # [n, top_k_group]
     group_mask = torch.zeros_like(group_scores)  # [n, n_group]
     group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(),
+                                    float("-inf"))  # [n, e]
 
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
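For context, a minimal standalone sketch (not the PR's code) of the selection path this hunk changes, assuming a sigmoid scoring function with a per-expert e_score_correction_bias; the expert counts, tensor values, and shapes are illustrative only, and the tensor names mirror the diff.

# Toy reproduction of the bias-corrected grouped top-k selection shown above.
import torch

num_token, num_expert, num_expert_group, topk_group, topk = 2, 8, 4, 2, 3

gating_output = torch.randn(num_token, num_expert)
scores = gating_output.sigmoid()                         # scoring_func == "sigmoid"
e_score_correction_bias = torch.randn(num_expert)        # hypothetical bias values

original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)   # biased scores for selection

# With a correction bias, each group is ranked by the sum of its top-2 biased
# scores rather than by its single max (the new branch in the diff).
group_scores = (scores.view(num_token, num_expert_group,
                            -1).topk(2, dim=-1)[0].sum(dim=-1))      # [n, n_group]

group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(
    num_token, num_expert_group,
    num_expert // num_expert_group).reshape(num_token, -1)           # [n, e]

# Filling masked groups with -inf (instead of 0.0) matters because the biased
# sigmoid scores can be negative; a 0.0 fill could outrank experts that belong
# to the selected groups.
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))

topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
topk_weights = original_scores.gather(1, topk_ids)  # weights from unbiased scores
print(topk_ids)
print(topk_weights)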