@@ -1297,30 +1297,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
             qintermediate_cache2 = intermediate_cache2
             a2q_scale = a2_scale

-        invoke_fused_moe_kernel(
-            qintermediate_cache2,
-            w2,
-            intermediate_cache3,
-            a2q_scale,
-            w2_scale,
-            w2_zp,
-            curr_topk_weights,
-            sorted_token_ids,
-            expert_ids,
-            num_tokens_post_padded,
-            False,  #True,
-            1,
-            config,
-            compute_type=compute_type,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
-            block_shape=block_shape)
-
-        if True:
-            intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
-            intermediate_cache3.mul_(
-                curr_topk_weights.view(tokens_in_chunk, -1, 1))
+        invoke_fused_moe_kernel(qintermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2q_scale,
+                                w2_scale,
+                                w2_zp,
+                                curr_topk_weights,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                True,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8_w8a8=use_fp8_w8a8,
+                                use_int8_w8a16=use_int8_w8a16,
+                                use_int4_w4a16=use_int4_w4a16,
+                                block_shape=block_shape)

         ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])
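
The hunk above drops the explicit post-GEMM scaling (the removed `if True:` block) and instead passes `True` for the kernel's mul_routed_weight argument, so the routed top-k weights are applied in the kernel epilogue before the result is written to intermediate_cache3. Below is a minimal sketch of why the two paths produce the same result, using plain dense tensors as stand-ins rather than the actual Triton kernel; the shapes and names here are illustrative assumptions, not the kernel's real layout.

    # Sketch only: dense stand-ins for the second MoE GEMM.
    import torch

    tokens, top_k, N, K = 4, 2, 8, 16
    x = torch.randn(tokens, top_k, N)         # stand-in for intermediate_cache2
    w2 = torch.randn(N, K)                    # stand-in for one expert's w2
    topk_weights = torch.rand(tokens, top_k)  # stand-in for curr_topk_weights

    # Fused path: scale the GEMM output by the routed weight in the epilogue,
    # before it is stored (what mul_routed_weight=True asks the kernel to do).
    out_fused = (x @ w2) * topk_weights.view(tokens, top_k, 1)

    # Removed two-pass path: store the unscaled GEMM output, then run a
    # separate elementwise mul_ over it, as the old `if True:` block did.
    out_twopass = (x @ w2).view(-1, top_k, K)
    out_twopass.mul_(topk_weights.view(tokens, -1, 1))

    torch.testing.assert_close(out_fused.view(-1, top_k, K), out_twopass)

Folding the scaling into the kernel avoids a separate elementwise pass (an extra read and write) over intermediate_cache3.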