diff --git a/benchmarks/bench_bgmv_moe.py b/benchmarks/bench_bgmv_moe.py new file mode 100644 index 0000000000..ef8e3e12f3 --- /dev/null +++ b/benchmarks/bench_bgmv_moe.py @@ -0,0 +1,366 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Multi-LoRA MoE BGMV Kernel Benchmark. + +Compares the BGMV MoE CUDA kernel against FlashInfer's grouped_mm_bf16 baseline +across multiple model configurations and token counts. + +Usage: + FLASHINFER_DISABLE_VERSION_CHECK=1 python benchmarks/bench_bgmv_moe.py +""" + +import os + +os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1") + +import time +from dataclasses import dataclass +from typing import Callable + +import torch + + +@dataclass +class BenchmarkConfig: + """Configuration for a single benchmark run.""" + + name: str + num_tokens: int + hidden_size: int + rank: int + num_experts: int + top_k: int + num_loras: int + num_slices: int + dtype: torch.dtype + + +# Model configurations to benchmark +CONFIGS = [ + # Large MoE (hidden=3072, rank=32, 128 experts) + BenchmarkConfig("Decode-1tok-LargeMoE", 1, 3072, 32, 128, 2, 8, 1, torch.bfloat16), + BenchmarkConfig("Decode-4tok-LargeMoE", 4, 3072, 32, 128, 2, 8, 1, torch.bfloat16), + BenchmarkConfig("Decode-8tok-LargeMoE", 8, 3072, 32, 128, 2, 8, 1, torch.bfloat16), + BenchmarkConfig( + "Decode-32tok-LargeMoE", 32, 3072, 32, 128, 2, 8, 1, torch.bfloat16 + ), + BenchmarkConfig( + "Prefill-256tok-LargeMoE", 256, 3072, 32, 128, 2, 8, 1, torch.bfloat16 + ), + BenchmarkConfig( + "Prefill-512tok-LargeMoE", 512, 3072, 32, 128, 2, 8, 1, torch.bfloat16 + ), + BenchmarkConfig( + "Prefill-1024tok-LargeMoE", 1024, 3072, 32, 128, 2, 8, 1, torch.bfloat16 + ), + # Nemotron-Nano-3-30B-A3B (hidden=2688, rank=32, 128 experts) + BenchmarkConfig("Decode-1tok-Nemotron", 1, 2688, 32, 128, 2, 8, 1, torch.bfloat16), + BenchmarkConfig("Decode-4tok-Nemotron", 4, 2688, 32, 128, 2, 8, 1, torch.bfloat16), + BenchmarkConfig("Decode-8tok-Nemotron", 8, 2688, 32, 128, 2, 8, 1, torch.bfloat16), + BenchmarkConfig( + "Decode-32tok-Nemotron", 32, 2688, 32, 128, 2, 8, 1, torch.bfloat16 + ), + BenchmarkConfig( + "Prefill-256tok-Nemotron", 256, 2688, 32, 128, 2, 8, 1, torch.bfloat16 + ), + BenchmarkConfig( + "Prefill-512tok-Nemotron", 512, 2688, 32, 128, 2, 8, 1, torch.bfloat16 + ), + BenchmarkConfig( + "Prefill-1024tok-Nemotron", 1024, 2688, 32, 128, 2, 8, 1, torch.bfloat16 + ), +] + + +def generate_test_data(config: BenchmarkConfig, device: str = "cuda"): + """Generate random test data for a benchmark configuration.""" + num_tokens = config.num_tokens + hidden_size = config.hidden_size + rank = config.rank + num_experts = config.num_experts + top_k = config.top_k + num_loras = config.num_loras + num_slices = config.num_slices + dtype = config.dtype + num_pairs = num_tokens * top_k + feat_out = hidden_size + + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) * 0.1 + lora_a_weights = [ + torch.randn( + num_loras, num_experts, rank, hidden_size, dtype=dtype, device=device + ) + * 0.01 + for _ in range(num_slices) + ] + lora_b_weights = [ + torch.randn(num_loras, num_experts, feat_out, rank, dtype=dtype, device=device) + * 0.01 + for _ in range(num_slices) + ] + sorted_token_ids = torch.arange( + num_tokens, device=device, dtype=torch.int64 + ).repeat_interleave(top_k) + expert_ids = torch.randint( + 0, num_experts, (num_pairs,), device=device, dtype=torch.int64 + ) + topk_weights = ( + torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1) + .view(-1) + .to(torch.float32) + ) + lora_indices = torch.randint( + 0, num_loras, (num_tokens,), device=device, dtype=torch.int64 + ) + + return { + "x": x, + "lora_a_weights": lora_a_weights, + "lora_b_weights": lora_b_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "topk_weights": topk_weights, + "lora_indices": lora_indices, + "num_pairs": num_pairs, + "feat_out": feat_out, + } + + +def benchmark_fn(fn: Callable, warmup: int = 10, repeat: int = 100) -> float: + """Benchmark a function, return median time in microseconds.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + times = [] + for _ in range(repeat): + torch.cuda.synchronize() + start = time.perf_counter_ns() + fn() + torch.cuda.synchronize() + end = time.perf_counter_ns() + times.append((end - start) / 1000.0) + + times.sort() + return times[len(times) // 2] + + +def run_benchmark(config: BenchmarkConfig): + """Run benchmark for a single configuration.""" + from flashinfer.fused_moe.bgmv_moe import ( + bgmv_moe_shrink, + bgmv_moe_expand, + fill_w_ptr, + ) + + data = generate_test_data(config) + results = {"config": config.name} + + num_tokens = config.num_tokens + num_pairs = data["num_pairs"] + rank = config.rank + num_experts = config.num_experts + num_slices = config.num_slices + num_loras = config.num_loras + feat_out = data["feat_out"] + hidden_size = config.hidden_size + dtype = config.dtype + device = "cuda" + + # === BGMV MoE kernel === + w_ptr_a = torch.zeros(num_slices, num_experts, dtype=torch.int64, device=device) + lora_stride_a = 0 + for s in range(num_slices): + lora_stride_a = fill_w_ptr(w_ptr_a, data["lora_a_weights"][s], num_experts, s) + + w_ptr_b = torch.zeros(num_slices, num_experts, dtype=torch.int64, device=device) + lora_stride_b = 0 + for s in range(num_slices): + lora_stride_b = fill_w_ptr(w_ptr_b, data["lora_b_weights"][s], num_experts, s) + + shrink_out = torch.zeros(num_slices, num_pairs, rank, dtype=dtype, device=device) + slice_start_loc = torch.zeros(num_slices, dtype=torch.int64, device=device) + for s in range(num_slices): + slice_start_loc[s] = s * feat_out + output_slices = [feat_out] * num_slices + y_accum = torch.zeros( + num_tokens, feat_out * num_slices, dtype=torch.float32, device=device + ) + + def cuda_fn(): + shrink_out.zero_() + y_accum.zero_() + bgmv_moe_shrink( + shrink_out, + data["x"], + w_ptr_a, + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + lora_stride_a, + ) + bgmv_moe_expand( + y_accum, + shrink_out, + w_ptr_b, + data["sorted_token_ids"], + data["expert_ids"], + data["topk_weights"], + data["lora_indices"], + slice_start_loc, + output_slices, + lora_stride_b, + ) + + cuda_time = benchmark_fn(cuda_fn) + results["bgmv_moe_us"] = cuda_time + + # === grouped_mm_bf16 baseline === + try: + from flashinfer.grouped_mm import grouped_mm_bf16 + + num_groups = num_loras * num_experts + lora_ids_expanded = data["lora_indices"][data["sorted_token_ids"]] + group_ids = lora_ids_expanded * num_experts + data["expert_ids"] + group_ids[lora_ids_expanded < 0] = num_groups + + sorted_indices = torch.argsort(group_ids) + sorted_group_ids = group_ids[sorted_indices] + valid_mask = sorted_group_ids < num_groups + num_valid = valid_mask.sum().item() + + sorted_token_indices = data["sorted_token_ids"][sorted_indices[:num_valid]] + g_input = data["x"][sorted_token_indices] + + counts = torch.zeros(num_groups + 1, dtype=torch.int32, device=device) + valid_groups = sorted_group_ids[:num_valid] + for g in range(num_groups): + counts[g + 1] = counts[g] + (valid_groups == g).sum().to(torch.int32) + g_m_indptr = counts + + g_lora_a = data["lora_a_weights"][0].view( + num_loras * num_experts, rank, hidden_size + ) + g_lora_b = data["lora_b_weights"][0].view( + num_loras * num_experts, feat_out, rank + ) + g_shrink_out = torch.zeros(num_valid, rank, dtype=dtype, device=device) + g_expand_out = torch.zeros(num_valid, feat_out, dtype=dtype, device=device) + + # Warmup + grouped_mm_bf16(g_input, g_lora_a, g_m_indptr, out=g_shrink_out) + grouped_mm_bf16(g_shrink_out, g_lora_b, g_m_indptr, out=g_expand_out) + torch.cuda.synchronize() + + # Kernel only + def gg_kernel_fn(): + g_shrink_out.zero_() + g_expand_out.zero_() + grouped_mm_bf16(g_input, g_lora_a, g_m_indptr, out=g_shrink_out) + grouped_mm_bf16(g_shrink_out, g_lora_b, g_m_indptr, out=g_expand_out) + + gg_kernel_time = benchmark_fn(gg_kernel_fn) + results["gg_kernel_us"] = gg_kernel_time + + # Sort + kernel + def gg_full_fn(): + _sorted_indices = torch.argsort(group_ids) + _sorted_token_indices = data["sorted_token_ids"][ + _sorted_indices[:num_valid] + ] + _g_input = data["x"][_sorted_token_indices] + g_shrink_out.zero_() + g_expand_out.zero_() + grouped_mm_bf16(_g_input, g_lora_a, g_m_indptr, out=g_shrink_out) + grouped_mm_bf16(g_shrink_out, g_lora_b, g_m_indptr, out=g_expand_out) + + gg_full_time = benchmark_fn(gg_full_fn) + results["gg_full_us"] = gg_full_time + except (ImportError, RuntimeError) as e: + print(f" [SKIP] grouped_mm_bf16 baseline: {e}") + results["gg_kernel_us"] = float("nan") + results["gg_full_us"] = float("nan") + + return results + + +def main(): + """Run all benchmarks and print results table.""" + if not torch.cuda.is_available(): + print("CUDA not available, skipping benchmarks.") + return + + device_name = torch.cuda.get_device_name(0) + + # Trigger JIT compilation before printing the table + from flashinfer.fused_moe.bgmv_moe import _get_bgmv_moe_module + + _get_bgmv_moe_module() + + print(f"\n{'=' * 100}") + print("Multi-LoRA MoE BGMV Kernel Benchmark") + print(f"Device: {device_name}") + print(f"{'=' * 100}\n") + + # Header + print( + f"{'Config':<28} {'GG-kern (μs)':>13} {'GG-sort+kern (μs)':>18} " + f"{'BGMV MoE (μs)':>14} {'vs GG-kern':>11} {'vs GG-sort+kern':>16}" + ) + print(f"{'-' * 28} {'-' * 13} {'-' * 18} {'-' * 14} {'-' * 11} {'-' * 16}") + + for config in CONFIGS: + results = run_benchmark(config) + + def fmt(v): + return f"{v:.1f}" if v == v else "N/A" + + def fmt_speedup(baseline, kernel): + if baseline != baseline or kernel != kernel or kernel == 0: + return "N/A" + return f"{baseline / kernel:.2f}x" + + bgmv = results["bgmv_moe_us"] + gg_k = results.get("gg_kernel_us", float("nan")) + gg_f = results.get("gg_full_us", float("nan")) + + print( + f"{results['config']:<28} " + f"{fmt(gg_k):>13} " + f"{fmt(gg_f):>18} " + f"{fmt(bgmv):>14} " + f"{fmt_speedup(gg_k, bgmv):>11} " + f"{fmt_speedup(gg_f, bgmv):>16}" + ) + + print(f"\n{'=' * 100}") + print("\nNotes:") + print( + " - 'GG-kern' = FlashInfer grouped_mm_bf16 kernel only (pre-sorted, no sort overhead)" + ) + print( + " - 'GG-sort+kern' = FlashInfer grouped_mm_bf16 with token sorting (sort + kernel)" + ) + print(" - 'BGMV MoE' = BGMV MoE CUDA kernel (this PR)") + print(" - 'LargeMoE' = hidden=3072, rank=32, 128 experts (large MoE model config)") + print( + " - 'Nemotron' = hidden=2688, rank=32, 128 experts (Nemotron-Nano-3-30B-A3B)" + ) + print(" - All times are median of 100 runs after 10 warmup iterations") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 9554e0e6b6..56395a9258 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -200,6 +200,7 @@ "cutlass_fused_moe", "cute_dsl_fp4_block_scale_moe", "b12x_fused_moe", + "bgmv_moe", ], "moe_comm": [ "moe_a2a_dispatch_combine", diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index fb260249ad..075fde70f1 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -98,6 +98,8 @@ def run_moe_test(args): return testCuteDslFp4BlockScaleMoe(args) elif args.routine == "b12x_fused_moe": return testB12xFusedMoe(args) + elif args.routine == "bgmv_moe": + return testBgmvMoe(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -119,7 +121,7 @@ def parse_moe_args(line, parser): "--intermediate_size", type=int, required=True, - help="Intermediate dimension size.", + help="Intermediate dimension size (not used for bgmv_moe).", ) # Note: num_experts/top_k is added by add_common_moe_args parser.add_argument( @@ -296,6 +298,29 @@ def parse_moe_args(line, parser): help="Expert parallel rank for cutlass_fused_moe.", ) + # BGMV MoE specific arguments + parser.add_argument( + "--rank", + type=int, + required=False, + default=32, + help="LoRA rank for bgmv_moe benchmark.", + ) + parser.add_argument( + "--num_loras", + type=int, + required=False, + default=8, + help="Number of LoRA adapters for bgmv_moe benchmark.", + ) + parser.add_argument( + "--num_slices", + type=int, + required=False, + default=1, + help="Number of weight slices for bgmv_moe benchmark (e.g., 2 for gate+up).", + ) + args = parser.parse_args(line) # Normalize routing method (map string to internal int expected by kernels) @@ -2364,3 +2389,366 @@ def run_fp8_per_tensor_moe( res.append(cur_res) return res + + +def testBgmvMoe(args): + """ + Benchmark BGMV MoE kernels for multi-LoRA inference. + + Measures the performance of the fused shrink + expand CUDA kernels + that apply LoRA adapters in MoE models. Compares against FlashInfer's + grouped_mm_bf16 as a baseline. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + list: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testBgmvMoe") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + num_tokens = args.num_tokens + hidden_size = args.hidden_size + num_experts = args.num_experts + top_k = args.top_k + rank = args.rank + num_loras = args.num_loras + num_slices = getattr(args, "num_slices", 1) + res = [] + + if args.verbose >= 1: + print( + f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, " + f"rank={rank}, experts={num_experts}, top_k={top_k}, " + f"num_loras={num_loras}, num_slices={num_slices}" + ) + + # Import BGMV MoE kernels + from flashinfer.fused_moe.bgmv_moe import ( + bgmv_moe_shrink, + bgmv_moe_expand, + fill_w_ptr, + ) + + # Generate test data + num_pairs = num_tokens * top_k + feat_out = hidden_size + + x = torch.randn(num_tokens, hidden_size, dtype=input_dtype, device=device) * 0.1 + lora_a_weights = [ + torch.randn( + num_loras, num_experts, rank, hidden_size, dtype=input_dtype, device=device + ) + * 0.01 + for _ in range(num_slices) + ] + lora_b_weights = [ + torch.randn( + num_loras, num_experts, feat_out, rank, dtype=input_dtype, device=device + ) + * 0.01 + for _ in range(num_slices) + ] + + sorted_token_ids = torch.arange( + num_tokens, device=device, dtype=torch.int64 + ).repeat_interleave(top_k) + expert_ids = torch.randint( + 0, num_experts, (num_pairs,), device=device, dtype=torch.int64 + ) + topk_weights = ( + torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1) + .view(-1) + .to(torch.float32) + ) + lora_indices = torch.randint( + 0, num_loras, (num_tokens,), device=device, dtype=torch.int64 + ) + + # Pre-allocate buffers for BGMV MoE + w_ptr_a = torch.zeros(num_slices, num_experts, dtype=torch.int64, device=device) + lora_stride_a = 0 + for s in range(num_slices): + lora_stride_a = fill_w_ptr(w_ptr_a, lora_a_weights[s], num_experts, s) + + w_ptr_b = torch.zeros(num_slices, num_experts, dtype=torch.int64, device=device) + lora_stride_b = 0 + for s in range(num_slices): + lora_stride_b = fill_w_ptr(w_ptr_b, lora_b_weights[s], num_experts, s) + + shrink_out = torch.zeros( + num_slices, num_pairs, rank, dtype=input_dtype, device=device + ) + slice_start_loc = torch.zeros(num_slices, dtype=torch.int64, device=device) + for s in range(num_slices): + slice_start_loc[s] = s * feat_out + output_slices = [feat_out] * num_slices + + y_accum = torch.zeros( + num_tokens, feat_out * num_slices, dtype=torch.float32, device=device + ) + + # Define benchmark function (shrink + expand) + def run_bgmv_moe( + shrink_out, + x, + w_ptr_a, + sorted_token_ids, + expert_ids, + lora_indices, + lora_stride_a, + y_accum, + w_ptr_b, + topk_weights, + slice_start_loc, + output_slices, + lora_stride_b, + ): + shrink_out.zero_() + y_accum.zero_() + bgmv_moe_shrink( + shrink_out, + x, + w_ptr_a, + sorted_token_ids, + expert_ids, + lora_indices, + lora_stride_a, + ) + bgmv_moe_expand( + y_accum, + shrink_out, + w_ptr_b, + sorted_token_ids, + expert_ids, + topk_weights, + lora_indices, + slice_start_loc, + output_slices, + lora_stride_b, + ) + + is_cuda_graph_compatible = not args.no_cuda_graph + + # Benchmark BGMV MoE + times = bench_gpu_time( + fn=run_bgmv_moe, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + sleep_after_run=False, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + cold_l2_cache=True, + input_args=( + shrink_out, + x, + w_ptr_a, + sorted_token_ids, + expert_ids, + lora_indices, + lora_stride_a, + y_accum, + w_ptr_b, + topk_weights, + slice_start_loc, + output_slices, + lora_stride_b, + ), + ) + + bgmv_median = np.median(times) + bgmv_std = np.std(times) + + # Benchmark grouped_mm_bf16 as baseline comparison (single-slice only) + gg_kernel_median = float("nan") + gg_full_median = float("nan") + if num_slices != 1: + print( + "[INFO] grouped_mm_bf16 baseline skipped: only supported for num_slices=1" + ) + else: + try: + from flashinfer.grouped_mm import grouped_mm_bf16 + + # For grouped GEMM, sort tokens by (lora_id, expert_id) and build m_indptr + num_groups = num_loras * num_experts + + # Assign each pair a group_id = lora_id * num_experts + expert_id + lora_ids_expanded = lora_indices[sorted_token_ids] + group_ids = lora_ids_expanded * num_experts + expert_ids + group_ids[lora_ids_expanded < 0] = num_groups + + # Sort by group_id + sorted_indices = torch.argsort(group_ids) + sorted_group_ids = group_ids[sorted_indices] + + valid_mask = sorted_group_ids < num_groups + num_valid = valid_mask.sum().item() + + sorted_token_indices = sorted_token_ids[sorted_indices[:num_valid]] + g_input = x[sorted_token_indices] + + # Build m_indptr + counts = torch.zeros(num_groups + 1, dtype=torch.int32, device=device) + valid_groups = sorted_group_ids[:num_valid] + for g in range(num_groups): + counts[g + 1] = counts[g] + (valid_groups == g).sum().to(torch.int32) + g_m_indptr = counts + + # Reshape LoRA weights for grouped GEMM + g_lora_a = lora_a_weights[0].view( + num_loras * num_experts, rank, hidden_size + ) + g_lora_b = lora_b_weights[0].view(num_loras * num_experts, feat_out, rank) + + g_shrink_out = torch.zeros( + num_valid, rank, dtype=input_dtype, device=device + ) + g_expand_out = torch.zeros( + num_valid, feat_out, dtype=input_dtype, device=device + ) + + # Warmup + grouped_mm_bf16(g_input, g_lora_a, g_m_indptr, out=g_shrink_out) + grouped_mm_bf16(g_shrink_out, g_lora_b, g_m_indptr, out=g_expand_out) + torch.cuda.synchronize() + + # Benchmark: kernel only (pre-sorted, no sort overhead) + def run_gg_kernel( + g_input, g_lora_a, g_m_indptr, g_shrink_out, g_lora_b, g_expand_out + ): + g_shrink_out.zero_() + g_expand_out.zero_() + grouped_mm_bf16(g_input, g_lora_a, g_m_indptr, out=g_shrink_out) + grouped_mm_bf16(g_shrink_out, g_lora_b, g_m_indptr, out=g_expand_out) + + gg_kernel_times = bench_gpu_time( + fn=run_gg_kernel, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + sleep_after_run=False, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + cold_l2_cache=True, + input_args=( + g_input, + g_lora_a, + g_m_indptr, + g_shrink_out, + g_lora_b, + g_expand_out, + ), + ) + gg_kernel_median = np.median(gg_kernel_times) + + # Benchmark: sort + kernel (full end-to-end) + def run_gg_full( + x, + sorted_token_ids, + lora_indices, + expert_ids, + num_experts, + num_groups, + g_lora_a, + g_lora_b, + g_m_indptr, + g_shrink_out, + g_expand_out, + ): + _lora_ids = lora_indices[sorted_token_ids] + _group_ids = _lora_ids * num_experts + expert_ids + _group_ids[_lora_ids < 0] = num_groups + _sorted_indices = torch.argsort(_group_ids) + _sorted_token_indices = sorted_token_ids[_sorted_indices[:num_valid]] + _g_input = x[_sorted_token_indices] + g_shrink_out.zero_() + g_expand_out.zero_() + grouped_mm_bf16(_g_input, g_lora_a, g_m_indptr, out=g_shrink_out) + grouped_mm_bf16(g_shrink_out, g_lora_b, g_m_indptr, out=g_expand_out) + + gg_full_times = bench_gpu_time( + fn=run_gg_full, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + sleep_after_run=False, + enable_cupti=args.use_cupti, + use_cuda_graph=False, # sort is not cuda-graph compatible + cold_l2_cache=True, + input_args=( + x, + sorted_token_ids, + lora_indices, + expert_ids, + num_experts, + num_groups, + g_lora_a, + g_lora_b, + g_m_indptr, + g_shrink_out, + g_expand_out, + ), + ) + gg_full_median = np.median(gg_full_times) + + except (ImportError, RuntimeError) as e: + print(f"[INFO] grouped_mm_bf16 baseline skipped: {e}") + + # Print comparison results + def _fmt(v): + return f"{v:.3f}" if v == v else "N/A" + + def _speedup(baseline, kernel): + if baseline != baseline or kernel != kernel or kernel == 0: + return "N/A" + return f"{baseline / kernel:.2f}x" + + print(f"\n{'=' * 70}") + print( + f" BGMV MoE Benchmark: tokens={num_tokens}, hidden={hidden_size}, " + f"rank={rank}, experts={num_experts}, top_k={top_k}" + ) + print(f"{'=' * 70}") + print(f" {'Method':<25} {'Median (ms)':>12} {'Speedup vs BGMV MoE':>20}") + print(f" {'-' * 25} {'-' * 12} {'-' * 20}") + print(f" {'BGMV MoE (this PR)':<25} {_fmt(bgmv_median):>12} {'—':>20}") + print( + f" {'grouped_mm (kernel)':<25} {_fmt(gg_kernel_median):>12} {_speedup(gg_kernel_median, bgmv_median):>20}" + ) + print( + f" {'grouped_mm (sort+kern)':<25} {_fmt(gg_full_median):>12} {_speedup(gg_full_median, bgmv_median):>20}" + ) + print(f"{'=' * 70}\n") + + # Also print in standard format + flops = num_slices * num_pairs * (hidden_size * rank + rank * feat_out) * 2 + tflops = flops / (bgmv_median * 1e-3) / 1e12 + + backend = "bgmv_moe" + print_perf_metrics(backend, bgmv_median, bgmv_std, tflops, 0.0) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = bgmv_median + cur_res["std_time"] = bgmv_std + cur_res["tflops"] = tflops + cur_res["backend"] = backend + cur_res["num_tokens"] = num_tokens + cur_res["hidden_size"] = hidden_size + cur_res["num_experts"] = num_experts + cur_res["top_k"] = top_k + cur_res["rank"] = rank + cur_res["num_loras"] = num_loras + cur_res["num_slices"] = num_slices + cur_res["input_dtype"] = str(input_dtype) + cur_res["gg_kernel_median"] = gg_kernel_median + cur_res["gg_full_median"] = gg_full_median + res.append(cur_res) + + return res diff --git a/csrc/bgmv_moe/kernel_config.h b/csrc/bgmv_moe/kernel_config.h new file mode 100644 index 0000000000..74fa275eb0 --- /dev/null +++ b/csrc/bgmv_moe/kernel_config.h @@ -0,0 +1,40 @@ +#pragma once + +/* + * BGMV MoE kernel tuning parameters. + * + * Target: H100/H200 (sm_90, 228 KB shared memory per SM) + * Also supports sm_70+ (V100, A100) with reduced pipeline depth. + * + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +struct MoeShrinkKernelConfig { + static constexpr int tx = 32; // threads per warp (x-dimension) + static constexpr int ty = 4; // number of warps (y-dimension) + static constexpr int vec_size = 8; // elements per vectorized load + static constexpr int rank_tile = 8; // rank elements per block (8x X reuse) + + // Multi-pair decode path: PPB=4 pairs per block for decode, + // PPB=1 for prefill (grid already saturates GPU). + static constexpr int pairs_per_block_prefill = 1; + static constexpr int pairs_per_block_decode = 4; + static constexpr int decode_threshold = 32; + + // Pipeline depth: 3 stages on sm_90 decode (216 KB / 228 KB = 95%), + // 2 stages for prefill (36 KB, leaves room for occupancy). + static constexpr int num_stages_default = 2; + static constexpr int num_stages_extended = 3; + + // Shared memory budget (decode, PPB=4, 3 stages, RANK_TILE=8, fp16): + // X: 3 * 4 * 1024 * 2 = 24 KB + // W: 3 * 4 * 8 * 1024 * 2 = 192 KB + // y: 4 * 8 * 4 * 4 = 512 B + // Total: ~216 KB (fits 228 KB on H100/H200) +}; + +struct MoeExpandKernelConfig { + static constexpr int tz = 4; + static constexpr int vec_size = 8; +}; diff --git a/csrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cu b/csrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cu new file mode 100644 index 0000000000..deed4c8ebe --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cu @@ -0,0 +1,10 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include "moe_bgmv_config.h" +#include "moe_bgmv_impl.cuh" + +// Shrink + expand (in_T=out_T=W_T=nv_bfloat16). +FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cu b/csrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cu new file mode 100644 index 0000000000..165770d1f7 --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cu @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include "moe_bgmv_config.h" +#include "moe_bgmv_impl.cuh" + +// Expand only (in_T=nv_bfloat16, Y=float32, W_T=nv_bfloat16). +// Shrink is covered by moe_bgmv_bf16_bf16_bf16.cu. + +#define INST_MOE_BGMV_EXPAND_ONLY(in_T, out_T, W_T, narrow, wide) \ + INST_MOE_BGMV_EXPAND_SLICED(narrow, wide, in_T, W_T) + +FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_EXPAND_ONLY, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/bgmv_moe/moe_bgmv_binding.cu b/csrc/bgmv_moe/moe_bgmv_binding.cu new file mode 100644 index 0000000000..c995300c0f --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_binding.cu @@ -0,0 +1,23 @@ +/* + * TVM-FFI binding for BGMV MoE kernels. + * + * Exports two functions: + * - bgmv_moe_shrink + * - bgmv_moe_expand + * + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include "moe_bgmv_ops.cu" + +// Forward declarations +void bgmv_moe_shrink(TensorView y, TensorView x, TensorView w_ptr, TensorView sorted_token_ids, + TensorView expert_ids, TensorView lora_indices, int64_t lora_stride); + +void bgmv_moe_expand(TensorView y, TensorView x, TensorView w_ptr, TensorView sorted_token_ids, + TensorView expert_ids, TensorView topk_weights, TensorView lora_indices, + TensorView slice_start_loc, int64_t first_feat_out, int64_t lora_stride); + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(bgmv_moe_shrink, bgmv_moe_shrink); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(bgmv_moe_expand, bgmv_moe_expand); diff --git a/csrc/bgmv_moe/moe_bgmv_config.h b/csrc/bgmv_moe/moe_bgmv_config.h new file mode 100644 index 0000000000..79480f4537 --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_config.h @@ -0,0 +1,82 @@ +#pragma once + +#include + +/* + * BGMV MoE dimension configuration and forward declarations. + * + * Defines the set of (narrow, wide) dimension pairs that are compiled. + * narrow = LoRA rank (8, 16, 32, 64) + * wide = hidden/intermediate dimension of the model + * + * Models covered: + * Qwen3-30B-A3B: gate_up=(2048,768), down=(768,2048) + * Qwen3.5-35B-A3B: gate_up=(2048,768), down=(768,2048) [256 experts, top-8] + * Gemma-4-26B-A4B: gate_up=(2816,2112), down=(2112,2816) [128 experts, top-8] + * Nemotron-Nano-3-30B-A3B: gate_up=(2688,1856), down=(1856,2688) + * Nemotron-3-Super-120B-A12B: gate_up=(4096,2688), down=(2688,4096) + * Large MoE (128 experts): gate_up=(3072,5888), down=(2944,3072) + * + * TP-sharded dimensions also included. + * All wide values must satisfy: wide % 32 == 0. + * + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +// clang-format off + +#define FOR_MOE_ALL_WIDE(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 384) \ + f(in_T, out_T, W_T, narrow, 736) \ + f(in_T, out_T, W_T, narrow, 768) \ + f(in_T, out_T, W_T, narrow, 1024) \ + f(in_T, out_T, W_T, narrow, 1344) \ + f(in_T, out_T, W_T, narrow, 1472) \ + f(in_T, out_T, W_T, narrow, 1536) \ + f(in_T, out_T, W_T, narrow, 1856) \ + f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2112) \ + f(in_T, out_T, W_T, narrow, 2688) \ + f(in_T, out_T, W_T, narrow, 2816) \ + f(in_T, out_T, W_T, narrow, 2880) \ + f(in_T, out_T, W_T, narrow, 2944) \ + f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 5120) \ + f(in_T, out_T, W_T, narrow, 5888) \ + f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 10240) \ + f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 16384) \ + f(in_T, out_T, W_T, narrow, 28672) + +#define FOR_MOE_ALL_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_MOE_ALL_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_MOE_ALL_WIDE(f, in_T, out_T, W_T, 16) \ + FOR_MOE_ALL_WIDE(f, in_T, out_T, W_T, 32) \ + FOR_MOE_ALL_WIDE(f, in_T, out_T, W_T, 64) + +// clang-format on + +// ===== Forward declarations ===== + +template +void moe_bgmv_shrink_sliced(out_T* __restrict__ Y, const in_T* __restrict__ X, + W_T** __restrict__ w_ptr, const int64_t* __restrict__ sorted_token_ids, + const int64_t* __restrict__ expert_ids, + const int64_t* __restrict__ lora_indices, int64_t num_pairs, + int64_t num_slices, int64_t num_experts, int64_t num_tokens, + int64_t lora_stride, float scale); + +template +void moe_bgmv_expand_sliced(float* __restrict__ Y, const in_T* __restrict__ X, + W_T** __restrict__ w_ptr, const int64_t* __restrict__ sorted_token_ids, + const int64_t* __restrict__ expert_ids, + const int64_t* __restrict__ lora_indices, + const float* __restrict__ topk_weights, + const int64_t* __restrict__ slice_start_loc, int64_t num_pairs, + int64_t num_slices, int64_t num_experts, int64_t total_feat_out, + int32_t current_feat_out, int64_t num_tokens, int64_t lora_stride, + float scale); diff --git a/csrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cu b/csrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cu new file mode 100644 index 0000000000..c1f30de2e3 --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cu @@ -0,0 +1,10 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include "moe_bgmv_config.h" +#include "moe_bgmv_impl.cuh" + +// Shrink + expand (in_T=out_T=W_T=nv_half). +FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_TWOSIDE, nv_half, nv_half, nv_half) diff --git a/csrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cu b/csrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cu new file mode 100644 index 0000000000..6682c099fa --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cu @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include "moe_bgmv_config.h" +#include "moe_bgmv_impl.cuh" + +// Expand only (in_T=nv_half, Y=float32, W_T=nv_half). +// Shrink is covered by moe_bgmv_fp16_fp16_fp16.cu. + +#define INST_MOE_BGMV_EXPAND_ONLY(in_T, out_T, W_T, narrow, wide) \ + INST_MOE_BGMV_EXPAND_SLICED(narrow, wide, in_T, W_T) + +FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_EXPAND_ONLY, nv_half, float, nv_half) diff --git a/csrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cu b/csrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cu new file mode 100644 index 0000000000..4ff5171949 --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cu @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include "moe_bgmv_config.h" +#include "moe_bgmv_impl.cuh" + +// Shrink only with mixed precision (in_T=float, out_T=nv_bfloat16, W_T=nv_bfloat16). +// This handles the case where X is accumulated in fp32 but LoRA weights are bf16. + +#define INST_MOE_BGMV_SHRINK_ONLY(in_T, out_T, W_T, narrow, wide) \ + INST_MOE_BGMV_SHRINK_SLICED(wide, narrow, in_T, out_T, W_T) + +FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_SHRINK_ONLY, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu b/csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu new file mode 100644 index 0000000000..bf49e6acea --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include "moe_bgmv_config.h" +#include "moe_bgmv_impl.cuh" + +// Shrink only with mixed precision (in_T=float, out_T=nv_half, W_T=nv_half). + +#define INST_MOE_BGMV_SHRINK_ONLY(in_T, out_T, W_T, narrow, wide) \ + INST_MOE_BGMV_SHRINK_SLICED(wide, narrow, in_T, out_T, W_T) + +FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_SHRINK_ONLY, float, nv_half, nv_half) diff --git a/csrc/bgmv_moe/moe_bgmv_impl.cuh b/csrc/bgmv_moe/moe_bgmv_impl.cuh new file mode 100644 index 0000000000..fb14256bbe --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_impl.cuh @@ -0,0 +1,431 @@ +#pragma once + +/* + * Multi-LoRA MoE BGMV CUDA Kernel Implementation. + * + * Two kernels: + * 1. Shrink: y[slice, pair, rank] += x[token] @ lora_a[expert, lora_id] + * - Compute-bound, uses async pipeline, RANK_TILE tiling, multi-pair blocking + * 2. Expand: y[token, feat] += shrink_out[pair, rank] @ lora_b[expert, lora_id] * topk_weight + * - Memory-bound, uses warp-level reduction + * + * Grid (shrink): (ceil(num_pairs/PPB), ceil(feat_out/RANK_TILE), num_slices) + * Grid (expand): (num_pairs, feat_out/(ty*tz), num_slices) + * + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include +#include + +#include + +// Get the current CUDA stream. In TVM-FFI context, the stream is set by the +// framework before kernel dispatch. We use the default stream (0) which maps +// to the current stream in per-thread default stream mode. +#define BGMV_MOE_GET_STREAM() 0 + +#include + +#include "kernel_config.h" + +using namespace flashinfer; + +namespace cg = cooperative_groups; + +// ============================================================ +// MoE BGMV Shrink Sliced Kernel +// +// Optimizations: +// 1. RANK_TILE tiling — reuse X tile across RANK_TILE weight rows +// 2. Multi-pair — PPB pairs per block (PPB=4 decode, PPB=1 prefill) +// 3. Deep pipeline — NUM_STAGES async pipeline stages (3 decode, 2 prefill) +// +// Uses dynamic shared memory so that large configurations compile for all archs. +// The host wrapper calls cudaFuncSetAttribute on sm_80+ to raise the limit. +// ============================================================ +template +__global__ void moe_bgmv_shrink_sliced_kernel( + out_T* __restrict__ Y, const in_T* __restrict__ X, W_T** __restrict__ w_ptr, + const int64_t* __restrict__ sorted_token_ids, const int64_t* __restrict__ expert_ids, + const int64_t* __restrict__ lora_indices, int64_t num_pairs, int64_t num_experts, + int64_t num_tokens, int64_t lora_stride, float scale) { + const int slice_id = blockIdx.z; + const int pair_block_idx = blockIdx.x; + const int rank_tile_idx = blockIdx.y; + const int j0 = rank_tile_idx * RANK_TILE; + const int p0 = pair_block_idx * PAIRS_PER_BLOCK; + + auto block = cg::this_thread_block(); + constexpr size_t tile_size = tx * ty * vec_size; + constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; + + // Per-pair metadata + const in_T* X_tok[PAIRS_PER_BLOCK]; + const W_T* W_base[PAIRS_PER_BLOCK]; + bool pair_valid[PAIRS_PER_BLOCK]; + +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + const int pair_idx = p0 + pp; + if (pair_idx < num_pairs) { + const int64_t token_idx = sorted_token_ids[pair_idx]; + if (token_idx >= 0 && token_idx < num_tokens) { + const int64_t eid = expert_ids[pair_idx]; + const int64_t lid = lora_indices[token_idx]; + if (lid >= 0) { + X_tok[pp] = X + token_idx * feat_in; + W_base[pp] = w_ptr[slice_id * num_experts + eid] + lid * lora_stride + j0 * feat_in; + pair_valid[pp] = true; + continue; + } + } + } + X_tok[pp] = nullptr; + W_base[pp] = nullptr; + pair_valid[pp] = false; + } + + bool any_valid = false; +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) any_valid |= pair_valid[pp]; + if (!any_valid) return; + + // Dynamic shared memory layout + extern __shared__ char smem[]; + constexpr size_t x_elems = NUM_STAGES * PAIRS_PER_BLOCK * tile_size; + constexpr size_t w_elems = NUM_STAGES * PAIRS_PER_BLOCK * RANK_TILE * tile_size; + in_T* X_shared = reinterpret_cast(smem); + W_T* W_shared = reinterpret_cast(smem + x_elems * sizeof(in_T)); + float* y_warpwise = + reinterpret_cast(smem + x_elems * sizeof(in_T) + w_elems * sizeof(W_T)); + + auto pipe = cuda::make_pipeline(); + const size_t toff = (threadIdx.y * tx + threadIdx.x) * vec_size; + + float y_acc[PAIRS_PER_BLOCK][RANK_TILE]; +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) y_acc[pp][r] = 0.f; + + vec_t x_vec; + vec_t w_vec; + + // Prologue: fill pipeline + constexpr size_t pro = (num_tiles < NUM_STAGES) ? num_tiles : NUM_STAGES; +#pragma unroll + for (size_t t = 0; t < pro; ++t) { + const size_t s = t % NUM_STAGES; + const size_t tb = t * tile_size; + pipe.producer_acquire(); +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + if (pair_valid[pp] && tb + toff < feat_in) { + cuda::memcpy_async(X_shared + (s * PAIRS_PER_BLOCK + pp) * tile_size + toff, + X_tok[pp] + tb + toff, cuda::aligned_size_t(X_copy_size), + pipe); +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) + if (j0 + r < feat_out) + cuda::memcpy_async( + W_shared + ((s * PAIRS_PER_BLOCK + pp) * RANK_TILE + r) * tile_size + toff, + W_base[pp] + r * feat_in + tb + toff, + cuda::aligned_size_t(W_copy_size), pipe); + } + } + pipe.producer_commit(); + } + + // Main loop + for (size_t t = pro; t < num_tiles; ++t) { + const size_t cs = (t - pro) % NUM_STAGES; + pipe.consumer_wait(); + block.sync(); +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + if (!pair_valid[pp]) continue; + x_vec.load(X_shared + (cs * PAIRS_PER_BLOCK + pp) * tile_size + toff); +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) { + if (j0 + r < feat_out) { + w_vec.load(W_shared + ((cs * PAIRS_PER_BLOCK + pp) * RANK_TILE + r) * tile_size + toff); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#pragma unroll + for (size_t off = tx / 2; off > 0; off /= 2) + sum += __shfl_down_sync(0xffffffff, sum, off); + if (threadIdx.x == 0) y_warpwise[pp * RANK_TILE * ty + r * ty + threadIdx.y] = sum; + } + } + } + block.sync(); +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + if (!pair_valid[pp]) continue; +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) + if (j0 + r < feat_out) { + float v = 0.f; + for (int w = 0; w < ty; ++w) v += y_warpwise[pp * RANK_TILE * ty + r * ty + w]; + y_acc[pp][r] += v; + } + } + block.sync(); + pipe.consumer_release(); + + // Load next tile + const size_t ls = t % NUM_STAGES; + const size_t tb = t * tile_size; + pipe.producer_acquire(); +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + if (pair_valid[pp] && tb + toff < feat_in) { + cuda::memcpy_async(X_shared + (ls * PAIRS_PER_BLOCK + pp) * tile_size + toff, + X_tok[pp] + tb + toff, cuda::aligned_size_t(X_copy_size), + pipe); +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) + if (j0 + r < feat_out) + cuda::memcpy_async( + W_shared + ((ls * PAIRS_PER_BLOCK + pp) * RANK_TILE + r) * tile_size + toff, + W_base[pp] + r * feat_in + tb + toff, + cuda::aligned_size_t(W_copy_size), pipe); + } + } + pipe.producer_commit(); + } + + // Epilogue: drain remaining pipeline stages + for (size_t t = (num_tiles > pro ? num_tiles - pro : 0); t < num_tiles; ++t) { + const size_t cs = t % NUM_STAGES; + const size_t ts = t * tile_size; + pipe.consumer_wait(); + block.sync(); +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + if (!pair_valid[pp]) continue; + x_vec.load(X_shared + (cs * PAIRS_PER_BLOCK + pp) * tile_size + toff); +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) { + if (j0 + r < feat_out) { + w_vec.load(W_shared + ((cs * PAIRS_PER_BLOCK + pp) * RANK_TILE + r) * tile_size + toff); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#pragma unroll + for (size_t off = tx / 2; off > 0; off /= 2) + sum += __shfl_down_sync(0xffffffff, sum, off); + if (threadIdx.x == 0) { + if (t == num_tiles - 1) sum = (ts + threadIdx.y * tx * vec_size < feat_in) ? sum : 0.f; + y_warpwise[pp * RANK_TILE * ty + r * ty + threadIdx.y] = sum; + } + } + } + } + block.sync(); +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + if (!pair_valid[pp]) continue; +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) + if (j0 + r < feat_out) { + float v = 0.f; + for (int w = 0; w < ty; ++w) v += y_warpwise[pp * RANK_TILE * ty + r * ty + w]; + y_acc[pp][r] += v; + } + } + block.sync(); + pipe.consumer_release(); + } + + // Write results + if (block.thread_rank() == 0) { +#pragma unroll + for (int pp = 0; pp < PAIRS_PER_BLOCK; ++pp) { + if (!pair_valid[pp]) continue; +#pragma unroll + for (int r = 0; r < RANK_TILE; ++r) + if (j0 + r < feat_out) + Y[slice_id * num_pairs * feat_out + (p0 + pp) * feat_out + j0 + r] += + static_cast(y_acc[pp][r]); + } + } +} + +// ============================================================ +// MoE BGMV Expand Sliced Kernel +// ============================================================ +template +__global__ void moe_bgmv_expand_sliced_kernel( + float* __restrict__ Y, const in_T* __restrict__ X, W_T** __restrict__ w_ptr, + const int64_t* __restrict__ sorted_token_ids, const int64_t* __restrict__ expert_ids, + const int64_t* __restrict__ lora_indices, const float* __restrict__ topk_weights, + const int64_t* __restrict__ slice_start_loc, int64_t num_pairs, int64_t num_experts, + int64_t total_feat_out, int32_t current_feat_out, int64_t num_tokens, int64_t lora_stride, + float scale) { + size_t pair_idx = blockIdx.x; + size_t tile_idx = blockIdx.y; + int64_t token_idx = sorted_token_ids[pair_idx]; + if (token_idx < 0 || token_idx >= num_tokens) return; + int64_t lora_id = lora_indices[token_idx]; + if (lora_id < 0) return; + int slice_id = blockIdx.z; + int64_t expert_id = expert_ids[pair_idx]; + float topk_w = topk_weights[pair_idx]; + int64_t col_offset = slice_start_loc[slice_id]; + const W_T* W = w_ptr[slice_id * num_experts + expert_id] + lora_id * lora_stride; + auto block = cg::this_thread_block(); + vec_t x_vec; + x_vec.load(X + slice_id * num_pairs * feat_in + pair_idx * feat_in + threadIdx.x * vec_size); + vec_t w_vec; + w_vec.load(W + (tile_idx * tz * ty) * feat_in + block.thread_rank() * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) sum += float(w_vec[i]) * float(x_vec[i]) * scale; + cg::thread_block_tile g = cg::tiled_partition(block); +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) sum += g.shfl_down(sum, offset); + sum = g.shfl(sum, 0); + if (threadIdx.x == 0) { + int out_col = col_offset + tile_idx * (tz * ty) + threadIdx.z * ty + threadIdx.y; + atomicAdd(Y + token_idx * total_feat_out + out_col, sum * topk_w); + } +} + +// ============================================================ +// Host-side dispatch: Shrink +// ============================================================ + +template +void moe_bgmv_shrink_sliced(out_T* __restrict__ Y, const in_T* __restrict__ X, + W_T** __restrict__ w_ptr, const int64_t* sorted_token_ids, + const int64_t* expert_ids, const int64_t* lora_indices, + int64_t num_pairs, int64_t num_slices, int64_t num_experts, + int64_t num_tokens, int64_t lora_stride, float scale) { + // Use the current CUDA stream + const cudaStream_t stream = BGMV_MOE_GET_STREAM(); + + constexpr int cfg_tx = MoeShrinkKernelConfig::tx; + constexpr int cfg_ty = MoeShrinkKernelConfig::ty; + constexpr int RT = MoeShrinkKernelConfig::rank_tile; + constexpr int gy = (feat_out + RT - 1) / RT; + constexpr size_t fvs = MoeShrinkKernelConfig::vec_size; + + // Runtime: detect sm_80+ for extended shared memory + int dev; + cudaGetDevice(&dev); + int sm_major = 0; + cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, dev); + const bool extended = (sm_major >= 9); + const bool decode = (num_pairs <= MoeShrinkKernelConfig::decode_threshold); + + const int ppb = (extended && decode) ? MoeShrinkKernelConfig::pairs_per_block_decode + : MoeShrinkKernelConfig::pairs_per_block_prefill; + const int nstg = (extended && decode) ? MoeShrinkKernelConfig::num_stages_extended + : MoeShrinkKernelConfig::num_stages_default; + +#define LAUNCH(PPB, NSTG, VS) \ + do { \ + constexpr size_t ts = cfg_tx * cfg_ty * (VS); \ + constexpr size_t shmem = (NSTG) * (PPB) * ts * sizeof(in_T) + \ + (NSTG) * (PPB) * RT * ts * sizeof(W_T) + \ + (PPB) * RT * cfg_ty * sizeof(float); \ + auto kfn = &moe_bgmv_shrink_sliced_kernel; \ + if constexpr (shmem > 48 * 1024) \ + cudaFuncSetAttribute(kfn, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem); \ + dim3 g((int)((num_pairs + (PPB) - 1) / (PPB)), gy, num_slices); \ + kfn<<>>(Y, X, w_ptr, sorted_token_ids, expert_ids, \ + lora_indices, num_pairs, num_experts, \ + num_tokens, lora_stride, scale); \ + } while (0) + +#define DISPATCH(VS) \ + do { \ + if (ppb == 4 && nstg == 3) { \ + LAUNCH(4, 3, VS); \ + } else { \ + LAUNCH(1, 2, VS); \ + } \ + } while (0) + + if constexpr (feat_in % (fvs * cfg_tx) == 0) { + DISPATCH(fvs); + } else if constexpr (feat_in % (fvs / 2 * cfg_tx) == 0) { + DISPATCH(fvs / 2); + } else if constexpr (feat_in % (fvs / 4 * cfg_tx) == 0) { + DISPATCH(fvs / 4); + } else if constexpr (feat_in % cfg_tx == 0) { + DISPATCH(1); + } + +#undef DISPATCH +#undef LAUNCH +} + +// ============================================================ +// Host-side dispatch: Expand +// ============================================================ + +template +void moe_bgmv_expand_sliced(float* __restrict__ Y, const in_T* __restrict__ X, + W_T** __restrict__ w_ptr, const int64_t* sorted_token_ids, + const int64_t* expert_ids, const int64_t* lora_indices, + const float* topk_weights, const int64_t* slice_start_loc, + int64_t num_pairs, int64_t num_slices, int64_t num_experts, + int64_t total_feat_out, int32_t current_feat_out, int64_t num_tokens, + int64_t lora_stride, float scale) { + const cudaStream_t stream = BGMV_MOE_GET_STREAM(); // current CUDA stream + + constexpr size_t vec_size = MoeExpandKernelConfig::vec_size; + constexpr int tz = MoeExpandKernelConfig::tz; + static_assert(feat_in % vec_size == 0); + constexpr int tx = feat_in / vec_size; + + if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { + constexpr int ty = 32 / tx; + moe_bgmv_expand_sliced_kernel + <<>>( + Y, X, w_ptr, sorted_token_ids, expert_ids, lora_indices, topk_weights, slice_start_loc, + num_pairs, num_experts, total_feat_out, current_feat_out, num_tokens, lora_stride, + scale); + } else if constexpr (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { + constexpr int ty = 16 / tx; + moe_bgmv_expand_sliced_kernel + <<>>( + Y, X, w_ptr, sorted_token_ids, expert_ids, lora_indices, topk_weights, slice_start_loc, + num_pairs, num_experts, total_feat_out, current_feat_out, num_tokens, lora_stride, + scale); + } else if constexpr (8 % tx == 0 && feat_out % (8 / tx * tz) == 0) { + constexpr int ty = 8 / tx; + moe_bgmv_expand_sliced_kernel + <<>>( + Y, X, w_ptr, sorted_token_ids, expert_ids, lora_indices, topk_weights, slice_start_loc, + num_pairs, num_experts, total_feat_out, current_feat_out, num_tokens, lora_stride, + scale); + } +} + +// ============================================================ +// Instantiation macros +// ============================================================ +#define INST_MOE_BGMV_SHRINK_SLICED(feat_in, feat_out, in_T, out_T, W_T) \ + template void moe_bgmv_shrink_sliced( \ + out_T*, const in_T*, W_T**, const int64_t*, const int64_t*, const int64_t*, int64_t, \ + int64_t, int64_t, int64_t, int64_t, float); + +#define INST_MOE_BGMV_EXPAND_SLICED(feat_in, feat_out, in_T, W_T) \ + template void moe_bgmv_expand_sliced( \ + float*, const in_T*, W_T**, const int64_t*, const int64_t*, const int64_t*, const float*, \ + const int64_t*, int64_t, int64_t, int64_t, int64_t, int32_t, int64_t, int64_t, float); + +#define INST_MOE_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ + INST_MOE_BGMV_SHRINK_SLICED(wide, narrow, in_T, out_T, W_T) \ + INST_MOE_BGMV_EXPAND_SLICED(narrow, wide, in_T, W_T) diff --git a/csrc/bgmv_moe/moe_bgmv_ops.cu b/csrc/bgmv_moe/moe_bgmv_ops.cu new file mode 100644 index 0000000000..a26d3028ad --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_ops.cu @@ -0,0 +1,174 @@ +/* + * Dispatch logic for BGMV MoE kernels. + * Routes to the correct template instantiation based on tensor dtypes and dimensions. + * + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include +#include + +#include + +#include "moe_bgmv_config.h" +#include "tvm_ffi_utils.h" + +// ====== Utils ====== + +inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { + return (uint64_t(a) << 32) | uint64_t(b); +} + +// ====== MoE BGMV Shrink Launcher ====== + +template +inline bool launch_moe_shrink_sliced_kernel(T* Y, const T* X, T** w_ptr, + const int64_t* sorted_token_ids, + const int64_t* expert_ids, const int64_t* lora_indices, + uint32_t feat_in, uint32_t feat_out, int64_t num_pairs, + int64_t num_slices, int64_t num_experts, + int64_t num_tokens, int64_t lora_stride) { + switch (pack_u32(feat_in, feat_out)) { +#define CASE_MOE_SHRINK(in_T, out_T, W_T, narrow, wide) \ + case pack_u32(wide, narrow): \ + moe_bgmv_shrink_sliced( \ + Y, X, w_ptr, sorted_token_ids, expert_ids, lora_indices, num_pairs, num_slices, \ + num_experts, num_tokens, lora_stride, 1.0f); \ + return true; + FOR_MOE_ALL_WIDE_NARROW(CASE_MOE_SHRINK, T, T, T) +#undef CASE_MOE_SHRINK + default: + return false; + } +} + +// ====== MoE BGMV Expand Launcher ====== + +template +inline bool launch_moe_expand_sliced_kernel( + float* Y, const T* X, T** w_ptr, const int64_t* sorted_token_ids, const int64_t* expert_ids, + const int64_t* lora_indices, const float* topk_weights, const int64_t* slice_start_loc, + uint32_t feat_in, uint32_t feat_out, int64_t num_pairs, int64_t num_slices, int64_t num_experts, + int64_t total_feat_out, int64_t num_tokens, int64_t lora_stride) { + switch (pack_u32(feat_in, feat_out)) { +#define CASE_MOE_EXPAND(in_T, out_T, W_T, narrow, wide) \ + case pack_u32(narrow, wide): \ + moe_bgmv_expand_sliced( \ + Y, X, w_ptr, sorted_token_ids, expert_ids, lora_indices, topk_weights, slice_start_loc, \ + num_pairs, num_slices, num_experts, total_feat_out, wide, num_tokens, lora_stride, 1.0f); \ + return true; + FOR_MOE_ALL_WIDE_NARROW(CASE_MOE_EXPAND, T, T, T) +#undef CASE_MOE_EXPAND + default: + return false; + } +} + +// ====== TVM-FFI dispatch: MoE Shrink ====== + +void bgmv_moe_shrink(TensorView y, TensorView x, TensorView w_ptr, TensorView sorted_token_ids, + TensorView expert_ids, TensorView lora_indices, int64_t lora_stride) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w_ptr); + CHECK_INPUT(sorted_token_ids); + CHECK_INPUT(expert_ids); + CHECK_INPUT(lora_indices); + CHECK_DIM(3, y); + CHECK_DIM(2, x); + CHECK_DIM(2, w_ptr); + CHECK_DIM(1, sorted_token_ids); + CHECK_DIM(1, expert_ids); + CHECK_DIM(1, lora_indices); + + int64_t num_slices = y.size(0); + int64_t num_pairs = sorted_token_ids.size(0); + int64_t num_tokens = lora_indices.size(0); + int64_t feat_in = x.size(1); + int64_t feat_out = y.size(2); + int64_t num_experts = w_ptr.size(1); + + TVM_FFI_ICHECK_EQ(w_ptr.size(0), num_slices) << "w_ptr slice dim mismatch"; + TVM_FFI_ICHECK(sorted_token_ids.dtype() == dl_int64) << "sorted_token_ids must be int64"; + TVM_FFI_ICHECK(expert_ids.dtype() == dl_int64) << "expert_ids must be int64"; + TVM_FFI_ICHECK(w_ptr.dtype() == dl_int64) << "w_ptr must be int64"; + TVM_FFI_ICHECK(lora_indices.dtype() == dl_int64) << "lora_indices must be int64"; + + ffi::CUDADeviceGuard guard(x.device().device_id); + bool ok = false; + + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(x.dtype(), DType, [&] { + ok = launch_moe_shrink_sliced_kernel( + static_cast(y.data_ptr()), static_cast(x.data_ptr()), + reinterpret_cast(static_cast(w_ptr.data_ptr())), + static_cast(sorted_token_ids.data_ptr()), + static_cast(expert_ids.data_ptr()), + static_cast(lora_indices.data_ptr()), feat_in, feat_out, num_pairs, num_slices, + num_experts, num_tokens, lora_stride); + return true; + }); + + TVM_FFI_ICHECK(ok) << "BGMV MoE shrink failed. feat_in=" << feat_in << " feat_out=" << feat_out + << ". Dimension pair not compiled."; +} + +// ====== TVM-FFI dispatch: MoE Expand ====== + +void bgmv_moe_expand(TensorView y, TensorView x, TensorView w_ptr, TensorView sorted_token_ids, + TensorView expert_ids, TensorView topk_weights, TensorView lora_indices, + TensorView slice_start_loc, int64_t first_feat_out, int64_t lora_stride) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w_ptr); + CHECK_INPUT(sorted_token_ids); + CHECK_INPUT(expert_ids); + CHECK_INPUT(topk_weights); + CHECK_INPUT(lora_indices); + CHECK_INPUT(slice_start_loc); + CHECK_DIM(2, y); + CHECK_DIM(3, x); + CHECK_DIM(2, w_ptr); + CHECK_DIM(1, sorted_token_ids); + CHECK_DIM(1, expert_ids); + CHECK_DIM(1, topk_weights); + CHECK_DIM(1, lora_indices); + CHECK_DIM(1, slice_start_loc); + + int64_t num_slices = x.size(0); + int64_t num_pairs = sorted_token_ids.size(0); + int64_t num_tokens = lora_indices.size(0); + int64_t feat_in = x.size(2); + int64_t total_feat_out = y.size(1); + int64_t num_experts = w_ptr.size(1); + + TVM_FFI_ICHECK_EQ(w_ptr.size(0), num_slices) << "w_ptr slice dim mismatch"; + TVM_FFI_ICHECK(sorted_token_ids.dtype() == dl_int64) << "sorted_token_ids must be int64"; + TVM_FFI_ICHECK(expert_ids.dtype() == dl_int64) << "expert_ids must be int64"; + TVM_FFI_ICHECK(w_ptr.dtype() == dl_int64) << "w_ptr must be int64"; + TVM_FFI_ICHECK(lora_indices.dtype() == dl_int64) << "lora_indices must be int64"; + TVM_FFI_ICHECK(slice_start_loc.dtype() == dl_int64) << "slice_start_loc must be int64"; + TVM_FFI_ICHECK(topk_weights.dtype() == dl_float32) << "topk_weights must be float32"; + TVM_FFI_ICHECK(y.dtype() == dl_float32) << "y must be float32 accumulation buffer"; + TVM_FFI_ICHECK(first_feat_out > 0) << "first_feat_out must be positive"; + + ffi::CUDADeviceGuard guard(x.device().device_id); + bool ok = false; + + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(x.dtype(), DType, [&] { + ok = launch_moe_expand_sliced_kernel( + static_cast(y.data_ptr()), static_cast(x.data_ptr()), + reinterpret_cast(static_cast(w_ptr.data_ptr())), + static_cast(sorted_token_ids.data_ptr()), + static_cast(expert_ids.data_ptr()), + static_cast(lora_indices.data_ptr()), + static_cast(topk_weights.data_ptr()), + static_cast(slice_start_loc.data_ptr()), feat_in, + static_cast(first_feat_out), num_pairs, num_slices, num_experts, total_feat_out, + num_tokens, lora_stride); + return true; + }); + + TVM_FFI_ICHECK(ok) << "BGMV MoE expand failed. feat_in=" << feat_in + << " feat_out=" << first_feat_out << ". Dimension pair not compiled."; +} diff --git a/csrc/bgmv_moe/moe_bgmv_ops.h b/csrc/bgmv_moe/moe_bgmv_ops.h new file mode 100644 index 0000000000..ef54d55849 --- /dev/null +++ b/csrc/bgmv_moe/moe_bgmv_ops.h @@ -0,0 +1,20 @@ +#pragma once + +/* + * Public C++ interface for BGMV MoE kernels (TVM-FFI). + * + * Copyright (c) 2025 by FlashInfer team. + * Licensed under the Apache License, Version 2.0. + */ + +#include + +// Forward declarations for TVM-FFI dispatch functions. +// These are defined in moe_bgmv_ops.cu and exported via moe_bgmv_binding.cu. + +void bgmv_moe_shrink(TensorView y, TensorView x, TensorView w_ptr, TensorView sorted_token_ids, + TensorView expert_ids, TensorView lora_indices, int64_t lora_stride); + +void bgmv_moe_expand(TensorView y, TensorView x, TensorView w_ptr, TensorView sorted_token_ids, + TensorView expert_ids, TensorView topk_weights, TensorView lora_indices, + TensorView slice_start_loc, int64_t first_feat_out, int64_t lora_stride); diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 6d75420d66..455cd5321b 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -67,6 +67,7 @@ gen_cutlass_fused_moe_sm120_module, gen_trtllm_gen_fused_moe_sm100_module, ) +from .jit.bgmv_moe import gen_bgmv_moe_module from .jit.gdn import gen_gdn_prefill_sm90_module from .jit.gemm import ( gen_fp8_blockscale_gemm_sm90_module, @@ -493,6 +494,8 @@ def gen_all_modules( if add_moe: jit_specs.append(gen_gemm_module()) + # Multi-LoRA MoE BGMV kernel + jit_specs.append(gen_bgmv_moe_module()) if has_sm90: jit_specs.append(gen_gemm_sm90_module()) # fp8 blockscale GEMM (SM90) diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 6b3279ef64..a5737ad1a7 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -47,6 +47,14 @@ fused_topk_deepseek as fused_topk_deepseek, ) +from .bgmv_moe import ( # noqa: F401 + bgmv_moe as bgmv_moe, + bgmv_moe_shrink as bgmv_moe_shrink, + bgmv_moe_expand as bgmv_moe_expand, + fill_w_ptr as fill_w_ptr, + has_bgmv_moe as has_bgmv_moe, +) + # CuteDSL MoE APIs (conditionally imported if cute_dsl available) try: from .cute_dsl import ( @@ -85,6 +93,11 @@ "trtllm_mxint4_block_scale_moe", "trtllm_mxint4_block_scale_routed_moe", "fused_topk_deepseek", + "bgmv_moe", + "bgmv_moe_shrink", + "bgmv_moe_expand", + "fill_w_ptr", + "has_bgmv_moe", ] # Add CuteDSL exports if available diff --git a/flashinfer/fused_moe/bgmv_moe.py b/flashinfer/fused_moe/bgmv_moe.py new file mode 100644 index 0000000000..3793ba9bff --- /dev/null +++ b/flashinfer/fused_moe/bgmv_moe.py @@ -0,0 +1,260 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import functools +from typing import List, Optional + +import torch + +from ..api_logging import flashinfer_api + + +@functools.cache +def _get_bgmv_moe_module(): + """Lazily load the BGMV MoE CUDA extension. + + Tries in order: + Loads via FlashInfer's JIT compilation system (TVM-FFI). + """ + try: + from ..jit.bgmv_moe import load_bgmv_moe_module + + return load_bgmv_moe_module() + except (ImportError, FileNotFoundError, RuntimeError) as e: + raise ImportError( + f"Failed to load BGMV MoE CUDA extension via JIT. " + f"Ensure CUDA toolkit is available and csrc/bgmv_moe/ sources exist.\n" + f"Error: {e}" + ) from e + + +@functools.cache +def has_bgmv_moe() -> bool: + """Return True if the BGMV MoE CUDA extension is available.""" + try: + _get_bgmv_moe_module() + return True + except ImportError: + return False + + +@flashinfer_api +def bgmv_moe_shrink( + y: torch.Tensor, + x: torch.Tensor, + w_ptr: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + lora_indices: torch.Tensor, + lora_stride: int, +) -> None: + """ + MoE LoRA shrink operation: project input through LoRA-A matrices. + + For each (token, expert) pair, computes: + y[slice, pair, rank] += x[token] @ lora_a[expert, lora_id, :, :] + + Args: + y: Output tensor [num_slices, num_pairs, rank]. Accumulated in-place. + x: Input activations [num_tokens, hidden_dim]. + w_ptr: Pointer table [num_slices, num_experts] of int64. + Each entry points to the start of lora_a weights for (slice, expert). + The kernel uses lora_stride to index different LoRA adapters. + sorted_token_ids: Token indices for each pair [num_pairs]. + expert_ids: Expert indices for each pair [num_pairs]. + lora_indices: LoRA adapter ID for each token [num_tokens]. + -1 means no LoRA (pair is skipped). + lora_stride: Stride (in elements) between consecutive LoRA adapters + in the weight tensor. For layout [max_loras, num_experts, rank, feat], + this is num_experts * rank * feat. + """ + mod = _get_bgmv_moe_module() + mod.bgmv_moe_shrink( + y, x, w_ptr, sorted_token_ids, expert_ids, lora_indices, lora_stride + ) + + +@flashinfer_api +def bgmv_moe_expand( + y: torch.Tensor, + x: torch.Tensor, + w_ptr: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + topk_weights: torch.Tensor, + lora_indices: torch.Tensor, + slice_start_loc: torch.Tensor, + output_slices: List[int], + lora_stride: int, +) -> None: + """ + MoE LoRA expand operation: project through LoRA-B matrices with routing weights. + + For each (token, expert) pair, computes: + y[token, col_offset:col_offset+feat] += topk_weight * (x[slice, pair, :] @ lora_b[expert, lora_id, :, :]) + + Args: + y: Output tensor [num_tokens, total_feat_out]. Float32 accumulation buffer. + x: Shrink output [num_slices, num_pairs, rank]. + w_ptr: Pointer table [num_slices, num_experts] of int64. + sorted_token_ids: Token indices for each pair [num_pairs]. + expert_ids: Expert indices for each pair [num_pairs]. + topk_weights: Routing weights for each pair [num_pairs]. Float32. + lora_indices: LoRA adapter ID for each token [num_tokens]. + slice_start_loc: Column offset for each slice [num_slices]. Int64. + output_slices: Output feature dimension for each slice. + lora_stride: Stride between LoRA adapters in weight tensor. + """ + mod = _get_bgmv_moe_module() + mod.bgmv_moe_expand( + y, + x, + w_ptr, + sorted_token_ids, + expert_ids, + topk_weights, + lora_indices, + slice_start_loc, + output_slices[0], + lora_stride, + ) + + +def fill_w_ptr( + w_ptr: torch.Tensor, + weights: torch.Tensor, + num_experts: int, + slice_id: int, +) -> int: + """ + Fill the weight pointer table for a given slice. + + Populates w_ptr[slice_id, 0:num_experts] with data pointers for each expert. + Works with weight layout [max_loras, num_experts, rank, feat]. + + Args: + w_ptr: Pointer table [num_slices, num_experts] of int64. + weights: LoRA weight tensor [max_loras, num_experts, rank, feat]. + num_experts: Number of experts. + slice_id: Which slice to populate. + + Returns: + lora_stride: The stride (in elements) between LoRA adapters. + """ + # w shape: [max_loras, num_experts, rank, feat] + base_ptr = weights.data_ptr() + expert_stride_bytes = weights.stride(1) * weights.element_size() + + arange = torch.arange(num_experts, dtype=torch.int64, device=weights.device) + w_ptr[slice_id, :num_experts] = arange * expert_stride_bytes + base_ptr + + # lora_stride = stride along dim 0 (in elements) + return weights.stride(0) + + +@flashinfer_api +def bgmv_moe( + x: torch.Tensor, + lora_a_weights: List[torch.Tensor], + lora_b_weights: List[torch.Tensor], + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + lora_indices: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + output_dim: Optional[int] = None, +) -> torch.Tensor: + """ + High-level multi-LoRA MoE BGMV: shrink + expand in one call. + + Computes the LoRA delta for MoE: + delta[token] = Σ_expert (topk_weight * x[token] @ lora_a[expert, lora_id] @ lora_b[expert, lora_id]) + + Args: + x: Input activations [num_tokens, hidden_dim]. + lora_a_weights: List of LoRA-A weight tensors, one per slice. + Each has shape [max_loras, num_experts, rank, hidden_dim]. + lora_b_weights: List of LoRA-B weight tensors, one per slice. + Each has shape [max_loras, num_experts, feat_out, rank]. + sorted_token_ids: Token indices for each pair [num_pairs]. + expert_ids: Expert indices for each pair [num_pairs]. + lora_indices: LoRA adapter ID for each token [num_tokens]. + topk_weights: Routing weights for each pair [num_pairs]. + num_experts: Number of experts. + output_dim: Total output dimension. If None, inferred from lora_b_weights. + + Returns: + Output tensor [num_tokens, total_feat_out] with LoRA deltas. + """ + num_slices = len(lora_a_weights) + num_tokens = x.size(0) + num_pairs = sorted_token_ids.size(0) + rank = lora_a_weights[0].size(2) + device = x.device + dtype = x.dtype + + # Infer output dimension + feat_out_per_slice = [lora_b_weights[s].size(2) for s in range(num_slices)] + total_feat_out = output_dim if output_dim is not None else sum(feat_out_per_slice) + + # Build w_ptr for shrink (lora_a) + w_ptr_a = torch.zeros(num_slices, num_experts, dtype=torch.int64, device=device) + lora_stride_a = 0 + for s in range(num_slices): + lora_stride_a = fill_w_ptr(w_ptr_a, lora_a_weights[s], num_experts, s) + + # Shrink: x @ lora_a -> [num_slices, num_pairs, rank] + shrink_out = torch.zeros(num_slices, num_pairs, rank, dtype=dtype, device=device) + bgmv_moe_shrink( + shrink_out, + x, + w_ptr_a, + sorted_token_ids, + expert_ids, + lora_indices, + lora_stride_a, + ) + + # Build w_ptr for expand (lora_b) + w_ptr_b = torch.zeros(num_slices, num_experts, dtype=torch.int64, device=device) + lora_stride_b = 0 + for s in range(num_slices): + lora_stride_b = fill_w_ptr(w_ptr_b, lora_b_weights[s], num_experts, s) + + # Slice start locations (build on CPU, transfer once to avoid per-element sync) + slice_start_loc_cpu = torch.zeros(num_slices, dtype=torch.int64) + loc = 0 + for s in range(num_slices): + slice_start_loc_cpu[s] = loc + loc += feat_out_per_slice[s] + slice_start_loc = slice_start_loc_cpu.to(device=device) + + # Expand: shrink_out @ lora_b -> [num_tokens, total_feat_out] + y = torch.zeros(num_tokens, total_feat_out, dtype=torch.float32, device=device) + bgmv_moe_expand( + y, + shrink_out, + w_ptr_b, + sorted_token_ids, + expert_ids, + topk_weights, + lora_indices, + slice_start_loc, + feat_out_per_slice, + lora_stride_b, + ) + + return y.to(dtype) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 8378e0ab74..fe6de6db74 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -97,6 +97,8 @@ from .fp4_kv_quantization import ( gen_fp4_kv_quantization_module as gen_fp4_kv_quantization_module, ) +from .bgmv_moe import gen_bgmv_moe_module as gen_bgmv_moe_module +from .bgmv_moe import load_bgmv_moe_module as load_bgmv_moe_module cuda_lib_path = os.environ.get( diff --git a/flashinfer/jit/bgmv_moe.py b/flashinfer/jit/bgmv_moe.py new file mode 100644 index 0000000000..3d848c9de1 --- /dev/null +++ b/flashinfer/jit/bgmv_moe.py @@ -0,0 +1,144 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import functools +import os +import shutil +from pathlib import Path + +from . import env as jit_env +from .core import gen_jit_spec, logger, current_compilation_context + + +def _get_bgmv_moe_csrc_dir() -> Path: + """Get the path to the BGMV MoE CUDA source directory. + + Handles both installed package (data/csrc/bgmv_moe) and + development checkout (../../csrc/bgmv_moe relative to this file). + """ + # Standard path via FlashInfer's data directory + standard_path = jit_env.FLASHINFER_CSRC_DIR / "bgmv_moe" + if standard_path.exists(): + return standard_path + + # Development fallback: relative to this file + dev_path = Path(__file__).parent.parent.parent / "csrc" / "bgmv_moe" + if dev_path.exists(): + return dev_path + + raise FileNotFoundError( + f"BGMV MoE CUDA sources not found. Checked:\n" + f" - {standard_path}\n" + f" - {dev_path}\n" + f"Please ensure the csrc/bgmv_moe/ directory exists." + ) + + +def get_bgmv_moe_uri() -> str: + """Generate unique identifier for the BGMV MoE module.""" + return "bgmv_moe" + + +@functools.cache +def gen_bgmv_moe_module(): + """ + Generate the JIT compilation spec for the BGMV MoE CUDA kernels. + + This compiles the multi-LoRA MoE BGMV shrink/expand kernel pair. + Supports SM70+ (V100, A100, H100, B200). + + Returns: + JitSpec that can be built and loaded. + """ + csrc_dir = _get_bgmv_moe_csrc_dir() + uri = get_bgmv_moe_uri() + + # Create generation directory + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_directory, exist_ok=True) + + # Source files to copy + source_files = [ + "moe_bgmv_binding.cu", + "moe_bgmv_bf16_bf16_bf16.cu", + "moe_bgmv_bf16_fp32_bf16.cu", + "moe_bgmv_fp16_fp16_fp16.cu", + "moe_bgmv_fp16_fp32_fp16.cu", + "moe_bgmv_fp32_bf16_bf16.cu", + "moe_bgmv_fp32_fp16_fp16.cu", + ] + + # Header files to copy (includes moe_bgmv_ops.cu which is #included by binding) + header_files = [ + "moe_bgmv_impl.cuh", + "moe_bgmv_config.h", + "moe_bgmv_ops.h", + "moe_bgmv_ops.cu", + "kernel_config.h", + ] + + # Copy sources to gen directory + sources = [] + for fname in source_files: + src_path = csrc_dir / fname + if not src_path.exists(): + raise FileNotFoundError(f"BGMV MoE source file not found: {src_path}") + dest_path = gen_directory / fname + shutil.copy(src_path, dest_path) + sources.append(dest_path) + + # Copy headers to gen directory + for fname in header_files: + src_path = csrc_dir / fname + if not src_path.exists(): + raise FileNotFoundError(f"BGMV MoE header file not found: {src_path}") + shutil.copy(src_path, gen_directory / fname) + + # Get nvcc flags for supported architectures (SM70+) + nvcc_flags = current_compilation_context.get_nvcc_flags_list( + supported_major_versions=[9, 10, 11, 12] # SM90+ (H100, B200) + ) + + spec = gen_jit_spec( + name=uri, + sources=sources, + extra_cuda_cflags=nvcc_flags + + [ + "-DFLASHINFER_ENABLE_BF16", + "-DFLASHINFER_ENABLE_F16", + ], + extra_include_paths=[ + str(gen_directory), + str(jit_env.FLASHINFER_INCLUDE_DIR), + str(jit_env.FLASHINFER_CSRC_DIR), + ], + ) + + logger.info(f"Generated BGMV MoE JIT spec: {spec.name}") + return spec + + +@functools.cache +def load_bgmv_moe_module(): + """ + Build and load the BGMV MoE CUDA extension via FlashInfer's JIT system. + + Returns the loaded module with `bgmv_moe_shrink` and `bgmv_moe_expand` functions. + """ + spec = gen_bgmv_moe_module() + module = spec.build_and_load() + logger.info("BGMV MoE module loaded successfully") + return module diff --git a/scripts/task_jit_run_tests_part5.sh b/scripts/task_jit_run_tests_part5.sh index e6726a0112..4066207ae8 100755 --- a/scripts/task_jit_run_tests_part5.sh +++ b/scripts/task_jit_run_tests_part5.sh @@ -17,3 +17,4 @@ fi pytest -s tests/utils/test_logits_processor.py pytest -s tests/cli/test_cli_cmds.py pytest -s tests/cli/test_cli_cmds_gpu.py +pytest -s tests/moe/test_bgmv_moe.py diff --git a/tests/moe/test_bgmv_moe.py b/tests/moe/test_bgmv_moe.py new file mode 100644 index 0000000000..9506d4c83e --- /dev/null +++ b/tests/moe/test_bgmv_moe.py @@ -0,0 +1,598 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os + +os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1") + +import pytest +import torch + + +# ============================================================ +# PyTorch Reference Implementation +# ============================================================ + + +def reference_moe_bgmv_shrink( + x: torch.Tensor, # [num_tokens, hidden_dim] + lora_a_weights: list, # list of [max_loras, num_experts, rank, hidden_dim] + sorted_token_ids: torch.Tensor, # [num_pairs] + expert_ids: torch.Tensor, # [num_pairs] + lora_indices: torch.Tensor, # [num_tokens] +) -> torch.Tensor: + """ + Reference shrink: for each (token, expert) pair, compute x @ lora_a^T. + + Returns: [num_slices, num_pairs, rank] + """ + num_slices = len(lora_a_weights) + num_pairs = sorted_token_ids.size(0) + rank = lora_a_weights[0].size(2) + device = x.device + dtype = x.dtype + + y = torch.zeros(num_slices, num_pairs, rank, dtype=dtype, device=device) + + for pair_idx in range(num_pairs): + token_idx = sorted_token_ids[pair_idx].item() + if token_idx < 0 or token_idx >= x.size(0): + continue + lora_id = lora_indices[token_idx].item() + if lora_id < 0: + continue + expert_id = expert_ids[pair_idx].item() + x_tok = x[token_idx] # [hidden_dim] + + for s in range(num_slices): + # lora_a shape: [max_loras, num_experts, rank, hidden_dim] + w_a = lora_a_weights[s][lora_id, expert_id] # [rank, hidden_dim] + # y[s, pair, :] = x_tok @ w_a^T + y[s, pair_idx] = x_tok @ w_a.t() + + return y + + +def reference_moe_bgmv_expand( + shrink_out: torch.Tensor, # [num_slices, num_pairs, rank] + lora_b_weights: list, # list of [max_loras, num_experts, feat_out, rank] + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + lora_indices: torch.Tensor, + topk_weights: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """ + Reference expand: for each (token, expert) pair, compute shrink_out @ lora_b^T * topk_weight. + + Returns: [num_tokens, total_feat_out] + """ + num_slices = len(lora_b_weights) + num_pairs = sorted_token_ids.size(0) + device = shrink_out.device + + feat_out_per_slice = [lora_b_weights[s].size(2) for s in range(num_slices)] + total_feat_out = sum(feat_out_per_slice) + + y = torch.zeros(num_tokens, total_feat_out, dtype=torch.float32, device=device) + + for pair_idx in range(num_pairs): + token_idx = sorted_token_ids[pair_idx].item() + if token_idx < 0 or token_idx >= num_tokens: + continue + lora_id = lora_indices[token_idx].item() + if lora_id < 0: + continue + expert_id = expert_ids[pair_idx].item() + topk_w = topk_weights[pair_idx].item() + + col_offset = 0 + for s in range(num_slices): + # lora_b shape: [max_loras, num_experts, feat_out, rank] + w_b = lora_b_weights[s][lora_id, expert_id] # [feat_out, rank] + x_s = shrink_out[s, pair_idx] # [rank] + # y[token, col_offset:col_offset+feat_out] += topk_w * (x_s @ w_b^T) + y[token_idx, col_offset : col_offset + feat_out_per_slice[s]] += topk_w * ( + x_s @ w_b.t() + ) + col_offset += feat_out_per_slice[s] + + return y + + +def reference_bgmv_moe( + x: torch.Tensor, + lora_a_weights: list, + lora_b_weights: list, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + lora_indices: torch.Tensor, + topk_weights: torch.Tensor, +) -> torch.Tensor: + """Full reference: shrink + expand.""" + shrink_out = reference_moe_bgmv_shrink( + x, lora_a_weights, sorted_token_ids, expert_ids, lora_indices + ) + y = reference_moe_bgmv_expand( + shrink_out, + lora_b_weights, + sorted_token_ids, + expert_ids, + lora_indices, + topk_weights, + x.size(0), + ) + return y + + +# ============================================================ +# Test Fixtures +# ============================================================ + + +def generate_test_data( + num_tokens: int, + hidden_size: int, + rank: int, + num_experts: int, + top_k: int, + num_loras: int, + num_slices: int, + dtype: torch.dtype, + device: str = "cuda", +): + """Generate random test data for BGMV MoE kernels.""" + # Input activations + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) * 0.1 + + # LoRA weights: [max_loras, num_experts, rank, hidden_dim] for A + # [max_loras, num_experts, feat_out, rank] for B + max_loras = num_loras + feat_out = hidden_size # For simplicity, same as hidden_size + + lora_a_weights = [ + torch.randn( + max_loras, num_experts, rank, hidden_size, dtype=dtype, device=device + ) + * 0.01 + for _ in range(num_slices) + ] + lora_b_weights = [ + torch.randn(max_loras, num_experts, feat_out, rank, dtype=dtype, device=device) + * 0.01 + for _ in range(num_slices) + ] + + # Routing: each token is routed to top_k experts + # sorted_token_ids: flattened token indices (repeated for each expert) + # expert_ids: which expert each pair goes to + num_pairs = num_tokens * top_k + + # Simple routing: token i goes to experts [i*top_k % num_experts, ...] + sorted_token_ids = torch.arange( + num_tokens, device=device, dtype=torch.int64 + ).repeat_interleave(top_k) + expert_ids = torch.randint( + 0, num_experts, (num_pairs,), device=device, dtype=torch.int64 + ) + topk_weights = ( + torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1) + .view(-1) + .to(torch.float32) + ) + + # LoRA indices: assign each token a random LoRA adapter (some may be -1 = no LoRA) + lora_indices = torch.randint( + -1, max_loras, (num_tokens,), device=device, dtype=torch.int64 + ) + # Ensure at least some tokens have valid LoRA + lora_indices[: num_tokens // 2] = torch.randint( + 0, max_loras, (num_tokens // 2,), device=device, dtype=torch.int64 + ) + + return { + "x": x, + "lora_a_weights": lora_a_weights, + "lora_b_weights": lora_b_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "topk_weights": topk_weights, + "lora_indices": lora_indices, + "num_experts": num_experts, + "num_pairs": num_pairs, + "rank": rank, + "num_slices": num_slices, + "feat_out": hidden_size, + } + + +# ============================================================ +# Tests +# ============================================================ + +# BGMV MoE kernels are tested on SM80-SM90 (A100, H100, B200). +# Skip on consumer Blackwell GPUs (SM120, e.g., RTX 5090, RTX Pro 6000) +# where extended shared memory behavior may differ. +_SUPPORTED_SM = {90, 100, 103} + + +def _skip_if_unsupported_sm(): + """Skip test if current GPU SM version is not in the supported set.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + capability = torch.cuda.get_device_capability() + sm = capability[0] * 10 + capability[1] + if sm not in _SUPPORTED_SM: + pytest.skip( + f"BGMV MoE kernel not validated on SM{sm} " + f"(device: {torch.cuda.get_device_name()}). " + f"Supported: {sorted(_SUPPORTED_SM)}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestBgmvMoeShrink: + """Test the shrink kernel against reference.""" + + def setup_method(self): + _skip_if_unsupported_sm() + + @pytest.mark.parametrize("num_tokens", [1, 4, 32]) + @pytest.mark.parametrize("hidden_size", [768, 2048]) + @pytest.mark.parametrize("rank", [16, 32]) + @pytest.mark.parametrize("num_experts", [8, 64]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_shrink_correctness( + self, num_tokens, hidden_size, rank, num_experts, dtype + ): + from flashinfer.fused_moe.bgmv_moe import bgmv_moe_shrink, fill_w_ptr + + top_k = 2 + num_loras = 4 + num_slices = 1 + + data = generate_test_data( + num_tokens, + hidden_size, + rank, + num_experts, + top_k, + num_loras, + num_slices, + dtype, + ) + + # Reference + ref_out = reference_moe_bgmv_shrink( + data["x"], + data["lora_a_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + ) + + # CUDA kernel + num_pairs = data["num_pairs"] + w_ptr = torch.zeros(num_slices, num_experts, dtype=torch.int64, device="cuda") + lora_stride = fill_w_ptr(w_ptr, data["lora_a_weights"][0], num_experts, 0) + + cuda_out = torch.zeros(num_slices, num_pairs, rank, dtype=dtype, device="cuda") + bgmv_moe_shrink( + cuda_out, + data["x"], + w_ptr, + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + lora_stride, + ) + + # Compare + torch.testing.assert_close( + cuda_out.float(), + ref_out.float(), + atol=1e-2, + rtol=1e-2, + msg=f"Shrink mismatch: tokens={num_tokens}, hidden={hidden_size}, " + f"rank={rank}, experts={num_experts}, dtype={dtype}", + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestBgmvMoeExpand: + """Test the expand kernel against reference.""" + + def setup_method(self): + _skip_if_unsupported_sm() + + @pytest.mark.parametrize("num_tokens", [1, 4, 32]) + @pytest.mark.parametrize("hidden_size", [768, 2048]) + @pytest.mark.parametrize("rank", [16, 32]) + @pytest.mark.parametrize("num_experts", [8, 64]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_expand_correctness( + self, num_tokens, hidden_size, rank, num_experts, dtype + ): + from flashinfer.fused_moe.bgmv_moe import bgmv_moe_expand, fill_w_ptr + + top_k = 2 + num_loras = 4 + num_slices = 1 + + data = generate_test_data( + num_tokens, + hidden_size, + rank, + num_experts, + top_k, + num_loras, + num_slices, + dtype, + ) + + # Generate shrink output (use reference for isolation) + shrink_out = reference_moe_bgmv_shrink( + data["x"], + data["lora_a_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + ) + + # Reference expand + ref_out = reference_moe_bgmv_expand( + shrink_out, + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + num_tokens, + ) + + # CUDA kernel + w_ptr = torch.zeros(num_slices, num_experts, dtype=torch.int64, device="cuda") + lora_stride = fill_w_ptr(w_ptr, data["lora_b_weights"][0], num_experts, 0) + + slice_start_loc = torch.zeros(num_slices, dtype=torch.int64, device="cuda") + feat_out = hidden_size + output_slices = [feat_out] * num_slices + + cuda_out = torch.zeros(num_tokens, feat_out, dtype=torch.float32, device="cuda") + bgmv_moe_expand( + cuda_out, + shrink_out, + w_ptr, + data["sorted_token_ids"], + data["expert_ids"], + data["topk_weights"], + data["lora_indices"], + slice_start_loc, + output_slices, + lora_stride, + ) + + # Compare + torch.testing.assert_close( + cuda_out, + ref_out, + atol=1e-2, + rtol=1e-2, + msg=f"Expand mismatch: tokens={num_tokens}, hidden={hidden_size}, " + f"rank={rank}, experts={num_experts}, dtype={dtype}", + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestBgmvMoeEndToEnd: + """End-to-end test: shrink + expand combined.""" + + def setup_method(self): + _skip_if_unsupported_sm() + + @pytest.mark.parametrize("num_tokens", [1, 8, 32]) + @pytest.mark.parametrize("hidden_size", [768, 2048]) + @pytest.mark.parametrize("rank", [16, 32]) + @pytest.mark.parametrize("num_experts", [8, 64]) + @pytest.mark.parametrize("top_k", [2]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + def test_end_to_end(self, num_tokens, hidden_size, rank, num_experts, top_k, dtype): + from flashinfer.fused_moe.bgmv_moe import bgmv_moe + + num_loras = 4 + num_slices = 1 + + data = generate_test_data( + num_tokens, + hidden_size, + rank, + num_experts, + top_k, + num_loras, + num_slices, + dtype, + ) + + # Reference + ref_out = reference_bgmv_moe( + data["x"], + data["lora_a_weights"], + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + ) + + # CUDA kernel (high-level API) + cuda_out = bgmv_moe( + data["x"], + data["lora_a_weights"], + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + num_experts, + ) + + # Compare + torch.testing.assert_close( + cuda_out.float(), + ref_out.float(), + atol=5e-2, + rtol=5e-2, + msg=f"E2E mismatch: tokens={num_tokens}, hidden={hidden_size}, " + f"rank={rank}, experts={num_experts}, top_k={top_k}", + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestBgmvMoeEdgeCases: + """Edge case tests.""" + + def setup_method(self): + _skip_if_unsupported_sm() + + def test_all_tokens_no_lora(self): + """All tokens have lora_id=-1, output should be zero.""" + from flashinfer.fused_moe.bgmv_moe import bgmv_moe + + num_tokens, hidden_size, rank, num_experts, top_k = 16, 768, 16, 8, 2 + dtype = torch.bfloat16 + num_loras = 4 + num_slices = 1 + + data = generate_test_data( + num_tokens, + hidden_size, + rank, + num_experts, + top_k, + num_loras, + num_slices, + dtype, + ) + # Set all lora_indices to -1 + data["lora_indices"].fill_(-1) + + out = bgmv_moe( + data["x"], + data["lora_a_weights"], + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + num_experts, + ) + + assert torch.all(out == 0), "Output should be zero when no LoRA is active" + + def test_single_token_single_expert(self): + """Minimal case: 1 token, 1 expert, 1 LoRA.""" + from flashinfer.fused_moe.bgmv_moe import bgmv_moe + + num_tokens, hidden_size, rank, num_experts, top_k = 1, 768, 8, 1, 1 + dtype = torch.bfloat16 + num_loras = 1 + num_slices = 1 + + data = generate_test_data( + num_tokens, + hidden_size, + rank, + num_experts, + top_k, + num_loras, + num_slices, + dtype, + ) + data["lora_indices"][0] = 0 # Ensure valid LoRA + + ref_out = reference_bgmv_moe( + data["x"], + data["lora_a_weights"], + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + ) + + cuda_out = bgmv_moe( + data["x"], + data["lora_a_weights"], + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + num_experts, + ) + + torch.testing.assert_close( + cuda_out.float(), ref_out.float(), atol=1e-2, rtol=1e-2 + ) + + def test_multi_slice_w13(self): + """Test with 2 slices (simulating gate+up projection).""" + from flashinfer.fused_moe.bgmv_moe import bgmv_moe + + num_tokens, hidden_size, rank, num_experts, top_k = 8, 2048, 16, 8, 2 + dtype = torch.bfloat16 + num_loras = 4 + num_slices = 2 + + data = generate_test_data( + num_tokens, + hidden_size, + rank, + num_experts, + top_k, + num_loras, + num_slices, + dtype, + ) + + ref_out = reference_bgmv_moe( + data["x"], + data["lora_a_weights"], + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + ) + + cuda_out = bgmv_moe( + data["x"], + data["lora_a_weights"], + data["lora_b_weights"], + data["sorted_token_ids"], + data["expert_ids"], + data["lora_indices"], + data["topk_weights"], + num_experts, + ) + + torch.testing.assert_close( + cuda_out.float(), ref_out.float(), atol=5e-2, rtol=5e-2 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])