diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index ed81f364fe..b3b59cbd75 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -184,6 +184,7 @@ "trtllm_fp8_per_tensor_scale_moe", "cutlass_fused_moe", "cute_dsl_fp4_block_scale_moe", + "b12x_fused_moe", ], "moe_comm": [ "moe_a2a_dispatch_combine", @@ -451,8 +452,19 @@ def dtype_str_to_torch_dtype(dtype_str): "9.0": [], "10.0": ["cute-dsl"], "10.3": ["cute-dsl"], - "12.0": ["cute-dsl"], - "12.1": ["cute-dsl"], + "12.0": [], + "12.1": [], + }, + "b12x_fused_moe": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": [], + "10.3": [], + "12.0": ["b12x"], + "12.1": ["b12x"], }, # NORM "rmsnorm": { diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index 23260d5a35..26f4e45389 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -95,6 +95,8 @@ def run_moe_test(args): return testCutlassFusedMoe(args) elif args.routine == "cute_dsl_fp4_block_scale_moe": return testCuteDslFp4BlockScaleMoe(args) + elif args.routine == "b12x_fused_moe": + return testB12xFusedMoe(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -235,15 +237,17 @@ def parse_moe_args(line, parser): ), ) - # CuTe DSL MoE specific + # CuTe DSL / b12x MoE specific parser.add_argument( "--use_functional_api", action="store_true", default=False, help=( - "Use cute_dsl_fused_moe_nvfp4 functional API instead of CuteDslMoEWrapper " - "for cute_dsl_fp4_block_scale_moe benchmark. Useful for verifying that the " - "workspace cache eliminates per-call allocation overhead." + "Use the functional MoE API instead of the wrapper class: " + "cute_dsl_fused_moe_nvfp4 vs CuteDslMoEWrapper for " + "cute_dsl_fp4_block_scale_moe, and b12x_fused_moe vs B12xMoEWrapper for " + "b12x_fused_moe. Useful for verifying that the wrapper's workspace cache " + "eliminates per-call allocation overhead." ), ) @@ -1197,7 +1201,7 @@ def _interleave_linear_and_gate( return x -def _create_cute_dsl_moe_test_data( +def _create_nvfp4_moe_test_data( num_tokens: int, hidden_size: int, intermediate_size: int, @@ -1205,21 +1209,35 @@ def _create_cute_dsl_moe_test_data( num_local_experts: int, top_k: int, device: torch.device, + backend: str, is_gated: bool = True, ): - """Create NVFP4-quantized test data for CuteDSL MoE (Blackwell kernels). + """Create NVFP4-quantized test data for CuTe-DSL-family MoE kernels. - Routing is computed externally via simple top-k (CuteDslMoEWrapper takes - pre-computed token_selected_experts and token_final_scales). + Supports two backends that share the weight-quantization recipe but differ + in the FC1 weight layout and in whether input activations are pre-quantized: - Returns a dict with all tensors needed by CuteDslMoEWrapper.run(). + - ``backend="cute-dsl"`` (SM100/SM103 ``CuteDslMoEWrapper``): gated SwiGLU only, + FC1 weights interleaved via ``_interleave_linear_and_gate``, fp4-quantized + input activations + scale factors. + - ``backend="b12x"`` (SM120/SM121 ``B12xMoEWrapper``): either gated (SwiGLU, + ``[E, 2n, k]``) or non-gated (ReLU2, ``[E, n, k]``) FC1, no interleave, + bf16 input activations (the kernel fuses quantization internally). + + Routing is computed externally via simple top-k (the wrappers take + pre-computed ``token_selected_experts`` and ``token_final_scales``). """ from flashinfer.fp4_quantization import fp4_quantize from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout + if backend not in ("cute-dsl", "b12x"): + raise ValueError(f"Unsupported backend: {backend!r}") + if backend == "cute-dsl" and not is_gated: + raise ValueError("cute-dsl backend only supports gated (SwiGLU) activation") + sf_vec_size = 16 - # Input activations + # Input activations: fp4-quantized for cute-dsl, bf16 for b12x x_bf16 = ( torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) / 10 ) @@ -1235,8 +1253,8 @@ def _create_cute_dsl_moe_test_data( selected_experts = selected_experts.to(torch.int32) # GEMM1 weights - # Gated (SiLU/SwiGLU): [E, 2*n, k] — gate + up fused - # Non-gated (ReLU2): [E, n, k] — single FC1 matrix + # Gated (SwiGLU): [E, 2*n, k] — gate + up fused + # Non-gated (ReLU2, b12x only): [E, n, k] — single FC1 matrix w1_rows = (2 if is_gated else 1) * intermediate_size w1_bf16 = ( torch.randn( @@ -1248,11 +1266,11 @@ def _create_cute_dsl_moe_test_data( ) / 10 ) - sm_major = torch.cuda.get_device_capability(device)[0] - if sm_major == 12 or not is_gated: - w1_bf16_prepared = w1_bf16 # SM120 or non-gated: no interleave - else: + # cute-dsl (SM100) expects gate/up interleaved; b12x (SM120) does not. + if backend == "cute-dsl": w1_bf16_prepared = _interleave_linear_and_gate(w1_bf16, group_size=64, dim=1) + else: + w1_bf16_prepared = w1_bf16 w1_gs = torch.tensor([1.0], device=device, dtype=torch.float32) w1_flat = w1_bf16_prepared.view(num_local_experts * w1_rows, hidden_size) w1_q_flat, w1_sf_flat = fp4_quantize( @@ -1314,13 +1332,17 @@ def _create_cute_dsl_moe_test_data( def testCuteDslFp4BlockScaleMoe(args): """ - Test cute_dsl_fp4_block_scale_moe (CuteDSL NVFP4 MoE on Blackwell). + Test cute_dsl_fp4_block_scale_moe (CuTe DSL NVFP4 MoE on SM100/SM103). This test: - 1. Creates NVFP4-quantized weights and inputs for CuteDSL kernels - 2. Runs MoE via CuteDslMoEWrapper + 1. Creates NVFP4-quantized weights and fp4-quantized inputs for CuTe DSL kernels + 2. Runs MoE via CuteDslMoEWrapper (or cute_dsl_fused_moe_nvfp4 when + ``--use_functional_api`` is set). SwiGLU only. 3. Measures performance metrics (TFLOPS, TB/sec) + Note: SM120/SM121 has moved to the dedicated ``b12x_fused_moe`` routine + (see ``testB12xFusedMoe``). + Args: args: Parsed command line arguments containing test configuration @@ -1364,19 +1386,16 @@ def testCuteDslFp4BlockScaleMoe(args): f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" ) - # Map ActivationType enum to string for SM120 CuTe DSL API + # CuteDslMoEWrapper is gated SwiGLU only. activation_type = args.activation_type - _ACT_STR = {ActivationType.Swiglu: "silu", ActivationType.Relu2: "relu2"} - if activation_type not in _ACT_STR: + if activation_type != ActivationType.Swiglu: raise ValueError( - f"CuTe DSL MoE only supports Swiglu and Relu2 activations, " - f"got {activation_type.name}" + f"cute_dsl_fp4_block_scale_moe only supports Swiglu activation, " + f"got {activation_type.name}. Use --routine b12x_fused_moe for ReLU2." ) - activation_str = _ACT_STR[activation_type] - is_gated = activation_type == ActivationType.Swiglu - # Create CuteDSL-specific NVFP4 test data - tensors = _create_cute_dsl_moe_test_data( + # Create CuteDSL-specific NVFP4 test data (gated, fp4 input) + tensors = _create_nvfp4_moe_test_data( num_tokens=num_tokens, hidden_size=hidden_size, intermediate_size=intermediate_size, @@ -1384,7 +1403,8 @@ def testCuteDslFp4BlockScaleMoe(args): num_local_experts=local_num_experts, top_k=top_k, device=device, - is_gated=is_gated, + backend="cute-dsl", + is_gated=True, ) if args.verbose >= 2: @@ -1394,181 +1414,93 @@ def testCuteDslFp4BlockScaleMoe(args): use_functional = getattr(args, "use_functional_api", False) - sm_major_bm = torch.cuda.get_device_capability(device)[0] - is_sm120 = sm_major_bm == 12 - x_input = tensors["x_bf16"] if is_sm120 else tensors["x"] - if use_functional: from functools import partial + from flashinfer import cute_dsl_fused_moe_nvfp4 # Pre-allocate output buffer to avoid per-call allocation moe_output = torch.empty( num_tokens, hidden_size, dtype=torch.bfloat16, device=device ) - if is_sm120: - from flashinfer import b12x_fused_moe - - if args.verbose >= 1: - print("[INFO] Using b12x functional API (b12x_fused_moe)") - runner = partial( - b12x_fused_moe, - num_experts=num_experts, - top_k=top_k, - num_local_experts=local_num_experts, - output=moe_output, - activation=activation_str, - ) - else: - from flashinfer import cute_dsl_fused_moe_nvfp4 - - if args.verbose >= 1: - print("[INFO] Using CuTe DSL functional API (cute_dsl_fused_moe_nvfp4)") - runner = partial( - cute_dsl_fused_moe_nvfp4, - num_experts=num_experts, - top_k=top_k, - num_local_experts=local_num_experts, - local_expert_offset=local_expert_offset, - moe_output=moe_output, - ) + if args.verbose >= 1: + print("[INFO] Using CuTe DSL functional API (cute_dsl_fused_moe_nvfp4)") + runner = partial( + cute_dsl_fused_moe_nvfp4, + num_experts=num_experts, + top_k=top_k, + num_local_experts=local_num_experts, + local_expert_offset=local_expert_offset, + moe_output=moe_output, + ) # Warmup call to populate workspace cache before timed region - if is_sm120: - runner( - x=x_input, - w1_weight=tensors["w1_weight"], - w1_weight_sf=tensors["w1_weight_sf"], - w1_alpha=tensors["w1_alpha"], - fc2_input_scale=tensors["fc2_input_scale"], - w2_weight=tensors["w2_weight"], - w2_weight_sf=tensors["w2_weight_sf"], - w2_alpha=tensors["w2_alpha"], - token_selected_experts=tensors["token_selected_experts"], - token_final_scales=tensors["token_final_scales"], - ) - else: - runner( - x=x_input, - x_sf=tensors["x_sf"], - token_selected_experts=tensors["token_selected_experts"], - token_final_scales=tensors["token_final_scales"], - w1_weight=tensors["w1_weight"], - w1_weight_sf=tensors["w1_weight_sf"], - w1_alpha=tensors["w1_alpha"], - fc2_input_scale=tensors["fc2_input_scale"], - w2_weight=tensors["w2_weight"], - w2_weight_sf=tensors["w2_weight_sf"], - w2_alpha=tensors["w2_alpha"], - ) + runner( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + ) else: - if is_sm120: - from flashinfer import B12xMoEWrapper - - moe = B12xMoEWrapper( - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - use_cuda_graph=is_cuda_graph_compatible, - max_num_tokens=num_tokens, - num_local_experts=local_num_experts, - activation=activation_str, - ) - else: - moe = CuteDslMoEWrapper( - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - use_cuda_graph=is_cuda_graph_compatible, - max_num_tokens=num_tokens, - num_local_experts=local_num_experts, - local_expert_offset=local_expert_offset, - ) + moe = CuteDslMoEWrapper( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + use_cuda_graph=is_cuda_graph_compatible, + max_num_tokens=num_tokens, + num_local_experts=local_num_experts, + local_expert_offset=local_expert_offset, + ) runner = moe.run - if is_sm120: - - def run_cute_dsl_moe( - x, - w1_weight, - w1_weight_sf, - w1_alpha, - fc2_input_scale, - w2_weight, - w2_weight_sf, - w2_alpha, - token_selected_experts, - token_final_scales, - ): - return runner( - x=x, - w1_weight=w1_weight, - w1_weight_sf=w1_weight_sf, - w1_alpha=w1_alpha, - fc2_input_scale=fc2_input_scale, - w2_weight=w2_weight, - w2_weight_sf=w2_weight_sf, - w2_alpha=w2_alpha, - token_selected_experts=token_selected_experts, - token_final_scales=token_final_scales, - ) - - input_args = ( - x_input, - tensors["w1_weight"], - tensors["w1_weight_sf"], - tensors["w1_alpha"], - tensors["fc2_input_scale"], - tensors["w2_weight"], - tensors["w2_weight_sf"], - tensors["w2_alpha"], - tensors["token_selected_experts"], - tensors["token_final_scales"], + def run_cute_dsl_moe( + x, + x_sf, + token_selected_experts, + token_final_scales, + w1_weight, + w1_weight_sf, + w1_alpha, + fc2_input_scale, + w2_weight, + w2_weight_sf, + w2_alpha, + ): + return runner( + x=x, + x_sf=x_sf, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + w1_weight=w1_weight, + w1_weight_sf=w1_weight_sf, + w1_alpha=w1_alpha, + fc2_input_scale=fc2_input_scale, + w2_weight=w2_weight, + w2_weight_sf=w2_weight_sf, + w2_alpha=w2_alpha, ) - else: - - def run_cute_dsl_moe( - x, - x_sf, - token_selected_experts, - token_final_scales, - w1_weight, - w1_weight_sf, - w1_alpha, - fc2_input_scale, - w2_weight, - w2_weight_sf, - w2_alpha, - ): - return runner( - x=x, - x_sf=x_sf, - token_selected_experts=token_selected_experts, - token_final_scales=token_final_scales, - w1_weight=w1_weight, - w1_weight_sf=w1_weight_sf, - w1_alpha=w1_alpha, - fc2_input_scale=fc2_input_scale, - w2_weight=w2_weight, - w2_weight_sf=w2_weight_sf, - w2_alpha=w2_alpha, - ) - input_args = ( - x_input, - tensors["x_sf"], - tensors["token_selected_experts"], - tensors["token_final_scales"], - tensors["w1_weight"], - tensors["w1_weight_sf"], - tensors["w1_alpha"], - tensors["fc2_input_scale"], - tensors["w2_weight"], - tensors["w2_weight_sf"], - tensors["w2_alpha"], - ) + input_args = ( + tensors["x"], + tensors["x_sf"], + tensors["token_selected_experts"], + tensors["token_final_scales"], + tensors["w1_weight"], + tensors["w1_weight_sf"], + tensors["w1_alpha"], + tensors["fc2_input_scale"], + tensors["w2_weight"], + tensors["w2_weight_sf"], + tensors["w2_alpha"], + ) # Snapshot active expert count before any kernel execution, since # autotune tactic exploration may corrupt input tensors. @@ -1616,7 +1548,7 @@ def run_cute_dsl_moe( num_experts, top_k, median_time, - is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu), + is_gated=True, ) tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, @@ -1632,7 +1564,7 @@ def run_cute_dsl_moe( routing_logits_dtype=None, active_experts=num_active_experts, verbose=args.verbose, - is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu), + is_gated=True, ) print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) @@ -1655,6 +1587,268 @@ def run_cute_dsl_moe( cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype cur_res["fp4_mode"] = "nvfp4" + cur_res["activation_type"] = activation_type.name + res.append(cur_res) + + return res + + +def testB12xFusedMoe(args): + """ + Test b12x_fused_moe (SM120/SM121 CuTe DSL NVFP4 MoE). + + The b12x MoE takes **bf16** hidden states (the kernel fuses the + quantization internally) and NVFP4-quantized weights. Supports both + SwiGLU (gated) and ReLU2 (non-gated) activations. + + This test: + 1. Creates NVFP4-quantized weights and bf16 inputs for b12x kernels + 2. Runs MoE via B12xMoEWrapper (or b12x_fused_moe when + ``--use_functional_api`` is set) + 3. Measures performance metrics (TFLOPS, TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testB12xFusedMoe") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + from flashinfer import B12xMoEWrapper + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype) + + num_tokens = args.num_tokens + hidden_size = args.hidden_size + intermediate_size = args.intermediate_size + num_experts = args.num_experts + top_k = args.top_k + local_num_experts = args.local_num_experts or num_experts + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = ["b12x"] + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + if args.verbose >= 1: + print( + f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, " + f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" + ) + + # b12x supports SwiGLU (gated) and ReLU2 (non-gated) + activation_type = args.activation_type + _ACT_STR = {ActivationType.Swiglu: "silu", ActivationType.Relu2: "relu2"} + if activation_type not in _ACT_STR: + raise ValueError( + f"b12x_fused_moe only supports Swiglu and Relu2 activations, " + f"got {activation_type.name}" + ) + activation_str = _ACT_STR[activation_type] + is_gated = activation_type == ActivationType.Swiglu + + # Create b12x-specific NVFP4 test data (weights quantized, input stays bf16) + tensors = _create_nvfp4_moe_test_data( + num_tokens=num_tokens, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts=num_experts, + num_local_experts=local_num_experts, + top_k=top_k, + device=device, + backend="b12x", + is_gated=is_gated, + ) + + if args.verbose >= 2: + print(f"[VVERBOSE] x_bf16.shape = {tensors['x_bf16'].shape}") + print(f"[VVERBOSE] w1_weight.shape = {tensors['w1_weight'].shape}") + print(f"[VVERBOSE] w2_weight.shape = {tensors['w2_weight'].shape}") + + use_functional = getattr(args, "use_functional_api", False) + x_input = tensors["x_bf16"] + + if use_functional: + from functools import partial + from flashinfer import b12x_fused_moe + + # Pre-allocate output buffer to avoid per-call allocation + moe_output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=device + ) + + if args.verbose >= 1: + print("[INFO] Using b12x functional API (b12x_fused_moe)") + runner = partial( + b12x_fused_moe, + num_experts=num_experts, + top_k=top_k, + num_local_experts=local_num_experts, + output=moe_output, + activation=activation_str, + ) + + # Warmup call to populate workspace cache before timed region + runner( + x=x_input, + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + ) + else: + moe = B12xMoEWrapper( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + use_cuda_graph=is_cuda_graph_compatible, + max_num_tokens=num_tokens, + num_local_experts=local_num_experts, + activation=activation_str, + ) + runner = moe.run + + def run_b12x_moe( + x, + w1_weight, + w1_weight_sf, + w1_alpha, + fc2_input_scale, + w2_weight, + w2_weight_sf, + w2_alpha, + token_selected_experts, + token_final_scales, + ): + return runner( + x=x, + w1_weight=w1_weight, + w1_weight_sf=w1_weight_sf, + w1_alpha=w1_alpha, + fc2_input_scale=fc2_input_scale, + w2_weight=w2_weight, + w2_weight_sf=w2_weight_sf, + w2_alpha=w2_alpha, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + ) + + input_args = ( + x_input, + tensors["w1_weight"], + tensors["w1_weight_sf"], + tensors["w1_alpha"], + tensors["fc2_input_scale"], + tensors["w2_weight"], + tensors["w2_weight_sf"], + tensors["w2_alpha"], + tensors["token_selected_experts"], + tensors["token_final_scales"], + ) + + # Snapshot active expert count before any kernel execution, since + # autotune tactic exploration may corrupt input tensors. + num_active_experts = int(tensors["token_selected_experts"].unique().numel()) + + backend = "b12x" + + # Optional autotune warmup. + if getattr(args, "autotune", False): + warmup_iters = ( + args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 + ) + backend = "b12x_autotune" + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for b12x NVFP4 MoE: {warmup_iters} iters") + autotune_args = tuple( + t.clone() if isinstance(t, torch.Tensor) else t for t in input_args + ) + with autotune(True): + for _ in range(warmup_iters): + run_b12x_moe(*autotune_args) + del autotune_args + + # Benchmark timing + times = bench_gpu_time( + fn=run_b12x_moe, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + sleep_after_run=False, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + cold_l2_cache=True, + input_args=input_args, + ) + + # Compute performance metrics + median_time = np.median(times) + std_time = np.std(times) + tflops = calculate_moe_tflops( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + top_k, + median_time, + is_gated=is_gated, + ) + # Input format is bf16 for b12x (kernel fuses quantization), weights are nvfp4. + tb_per_sec = calculate_moe_kernel_bandwidth( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + top_k, + median_time, + input_dtype, + weight_dtype, + input_format="bf16", + weight_format="nvfp4", + routing_logits_dtype=None, + active_experts=num_active_experts, + verbose=args.verbose, + is_gated=is_gated, + ) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["backend"] = backend + cur_res["num_tokens"] = num_tokens + cur_res["hidden_size"] = hidden_size + cur_res["intermediate_size"] = intermediate_size + cur_res["num_experts"] = num_experts + cur_res["top_k"] = top_k + cur_res["local_num_experts"] = local_num_experts + cur_res["input_dtype"] = input_dtype + cur_res["weight_dtype"] = weight_dtype + cur_res["fp4_mode"] = "nvfp4" + cur_res["activation_type"] = activation_type.name res.append(cur_res) return res diff --git a/benchmarks/samples/sample_testlist.txt b/benchmarks/samples/sample_testlist.txt index 3c4aab7620..8737751821 100644 --- a/benchmarks/samples/sample_testlist.txt +++ b/benchmarks/samples/sample_testlist.txt @@ -41,12 +41,16 @@ --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --quantized_input --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_weights_quantized" --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 8 --top_k 2 --cutlass_variant base --input_dtype float16 --tp_size 2 --tp_rank 0 --ep_size 4 --ep_rank 0 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_ep_tp" -# CuteDSL NVFP4 MoE (Blackwell SM10.0+ only) +# CuteDSL NVFP4 MoE (Blackwell SM10.0/SM10.3 only; SwiGLU only) --routine cute_dsl_fp4_block_scale_moe --num_tokens 1024 --hidden_size 7168 --intermediate_size 2048 --num_experts 256 --top_k 8 -vv --generate_repro_command --case_tag "cute_dsl_moe_large" --routine cute_dsl_fp4_block_scale_moe --num_tokens 256 --hidden_size 1024 --intermediate_size 512 --num_experts 256 --top_k 2 -vv --generate_repro_command --case_tag "cute_dsl_moe_small" --routine cute_dsl_fp4_block_scale_moe --num_tokens 1024 --hidden_size 7168 --intermediate_size 2048 --num_experts 256 --top_k 8 --autotune -vv --generate_repro_command --case_tag "cute_dsl_moe_autotune" --routine cute_dsl_fp4_block_scale_moe --num_tokens 1024 --hidden_size 7168 --intermediate_size 2048 --num_experts 256 --top_k 8 --local_expert_offset 0 --local_num_experts 32 -vv --generate_repro_command --case_tag "cute_dsl_moe_ep8" +# b12x NVFP4 MoE (Blackwell SM12.0/SM12.1 only; SwiGLU or ReLU2, bf16 input) +--routine b12x_fused_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 2688 --num_experts 512 --top_k 22 -vv --generate_repro_command --case_tag "b12x_moe_swiglu" +--routine b12x_fused_moe --num_tokens 1 --hidden_size 1024 --intermediate_size 2688 --num_experts 512 --top_k 22 --activation-type Relu2 -vv --generate_repro_command --case_tag "b12x_moe_relu2_decode" + ## MoE Communication (requires mpirun, e.g.: mpirun -np 8 python benchmarks/flashinfer_benchmark.py ...) # Basic A2A dispatch+combine without quantization #--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 -vv --generate_repro_command --case_tag "moe_a2a_basic" diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py index f6cf1b67bb..e266cb771a 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py @@ -89,8 +89,8 @@ st_global_u64, scatter_add_bf16x2, ) -from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel as DenseGemmKernel, +from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel, ) diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py index e7fdae9293..670b3ad812 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py @@ -122,8 +122,8 @@ st_global_u64, scatter_add_bf16x2, ) -from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel as DenseGemmKernel, +from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel, ) diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py index 116ba23d3f..475fa51ae7 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py @@ -120,8 +120,8 @@ st_global_u64, scatter_add_bf16x2, ) -from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel as DenseGemmKernel, +from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel, ) diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index a7795beb61..def82216a2 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -61,11 +61,11 @@ from flashinfer.cute_dsl.utils import is_cute_dsl_available if is_cute_dsl_available(): - from .kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel, + from .kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel, ) - _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel") + _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel") except ImportError: pass diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index decc213e6f..6e3c8aa44b 100755 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -4858,8 +4858,8 @@ def _b12x_gemm_fp4_runner( """ import cutlass - from .kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel, + from .kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel, ) cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype) @@ -4905,7 +4905,7 @@ def get_valid_tactics( ] swap_ab = False for mma_tiler_mn in sm120_mma_tiler_candidates: - if not Sm120BlockScaledDenseGemmKernel.can_implement( + if not Sm120B12xBlockScaledDenseGemmKernel.can_implement( ab_dtype, sf_dtype, sf_vec_size, @@ -4945,11 +4945,10 @@ def forward( batch_size = 1 if tactic is None or tactic == -1: - _sm_count = torch.cuda.get_device_properties( - a.device - ).multi_processor_count tactic = ( - _select_default_sm120_mma_tiler(m, n, _sm_count), + _select_default_sm120_mma_tiler( + m, n, get_device_sm_count(a.device) + ), (1, 1), False, False, @@ -4987,7 +4986,7 @@ def forward( out_dtype, ) - make_kernel = lambda: Sm120BlockScaledDenseGemmKernel( + make_kernel = lambda: Sm120B12xBlockScaledDenseGemmKernel( sf_vec_size, mma_tiler_mn, cluster_shape_mn, diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py similarity index 99% rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py index c49bc81586..6eee27a709 100644 --- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py +++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py @@ -1550,7 +1550,7 @@ def wrapper( # Alias for FlashInfer integration -Sm120BlockScaledDenseGemmKernel = DenseGemmKernel +Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel class _DenseGemmLaunch: diff --git a/tests/moe/test_b12x_fused_moe.py b/tests/moe/test_b12x_fused_moe.py index bb9c561e72..087426f25e 100644 --- a/tests/moe/test_b12x_fused_moe.py +++ b/tests/moe/test_b12x_fused_moe.py @@ -1,5 +1,5 @@ """ -Copyright (c) 2025 by FlashInfer team. +Copyright (c) 2026 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.