@@ -212,6 +212,94 @@ def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
     ).sum(dim=1)


+def grouped_gemm_ref(
+    hidden_states_expanded: torch.Tensor,
+    hidden_states_3d: torch.Tensor,
+    weights: torch.Tensor,
+    topk_idx: torch.Tensor,
+    masked_m: torch.Tensor,
+    B: int,
+    topk: int,
+    num_experts: int,
+    *,
+    block_size: int = 16,
+) -> torch.Tensor:
+    """
+    Reference grouped GEMM: for each expert, fp4-quantize and dequantize the
+    routed activations and weights, run the matmul in the original dtype, and
+    repack the flattened result into a per-expert layout.
+
+    Returns:
+        out_ref: Tensor [num_experts, max_m, n_out]
+    """
+    device_hs = hidden_states_expanded.device
+    device_w = weights.device
+    out_dtype = weights.dtype
+    n_out = weights.shape[1]
+
+    # Flattened reference output (B * topk, n_out)
+    out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)
+
+    # Per-expert reference compute loop
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        if mask.any():
+            lhs = hidden_states_expanded[mask]
+            rhs = weights[i]
+
+            a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
+            b_amax = rhs.abs().max().to(torch.float32).to(device_w)
+
+            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+
+            lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
+            rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)
+
+            lhs_in_dtype = dequantize_nvfp4_to_dtype(
+                lhsq,
+                lhsq_sf,
+                a_gs,
+                dtype=lhs.dtype,
+                device=device_hs,
+                block_size=block_size,
+            )
+            rhs_in_dtype = dequantize_nvfp4_to_dtype(
+                rhsq,
+                rhsq_sf,
+                b_gs,
+                dtype=rhs.dtype,
+                device=device_w,
+                block_size=block_size,
+            )
+
+            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
+
+    # Determine the per-expert row capacity
+    max_m_val = int(masked_m.max().item())
+
+    # Repack into [num_experts, max_m, n_out]
+    out_ref = torch.zeros(
+        (num_experts, max_m_val, n_out),
+        dtype=out.dtype,
+        device=out.device,
+    )
+    expert_slot = [0] * num_experts
+
+    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
+        slot = expert_slot[expert_id]
+        if slot < max_m_val:
+            out_ref[expert_id, slot, :] = out[i]
+            expert_slot[expert_id] += 1
+        else:
+            raise IndexError(
+                f"Expert {expert_id} exceeded max slots ({max_m_val}). "
+                "Increase max_m or check masked_m."
+            )
+
+    return out_ref
+
+
 def flashinfer_cutedsl_grouped_gemm_nt_masked(
     hidden_states: torch.Tensor,  # 3d
     input_global_scale: torch.Tensor,  # (l,)
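
Aside (not part of the patch): the reference path above derives a per-tensor global scale as FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax. NVFP4 stores one FP8 (E4M3) scale per 16-element block on top of a per-tensor FP32 global scale, so the global scale is chosen to map the tensor's absolute maximum onto the largest value the two-level encoding can represent. A minimal sketch of that calculation, assuming FLOAT8_E4M3_MAX = 448.0 and FLOAT4_E2M1_MAX = 6.0 as used by this test's imported constants:

import torch

FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3 value (block scales)
FLOAT4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 (E2M1)

def nvfp4_global_scale(x: torch.Tensor) -> torch.Tensor:
    # Pick the per-tensor scale so that amax(x) lands at the top of the
    # range representable by (FP8 block scale) x (FP4 element).
    amax = x.abs().max().to(torch.float32)
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
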
@@ -419,7 +507,7 @@ def test_flashinfer_cutedsl_moe_masked(
         out.device
     ).unsqueeze(-1)
     torch.testing.assert_close(
-        out_weighted.cpu(), ref_output.cpu(), atol=1e-1, rtol=1e-1
+        out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
     )


@@ -449,48 +537,6 @@ def test_grouped_gemm_nt_masked(
         hidden_states_expanded, router_logits, num_experts, topk
     )

-    # reference
-    out = torch.zeros(
-        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
-    )
-    for i in range(num_experts):
-        mask = topk_idx.view(-1) == i
-        if mask.sum():
-            lhs = hidden_states_expanded[mask]
-            rhs = weights[i]
-            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
-            b_amax = rhs.abs().max().to(torch.float32).to(weights.device)
-            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
-            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
-
-            lhsq, lhsq_sf = fp4_quantize(
-                lhs,
-                a_gs,
-            )
-            rhsq, rhsq_sf = fp4_quantize(
-                rhs,
-                b_gs,
-            )
-
-            lhs_in_dtype = dequantize_nvfp4_to_dtype(
-                lhsq,
-                lhsq_sf,
-                a_gs,
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-                block_size=16,
-            )
-
-            rhs_in_dtype = dequantize_nvfp4_to_dtype(
-                rhsq,
-                rhsq_sf,
-                b_gs,
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-                block_size=16,
-            )
-            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
-
     a_amax = (
         hidden_states_3d.abs()
         .amax(dim=(1, 2))
@@ -503,16 +549,17 @@
     out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
         hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
    )
-
-    # re-pack out into [num_experts, max_m, n]
-    out_ref = torch.zeros(
-        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
+    # reference
+    out_ref = grouped_gemm_ref(
+        hidden_states_expanded=hidden_states_expanded,
+        hidden_states_3d=hidden_states_3d,
+        weights=weights,
+        topk_idx=topk_idx,
+        masked_m=masked_m,
+        B=B,
+        topk=topk,
+        num_experts=num_experts,
     )
-    expert_slot = [0] * num_experts
-    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
-        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
-        expert_slot[expert_id] += 1
-
     # Note: compare only the masked positions, since cutedsl may write nan
     # into unmasked positions.
     for i in range(num_experts):
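
The per-expert loop that follows is truncated in this diff; it checks only the first masked_m[i] rows of each expert, as the note above explains. A hedged sketch of what such a masked comparison can look like (the exact slicing and tolerances in the real test may differ):

for i in range(num_experts):
    m = int(masked_m[i].item())  # number of valid rows for expert i
    torch.testing.assert_close(
        out_flashinfer[i, :m].cpu(), out_ref[i, :m].cpu(), atol=1e-1, rtol=1e-1
    )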