
Commit 594196e

Upd
Signed-off-by: Shu Wang. <[email protected]>
1 parent b90f347 commit 594196e

4 files changed: +109 -58 lines changed


tests/kernels/moe/test_cutedsl_moe.py

Lines changed: 99 additions & 52 deletions
@@ -212,6 +212,94 @@ def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
     ).sum(dim=1)


+def grouped_gemm_ref(
+    hidden_states_expanded: torch.Tensor,
+    hidden_states_3d: torch.Tensor,
+    weights: torch.Tensor,
+    topk_idx: torch.Tensor,
+    masked_m: torch.Tensor,
+    B: int,
+    topk: int,
+    num_experts: int,
+    *,
+    block_size: int = 16,
+) -> torch.Tensor:
+    """
+    Computes the reference grouped GEMM (fp4 quantized per-expert loop),
+    computes flashinfer grouped GEMM (for scale consistency),
+    and returns ONLY the repacked reference output: out_ref.
+
+    Returns:
+        out_ref: Tensor [num_experts, max_m, n_out]
+    """
+    device_hs = hidden_states_expanded.device
+    device_w = weights.device
+    out_dtype = weights.dtype
+    n_out = weights.shape[1]
+
+    # Flattened reference output (B*topk, n_out)
+    out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)
+
+    # Per-expert reference compute loop
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        if mask.any():
+            lhs = hidden_states_expanded[mask]
+            rhs = weights[i]
+
+            a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
+            b_amax = rhs.abs().max().to(torch.float32).to(device_w)
+
+            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+
+            lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
+            rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)
+
+            lhs_in_dtype = dequantize_nvfp4_to_dtype(
+                lhsq,
+                lhsq_sf,
+                a_gs,
+                dtype=lhs.dtype,
+                device=device_hs,
+                block_size=block_size,
+            )
+            rhs_in_dtype = dequantize_nvfp4_to_dtype(
+                rhsq,
+                rhsq_sf,
+                b_gs,
+                dtype=rhs.dtype,
+                device=device_w,
+                block_size=block_size,
+            )
+
+            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
+
+    # Determine per-expert max_m
+    max_m_val = int(masked_m.max().item())
+
+    # Repack into [num_experts, max_m, n_out]
+    out_ref = torch.zeros(
+        (num_experts, max_m_val, n_out),
+        dtype=out.dtype,
+        device=out.device,
+    )
+    expert_slot = [0] * num_experts
+
+    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
+        slot = expert_slot[expert_id]
+        if slot < max_m_val:
+            out_ref[expert_id, slot, :] = out[i]
+            expert_slot[expert_id] += 1
+        else:
+            raise IndexError(
+                f"Expert {expert_id} exceeded max slots ({max_m_val}). "
+                "Increase max_m or check masked_m."
+            )
+
+    return out_ref
+
+
 def flashinfer_cutedsl_grouped_gemm_nt_masked(
     hidden_states: torch.Tensor,  # 3d
     input_global_scale: torch.Tensor,  # (l,)
@@ -419,7 +507,7 @@ def test_flashinfer_cutedsl_moe_masked(
         out.device
     ).unsqueeze(-1)
     torch.testing.assert_close(
-        out_weighted.cpu(), ref_output.cpu(), atol=1e-1, rtol=1e-1
+        out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
     )


@@ -449,48 +537,6 @@ def test_grouped_gemm_nt_masked(
         hidden_states_expanded, router_logits, num_experts, topk
     )

-    # reference
-    out = torch.zeros(
-        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
-    )
-    for i in range(num_experts):
-        mask = topk_idx.view(-1) == i
-        if mask.sum():
-            lhs = hidden_states_expanded[mask]
-            rhs = weights[i]
-            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
-            b_amax = rhs.abs().max().to(torch.float32).to(weights.device)
-            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
-            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
-
-            lhsq, lhsq_sf = fp4_quantize(
-                lhs,
-                a_gs,
-            )
-            rhsq, rhsq_sf = fp4_quantize(
-                rhs,
-                b_gs,
-            )
-
-            lhs_in_dtype = dequantize_nvfp4_to_dtype(
-                lhsq,
-                lhsq_sf,
-                a_gs,
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-                block_size=16,
-            )
-
-            rhs_in_dtype = dequantize_nvfp4_to_dtype(
-                rhsq,
-                rhsq_sf,
-                b_gs,
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-                block_size=16,
-            )
-            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
-
     a_amax = (
         hidden_states_3d.abs()
         .amax(dim=(1, 2))
@@ -503,16 +549,17 @@ def test_grouped_gemm_nt_masked(
     out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
         hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
     )
-
-    # re-pack out into [num_experts, max_m, n]
-    out_ref = torch.zeros(
-        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
+    # reference
+    out_ref = grouped_gemm_ref(
+        hidden_states_expanded=hidden_states_expanded,
+        hidden_states_3d=hidden_states_3d,
+        weights=weights,
+        topk_idx=topk_idx,
+        masked_m=masked_m,
+        B=B,
+        topk=topk,
+        num_experts=num_experts,
     )
-    expert_slot = [0] * num_experts
-    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
-        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
-        expert_slot[expert_id] += 1
-
     # Note: just to compare the masked position due to cutedsl may write nan
     # into unmasked position.
     for i in range(num_experts):
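The part of the new grouped_gemm_ref helper that goes beyond the old inline reference is the repacking step: scattering the flat [B * topk, n_out] per-token rows into per-expert [num_experts, max_m, n_out] slots. The snippet below is a minimal CPU-only sketch of just that step, with made-up sizes, random data standing in for the GEMM output, and masked_m assumed to equal the per-expert token count; the fp4 quantize/dequantize path is skipped entirely.

import torch

# Minimal sketch of the repack step from grouped_gemm_ref (no fp4 path).
B, topk, num_experts, n_out = 4, 2, 3, 8
topk_idx = torch.randint(0, num_experts, (B, topk))
out = torch.randn(B * topk, n_out)  # stand-in for the flat per-token GEMM output

# Assumption: masked_m holds the per-expert token count implied by topk_idx.
masked_m = torch.bincount(topk_idx.view(-1), minlength=num_experts)
max_m_val = int(masked_m.max().item())

out_ref = torch.zeros(num_experts, max_m_val, n_out)
expert_slot = [0] * num_experts
for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
    out_ref[expert_id, expert_slot[expert_id], :] = out[i]
    expert_slot[expert_id] += 1

assert expert_slot == masked_m.tolist()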

vllm/envs.py

Lines changed: 6 additions & 2 deletions
@@ -156,7 +156,9 @@
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
-    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "cutedsl"] = "latency"
+    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
+        "latency"
+    )
     VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -1228,7 +1230,9 @@ def get_vllm_port() -> int | None:
     # - "latency":
     #   Uses TensorRT-LLM kernels optimized for low-latency inference.
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
-        "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency", "cutedsl"]
+        "VLLM_FLASHINFER_MOE_BACKEND",
+        "latency",
+        ["throughput", "latency", "masked_gemm"],
     ),
     # Control the workspace buffer size for the FlashInfer backend.
     "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 3 additions & 3 deletions
@@ -143,10 +143,10 @@ def _do_quant(
         x = x.view((-1, hidden_dim))
         q_dtype = quant_config.quant_dtype

-        if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl":
+        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
             logger.info_once(
-                "Skip quantization when using FlashInfer CUTEDSL for "
-                "ModelOptNvFp4FusedMoE."
+                "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
+                "for ModelOptNvFp4FusedMoE."
             )
             q_dtype = None
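The net effect of the guard above is that activations are handed to the masked-GEMM path unquantized (q_dtype becomes None). A standalone illustration of that decision, with a hypothetical helper name and the backend read via os.environ rather than vllm.envs:

import os

def resolve_quant_dtype(requested_dtype):
    # Hypothetical helper mirroring the guard in _do_quant: with the
    # masked_gemm backend, skip activation quantization here (return None);
    # downstream handling is left to the FlashInfer CUTE DSL experts path.
    if os.environ.get("VLLM_FLASHINFER_MOE_BACKEND", "latency") == "masked_gemm":
        return None
    return requested_dtype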

vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def select_nvfp4_gemm_impl(
     """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""

     if allow_flashinfer:
-        if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl":
+        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
             return FlashInferCuteDSLExperts(
                 out_dtype=moe.in_dtype,
                 quant_config=moe_quant_config,
