-
Notifications
You must be signed in to change notification settings - Fork 830
chore: cute dsl nvfp4 moe clean up #2775
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
067cfec
7cf0321
4183632
699f086
086226c
66b52e2
f450151
fd50e23
5dcca81
71042bf
476899a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,6 +93,8 @@ def run_moe_test(args): | |
| return testTrtllmFp8PerTensorScaleMoe(args) | ||
| elif args.routine == "cutlass_fused_moe": | ||
| return testCutlassFusedMoe(args) | ||
| elif args.routine == "cute_dsl_fp4_block_scale_moe": | ||
| return testCuteDslFp4BlockScaleMoe(args) | ||
| else: | ||
| raise ValueError(f"Unsupported routine: {args.routine}") | ||
|
|
||
|
|
@@ -1144,6 +1146,333 @@ def run_cutlass( | |
| return res | ||
|
|
||
|
|
||
| def _interleave_linear_and_gate( | ||
| x: torch.Tensor, group_size: int = 64, dim: int = -1 | ||
| ) -> torch.Tensor: | ||
| """Interleave linear and gate weights for CuteDSL SwiGLU layout.""" | ||
| sizes = x.size() | ||
| dim = dim % x.dim() | ||
| assert sizes[dim] % (group_size * 2) == 0 | ||
| prev_sizes = sizes[:dim] | ||
| post_sizes = sizes[dim + 1 :] | ||
| x = x.view(*prev_sizes, 2, sizes[dim] // (group_size * 2), group_size, *post_sizes) | ||
| x = x.transpose(dim, dim + 1).contiguous().view(*sizes) | ||
| return x | ||
|
|
||
|
|
||
| def _create_cute_dsl_moe_test_data( | ||
| num_tokens: int, | ||
| hidden_size: int, | ||
| intermediate_size: int, | ||
| num_experts: int, | ||
| num_local_experts: int, | ||
| top_k: int, | ||
| device: torch.device, | ||
| ): | ||
| """Create NVFP4-quantized test data for CuteDSL MoE (Blackwell kernels). | ||
|
|
||
| Routing is computed externally via simple top-k (CuteDslMoEWrapper takes | ||
| pre-computed token_selected_experts and token_final_scales). | ||
|
|
||
| Returns a dict with all tensors needed by CuteDslMoEWrapper.run(). | ||
| """ | ||
| from flashinfer.fp4_quantization import fp4_quantize | ||
| from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout | ||
|
|
||
| sf_vec_size = 16 | ||
|
|
||
| # Input activations | ||
| x_bf16 = ( | ||
| torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) / 10 | ||
| ) | ||
|
Comment on lines
+1184
to
+1187
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This benchmark can report the wrong input/weight dtypes. The new CuteDSL route records Also applies to: 1192-1199, 1223-1230, 1292-1293, 1317-1326 🤖 Prompt for AI Agents |
||
| a1_gs = torch.tensor([1.0], device=device, dtype=torch.float32) | ||
| x_quantized, x_sf = fp4_quantize( | ||
| x_bf16, global_scale=a1_gs, sf_vec_size=sf_vec_size, is_sf_swizzled_layout=False | ||
| ) | ||
| x_sf = x_sf.unsqueeze(-1) | ||
|
|
||
| # Routing (simple top-k; the wrapper takes pre-computed assignments) | ||
| routing_logits = torch.randn(num_tokens, num_experts, device=device) | ||
| routing_weights, selected_experts = compute_routing(routing_logits, top_k) | ||
| selected_experts = selected_experts.to(torch.int32) | ||
|
|
||
| # GEMM1 weights (gate + up, interleaved for CuteDSL SwiGLU) | ||
| w1_bf16 = ( | ||
| torch.randn( | ||
| num_local_experts, | ||
| 2 * intermediate_size, | ||
| hidden_size, | ||
| dtype=torch.bfloat16, | ||
| device=device, | ||
| ) | ||
| / 10 | ||
| ) | ||
| w1_bf16_interleaved = _interleave_linear_and_gate(w1_bf16, group_size=64, dim=1) | ||
| w1_gs = torch.tensor([1.0], device=device, dtype=torch.float32) | ||
| w1_flat = w1_bf16_interleaved.view( | ||
| num_local_experts * 2 * intermediate_size, hidden_size | ||
| ) | ||
| w1_q_flat, w1_sf_flat = fp4_quantize( | ||
| w1_flat, global_scale=w1_gs, sf_vec_size=sf_vec_size, is_sf_swizzled_layout=True | ||
| ) | ||
| w1_weight = w1_q_flat.view( | ||
| num_local_experts, 2 * intermediate_size, hidden_size // 2 | ||
| ) | ||
| w1_weight_sf = convert_sf_to_mma_layout( | ||
| w1_sf_flat, | ||
| m=2 * intermediate_size, | ||
| k=hidden_size, | ||
| num_groups=num_local_experts, | ||
| sf_vec_size=sf_vec_size, | ||
| ) | ||
| w1_alpha = torch.ones(num_local_experts, device=device, dtype=torch.float32) | ||
|
|
||
| # GEMM2 weights (down projection) | ||
| w2_bf16 = ( | ||
| torch.randn( | ||
| num_local_experts, | ||
| hidden_size, | ||
| intermediate_size, | ||
| dtype=torch.bfloat16, | ||
| device=device, | ||
| ) | ||
| / 10 | ||
| ) | ||
| w2_gs = torch.tensor([1.0], device=device, dtype=torch.float32) | ||
| w2_flat = w2_bf16.view(num_local_experts * hidden_size, intermediate_size) | ||
| w2_q_flat, w2_sf_flat = fp4_quantize( | ||
| w2_flat, global_scale=w2_gs, sf_vec_size=sf_vec_size, is_sf_swizzled_layout=True | ||
| ) | ||
| w2_weight = w2_q_flat.view(num_local_experts, hidden_size, intermediate_size // 2) | ||
| w2_weight_sf = convert_sf_to_mma_layout( | ||
| w2_sf_flat, | ||
| m=hidden_size, | ||
| k=intermediate_size, | ||
| num_groups=num_local_experts, | ||
| sf_vec_size=sf_vec_size, | ||
| ) | ||
| w2_alpha = torch.ones(num_local_experts, device=device, dtype=torch.float32) | ||
|
|
||
| fc2_input_scale = torch.tensor([1.0], device=device, dtype=torch.float32) | ||
|
|
||
| return { | ||
| "x": x_quantized, | ||
| "x_sf": x_sf, | ||
| "token_selected_experts": selected_experts, | ||
| "token_final_scales": routing_weights, | ||
| "w1_weight": w1_weight, | ||
| "w1_weight_sf": w1_weight_sf, | ||
| "w1_alpha": w1_alpha, | ||
| "fc2_input_scale": fc2_input_scale, | ||
| "w2_weight": w2_weight, | ||
| "w2_weight_sf": w2_weight_sf, | ||
| "w2_alpha": w2_alpha, | ||
| } | ||
|
|
||
|
|
||
| def testCuteDslFp4BlockScaleMoe(args): | ||
| """ | ||
| Test cute_dsl_fp4_block_scale_moe (CuteDSL NVFP4 MoE on Blackwell). | ||
|
|
||
| This test: | ||
| 1. Creates NVFP4-quantized weights and inputs for CuteDSL kernels | ||
| 2. Runs MoE via CuteDslMoEWrapper | ||
| 3. Measures performance metrics (TFLOPS, TB/sec) | ||
|
|
||
| Args: | ||
| args: Parsed command line arguments containing test configuration | ||
|
|
||
| Returns: | ||
| dict: List of dictionaries containing performance results | ||
| """ | ||
| if args.verbose >= 1: | ||
| print("[INFO] Running testCuteDslFp4BlockScaleMoe") | ||
| print(f"[INFO] FlashInfer version: {flashinfer.__version__}") | ||
|
|
||
| from flashinfer import CuteDslMoEWrapper | ||
|
|
||
| device = get_device(args) | ||
| if args.generate_repro_command: | ||
| print( | ||
| f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" | ||
| ) | ||
|
|
||
| input_dtype = dtype_str_to_torch_dtype(args.input_dtype) | ||
| weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype) | ||
|
|
||
| num_tokens = args.num_tokens | ||
| hidden_size = args.hidden_size | ||
| intermediate_size = args.intermediate_size | ||
| num_experts = args.num_experts | ||
| top_k = args.top_k | ||
| local_expert_offset = args.local_expert_offset | ||
| local_num_experts = args.local_num_experts or num_experts | ||
| is_cuda_graph_compatible = not args.no_cuda_graph | ||
| res = [] | ||
|
|
||
| backends = ["cute-dsl"] | ||
| backends = filter_backends_by_compute_capability(backends, args.routine, device) | ||
| if len(backends) == 0: | ||
| print("[ERROR] No backends to test. Exiting.") | ||
| return res | ||
|
|
||
| if args.verbose >= 1: | ||
| print( | ||
| f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, " | ||
| f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" | ||
| ) | ||
|
|
||
| # Create CuteDSL-specific NVFP4 test data | ||
| tensors = _create_cute_dsl_moe_test_data( | ||
| num_tokens=num_tokens, | ||
| hidden_size=hidden_size, | ||
| intermediate_size=intermediate_size, | ||
| num_experts=num_experts, | ||
| num_local_experts=local_num_experts, | ||
| top_k=top_k, | ||
| device=device, | ||
| ) | ||
|
|
||
| if args.verbose >= 2: | ||
| print(f"[VVERBOSE] x.shape = {tensors['x'].shape}") | ||
| print(f"[VVERBOSE] w1_weight.shape = {tensors['w1_weight'].shape}") | ||
| print(f"[VVERBOSE] w2_weight.shape = {tensors['w2_weight'].shape}") | ||
|
|
||
| moe = CuteDslMoEWrapper( | ||
| num_experts=num_experts, | ||
| top_k=top_k, | ||
| hidden_size=hidden_size, | ||
| intermediate_size=intermediate_size, | ||
| use_cuda_graph=is_cuda_graph_compatible, | ||
| max_num_tokens=num_tokens, | ||
| num_local_experts=local_num_experts, | ||
| local_expert_offset=local_expert_offset, | ||
| ) | ||
|
|
||
| def run_cute_dsl_moe( | ||
| x, | ||
| x_sf, | ||
| token_selected_experts, | ||
| token_final_scales, | ||
| w1_weight, | ||
| w1_weight_sf, | ||
| w1_alpha, | ||
| fc2_input_scale, | ||
| w2_weight, | ||
| w2_weight_sf, | ||
| w2_alpha, | ||
| ): | ||
| return moe.run( | ||
| x=x, | ||
| x_sf=x_sf, | ||
| token_selected_experts=token_selected_experts, | ||
| token_final_scales=token_final_scales, | ||
| w1_weight=w1_weight, | ||
| w1_weight_sf=w1_weight_sf, | ||
| w1_alpha=w1_alpha, | ||
| fc2_input_scale=fc2_input_scale, | ||
| w2_weight=w2_weight, | ||
| w2_weight_sf=w2_weight_sf, | ||
| w2_alpha=w2_alpha, | ||
| ) | ||
|
|
||
| input_args = ( | ||
| tensors["x"], | ||
| tensors["x_sf"], | ||
| tensors["token_selected_experts"], | ||
| tensors["token_final_scales"], | ||
| tensors["w1_weight"], | ||
| tensors["w1_weight_sf"], | ||
| tensors["w1_alpha"], | ||
| tensors["fc2_input_scale"], | ||
| tensors["w2_weight"], | ||
| tensors["w2_weight_sf"], | ||
| tensors["w2_alpha"], | ||
| ) | ||
|
|
||
| # Snapshot active expert count before any kernel execution, since | ||
| # autotune tactic exploration may corrupt input tensors. | ||
| num_active_experts = int(tensors["token_selected_experts"].unique().numel()) | ||
|
|
||
| backend = "cute-dsl" | ||
|
|
||
| # Optional autotune warmup. | ||
| # Clone input_args so autotune tactic exploration doesn't corrupt the | ||
| # original tensors used by the subsequent benchmark. | ||
| if getattr(args, "autotune", False): | ||
| warmup_iters = ( | ||
| args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 | ||
| ) | ||
| backend = "cute-dsl_autotune" | ||
| if args.verbose >= 1: | ||
| print(f"[INFO] Autotune warmup for CuteDSL NVFP4 MoE: {warmup_iters} iters") | ||
| autotune_args = tuple( | ||
| t.clone() if isinstance(t, torch.Tensor) else t for t in input_args | ||
| ) | ||
| with autotune(True): | ||
| for _ in range(warmup_iters): | ||
| run_cute_dsl_moe(*autotune_args) | ||
| del autotune_args | ||
|
Comment on lines
+1402
to
+1415
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Autotune cache path not supported. Unlike other test functions ( 🔧 Proposed fix to add cache support+ cache_path = getattr(args, "autotune_cache", None)
+
# Optional autotune warmup.
# Clone input_args so autotune tactic exploration doesn't corrupt the
# original tensors used by the subsequent benchmark.
if getattr(args, "autotune", False):
warmup_iters = (
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
backend = "cute-dsl_autotune"
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for CuteDSL NVFP4 MoE: {warmup_iters} iters")
autotune_args = tuple(
t.clone() if isinstance(t, torch.Tensor) else t for t in input_args
)
- with autotune(True):
+ with autotune(True, cache=cache_path):
for _ in range(warmup_iters):
run_cute_dsl_moe(*autotune_args)
del autotune_args
+ elif cache_path:
+ with autotune(False, cache=cache_path):
+ pass 🤖 Prompt for AI Agents |
||
|
|
||
| # Benchmark timing | ||
| times = bench_gpu_time( | ||
| fn=run_cute_dsl_moe, | ||
| dry_run_iters=args.dry_run_iters, | ||
| repeat_iters=args.num_iters, | ||
| sleep_after_run=False, | ||
| enable_cupti=args.use_cupti, | ||
| use_cuda_graph=is_cuda_graph_compatible, | ||
| cold_l2_cache=True, | ||
| input_args=input_args, | ||
| ) | ||
|
|
||
| # Compute performance metrics | ||
| median_time = np.median(times) | ||
| std_time = np.std(times) | ||
| tflops = calculate_moe_tflops( | ||
| num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time | ||
| ) | ||
| tb_per_sec = calculate_moe_kernel_bandwidth( | ||
| num_tokens, | ||
| hidden_size, | ||
| intermediate_size, | ||
| num_experts, | ||
| top_k, | ||
| median_time, | ||
| input_dtype, | ||
| weight_dtype, | ||
| input_format="nvfp4", | ||
| weight_format="nvfp4", | ||
| routing_logits_dtype=None, | ||
| active_experts=num_active_experts, | ||
| verbose=args.verbose, | ||
| ) | ||
|
|
||
| print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) | ||
|
|
||
| if args.output_path is not None: | ||
| cur_res = defaultdict(str) | ||
| cur_res["routine"] = args.routine | ||
| cur_res["median_time"] = median_time | ||
| cur_res["std_time"] = std_time | ||
| cur_res["tflops"] = tflops | ||
| cur_res["tb_per_sec"] = tb_per_sec | ||
| cur_res["backend"] = backend | ||
| cur_res["num_tokens"] = num_tokens | ||
| cur_res["hidden_size"] = hidden_size | ||
| cur_res["intermediate_size"] = intermediate_size | ||
| cur_res["num_experts"] = num_experts | ||
| cur_res["top_k"] = top_k | ||
| cur_res["local_expert_offset"] = local_expert_offset | ||
| cur_res["local_num_experts"] = local_num_experts | ||
| cur_res["input_dtype"] = input_dtype | ||
| cur_res["weight_dtype"] = weight_dtype | ||
| cur_res["fp4_mode"] = "nvfp4" | ||
| res.append(cur_res) | ||
|
|
||
| return res | ||
|
|
||
|
|
||
| def testTrtllmFp8BlockScaleMoe(args): | ||
| """ | ||
| Test trtllm_fp8_block_scale_moe API (TensorRT-LLM fused MoE). | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.