diff --git a/benchmarks/bench_moe_deepseek.py b/benchmarks/bench_moe_deepseek.py new file mode 100644 index 0000000000..149394a614 --- /dev/null +++ b/benchmarks/bench_moe_deepseek.py @@ -0,0 +1,1074 @@ +#!/usr/bin/env python3 +"""DeepSeek-V3 MoE Performance Benchmark - CuteDSL vs CUTLASS vs TRTLLM. + +Compares three NVFP4 MoE backends on DeepSeek-V3 configuration: +- CuteDSL: FlashInfer's CuteDSL-based implementation +- CUTLASS: NVIDIA CUTLASS-based implementation +- TRTLLM: TensorRT-LLM's implementation + +Usage: + # Throughput benchmark (large batches: 128-4096 tokens) + python bench_moe_deepseek.py + + # Generation phase benchmark (small batches: 1-128 tokens) + python bench_moe_deepseek.py --gen-phase + + # With Expert Parallelism simulation + python bench_moe_deepseek.py --ep 1 # 256 local experts (no parallelism) + python bench_moe_deepseek.py --ep 8 # 32 local experts (8-way EP) + python bench_moe_deepseek.py --ep 16 # 16 local experts (16-way EP) + + # Custom token counts + python bench_moe_deepseek.py --num-tokens 64,128,256 + + # Disable CUDA graph (useful for debugging or profiling) + python bench_moe_deepseek.py --no-cuda-graph + + # Disable CUPTI (use CUDA events for timing instead) + python bench_moe_deepseek.py --no-cupti + +Metrics: + - ms: Latency in milliseconds + - TFLOPS: Computational throughput + - Speedup: CuteDSL latency / other backend latency (>1 = CuteDSL faster) +""" + +import argparse +from dataclasses import dataclass +import numpy as np +import torch + + +@dataclass +class DeepSeekConfig: + hidden_size: int = 7168 + intermediate_size: int = 2048 + num_experts: int = 256 + n_group: int = 8 + topk_group: int = 4 + top_k: int = 8 + routed_scaling_factor: float = 2.5 + + +CFG = DeepSeekConfig() +TOKEN_COUNTS = [128, 256, 512, 1024, 2048, 4096] + +# Generation phase token counts (small batches typical in decode) +GEN_PHASE_TOKENS = [1, 2, 4, 8, 16, 32, 64, 128] + +# Expert Parallelism configurations +# EP=1: all 256 experts on single GPU 
+# EP=8: 32 experts per GPU (256/8) +# EP=16: 16 experts per GPU (256/16) +EP_CONFIGS = { + 1: {"num_local_experts": 256, "local_expert_offset": 0}, + 8: {"num_local_experts": 32, "local_expert_offset": 0}, + 16: {"num_local_experts": 16, "local_expert_offset": 0}, +} + + +def is_sm100_family(): + """Check for SM100 family (Blackwell: SM100, SM103, SM110). + + CuteDSL MoE NVFP4 kernels are optimized for SM100 architecture. + SM120+ (Rubin) may have different shared memory/TMEM configurations. + """ + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + return props.major == 10 + + +def calc_tflops(n, ms, num_local_experts=None): + """Calculate TFLOPS for MoE computation. + + With EP, only tokens routed to local experts are computed. + Assumes uniform routing distribution across experts. + """ + if num_local_experts is None: + num_local_experts = CFG.num_experts + + # Fraction of work done locally (assuming uniform distribution) + local_fraction = num_local_experts / CFG.num_experts + + flops = ( + n + * CFG.top_k + * local_fraction # Only local expert pairs are computed + * ( + 2 * CFG.hidden_size * 2 * CFG.intermediate_size + + 2 * CFG.intermediate_size * CFG.hidden_size + ) + ) + return flops / (ms * 1e-3) / 1e12 + + +def interleave(x, gs=64): + M, K = x.shape[-2], x.shape[-1] + return ( + x.view(*x.shape[:-2], 2, M // (gs * 2), gs, K) + .transpose(-4, -3) + .contiguous() + .view(*x.shape) + ) + + +def create_inputs(n, dev="cuda"): + """Create inputs for all backends (CuteDSL, CUTLASS, TRTLLM).""" + from flashinfer.fp4_quantization import fp4_quantize + + torch.manual_seed(42) + sv = 16 + FP8M = torch.finfo(torch.float8_e4m3fn).max + FP4M = 6.0 + + # Router logits and bias + rl = torch.randn(n, CFG.num_experts, device=dev, dtype=torch.float32) + rb = torch.randn(CFG.num_experts, device=dev, dtype=torch.bfloat16) + + # Hidden states + hb = torch.randn(n, CFG.hidden_size, device=dev, dtype=torch.bfloat16) / 10 + hg = FP8M 
* FP4M / hb.abs().max().float() + + # Weights (BF16) + w1b = ( + torch.randn( + CFG.num_experts, + 2 * CFG.intermediate_size, + CFG.hidden_size, + device=dev, + dtype=torch.bfloat16, + ) + / 10 + ) + w2b = ( + torch.randn( + CFG.num_experts, + CFG.hidden_size, + CFG.intermediate_size, + device=dev, + dtype=torch.bfloat16, + ) + / 10 + ) + + # Compute per-expert global scales + w1_gs_list, w2_gs_list = [], [] + for e in range(CFG.num_experts): + w1_gs_list.append(FP8M * FP4M / w1b[e].abs().max().float()) + w2_gs_list.append(FP8M * FP4M / w2b[e].abs().max().float()) + w1_gs = torch.tensor(w1_gs_list, device=dev) + w2_gs = torch.tensor(w2_gs_list, device=dev) + + # CUTLASS format: quantize with swizzled scale factors + w1_fp4_list, w1_sf_list = [], [] + w2_fp4_list, w2_sf_list = [], [] + for e in range(CFG.num_experts): + q1, s1 = fp4_quantize(w1b[e], w1_gs[e], sv, False, True) # swizzled + w1_fp4_list.append(q1) + w1_sf_list.append(s1) + q2, s2 = fp4_quantize(w2b[e], w2_gs[e], sv, False, True) # swizzled + w2_fp4_list.append(q2) + w2_sf_list.append(s2) + + return { + "router_logits": rl, + "routing_bias": rb, + "hidden_bf16": hb, + "hidden_gs": hg, + "w1_bf16": w1b, + "w1_gs": w1_gs, + "w2_bf16": w2b, + "w2_gs": w2_gs, + # CUTLASS specific + "w1_fp4": torch.stack(w1_fp4_list), + "w1_sf": torch.stack(w1_sf_list), + "w2_fp4": torch.stack(w2_fp4_list), + "w2_sf": torch.stack(w2_sf_list), + } + + +# ============================================================================= +# Benchmark Functions +# ============================================================================= + + +def bench_cute_dsl( + inputs, + warmup=10, + iters=100, + num_local_experts=None, + local_expert_offset=0, + use_cuda_graph=True, + use_cupti=True, + use_wrapper=False, +): + """Benchmark CuteDSL MoE. + + Args: + use_wrapper: If True, use CuteDslMoEWrapper API (recommended for CUDA graph). + If False, use cute_dsl_fused_moe_nvfp4 functional API. 
+ """ + from flashinfer.fused_moe import fused_topk_deepseek + from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout + from flashinfer.fp4_quantization import fp4_quantize + from flashinfer.testing.utils import bench_gpu_time + + if num_local_experts is None: + num_local_experts = CFG.num_experts + + n, sv, dev = inputs["router_logits"].shape[0], 16, "cuda" + gs1 = torch.tensor([1.0], device=dev) + + tv = torch.empty(n, CFG.top_k, dtype=torch.float32, device=dev) + ti = torch.empty(n, CFG.top_k, dtype=torch.int32, device=dev) + + xf, xs = fp4_quantize(inputs["hidden_bf16"], gs1, sv, False, False) + xs = xs.unsqueeze(-1) + + # Expert range for this EP partition + expert_start = local_expert_offset + expert_end = local_expert_offset + num_local_experts + + # Slice weights to LOCAL experts only + w1_local = inputs["w1_bf16"][expert_start:expert_end] + w2_local = inputs["w2_bf16"][expert_start:expert_end] + + w1i = interleave(w1_local, 64) + w1f = w1i.view(num_local_experts * 2 * CFG.intermediate_size, CFG.hidden_size) + w1q, w1s = fp4_quantize(w1f, gs1, sv, False, True) + w1q = w1q.view(num_local_experts, 2 * CFG.intermediate_size, CFG.hidden_size // 2) + w1s = convert_sf_to_mma_layout( + w1s, 2 * CFG.intermediate_size, CFG.hidden_size, num_local_experts, sv + ) + + w2f = w2_local.view(num_local_experts * CFG.hidden_size, CFG.intermediate_size) + w2q, w2s = fp4_quantize(w2f, gs1, sv, False, True) + w2q = w2q.view(num_local_experts, CFG.hidden_size, CFG.intermediate_size // 2) + w2s = convert_sf_to_mma_layout( + w2s, CFG.hidden_size, CFG.intermediate_size, num_local_experts, sv + ) + + # Alpha sized for LOCAL experts only + alpha, fc2sc = ( + torch.ones(num_local_experts, device=dev), + torch.tensor([1.0], device=dev), + ) + + # Pre-convert routing bias to float32 + routing_bias_f32 = inputs["routing_bias"].float() + + if use_wrapper: + # Use CuteDslMoEWrapper (recommended for CUDA graph) + from flashinfer import CuteDslMoEWrapper + + moe = CuteDslMoEWrapper( 
+ num_experts=CFG.num_experts, + top_k=CFG.top_k, + hidden_size=CFG.hidden_size, + intermediate_size=CFG.intermediate_size, + use_cuda_graph=use_cuda_graph, + max_num_tokens=n, + num_local_experts=num_local_experts, + local_expert_offset=local_expert_offset, + ) + + def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices): + fused_topk_deepseek( + scores=router_logits, + bias=routing_bias, + n_group=CFG.n_group, + topk_group=CFG.topk_group, + topk=CFG.top_k, + routed_scaling_factor=CFG.routed_scaling_factor, + topk_values=topk_values, + topk_indices=topk_indices, + ) + return moe.run( + x=x, + x_sf=x_sf, + token_selected_experts=topk_indices, + token_final_scales=topk_values, + w1_weight=w1q, + w1_weight_sf=w1s, + w1_alpha=alpha, + fc2_input_scale=fc2sc, + w2_weight=w2q, + w2_weight_sf=w2s, + w2_alpha=alpha, + ) + else: + # Use functional API + from flashinfer import cute_dsl_fused_moe_nvfp4 + + def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices): + fused_topk_deepseek( + scores=router_logits, + bias=routing_bias, + n_group=CFG.n_group, + topk_group=CFG.topk_group, + topk=CFG.top_k, + routed_scaling_factor=CFG.routed_scaling_factor, + topk_values=topk_values, + topk_indices=topk_indices, + ) + return cute_dsl_fused_moe_nvfp4( + x=x, + x_sf=x_sf, + token_selected_experts=topk_indices, + token_final_scales=topk_values, + w1_weight=w1q, + w1_weight_sf=w1s, + w1_alpha=alpha, + fc2_input_scale=fc2sc, + w2_weight=w2q, + w2_weight_sf=w2s, + w2_alpha=alpha, + num_experts=CFG.num_experts, + top_k=CFG.top_k, + num_local_experts=num_local_experts, + local_expert_offset=local_expert_offset, + ) + + # Pass input tensors via input_kwargs for cold L2 cache rotation + input_kwargs = { + "x": xf, + "x_sf": xs, + "router_logits": inputs["router_logits"], + "routing_bias": routing_bias_f32, + "topk_values": tv, + "topk_indices": ti, + } + + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + 
enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) + return np.median(times) + + +def bench_cutlass( + inputs, + warmup=10, + iters=100, + num_local_experts=None, + local_expert_offset=0, + use_cuda_graph=True, + use_cupti=True, +): + from flashinfer.fused_moe import fused_topk_deepseek, cutlass_fused_moe + from flashinfer.fp4_quantization import fp4_quantize + from flashinfer.testing.utils import bench_gpu_time + + if num_local_experts is None: + num_local_experts = CFG.num_experts + + n, sv, dev = inputs["router_logits"].shape[0], 16, "cuda" + + tv = torch.empty(n, CFG.top_k, dtype=torch.float32, device=dev) + ti = torch.empty(n, CFG.top_k, dtype=torch.int32, device=dev) + + # Expert range for this EP partition + expert_start = local_expert_offset + expert_end = local_expert_offset + num_local_experts + + # Slice weights to LOCAL experts only (for fair EP comparison) + w1_fp4_local = inputs["w1_fp4"][expert_start:expert_end] + w1_sf_local = inputs["w1_sf"][expert_start:expert_end] + w1_gs_local = inputs["w1_gs"][expert_start:expert_end] + w2_fp4_local = inputs["w2_fp4"][expert_start:expert_end] + w2_sf_local = inputs["w2_sf"][expert_start:expert_end] + w2_gs_local = inputs["w2_gs"][expert_start:expert_end] + + # Prepare CUTLASS inputs + a1_gs = torch.tensor(1.0, device=dev, dtype=torch.float32) + a2_gs = torch.tensor(1.0, device=dev, dtype=torch.float32) + + quant_scales = [ + a1_gs, + w1_sf_local.view(torch.int32), + 1.0 / (a1_gs * w1_gs_local), + a2_gs, + w2_sf_local.view(torch.int32), + 1.0 / (a2_gs * w2_gs_local), + ] + + hidden_fp4, input_sf = fp4_quantize(inputs["hidden_bf16"], a1_gs, sv, False, True) + output = torch.empty(n, CFG.hidden_size, dtype=torch.bfloat16, device=dev) + + # Pre-convert routing bias to float32 + routing_bias_f32 = inputs["routing_bias"].float() + + # Pre-compute values that need conversion + w1_fp4_view = w1_fp4_local.contiguous().view(torch.long) + w2_fp4_view = 
w2_fp4_local.contiguous().view(torch.long) + + # Compute EP size from config + ep_size = CFG.num_experts // num_local_experts + + def run(hidden, sf, router_logits, routing_bias, topk_values, topk_indices): + # Routing (included in timing for fair comparison with TRTLLM) + fused_topk_deepseek( + scores=router_logits, + bias=routing_bias, + n_group=CFG.n_group, + topk_group=CFG.topk_group, + topk=CFG.top_k, + routed_scaling_factor=CFG.routed_scaling_factor, + topk_values=topk_values, + topk_indices=topk_indices, + ) + cutlass_fused_moe( + hidden, + topk_indices.to(torch.int), + topk_values, + w1_fp4_view, + w2_fp4_view, + torch.bfloat16, + quant_scales=quant_scales, + input_sf=sf, + output=output, + ep_size=ep_size, + ep_rank=0, # Simulating rank 0 of EP + ) + return output + + input_kwargs = { + "hidden": hidden_fp4, + "sf": input_sf, + "router_logits": inputs["router_logits"], + "routing_bias": routing_bias_f32, + "topk_values": tv, + "topk_indices": ti, + } + + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) + return np.median(times) + + +def bench_trtllm( + inputs, + warmup=10, + iters=100, + num_local_experts=None, + local_expert_offset=0, + use_cuda_graph=True, + use_cupti=True, +): + from flashinfer.fused_moe import trtllm_fp4_block_scale_moe + from flashinfer.fused_moe.core import ( + RoutingMethodType, + _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, + ) + from flashinfer.fp4_quantization import fp4_quantize, block_scale_interleave + from flashinfer.testing.utils import bench_gpu_time + + if num_local_experts is None: + num_local_experts = CFG.num_experts + + n, dev = inputs["router_logits"].shape[0], inputs["router_logits"].device + sv, etm, cache = 16, 128, {} + + # Expert range for this EP partition + expert_start = local_expert_offset + expert_end = local_expert_offset + 
num_local_experts + + hg = inputs["hidden_gs"] + hfp, hsf = fp4_quantize(inputs["hidden_bf16"], hg, sv, False, True) + hfp = hfp.view(torch.uint8).reshape(n, CFG.hidden_size // 2) + hsc = ( + hsf.view(torch.float8_e4m3fn) + .flatten()[: n * CFG.hidden_size // sv] + .reshape(n, CFG.hidden_size // sv) + ) + + def prep(bf16, gs, M, K): + """Prepare weights for LOCAL experts only.""" + fl, sl = [], [] + for e in range(expert_start, expert_end): + q, s = fp4_quantize(bf16[e], gs[e], sv, False, False) + fl.append(q.view(torch.uint8).reshape(M, K // 2)) + sl.append(s.view(torch.float8_e4m3fn).reshape(M, K // sv)) + return torch.stack(fl), torch.stack(sl) + + w1f, w1s = prep( + inputs["w1_bf16"], inputs["w1_gs"], 2 * CFG.intermediate_size, CFG.hidden_size + ) + w2f, w2s = prep( + inputs["w2_bf16"], inputs["w2_gs"], CFG.hidden_size, CFG.intermediate_size + ) + + def shuf(fp4, sf, perm_fn): + """Shuffle weights for LOCAL experts only.""" + fsh, ssh = [], [] + for i in range(num_local_experts): + p = perm_fn(cache, fp4[i], etm) + fsh.append(fp4[i][p.to(dev)].contiguous()) + ps = perm_fn(cache, sf[i].view(torch.uint8), etm, sv) + ssh.append( + block_scale_interleave(sf[i].view(torch.uint8)[ps.to(dev)].contiguous()) + ) + return torch.stack(fsh), torch.stack(ssh) + + w1f, w1s = shuf(w1f, w1s, _maybe_get_cached_w3_w1_permute_indices) + w2f, w2s = shuf(w2f, w2s, get_w2_permute_indices_with_cache) + w1s = w1s.view(torch.float8_e4m3fn).reshape( + num_local_experts, 2 * CFG.intermediate_size, CFG.hidden_size // sv + ) + w2s = w2s.view(torch.float8_e4m3fn).reshape( + num_local_experts, CFG.hidden_size, CFG.intermediate_size // sv + ) + + # Scale tensors sized for LOCAL experts only + sc = torch.ones(num_local_experts, device=dev, dtype=torch.float32) + + def run(routing_logits, routing_bias, hidden_states, hidden_states_scale): + return trtllm_fp4_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + 
hidden_states_scale=hidden_states_scale, + gemm1_weights=w1f, + gemm1_weights_scale=w1s, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2f, + gemm2_weights_scale=w2s, + gemm2_bias=None, + output1_scale_scalar=sc, + output1_scale_gate_scalar=sc, + output2_scale_scalar=sc, + num_experts=CFG.num_experts, + top_k=CFG.top_k, + n_group=CFG.n_group, + topk_group=CFG.topk_group, + intermediate_size=CFG.intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=num_local_experts, + routed_scaling_factor=CFG.routed_scaling_factor, + routing_method_type=RoutingMethodType.DeepSeekV3, + do_finalize=True, + ) + + input_kwargs = { + "routing_logits": inputs["router_logits"], + "routing_bias": inputs["routing_bias"], + "hidden_states": hfp, + "hidden_states_scale": hsc, + } + + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) + return np.median(times) + + +# ============================================================================= +# Autotune +# ============================================================================= + + +def run_autotune(inputs, verbose=True): + from flashinfer.fused_moe import ( + fused_topk_deepseek, + cutlass_fused_moe, + trtllm_fp4_block_scale_moe, + ) + from flashinfer.fused_moe.core import ( + RoutingMethodType, + _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, + ) + from flashinfer import cute_dsl_fused_moe_nvfp4 + from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout + from flashinfer.fp4_quantization import fp4_quantize, block_scale_interleave + from flashinfer.autotuner import autotune + + if verbose: + print("\nRunning autotune warmup for all backends...") + print("-" * 80) + + n, sv, dev = inputs["router_logits"].shape[0], 16, "cuda" + gs1 = torch.tensor([1.0], device=dev) + + tv = 
torch.empty(n, CFG.top_k, dtype=torch.float32, device=dev) + ti = torch.empty(n, CFG.top_k, dtype=torch.int32, device=dev) + fused_topk_deepseek( + scores=inputs["router_logits"], + bias=inputs["routing_bias"].float(), + n_group=CFG.n_group, + topk_group=CFG.topk_group, + topk=CFG.top_k, + routed_scaling_factor=CFG.routed_scaling_factor, + topk_values=tv, + topk_indices=ti, + ) + + # ------------------------------------------------------------------------- + # CuteDSL autotune + # ------------------------------------------------------------------------- + if verbose: + print("Autotuning CuteDSL...") + + xf, xs = fp4_quantize(inputs["hidden_bf16"], gs1, sv, False, False) + xs = xs.unsqueeze(-1) + + w1i = interleave(inputs["w1_bf16"], 64) + w1f = w1i.view(CFG.num_experts * 2 * CFG.intermediate_size, CFG.hidden_size) + w1q, w1s = fp4_quantize(w1f, gs1, sv, False, True) + w1q = w1q.view(CFG.num_experts, 2 * CFG.intermediate_size, CFG.hidden_size // 2) + w1s = convert_sf_to_mma_layout( + w1s, 2 * CFG.intermediate_size, CFG.hidden_size, CFG.num_experts, sv + ) + + w2f = inputs["w2_bf16"].view( + CFG.num_experts * CFG.hidden_size, CFG.intermediate_size + ) + w2q, w2s = fp4_quantize(w2f, gs1, sv, False, True) + w2q = w2q.view(CFG.num_experts, CFG.hidden_size, CFG.intermediate_size // 2) + w2s = convert_sf_to_mma_layout( + w2s, CFG.hidden_size, CFG.intermediate_size, CFG.num_experts, sv + ) + + alpha, fc2sc = ( + torch.ones(CFG.num_experts, device=dev), + torch.tensor([1.0], device=dev), + ) + + with autotune(True): + for _ in range(10): + cute_dsl_fused_moe_nvfp4( + x=xf, + x_sf=xs, + token_selected_experts=ti, + token_final_scales=tv, + w1_weight=w1q, + w1_weight_sf=w1s, + w1_alpha=alpha, + fc2_input_scale=fc2sc, + w2_weight=w2q, + w2_weight_sf=w2s, + w2_alpha=alpha, + num_experts=CFG.num_experts, + top_k=CFG.top_k, + num_local_experts=CFG.num_experts, + local_expert_offset=0, + ) + torch.cuda.synchronize() + + # 
------------------------------------------------------------------------- + # CUTLASS autotune + # ------------------------------------------------------------------------- + if verbose: + print("Autotuning CUTLASS...") + + a1_gs = torch.tensor(1.0, device=dev, dtype=torch.float32) + a2_gs = torch.tensor(1.0, device=dev, dtype=torch.float32) + quant_scales = [ + a1_gs, + inputs["w1_sf"].view(torch.int32), + 1.0 / (a1_gs * inputs["w1_gs"]), + a2_gs, + inputs["w2_sf"].view(torch.int32), + 1.0 / (a2_gs * inputs["w2_gs"]), + ] + hidden_fp4, input_sf = fp4_quantize(inputs["hidden_bf16"], a1_gs, sv, False, True) + output_cutlass = torch.empty(n, CFG.hidden_size, dtype=torch.bfloat16, device=dev) + + with autotune(True): + for _ in range(10): + cutlass_fused_moe( + hidden_fp4, + ti.to(torch.int), + tv, + inputs["w1_fp4"].contiguous().view(torch.long), + inputs["w2_fp4"].contiguous().view(torch.long), + torch.bfloat16, + quant_scales=quant_scales, + input_sf=input_sf, + output=output_cutlass, + ) + torch.cuda.synchronize() + + # ------------------------------------------------------------------------- + # TRTLLM Gen autotune + # ------------------------------------------------------------------------- + if verbose: + print("Autotuning TRTLLM Gen...") + + etm, cache = 128, {} + hg = inputs["hidden_gs"] + hfp, hsf = fp4_quantize(inputs["hidden_bf16"], hg, sv, False, True) + hfp = hfp.view(torch.uint8).reshape(n, CFG.hidden_size // 2) + hsc = ( + hsf.view(torch.float8_e4m3fn) + .flatten()[: n * CFG.hidden_size // sv] + .reshape(n, CFG.hidden_size // sv) + ) + + def prep(bf16, gs, M, K): + fl, sl = [], [] + for e in range(CFG.num_experts): + q, s = fp4_quantize(bf16[e], gs[e], sv, False, False) + fl.append(q.view(torch.uint8).reshape(M, K // 2)) + sl.append(s.view(torch.float8_e4m3fn).reshape(M, K // sv)) + return torch.stack(fl), torch.stack(sl) + + w1f_trt, w1s_trt = prep( + inputs["w1_bf16"], inputs["w1_gs"], 2 * CFG.intermediate_size, CFG.hidden_size + ) + w2f_trt, w2s_trt 
= prep( + inputs["w2_bf16"], inputs["w2_gs"], CFG.hidden_size, CFG.intermediate_size + ) + + def shuf(fp4, sf, perm_fn): + fsh, ssh = [], [] + for i in range(CFG.num_experts): + p = perm_fn(cache, fp4[i], etm) + fsh.append(fp4[i][p.to(dev)].contiguous()) + ps = perm_fn(cache, sf[i].view(torch.uint8), etm, sv) + ssh.append( + block_scale_interleave(sf[i].view(torch.uint8)[ps.to(dev)].contiguous()) + ) + return torch.stack(fsh), torch.stack(ssh) + + w1f_trt, w1s_trt = shuf(w1f_trt, w1s_trt, _maybe_get_cached_w3_w1_permute_indices) + w2f_trt, w2s_trt = shuf(w2f_trt, w2s_trt, get_w2_permute_indices_with_cache) + w1s_trt = w1s_trt.view(torch.float8_e4m3fn).reshape( + CFG.num_experts, 2 * CFG.intermediate_size, CFG.hidden_size // sv + ) + w2s_trt = w2s_trt.view(torch.float8_e4m3fn).reshape( + CFG.num_experts, CFG.hidden_size, CFG.intermediate_size // sv + ) + + sc = torch.ones(CFG.num_experts, device=dev, dtype=torch.float32) + + with autotune(True): + for _ in range(10): + trtllm_fp4_block_scale_moe( + routing_logits=inputs["router_logits"], + routing_bias=inputs["routing_bias"], + hidden_states=hfp, + hidden_states_scale=hsc, + gemm1_weights=w1f_trt, + gemm1_weights_scale=w1s_trt, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2f_trt, + gemm2_weights_scale=w2s_trt, + gemm2_bias=None, + output1_scale_scalar=sc, + output1_scale_gate_scalar=sc, + output2_scale_scalar=sc, + num_experts=CFG.num_experts, + top_k=CFG.top_k, + n_group=CFG.n_group, + topk_group=CFG.topk_group, + intermediate_size=CFG.intermediate_size, + local_expert_offset=0, + local_num_experts=CFG.num_experts, + routed_scaling_factor=CFG.routed_scaling_factor, + routing_method_type=RoutingMethodType.DeepSeekV3, + do_finalize=True, + ) + torch.cuda.synchronize() + + if verbose: + print("-" * 80) + print("Autotune complete for all backends.\n") + + +# ============================================================================= +# Main Benchmark +# 
============================================================================= + + +@dataclass +class BenchResult: + """Single benchmark result for one backend at one token count.""" + + backend: str + tokens: int + latency_ms: float + tflops: float + + +def run_benchmark( + token_counts, + warmup=10, + iters=100, + ep_config=1, + do_autotune=True, + verbose=True, + use_cuda_graph=True, + use_cupti=True, + use_wrapper=True, +): + """ + Unified benchmark for DeepSeek-V3 MoE backends. + + Args: + token_counts: List of token counts to benchmark + warmup: Warmup iterations + iters: Benchmark iterations + ep_config: Expert Parallelism config (1, 8, or 16) + do_autotune: Whether to run autotune before benchmarking + verbose: Print results to stdout + use_cuda_graph: Whether to use CUDA graph for benchmarking + use_cupti: Whether to use CUPTI for accurate GPU timing + use_wrapper: Whether to use CuteDslMoEWrapper API (recommended) + + Returns: + List of BenchResult objects + """ + # Get EP configuration + ep_cfg = EP_CONFIGS.get(ep_config, EP_CONFIGS[1]) + num_local = ep_cfg["num_local_experts"] + local_offset = ep_cfg["local_expert_offset"] + + # Run autotune if requested (BEFORE printing header to avoid interleaved output) + if do_autotune: + run_autotune(create_inputs(max(token_counts)), verbose=verbose) + + # Print header AFTER autotune completes + if verbose: + _print_header(ep_config, num_local, use_cuda_graph, use_cupti) + + # Run benchmarks + results = [] + for n in token_counts: + row = _benchmark_single( + n, + warmup, + iters, + num_local, + local_offset, + use_cuda_graph, + use_cupti, + use_wrapper=use_wrapper, + ) + results.extend(row) + if verbose: + _print_row(row) + + # Print footer + if verbose: + _print_footer(ep_config, num_local) + + return results + + +def _benchmark_single( + n, + warmup, + iters, + num_local, + local_offset, + use_cuda_graph, + use_cupti, + use_wrapper=True, +): + """Benchmark all backends for a single token count. 
+ + Args: + use_wrapper: If True, use CuteDslMoEWrapper API for CuteDSL. + """ + inputs = create_inputs(n) + + # Run all three backends + lat = { + "CuteDSL": bench_cute_dsl( + inputs, + warmup, + iters, + num_local, + local_offset, + use_cuda_graph, + use_cupti, + use_wrapper=use_wrapper, + ), + "CUTLASS": bench_cutlass( + inputs, warmup, iters, num_local, local_offset, use_cuda_graph, use_cupti + ), + "TRTLLM": bench_trtllm( + inputs, warmup, iters, num_local, local_offset, use_cuda_graph, use_cupti + ), + } + + # Build results + results = [] + for backend, latency in lat.items(): + results.append( + BenchResult( + backend=backend, + tokens=n, + latency_ms=latency, + tflops=calc_tflops(n, latency, num_local), + ) + ) + return results + + +def _print_header(ep_config, num_local, use_cuda_graph, use_cupti): + """Print benchmark header.""" + print("\n" + "=" * 100) + print(f"DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP={ep_config})") + print("=" * 100) + print( + f"Model: hidden={CFG.hidden_size}, intermediate={CFG.intermediate_size}, " + f"experts={CFG.num_experts}, top_k={CFG.top_k}" + ) + print( + f"EP Config: {num_local} local experts (simulating {CFG.num_experts // num_local}-way parallelism)" + ) + print( + f"CUDA Graph: {'enabled' if use_cuda_graph else 'disabled'}, CUPTI: {'enabled' if use_cupti else 'disabled'}" + ) + print("-" * 100) + print( + f"{'Tokens':>6} | " + f"{'CuteDSL':^15} | " + f"{'CUTLASS':^15} | " + f"{'TRTLLM':^15} | " + f"{'Speedup (CuteDSL/X)':^18} | " + f"{'Winner':^8}" + ) + print( + f"{'':>6} | " + f"{'ms':>7} {'TFLOPS':>7} | " + f"{'ms':>7} {'TFLOPS':>7} | " + f"{'ms':>7} {'TFLOPS':>7} | " + f"{'CUTLASS':>8} {'TRTLLM':>8} |" + ) + print("-" * 100) + + +def _print_row(results): + """Print a single row of benchmark results.""" + # Extract values by backend + r = {r.backend: r for r in results} + cute, cutlass, trtllm = r["CuteDSL"], r["CUTLASS"], r["TRTLLM"] + + # Calculate speedups (> 1.0 means CuteDSL is faster) + 
speedup_cutlass = cutlass.latency_ms / cute.latency_ms + speedup_trtllm = trtllm.latency_ms / cute.latency_ms + + # Find winner + winner = min(r.values(), key=lambda x: x.latency_ms).backend + + print( + f"{cute.tokens:>6} | " + f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} | " + f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} | " + f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} | " + f"{speedup_cutlass:>7.2f}x {speedup_trtllm:>7.2f}x | " + f"{winner:^8}" + ) + + +def _print_footer(ep_config, num_local): + """Print benchmark footer.""" + print("-" * 100) + print("Speedup > 1.0 means CuteDSL is faster than that backend") + + +def main(): + parser = argparse.ArgumentParser( + description="DeepSeek-V3 MoE Performance Benchmark" + ) + parser.add_argument( + "--num-tokens", + type=str, + default=None, + help="Comma-separated token counts (default: 128-4096 for throughput, 1-128 for gen-phase)", + ) + parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") + parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations") + parser.add_argument("--no-autotune", action="store_true", help="Disable autotune") + parser.add_argument("--quiet", action="store_true", help="Minimal output") + parser.add_argument( + "--gen-phase", + action="store_true", + help="Use generation phase token counts (1-128 instead of 128-4096)", + ) + parser.add_argument( + "--ep", + type=int, + default=1, + choices=[1, 8, 16], + help="Expert Parallelism: 1 (256 local), 8 (32 local), 16 (16 local)", + ) + parser.add_argument( + "--no-cuda-graph", + action="store_true", + help="Disable CUDA graph for benchmarking (enabled by default)", + ) + parser.add_argument( + "--no-cupti", + action="store_true", + help="Disable CUPTI for GPU timing (enabled by default)", + ) + parser.add_argument( + "--functional-api", + action="store_true", + help="Use functional API instead of CuteDslMoEWrapper for CuteDSL benchmark", + ) + args = parser.parse_args() + + if not 
is_sm100_family(): + print("ERROR: Requires SM100 family GPU (Blackwell: SM100, SM103, SM110)") + return 1 + + # Determine token counts + if args.num_tokens: + tokens = [int(x) for x in args.num_tokens.split(",")] + elif args.gen_phase: + tokens = GEN_PHASE_TOKENS # [1, 2, 4, 8, 16, 32, 64, 128] + else: + tokens = TOKEN_COUNTS # [128, 256, 512, 1024, 2048, 4096] + + print("\nDeepSeek-V3 MoE Performance Benchmark") + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"CuteDSL API: {'Functional' if args.functional_api else 'Wrapper'}") + + run_benchmark( + token_counts=tokens, + warmup=args.warmup, + iters=args.iters, + ep_config=args.ep, + do_autotune=not args.no_autotune, + verbose=not args.quiet, + use_cuda_graph=not args.no_cuda_graph, + use_cupti=not args.no_cupti, + use_wrapper=not args.functional_api, + ) + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/csrc/moe_utils_binding.cu b/csrc/moe_utils_binding.cu new file mode 100644 index 0000000000..ced63cb71b --- /dev/null +++ b/csrc/moe_utils_binding.cu @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#ifdef ENABLE_FP8 +#include +#endif +#ifdef ENABLE_FP4 +#include +#endif + +#include "flashinfer/trtllm/fused_moe/RoutingKernel.h" +#include "tensorrt_llm/kernels/cuteDslKernels/moeUtils.h" +#include "tvm_ffi_utils.h" + +using namespace tensorrt_llm::kernels::cute_dsl; + +namespace { +// Helper function to compute log2 of a value (returns -1 if not power of 2) +inline int32_t computeLog2(int32_t val) { + int32_t n = val; + int32_t out = 0; + while (n >>= 1) { + ++out; + } + if ((1 << out) != val) { + out = -1; + } + return out; +} +} // namespace + +// ============================ moePermute bindings ============================ + +void moe_permute_fp16(int64_t input_ptr, int64_t permuted_output_ptr, int64_t input_sf_ptr, + int64_t permuted_sf_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t permuted_idx_to_expanded_idx_ptr, int64_t num_non_exiting_tiles_ptr, + int32_t max_num_permuted_tokens, int32_t hidden_size, int32_t top_k, + int32_t tile_size, bool enable_pdl) { + moePermute( + reinterpret_cast(input_ptr), reinterpret_cast(permuted_output_ptr), + reinterpret_cast(input_sf_ptr), reinterpret_cast(permuted_sf_ptr), + reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(permuted_idx_to_expanded_idx_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), max_num_permuted_tokens, + hidden_size, top_k, tile_size, enable_pdl, get_current_stream()); +} + +#ifdef ENABLE_BF16 +void moe_permute_bf16(int64_t input_ptr, int64_t permuted_output_ptr, int64_t input_sf_ptr, + int64_t permuted_sf_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t permuted_idx_to_expanded_idx_ptr, int64_t num_non_exiting_tiles_ptr, + int32_t max_num_permuted_tokens, int32_t hidden_size, int32_t top_k, + int32_t tile_size, bool enable_pdl) { + moePermute<__nv_bfloat16, uint8_t>( + reinterpret_cast<__nv_bfloat16 const*>(input_ptr), + reinterpret_cast<__nv_bfloat16*>(permuted_output_ptr), + reinterpret_cast(input_sf_ptr), reinterpret_cast(permuted_sf_ptr), 
+ reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(permuted_idx_to_expanded_idx_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), max_num_permuted_tokens, + hidden_size, top_k, tile_size, enable_pdl, get_current_stream()); +} +#endif + +#ifdef ENABLE_FP8 +void moe_permute_fp8(int64_t input_ptr, int64_t permuted_output_ptr, int64_t input_sf_ptr, + int64_t permuted_sf_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t permuted_idx_to_expanded_idx_ptr, int64_t num_non_exiting_tiles_ptr, + int32_t max_num_permuted_tokens, int32_t hidden_size, int32_t top_k, + int32_t tile_size, bool enable_pdl) { + moePermute<__nv_fp8_e4m3, uint8_t>( + reinterpret_cast<__nv_fp8_e4m3 const*>(input_ptr), + reinterpret_cast<__nv_fp8_e4m3*>(permuted_output_ptr), + reinterpret_cast(input_sf_ptr), reinterpret_cast(permuted_sf_ptr), + reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(permuted_idx_to_expanded_idx_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), max_num_permuted_tokens, + hidden_size, top_k, tile_size, enable_pdl, get_current_stream()); +} +#endif + +#ifdef ENABLE_FP4 +void moe_permute_fp4(int64_t input_ptr, int64_t permuted_output_ptr, int64_t input_sf_ptr, + int64_t permuted_sf_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t permuted_idx_to_expanded_idx_ptr, int64_t num_non_exiting_tiles_ptr, + int32_t max_num_permuted_tokens, int32_t hidden_size, int32_t top_k, + int32_t tile_size, bool enable_pdl) { + moePermute<__nv_fp4_e2m1, uint8_t>( + reinterpret_cast<__nv_fp4_e2m1 const*>(input_ptr), + reinterpret_cast<__nv_fp4_e2m1*>(permuted_output_ptr), + reinterpret_cast(input_sf_ptr), reinterpret_cast(permuted_sf_ptr), + reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(permuted_idx_to_expanded_idx_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), max_num_permuted_tokens, + hidden_size, top_k, tile_size, enable_pdl, get_current_stream()); +} +#endif + +// ============================ moeUnpermute bindings 
============================ + +void moe_unpermute_fp16_float_scale(int64_t permuted_input_ptr, int64_t output_ptr, + int64_t expanded_idx_to_permuted_idx_ptr, + int64_t topk_scales_ptr, int32_t num_tokens, + int32_t hidden_size, int32_t top_k, bool enable_pdl) { + moeUnpermute(reinterpret_cast(permuted_input_ptr), + reinterpret_cast(output_ptr), + reinterpret_cast(expanded_idx_to_permuted_idx_ptr), + reinterpret_cast(topk_scales_ptr), num_tokens, + hidden_size, top_k, enable_pdl, get_current_stream()); +} + +void moe_unpermute_fp16_half_scale(int64_t permuted_input_ptr, int64_t output_ptr, + int64_t expanded_idx_to_permuted_idx_ptr, + int64_t topk_scales_ptr, int32_t num_tokens, int32_t hidden_size, + int32_t top_k, bool enable_pdl) { + moeUnpermute(reinterpret_cast(permuted_input_ptr), + reinterpret_cast(output_ptr), + reinterpret_cast(expanded_idx_to_permuted_idx_ptr), + reinterpret_cast(topk_scales_ptr), num_tokens, hidden_size, + top_k, enable_pdl, get_current_stream()); +} + +#ifdef ENABLE_BF16 +void moe_unpermute_bf16_float_scale(int64_t permuted_input_ptr, int64_t output_ptr, + int64_t expanded_idx_to_permuted_idx_ptr, + int64_t topk_scales_ptr, int32_t num_tokens, + int32_t hidden_size, int32_t top_k, bool enable_pdl) { + moeUnpermute<__nv_bfloat16, float>( + reinterpret_cast<__nv_bfloat16 const*>(permuted_input_ptr), + reinterpret_cast<__nv_bfloat16*>(output_ptr), + reinterpret_cast(expanded_idx_to_permuted_idx_ptr), + reinterpret_cast(topk_scales_ptr), num_tokens, hidden_size, top_k, enable_pdl, + get_current_stream()); +} + +void moe_unpermute_bf16_bf16_scale(int64_t permuted_input_ptr, int64_t output_ptr, + int64_t expanded_idx_to_permuted_idx_ptr, + int64_t topk_scales_ptr, int32_t num_tokens, int32_t hidden_size, + int32_t top_k, bool enable_pdl) { + moeUnpermute<__nv_bfloat16, __nv_bfloat16>( + reinterpret_cast<__nv_bfloat16 const*>(permuted_input_ptr), + reinterpret_cast<__nv_bfloat16*>(output_ptr), + 
reinterpret_cast(expanded_idx_to_permuted_idx_ptr), + reinterpret_cast<__nv_bfloat16 const*>(topk_scales_ptr), num_tokens, hidden_size, top_k, + enable_pdl, get_current_stream()); +} +#endif + +// ============================ moeOutputMemset bindings ============================ + +void moe_output_memset_fp16(int64_t input_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t expanded_idx_to_permuted_idx_ptr, + int64_t permuted_idx_to_expanded_idx_ptr, + int64_t num_non_exiting_tiles_ptr, int32_t max_num_permuted_tokens, + int32_t hidden_size, int32_t top_k, int32_t tile_size, + bool enable_pdl) { + moeOutputMemset(reinterpret_cast(input_ptr), + reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(expanded_idx_to_permuted_idx_ptr), + reinterpret_cast(permuted_idx_to_expanded_idx_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), + max_num_permuted_tokens, hidden_size, top_k, tile_size, enable_pdl, + get_current_stream()); +} + +#ifdef ENABLE_BF16 +void moe_output_memset_bf16(int64_t input_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t expanded_idx_to_permuted_idx_ptr, + int64_t permuted_idx_to_expanded_idx_ptr, + int64_t num_non_exiting_tiles_ptr, int32_t max_num_permuted_tokens, + int32_t hidden_size, int32_t top_k, int32_t tile_size, + bool enable_pdl) { + moeOutputMemset<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(input_ptr), + reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(expanded_idx_to_permuted_idx_ptr), + reinterpret_cast(permuted_idx_to_expanded_idx_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), + max_num_permuted_tokens, hidden_size, top_k, tile_size, enable_pdl, + get_current_stream()); +} +#endif + +// ============================ moeActivation bindings ============================ + +void moe_activation_fp16(int64_t input_ptr, int64_t output_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t num_non_exiting_tiles_ptr, int32_t activation_type, + int32_t max_num_permuted_tokens, int32_t interm_size, int32_t tile_size, + 
bool enable_pdl) { + moeActivation(reinterpret_cast(input_ptr), reinterpret_cast(output_ptr), + reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), + static_cast(activation_type), max_num_permuted_tokens, + interm_size, tile_size, enable_pdl, get_current_stream()); +} + +#ifdef ENABLE_BF16 +void moe_activation_bf16(int64_t input_ptr, int64_t output_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t num_non_exiting_tiles_ptr, int32_t activation_type, + int32_t max_num_permuted_tokens, int32_t interm_size, int32_t tile_size, + bool enable_pdl) { + moeActivation<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 const*>(input_ptr), + reinterpret_cast<__nv_bfloat16*>(output_ptr), + reinterpret_cast(tile_idx_to_mn_limit_ptr), + reinterpret_cast(num_non_exiting_tiles_ptr), + static_cast(activation_type), + max_num_permuted_tokens, interm_size, tile_size, enable_pdl, + get_current_stream()); +} +#endif + +// ============================ TVM FFI Registration ============================ + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_permute_fp16, moe_permute_fp16); +#ifdef ENABLE_BF16 +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_permute_bf16, moe_permute_bf16); +#endif +#ifdef ENABLE_FP8 +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_permute_fp8, moe_permute_fp8); +#endif +#ifdef ENABLE_FP4 +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_permute_fp4, moe_permute_fp4); +#endif + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_unpermute_fp16_float_scale, + moe_unpermute_fp16_float_scale); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_unpermute_fp16_half_scale, + moe_unpermute_fp16_half_scale); +#ifdef ENABLE_BF16 +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_unpermute_bf16_float_scale, + moe_unpermute_bf16_float_scale); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_unpermute_bf16_bf16_scale, + moe_unpermute_bf16_bf16_scale); +#endif + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_output_memset_fp16, moe_output_memset_fp16); +#ifdef ENABLE_BF16 
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_output_memset_bf16, moe_output_memset_bf16); +#endif + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_activation_fp16, moe_activation_fp16); +#ifdef ENABLE_BF16 +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_activation_bf16, moe_activation_bf16); +#endif + +// ============================ moeSort bindings ============================ +// moe_sort - Sort tokens by expert assignment and generate mapping tensors +// This uses DeepSeekV3 routing method with pre-computed expert selections +// +// Returns via output pointers: +// - tile_idx_to_expert_idx: [max_num_tiles], mapping from tile to local expert index +// - tile_idx_to_mn_limit: [max_num_tiles], M/N limit for each tile +// - expanded_idx_to_permuted_idx: [num_tokens, top_k], mapping from expanded to permuted index +// - permuted_idx_to_expanded_idx: [max_num_permuted_tokens], mapping from permuted to expanded +// - total_num_padded_tokens: [1], total number of padded tokens +// - num_non_exiting_tiles: [1], number of non-exiting tiles + +void moe_sort( + // Inputs + int64_t token_selected_experts_ptr, // [num_tokens, top_k], int32 + int64_t token_final_scales_ptr, // [num_tokens, top_k], float32 or bf16 + int32_t num_tokens, int32_t num_experts, int32_t top_k, int32_t local_expert_offset, + int32_t num_local_experts, int32_t tile_tokens_dim, bool use_pdl, + // Outputs (pre-allocated buffers) + int64_t tile_idx_to_expert_idx_ptr, int64_t tile_idx_to_mn_limit_ptr, + int64_t expanded_idx_to_permuted_idx_ptr, int64_t permuted_idx_to_expanded_idx_ptr, + int64_t total_num_padded_tokens_ptr, int64_t num_non_exiting_tiles_ptr, + // Optional: expert counts buffer for large token counts (>1024) + // Should be size 2 * num_experts, int32 + int64_t expert_counts_ptr, + // Optional: explicit CUDA stream pointer for CUDA graph compatibility + // If 0, uses TVM FFI's current stream + int64_t cuda_stream_ptr) { + // Set up the routing data structure + 
moe::dev::routing::routingDeepSeek::Data routingData; + + // Configure dtypes + routingData.mDtypeExpW = batchedGemm::trtllm::gen::Dtype::Bfloat16; + routingData.mDtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; + routingData.mDtypeScore = batchedGemm::trtllm::gen::Dtype::Fp32; + routingData.mUsePdl = use_pdl; + + // Input tensors (pre-computed expert selections) + routingData.mPtrTopKIds = reinterpret_cast(token_selected_experts_ptr); + routingData.mPtrTopKWeights = reinterpret_cast(token_final_scales_ptr); + routingData.mPtrScores = nullptr; // Not using routing logits + routingData.mPtrRoutingBias = nullptr; // Not using bias + + // Output tensors + routingData.mPtrCtaIdxXyToBatchIdx = reinterpret_cast(tile_idx_to_expert_idx_ptr); + routingData.mPtrCtaIdxXyToMnLimit = reinterpret_cast(tile_idx_to_mn_limit_ptr); + routingData.mPtrExpandedIdxToPermutedIdx = + reinterpret_cast(expanded_idx_to_permuted_idx_ptr); + routingData.mPtrPermutedIdxToTokenIdx = + reinterpret_cast(permuted_idx_to_expanded_idx_ptr); + routingData.mPtrPermutedIdxSize = reinterpret_cast(total_num_padded_tokens_ptr); + routingData.mPtrNumNonExitingCtas = reinterpret_cast(num_non_exiting_tiles_ptr); + + // Not using packed format since we have explicit TopK IDs + routingData.mPtrTopKPacked = nullptr; + + // Expert counts buffer: required when num_tokens > 1024 + // The kernel will set this to nullptr internally for small token counts + routingData.mPtrExpertCounts = reinterpret_cast(expert_counts_ptr); + + // Metadata + routingData.mNumTokens = num_tokens; + routingData.mNumExperts = num_experts; + routingData.mTopK = top_k; + routingData.mPaddingLog2 = computeLog2(tile_tokens_dim); + routingData.mTileTokensDim = tile_tokens_dim; + routingData.mLocalExpertsStartIdx = local_expert_offset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = num_local_experts; + + // DeepSeekV3 specific parameters + // For moe_sort, we use n_group=1, topk_group=1 since experts are 
already selected + routingData.mNumExpertGroups = 1; + routingData.mNumLimitedGroups = 1; + routingData.mRouteScale = 1.0f; + routingData.mUseRoutingSoftmax = false; + + // Run the routing kernel + // Use explicit stream if provided (for CUDA graph compatibility), otherwise fall back to TVM FFI + // stream + cudaStream_t stream = + cuda_stream_ptr != 0 ? reinterpret_cast(cuda_stream_ptr) : get_current_stream(); + moe::dev::routing::routingDeepSeek::run(routingData, stream); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(flashinfer_moe_sort, moe_sort); diff --git a/csrc/nv_internal/include/tensorrt_llm/common/config.h b/csrc/nv_internal/include/tensorrt_llm/common/config.h new file mode 100644 index 0000000000..cb157f6140 --- /dev/null +++ b/csrc/nv_internal/include/tensorrt_llm/common/config.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#ifndef TRTLLM_CONFIG_H +#define TRTLLM_CONFIG_H + +/** + * \def TRTLLM_ABI_NAMESPACE + * This macro is used to open an implicitly inline namespace block for the ABI version. + * This macro can be overridden to change the ABI version. + * The default ABI version is _v1. 
+ */ +#ifndef TRTLLM_ABI_NAMESPACE +#define TRTLLM_ABI_NAMESPACE _v1 +#endif + +#ifndef TRTLLM_ABI_NAMESPACE_BEGIN +#define TRTLLM_ABI_NAMESPACE_BEGIN inline namespace TRTLLM_ABI_NAMESPACE { +#endif + +#ifndef TRTLLM_ABI_NAMESPACE_END +#define TRTLLM_ABI_NAMESPACE_END } +#endif + +/** + * \def TRTLLM_NAMESPACE_BEGIN + * This macro is used to open a `tensorrt_llm::` namespace block, along with any + * enclosing namespaces requested by TRTLLM_WRAPPED_NAMESPACE, etc. + * This macro is defined by TensorRT-LLM and may not be overridden. + */ +#define TRTLLM_NAMESPACE_BEGIN \ + namespace tensorrt_llm { \ + TRTLLM_ABI_NAMESPACE_BEGIN + +/** + * \def TRTLLM_NAMESPACE_END + * This macro is used to close a `tensorrt_llm::` namespace block, along with any + * enclosing namespaces requested by TRTLLM_WRAPPED_NAMESPACE, etc. + * This macro is defined by TensorRT-LLM and may not be overridden. + */ +#define TRTLLM_NAMESPACE_END \ + TRTLLM_ABI_NAMESPACE_END \ + } /* end namespace tensorrt_llm */ + +#endif // TRTLLM_CONFIG_H diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu b/csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu new file mode 100644 index 0000000000..0144d0885c --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu @@ -0,0 +1,485 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/config.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cuteDslKernels/moeUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh" + +#ifdef ENABLE_FP4 +#include +#endif +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::cute_dsl { +namespace { +using ElemCopyType = uint4; +using SFCopyType = uint32_t; + +template +auto constexpr bitsPerElem() { +#ifdef ENABLE_FP4 + return std::is_same_v ? 4 : cute::sizeof_bits_v; +#else + return cute::sizeof_bits_v; +#endif +} + +template +auto constexpr elemPerCopy() { + return bitsPerElem() / bitsPerElem(); +} + +template +auto constexpr sfElemPerCopy() { + return bitsPerElem() / bitsPerElem(); +} + +// Helper to get max active blocks per SM +template +int32_t getMaxActiveBlocksPerSM(KernelFunc kernel, int32_t threadsPerBlock, + size_t dynamicSmemSize) { + int numBlocks = 0; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, threadsPerBlock, + dynamicSmemSize); + return numBlocks; +} + +} // namespace + +template +__global__ void moePermuteKernel(InputType const* input, InputType* permuted_output, + SFType const* input_sf, SFType* permuted_sf, + int32_t const* tile_idx_to_mn_limit, + int32_t const* permuted_idx_to_expanded_idx, + int32_t const* num_non_exiting_tiles, int32_t const hidden_size, + int32_t const top_k, int32_t const tile_size) { + int32_t constexpr kElemPerCopy = elemPerCopy(); + [[maybe_unused]] int32_t constexpr kSFElemPerCopy = sfElemPerCopy(); + // Need int64_t to prevent overflow when computing pointer offsets. 
+  int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
+
+  int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
+  for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x) {
+    int32_t const tile_idx = permuted_idx / tile_size;
+    if (permuted_idx >= tile_idx_to_mn_limit[tile_idx]) {
+      continue;
+    }
+    int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
+    int32_t const token_idx = expanded_idx / top_k;
+
+    auto const* src_ptr = reinterpret_cast<ElemCopyType const*>(input) + token_idx * kCopyPerToken;
+    auto* dst_ptr = reinterpret_cast<ElemCopyType*>(permuted_output) + permuted_idx * kCopyPerToken;
+    for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock) {
+      dst_ptr[i] = src_ptr[i];
+    }
+
+    // Note: FP4 scale factor handling is deferred to Phase 3
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
+template <typename InputType, typename SFType>
+void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf,
+                SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit,
+                int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
+                int32_t const max_num_permuted_tokens, int32_t const hidden_size,
+                int32_t const top_k, int32_t const tile_size, bool enable_pdl,
+                cudaStream_t stream) {
+  int32_t constexpr kThreadsPerBlock = 256;
+  int32_t constexpr kSFVecSize = 16;
+  int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+  TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.",
+                       kElemPerCopy);
+
+  auto kernel = &moePermuteKernel<InputType, SFType, kThreadsPerBlock>;
+  static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+  int32_t const maxBlocksPerSM = getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
+  int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
+  int32_t const threads =
kThreadsPerBlock; + + cudaLaunchConfig_t config; + config.gridDim = blocks; + config.blockDim = threads; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf, + tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, + hidden_size, top_k, tile_size); +} + +#define INSTANTIATE_MOE_PERMUTE(InputType, SFType) \ + template void moePermute( \ + InputType const* input, InputType* permuted_output, SFType const* input_sf, \ + SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, \ + int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles, \ + int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, \ + int32_t const tile_size, bool enable_pdl, cudaStream_t stream) + +INSTANTIATE_MOE_PERMUTE(half, uint8_t); +#ifdef ENABLE_BF16 +INSTANTIATE_MOE_PERMUTE(__nv_bfloat16, uint8_t); +#endif +#ifdef ENABLE_FP8 +INSTANTIATE_MOE_PERMUTE(__nv_fp8_e4m3, uint8_t); +#endif +#ifdef ENABLE_FP4 +INSTANTIATE_MOE_PERMUTE(__nv_fp4_e2m1, uint8_t); +#endif +#undef INSTANTIATE_MOE_PERMUTE + +template +__global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* output, + int32_t const* expanded_idx_to_permuted_idx, + TopKScaleType const* topk_scales, int32_t const hidden_size, + int32_t const top_k) { + using AccumType = float; + int32_t constexpr kElemPerCopy = elemPerCopy(); + // Need int64_t to prevent overflow when computing pointer offsets. 
+  int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
+  InputType rmem[kElemPerCopy];
+  AccumType rmemAccum[kElemPerCopy];
+
+  int32_t const token_idx = blockIdx.x;
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
+
+  auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
+  for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock) {
+#pragma unroll
+    for (int32_t j = 0; j < kElemPerCopy; j++) {
+      rmemAccum[j] = 0;
+    }
+    for (int32_t k = 0; k < top_k; k++) {
+      int32_t const permuted_idx = expanded_idx_to_permuted_idx[token_idx * top_k + k];
+      if (permuted_idx < 0) {
+        continue;
+      }
+      auto const* src_ptr =
+          reinterpret_cast<ElemCopyType const*>(permuted_input) + permuted_idx * kCopyPerToken;
+      *reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
+      TopKScaleType const scale = topk_scales[token_idx * top_k + k];
+
+#pragma unroll
+      for (int32_t j = 0; j < kElemPerCopy; j++) {
+        rmemAccum[j] += static_cast<AccumType>(rmem[j]) * static_cast<AccumType>(scale);
+      }
+    }
+#pragma unroll
+    for (int32_t j = 0; j < kElemPerCopy; j++) {
+      rmem[j] = static_cast<InputType>(rmemAccum[j]);
+    }
+    dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
+template <typename InputType, typename TopKScaleType>
+void moeUnpermute(InputType const* permuted_input, InputType* output,
+                  int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales,
+                  int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
+                  bool enable_pdl, cudaStream_t stream) {
+  int32_t constexpr kThreadsPerBlock = 256;
+  int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+  TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.",
+                       kElemPerCopy);
+
+  int32_t const blocks = num_tokens;
+  int32_t const threads = kThreadsPerBlock;
+
+  auto kernel = &moeUnpermuteKernel<InputType, TopKScaleType, kThreadsPerBlock>;
+
+  cudaLaunchConfig_t config;
+  config.gridDim = blocks;
+  config.blockDim = threads;
+  config.dynamicSmemBytes = 0;
+ config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel, permuted_input, output, expanded_idx_to_permuted_idx, + topk_scales, hidden_size, top_k); +} + +#define INSTANTIATE_MOE_UNPERMUTE(InputType, TopKScaleType) \ + template void moeUnpermute(InputType const* permuted_input, InputType* output, \ + int32_t const* expanded_idx_to_permuted_idx, \ + TopKScaleType const* topk_scales, \ + int32_t const num_tokens, int32_t const hidden_size, \ + int32_t const top_k, bool enable_pdl, cudaStream_t stream) + +INSTANTIATE_MOE_UNPERMUTE(half, float); +INSTANTIATE_MOE_UNPERMUTE(half, half); +#ifdef ENABLE_BF16 +INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, float); +INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16); +#endif +#undef INSTANTIATE_MOE_UNPERMUTE + +template +__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit, + int32_t const* expanded_idx_to_permuted_idx, + int32_t const* permuted_idx_to_expanded_idx, + int32_t const* num_non_exiting_tiles, + int32_t const hidden_size, int32_t const top_k, + int32_t const tile_size) { + int32_t constexpr kElemPerCopy = elemPerCopy(); + int64_t const kCopyPerToken = hidden_size / kElemPerCopy; + + InputType rmem[kElemPerCopy]; +#pragma unroll + for (int32_t j = 0; j < kElemPerCopy; j++) { + rmem[j] = 0; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size; + for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x) { + int32_t const tile_idx = permuted_idx / tile_size; + if (permuted_idx >= tile_idx_to_mn_limit[tile_idx]) { + continue; + } + int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx]; + 
int32_t const token_idx = expanded_idx / top_k;
+    int32_t const topk_idx = expanded_idx % top_k;
+
+    bool is_first_in_topk = true;
+    for (int32_t k = 0; k < topk_idx; k++) {
+      if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0) {
+        is_first_in_topk = false;
+        break;
+      }
+    }
+    if (!is_first_in_topk) {
+      continue;
+    }
+
+    auto* dst_ptr = reinterpret_cast<ElemCopyType*>(input) + token_idx * kCopyPerToken;
+    for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock) {
+      dst_ptr[i] = *reinterpret_cast<ElemCopyType const*>(rmem);
+    }
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
+template <typename InputType>
+void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit,
+                     int32_t const* expanded_idx_to_permuted_idx,
+                     int32_t const* permuted_idx_to_expanded_idx,
+                     int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens,
+                     int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
+                     bool enable_pdl, cudaStream_t stream) {
+  int32_t constexpr kThreadsPerBlock = 256;
+  int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+  TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.",
+                       kElemPerCopy);
+
+  auto kernel = &moeOutputMemsetKernel<InputType, kThreadsPerBlock>;
+  static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+  int32_t const maxBlocksPerSM = getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
+  int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
+  int32_t const threads = kThreadsPerBlock;
+
+  cudaLaunchConfig_t config;
+  config.gridDim = blocks;
+  config.blockDim = threads;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, + tile_size); +} + +#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType) \ + template void moeOutputMemset( \ + InputType * input, int32_t const* tile_idx_to_mn_limit, \ + int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx, \ + int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, \ + int32_t const hidden_size, int32_t const top_k, int32_t const tile_size, bool enable_pdl, \ + cudaStream_t stream) + +INSTANTIATE_MOE_OUTPUT_MEMSET(half); +#ifdef ENABLE_BF16 +INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16); +#endif +#undef INSTANTIATE_MOE_OUTPUT_MEMSET + +// ============================== Activation Kernels ============================== + +template +__global__ void moeActivationKernel(InputType const* input, InputType* output, + int32_t const* tile_idx_to_mn_limit, + int32_t const* num_non_exiting_tiles, int32_t const interm_size, + int32_t const tile_size) { + using ComputeType = float; + int32_t constexpr kElemPerCopy = elemPerCopy(); + // Need int64_t to prevent overflow when computing pointer offsets. + int64_t const kCopyPerToken = interm_size / kElemPerCopy; + InputType rmem[kElemPerCopy]; + InputType rmemGate[kElemPerCopy]; + ActFn act{}; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size; + for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x) { + int32_t const tile_idx = permuted_idx / tile_size; + if (permuted_idx >= tile_idx_to_mn_limit[tile_idx]) { + continue; + } + auto const* src_ptr = reinterpret_cast(input) + + permuted_idx * kCopyPerToken * (ActFn::IS_GLU ? 
2 : 1); + auto* dst_ptr = reinterpret_cast(output) + permuted_idx * kCopyPerToken; + for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock) { + *reinterpret_cast(rmem) = src_ptr[i]; + if constexpr (ActFn::IS_GLU) { + *reinterpret_cast(rmemGate) = src_ptr[i + kCopyPerToken]; +#pragma unroll + for (int32_t j = 0; j < kElemPerCopy; j++) { + rmem[j] = static_cast( + act(static_cast(rmemGate[j]), static_cast(rmem[j]))); + } + } else { +#pragma unroll + for (int32_t j = 0; j < kElemPerCopy; j++) { + rmem[j] = static_cast(act(static_cast(rmem[j]))); + } + } + + dst_ptr[i] = *reinterpret_cast(rmem); + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +void moeActivation(InputType const* input, InputType* output, int32_t const* tile_idx_to_mn_limit, + int32_t const* num_non_exiting_tiles, MoeActivationType activation_type, + int32_t const max_num_permuted_tokens, int32_t const interm_size, + int32_t const tile_size, bool enable_pdl, cudaStream_t stream) { + int32_t constexpr kThreadsPerBlock = 256; + int32_t constexpr kElemPerCopy = elemPerCopy(); + TLLM_CHECK_WITH_INFO(interm_size % kElemPerCopy == 0, "interm_size must be divisible by %d.", + kElemPerCopy); + + using namespace cutlass_kernels; + + auto get_act_kernel = [](MoeActivationType act_type) -> void (*)(InputType const*, InputType*, + int32_t const*, int32_t const*, + int32_t const, int32_t const) { + switch (act_type) { + case MoeActivationType::Identity: + return &moeActivationKernel, + kThreadsPerBlock>; + case MoeActivationType::Gelu: + return &moeActivationKernel, + kThreadsPerBlock>; + case MoeActivationType::Geglu: + return &moeActivationKernel, + kThreadsPerBlock>; + case MoeActivationType::Relu: + return &moeActivationKernel, + kThreadsPerBlock>; + case MoeActivationType::Silu: + return &moeActivationKernel, + kThreadsPerBlock>; + case MoeActivationType::Swiglu: + return &moeActivationKernel, + 
kThreadsPerBlock>; + default: + TLLM_CHECK_WITH_INFO(false, "Unsupported activation type: %d", static_cast(act_type)); + return nullptr; + } + }; + + auto kernel = get_act_kernel(activation_type); + + static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); + int32_t const maxBlocksPerSM = getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0); + int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens); + int32_t const threads = kThreadsPerBlock; + + cudaLaunchConfig_t config; + config.gridDim = blocks; + config.blockDim = threads; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel, input, output, tile_idx_to_mn_limit, num_non_exiting_tiles, + interm_size, tile_size); +} + +#define INSTANTIATE_MOE_ACTIVATION(InputType) \ + template void moeActivation( \ + InputType const* input, InputType* output, int32_t const* tile_idx_to_mn_limit, \ + int32_t const* num_non_exiting_tiles, MoeActivationType activation_type, \ + int32_t const max_num_permuted_tokens, int32_t const interm_size, int32_t const tile_size, \ + bool enable_pdl, cudaStream_t stream) + +INSTANTIATE_MOE_ACTIVATION(half); +#ifdef ENABLE_BF16 +INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16); +#endif +#undef INSTANTIATE_MOE_ACTIVATION + +// Note: moeActivationQuantize (fused activation + FP4 quantization) will be added later +// when NVFP4 output support is needed. 
+ +} // namespace kernels::cute_dsl + +TRTLLM_NAMESPACE_END diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h b/csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h new file mode 100644 index 0000000000..0ea18a5bf7 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <cuda_runtime.h> + +#include <cstdint> + +#include "tensorrt_llm/common/config.h" + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::cute_dsl { + +// Activation type enum for standalone moeActivation kernel +// Note: Matches ActivationType in cutlass_kernels/include/common.h +enum class MoeActivationType { + Gelu = 0, + Relu = 1, + Silu = 2, + Swiglu = 3, + Geglu = 4, + Identity = 5, +}; + +template <typename InputType, typename SFType> +void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf, + SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, + int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles, + int32_t const max_num_permuted_tokens, int32_t const hidden_size, + int32_t const top_k, int32_t const tile_size, bool enable_pdl, cudaStream_t stream); + +template <typename InputType, typename TopKScaleType> +void moeUnpermute(InputType const* permuted_input, InputType* output, + int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, + int32_t const num_tokens, int32_t const
hidden_size, int32_t const top_k, + bool enable_pdl, cudaStream_t stream); + +template <typename InputType> +void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, + int32_t const* expanded_idx_to_permuted_idx, + int32_t const* permuted_idx_to_expanded_idx, + int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, + int32_t const hidden_size, int32_t const top_k, int32_t const tile_size, + bool enable_pdl, cudaStream_t stream); + +// ============================== Activation Kernels ============================== + +/** + * @brief Apply activation function to MoE intermediate outputs. + * + * For GLU activations (Swiglu, Geglu), input shape is (num_tokens, 2 * interm_size) + * where first half is linear projection and second half is gate. + * Output shape is (num_tokens, interm_size). + * + * For non-GLU activations (Gelu, Relu, Silu, Identity), input and output shape + * are both (num_tokens, interm_size). + * + * @param input Input tensor + * @param output Output tensor (same dtype as input for non-FP4 output) + * @param tile_idx_to_mn_limit Valid token count per tile + * @param num_non_exiting_tiles Number of valid tiles (scalar on device) + * @param activation_type Type of activation to apply + * @param max_num_permuted_tokens Maximum number of permuted tokens + * @param interm_size Intermediate size (output hidden dimension) + * @param tile_size Tile size for scheduling + * @param enable_pdl Enable Programmatic Dependent Launch + * @param stream CUDA stream + */ +template <typename InputType> +void moeActivation(InputType const* input, InputType* output, int32_t const* tile_idx_to_mn_limit, + int32_t const* num_non_exiting_tiles, MoeActivationType activation_type, + int32_t const max_num_permuted_tokens, int32_t const interm_size, + int32_t const tile_size, bool enable_pdl, cudaStream_t stream); + +/** + * @brief Fused activation with NVFP4 dynamic quantization. + * + * Combines activation function with per-block NVFP4 quantization in a single kernel pass.
+ * Output is packed FP4 with swizzled scale factors. + * + * @param input Input tensor (bf16/fp16) + * @param output Output tensor (packed FP4, uint8) + * @param global_sf Global scale factor for quantization + * @param output_sf Per-block scale factors (FP8 E4M3, swizzled layout) + * @param tile_idx_to_mn_limit Valid token count per tile + * @param num_non_exiting_tiles Number of valid tiles + * @param activation_type Type of activation to apply + * @param max_num_permuted_tokens Maximum number of permuted tokens + * @param interm_size Intermediate size + * @param tile_size Tile size for scheduling + * @param enable_pdl Enable Programmatic Dependent Launch + * @param stream CUDA stream + */ +template <typename InputType, typename OutputType, typename SFType> +void moeActivationQuantize(InputType const* input, OutputType* output, float const* global_sf, + SFType* output_sf, int32_t const* tile_idx_to_mn_limit, + int32_t const* num_non_exiting_tiles, MoeActivationType activation_type, + int32_t const max_num_permuted_tokens, int32_t const interm_size, + int32_t const tile_size, bool enable_pdl, cudaStream_t stream); + +} // namespace kernels::cute_dsl + +TRTLLM_NAMESPACE_END diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh new file mode 100644 index 0000000000..8c96d64808 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include "cutlass/epilogue/thread/activation.h" +#include "tensorrt_llm/common/config.h" + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::cutlass_kernels { +// ============================== Activation Adaptors ================================= + +template