Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ def dtype_str_to_torch_dtype(dtype_str):
"9.0": [],
"10.0": ["cute-dsl"],
"10.3": ["cute-dsl"],
"12.0": [],
"12.1": [],
"12.0": ["cute-dsl"],
"12.1": ["cute-dsl"],
},
# NORM
"rmsnorm": {
Expand Down
96 changes: 80 additions & 16 deletions benchmarks/routines/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
cutlass_fused_moe,
fused_topk_deepseek,
)
from flashinfer.fused_moe.core import RoutingMethodType
from flashinfer.tllm_enums import RoutingMethodType
from flashinfer import fp4_quantize, mxfp8_quantize
from flashinfer.testing.utils import (
bench_gpu_time,
Expand Down Expand Up @@ -235,6 +235,18 @@ def parse_moe_args(line, parser):
),
)

# CuTe DSL MoE specific
parser.add_argument(
"--use_functional_api",
action="store_true",
default=False,
help=(
"Use cute_dsl_fused_moe_nvfp4 functional API instead of CuteDslMoEWrapper "
"for cute_dsl_fp4_block_scale_moe benchmark. Useful for verifying that the "
"workspace cache eliminates per-call allocation overhead."
),
)

# CUTLASS fused MoE specific
parser.add_argument(
"--cutlass_variant",
Expand Down Expand Up @@ -1196,7 +1208,9 @@ def _create_cute_dsl_moe_test_data(
routing_weights, selected_experts = compute_routing(routing_logits, top_k)
selected_experts = selected_experts.to(torch.int32)

# GEMM1 weights (gate + up, interleaved for CuteDSL SwiGLU)
# GEMM1 weights (gate + up)
# SM100/103: interleaved in 64-row groups for CuTe DSL SwiGLU epilogue
# SM120/121: non-interleaved [up_0:N, gate_0:N] for b12x fused kernel
w1_bf16 = (
torch.randn(
num_local_experts,
Expand All @@ -1207,9 +1221,13 @@ def _create_cute_dsl_moe_test_data(
)
/ 10
)
w1_bf16_interleaved = _interleave_linear_and_gate(w1_bf16, group_size=64, dim=1)
sm_major = torch.cuda.get_device_capability(device)[0]
if sm_major == 12:
w1_bf16_prepared = w1_bf16 # SM120: non-interleaved
else:
w1_bf16_prepared = _interleave_linear_and_gate(w1_bf16, group_size=64, dim=1)
w1_gs = torch.tensor([1.0], device=device, dtype=torch.float32)
w1_flat = w1_bf16_interleaved.view(
w1_flat = w1_bf16_prepared.view(
num_local_experts * 2 * intermediate_size, hidden_size
)
w1_q_flat, w1_sf_flat = fp4_quantize(
Expand Down Expand Up @@ -1257,6 +1275,7 @@ def _create_cute_dsl_moe_test_data(

return {
"x": x_quantized,
"x_bf16": x_bf16,
"x_sf": x_sf,
"token_selected_experts": selected_experts,
"token_final_scales": routing_weights,
Expand Down Expand Up @@ -1338,16 +1357,61 @@ def testCuteDslFp4BlockScaleMoe(args):
print(f"[VVERBOSE] w1_weight.shape = {tensors['w1_weight'].shape}")
print(f"[VVERBOSE] w2_weight.shape = {tensors['w2_weight'].shape}")

moe = CuteDslMoEWrapper(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
use_cuda_graph=is_cuda_graph_compatible,
max_num_tokens=num_tokens,
num_local_experts=local_num_experts,
local_expert_offset=local_expert_offset,
)
use_functional = getattr(args, "use_functional_api", False)

# SM120 passes bf16 as x (kernel fuses quantization); SM100 passes FP4.
sm_major_bm = torch.cuda.get_device_capability(device)[0]
x_input = tensors["x_bf16"] if sm_major_bm == 12 else tensors["x"]

if use_functional:
from flashinfer import cute_dsl_fused_moe_nvfp4
from functools import partial

if args.verbose >= 1:
print(
"[INFO] Using functional API (cute_dsl_fused_moe_nvfp4) with workspace cache"
)

# Pre-allocate output buffer to avoid per-call allocation
moe_output = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=device
)

runner = partial(
cute_dsl_fused_moe_nvfp4,
num_experts=num_experts,
top_k=top_k,
num_local_experts=local_num_experts,
local_expert_offset=local_expert_offset,
moe_output=moe_output,
)

# Warmup call to populate workspace cache before timed region
runner(
x=x_input,
x_sf=tensors["x_sf"],
token_selected_experts=tensors["token_selected_experts"],
token_final_scales=tensors["token_final_scales"],
w1_weight=tensors["w1_weight"],
w1_weight_sf=tensors["w1_weight_sf"],
w1_alpha=tensors["w1_alpha"],
fc2_input_scale=tensors["fc2_input_scale"],
w2_weight=tensors["w2_weight"],
w2_weight_sf=tensors["w2_weight_sf"],
w2_alpha=tensors["w2_alpha"],
)
else:
moe = CuteDslMoEWrapper(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
use_cuda_graph=is_cuda_graph_compatible,
max_num_tokens=num_tokens,
num_local_experts=local_num_experts,
local_expert_offset=local_expert_offset,
)
runner = moe.run

Comment on lines +1360 to 1415
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Reject EP-style configs on SM120/SM121.

This path now runs the SM12x backend, but it still accepts local_num_experts != num_experts and non-zero local_expert_offset. SM120/SM121 does not support local-expert remapping, so those arguments can produce unsupported benchmark cases or route into experts that are not present in the locally-created weight tensors.

Suggested guard
     use_functional = getattr(args, "use_functional_api", False)

     # SM120 passes bf16 as x (kernel fuses quantization); SM100 passes FP4.
     sm_major_bm = torch.cuda.get_device_capability(device)[0]
+    if sm_major_bm == 12 and (
+        local_num_experts != num_experts or local_expert_offset != 0
+    ):
+        raise ValueError(
+            "cute_dsl_fp4_block_scale_moe on SM120/SM121 does not support "
+            "local expert sharding; use local_num_experts=num_experts and "
+            "local_expert_offset=0."
+        )
     x_input = tensors["x_bf16"] if sm_major_bm == 12 else tensors["x"]

Based on learnings: Expert Parallelism (EP) is unsupported on SM120, and the SM120 dispatch paths intentionally do not forward local_expert_offset because kernel-side remapping is missing.

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `benchmarks/routines/moe.py`, around lines 1360-1415: detect SM120/SM121
(sm_major_bm == 12) and explicitly reject Expert-Parallel (EP) configs by
validating local_num_experts == num_experts and local_expert_offset == 0 before
selecting the SM12x backend; if the check fails, raise a clear error (or exit)
so neither cute_dsl_fused_moe_nvfp4 (used when use_functional is set) nor
CuteDslMoEWrapper.run is invoked with unsupported local-expert remapping
parameters. Ensure the guard references sm_major_bm, local_num_experts,
local_expert_offset, cute_dsl_fused_moe_nvfp4, and CuteDslMoEWrapper so it runs
for both the functional and wrapper code paths.

def run_cute_dsl_moe(
x,
Expand All @@ -1362,7 +1426,7 @@ def run_cute_dsl_moe(
w2_weight_sf,
w2_alpha,
):
return moe.run(
return runner(
x=x,
x_sf=x_sf,
token_selected_experts=token_selected_experts,
Expand All @@ -1377,7 +1441,7 @@ def run_cute_dsl_moe(
)

input_args = (
tensors["x"],
x_input,
tensors["x_sf"],
tensors["token_selected_experts"],
tensors["token_final_scales"],
Expand Down
Loading
Loading