diff --git a/benchmarks/kernels/bench_ll_a_gemm.py b/benchmarks/kernels/bench_ll_a_gemm.py new file mode 100644 index 000000000000..e003f68f3646 --- /dev/null +++ b/benchmarks/kernels/bench_ll_a_gemm.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import os +import sys +import torch +from triton.testing import do_bench_cudagraph + +from vllm import _custom_ops as ops + +q = [0.5, 0.2, 0.8] +_HAS_DSV3 = hasattr(ops, 'dsv3_fused_a_gemm') + +try: + from flashinfer.gemm import tgv_gemm_sm100 + from flashinfer import autotune + _HAS_TGV = True +except ImportError: + _HAS_TGV = False + +try: + from flashinfer.gemm import mm_bf16 + _HAS_MM_BF16 = True +except ImportError: + _HAS_MM_BF16 = False + +parser = argparse.ArgumentParser(description='Benchmark ll_a_gemm kernels') +parser.add_argument('--l2-pollute', action='store_true', + help='Measure with cold L2 via nsys+CG (slower but more realistic)') +parser.add_argument('--shape', type=str, default=None, + help='Filter: K,N (e.g. "7168,2112") or label substring (e.g. "a_proj")') +parser.add_argument('--M', type=str, default=None, + help='Filter M values: comma-separated (e.g. 
def _get_best_splitk(a8, b8, M, K_phys, N):
    """Sweep split-K FP8 configurations and return the fastest one.

    Every (split_k, num_stages) pair that evenly divides the K tiles is
    compiled (compiled kernels are memoised in the module-level ``_sk_cache``)
    and timed with ``do_bench_cudagraph``.

    Args:
        a8: Activation tensor. Callers in this file pass fp8 data
            reinterpreted as bfloat16, hence the halved K extent below.
        b8: Weight tensor, prepared the same way as ``a8``.
        M: Second dimension of the ``(N, M)`` output buffer allocated here.
        K_phys: K extent before the bfloat16 reinterpretation.
        N: First dimension of the output buffer.

    Returns:
        ``(best_time, best_compiled)``. ``best_time`` is the first requested
        quantile multiplied by 1000 (microseconds, presuming
        ``do_bench_cudagraph`` reports milliseconds), or NaN when no
        configuration could be timed. ``best_compiled`` is the matching
        compiled callable, or ``None``.
    """
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack
    from cuda.bindings.driver import CUstream
    from torch.cuda import current_stream

    from vllm.model_executor.layers.fused_moe.router._ll_a_gemm_kernels import LLAGemm

    def _as_dynamic(tensor, divisibility):
        # Same wrapper chain for all three operands; only divisibility differs.
        return (
            from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True)
            .mark_layout_dynamic(leading_dim=1)
            .mark_compact_shape_dynamic(
                mode=1, stride_order=(0, 1), divisibility=divisibility
            )
        )

    div = 8
    K_view = K_phys // 2
    tiles = K_view // 256
    best_t = float('inf')
    best_compiled = None
    last_failure = None

    for sk in [2, 3, 4, 6, 8, 12]:
        if tiles % sk != 0:
            continue
        for ns in [2, 3, 4]:
            if ns > tiles // sk:
                continue
            ck = (True, sk, ns, K_view, N)
            try:
                out = torch.empty(N, M, dtype=torch.bfloat16, device=a8.device)
                if ck not in _sk_cache:
                    mA = _as_dynamic(b8, div)
                    mB = _as_dynamic(a8, div)
                    mC = _as_dynamic(out, 1)
                    gemm = LLAGemm(tile_n=8, tile_k=256, num_stages=ns,
                                   num_dma_warps=4, is_fp8=True, split_k=sk)
                    _sk_cache[ck] = cute.compile(
                        gemm.call_splitk, mA, mB, mC,
                        CUstream(current_stream().cuda_stream),
                        options="--enable-tvm-ffi")
                c = _sk_cache[ck]
                # NOTE(review): the stream is passed here as a raw handle while
                # other call sites wrap it in CUstream -- confirm both are accepted.
                t = do_bench_cudagraph(
                    lambda c=c, a=a8, b=b8, o=out: c(b, a, o, current_stream().cuda_stream, 1.0),
                    quantiles=q)[0]
                if t < best_t:
                    best_t = t
                    best_compiled = c
            except Exception as exc:
                # Best-effort sweep: one unsupported config must not abort the
                # benchmark, but remember the error so total failure is visible.
                last_failure = (sk, ns, exc)

    if best_compiled is None and last_failure is not None:
        failed_sk, failed_ns, exc = last_failure
        print(
            f'[split-K autotune] no working config for M={M} K={K_phys} N={N}; '
            f'last failure (split_k={failed_sk}, num_stages={failed_ns}): {exc!r}',
            file=sys.stderr, flush=True)

    return best_t * 1000 if best_t < float('inf') else float('nan'), best_compiled
kernels[kt] = lambda: tgv_gemm_sm100(a, b.T, bias, out=out) + elif kt == 'fi-bf16': + with autotune(True): + mm_bf16(a, b.T, backend='auto') + torch.cuda.synchronize() + kernels[kt] = lambda: mm_bf16(a, b.T, backend='auto') + elif kt == 'sk-fp8': + from cuda.bindings.driver import CUstream as _CUStream + out_sk = torch.empty(N, M, dtype=torch.bfloat16, device='cuda') + _, best_c = _get_best_splitk(a8, b8, M, K, N) + if best_c: + kernels[kt] = lambda c=best_c, o=out_sk: c(b8, a8, o, _CUStream(_cs().cuda_stream), 1.0) + elif kt == 'cuBLAS': + omm = torch.empty(M, N, dtype=torch.bfloat16, device='cuda') + kernels[kt] = lambda: torch.mm(a, b.T, out=omm) + elif kt == 'smm': + a_fp8 = a.to(torch.float8_e4m3fn) + b_fp8 = b.to(torch.float8_e4m3fn) + bt = b_fp8.T.contiguous() + s1 = torch.ones(1, device='cuda', dtype=torch.float32) + kernels[kt] = lambda: torch._scaled_mm(a_fp8, bt, scale_a=s1, scale_b=s1, out_dtype=torch.bfloat16) + except Exception: + pass + + stream = torch.cuda.Stream() + graphs = {} + for kt, kfn in kernels.items(): + with torch.cuda.stream(stream): + _l2p(); kfn(); torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=stream): + _l2p(); kfn() + graphs[kt] = g + + for g in graphs.values(): + for _ in range(3): g.replay() + torch.cuda.synchronize() + + # Tag L2 prefix for identification in nsys stats + torch.cuda.nvtx.range_push('BENCH__L2PREFIX') + _l2p(); torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + # Profiled replays with NVTX markers + for kt, g in graphs.items(): + torch.cuda.nvtx.range_push(f'BENCH_{kt}') + for _ in range(20): g.replay() + torch.cuda.nvtx.range_pop() + torch.cuda.synchronize() + + print('NSYS_KERNEL_DONE:' + ','.join(graphs.keys())) + sys.exit(0) + +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +from cuda.bindings.driver import CUstream +from torch.cuda import current_stream +from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import ll_a_gemm +from 
def _parse_nvtx_kern_sum(stats_out):
    """Extract per-kernel median times from an ``nvtx_kern_sum`` text report.

    Rows are whitespace-split; column 0 is the NVTX range (``:BENCH_<name>``),
    column 8 the median in nanoseconds and columns 12+ the kernel name.
    Kernels that also appear under the ``:BENCH__L2PREFIX`` range are the
    cache-pollution prefix and are excluded. For each range the first
    remaining row wins.

    Returns:
        Dict mapping kernel label to median time (nanoseconds / 1000).
    """
    rows = stats_out.split('\n')

    # First 40 characters of every kernel name seen under the L2-prefix range.
    l2_prefixes = []
    for row in rows:
        if ':BENCH__L2PREFIX' in row:
            cols = row.split()
            if len(cols) > 12:
                l2_prefixes.append(' '.join(cols[12:])[:40])
    l2_prefixes = tuple(l2_prefixes)

    results = {}
    for row in rows:
        if ':BENCH_' not in row or ':BENCH__L2PREFIX' in row:
            continue
        cols = row.split()
        if len(cols) < 9:
            continue
        label = cols[0].replace(':BENCH_', '')
        try:
            # Drop thousands separators in case the report prints "5,120.0".
            med_ns = float(cols[8].replace(',', ''))
        except ValueError:
            continue
        kernel_name = ' '.join(cols[12:]) if len(cols) > 12 else ''
        if kernel_name.startswith(l2_prefixes):
            continue
        results.setdefault(label, med_ns / 1000.0)
    return results


def _run_nsys_batch(M, K, N):
    """Profile one (M, K, N) shape in an nsys subprocess and return timings.

    Re-invokes this script with the hidden ``--nsys-*`` flags under
    ``nsys profile``, then runs ``nsys stats --report nvtx_kern_sum`` on the
    resulting report and parses it.

    The report is written to a fresh temporary directory so that a failed
    profile run can never be mistaken for a stale report left behind by an
    earlier shape, and concurrent runs do not overwrite each other.

    Returns:
        Dict mapping kernel label to median time; empty when nsys is missing,
        times out, or produces no report.
    """
    import shutil
    import subprocess
    import tempfile

    nsys = shutil.which('nsys') or '/usr/local/bin/nsys'
    script = os.path.abspath(sys.argv[0])

    # fp8 kernels always run; bf16 kernels only for K <= 7168.
    kt_list = ['p-fp8', 'sk-fp8', 't-fp8', 'smm']
    if K <= 7168:
        kt_list = ['p-bf16', 't-bf16', 'cuBLAS', 'fi-bf16'] + kt_list
    if K == 7168 and N == 2112:
        kt_list.append('DSV3')
    if N % 16 == 0 and K <= 7168:
        kt_list.append('TGV')

    results = {}
    try:
        with tempfile.TemporaryDirectory(prefix='bench_nsys_') as tmpdir:
            report_base = os.path.join(tmpdir, 'report')
            report_path = report_base + '.nsys-rep'

            cmd = [nsys, 'profile', '--stats=true', '-t', 'cuda,nvtx',
                   '--cuda-graph-trace=node', '-o', report_base, '-f', 'true',
                   sys.executable, script,
                   '--nsys-kernel', ','.join(kt_list),
                   '--nsys-M', str(M), '--nsys-K', str(K), '--nsys-N', str(N)]
            subprocess.run(cmd, capture_output=True, text=True, timeout=300)

            if not os.path.exists(report_path):
                print(f'[nsys] no report produced for M={M} K={K} N={N}',
                      file=sys.stderr, flush=True)
                return results

            # Parse nvtx_kern_sum from nsys stats
            stats_cmd = [nsys, 'stats', '--force-export=true',
                         '--report', 'nvtx_kern_sum', report_path]
            stats_result = subprocess.run(stats_cmd, capture_output=True,
                                          text=True, timeout=60)
            results = _parse_nvtx_kern_sum(
                stats_result.stdout + stats_result.stderr)
    except (subprocess.TimeoutExpired, FileNotFoundError) as exc:
        # Best effort: a missing or hung nsys yields N/A columns, not a crash.
        print(f'[nsys] profiling skipped for M={M} K={K} N={N}: {exc!r}',
              file=sys.stderr, flush=True)
    return results
r['t-fp8'] = _bench(lambda: ll_a_gemm_tma(a8, b8, is_fp8=True)) + except: r['t-fp8'] = float('nan') + + if _HAS_DSV3 and K == 7168 and N == 2112 and M <= 16 and not fp8_only: + o = torch.empty(M, N, dtype=torch.bfloat16, device='cuda') + r['DSV3'] = _bench(lambda: ops.dsv3_fused_a_gemm(o, a, b.T)) + else: r['DSV3'] = float('nan') + + if _HAS_TGV and N % 16 == 0 and not fp8_only: + bias = torch.zeros(N, dtype=torch.bfloat16, device='cuda') + out = torch.empty(M, N, dtype=torch.bfloat16, device='cuda') + with autotune(True): + tgv_gemm_sm100(a, b.T, bias, out=out) + torch.cuda.synchronize() + r['TGV'] = _bench(lambda: tgv_gemm_sm100(a, b.T, bias, out=out)) + else: r['TGV'] = float('nan') + + if _HAS_MM_BF16 and not fp8_only: + try: + with autotune(True): + mm_bf16(a, b.T, backend='auto') + torch.cuda.synchronize() + r['fi-bf16'] = _bench(lambda: mm_bf16(a, b.T, backend='auto')) + except: r['fi-bf16'] = float('nan') + else: r['fi-bf16'] = float('nan') + + if not fp8_only: + omm = torch.empty(M, N, dtype=torch.bfloat16, device='cuda') + r['cuBLAS'] = _bench(lambda: torch.mm(a, b.T, out=omm)) + else: r['cuBLAS'] = float('nan') + + a_fp8 = a.to(torch.float8_e4m3fn) + b_fp8 = b.to(torch.float8_e4m3fn) + bt = b_fp8.T.contiguous() + s1 = torch.ones(1, device='cuda', dtype=torch.float32) + r['smm'] = _bench(lambda: torch._scaled_mm(a_fp8, bt, scale_a=s1, scale_b=s1, out_dtype=torch.bfloat16)) + return r + +_M_vals = [int(x) for x in args.M.split(',')] if args.M else [1, 4, 8, 16] + +for K, N, label in SHAPES: + if args.shape: + s = args.shape + if ',' in s: + fk, fn = s.split(',', 1) + if int(fk) != K or int(fn) != N: + continue + elif s.lower() not in label.lower(): + continue + + print(f'=== {label}: K={K}, N={N} ===', flush=True) + hdr = f"{'M':>3} |" + "".join(f" {c:>9}" for c in all_keys) + print(hdr, flush=True) + print('-' * len(hdr), flush=True) + + for M in _M_vals: + r = bench_one(M, K, N, label) + + bf16_base = r['DSV3'] if r['DSV3'] == r['DSV3'] else r['cuBLAS'] 
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M"],
        x_vals=[1, 2, 4, 8, 16],
        x_log=False,
        line_arg="provider",
        line_vals=_providers,
        line_names=_providers,
        ylabel="Latency (us, lower is better)",
        plot_name="LL Router GEMM",
        args={},
    )
)
def benchmark(M, provider, N, K):
    """Time one router-GEMM provider for an (M, N, K) problem.

    Relies on module-level state created in the ``__main__`` section of this
    script: ``args`` (parsed CLI flags) and, when ``--l2-pollute`` is set, the
    prefix-matmul operands ``_ap`` / ``_wp`` / ``_op`` and its measured cost
    ``_l2_tp``.

    Args:
        M: Number of rows of the activation matrix.
        provider: One of the names in ``_providers``.
        N: Number of rows of the weight matrix.
        K: Shared inner dimension.

    Returns:
        ``(median, q20, q80)`` latency scaled by 1000 (microseconds), or a
        triple of NaN when the provider does not support the shape.

    Raises:
        ValueError: If ``provider`` is not a known provider name.
    """
    device = "cuda"
    quantiles = [0.5, 0.2, 0.8]

    if provider == "ll-router-bf16":
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N, K, dtype=torch.bfloat16, device=device)

        def kernel():
            ll_router_gemm(a, b)

    elif provider == "ll-router-fp8":
        a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
        b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)

        def kernel():
            ll_router_gemm(a, b)

    elif provider == "cublas-bf16":
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N, K, dtype=torch.bfloat16, device=device)
        out = torch.empty(M, N, dtype=torch.bfloat16, device=device)

        def kernel():
            torch.mm(a, b.T, out=out)

    elif provider == "dsv3-trtllm":
        # DSV3 only supports N in {256, 384} and K = 7168
        if N not in (256, 384) or K != 7168:
            return float("nan"), float("nan"), float("nan")
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N, K, dtype=torch.bfloat16, device=device)

        def kernel():
            ops.dsv3_router_gemm(a, b, torch.float32)

    else:
        raise ValueError(f"unknown provider: {provider!r}")

    if args.l2_pollute:
        # Run a large matmul first so the kernel under test starts from a
        # polluted L2; its cost is subtracted again below.
        def timed():
            torch.mm(_ap, _wp.T, out=_op)
            kernel()
    else:
        timed = kernel

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        timed, quantiles=quantiles
    )
    if args.l2_pollute:
        ms -= _l2_tp
        min_ms -= _l2_tp
        max_ms -= _l2_tp

    # Return latency in us
    return ms * 1000, min_ms * 1000, max_ms * 1000
def _assert_correct(out, ref, min_cos_sim=0.99, context=""):
    """Assert that ``out`` matches ``ref`` in shape and direction.

    Checks, in order: identical shapes, ``out`` lives on a CUDA device,
    ``out`` is finite everywhere, and the cosine similarity of the flattened
    fp32 tensors exceeds ``min_cos_sim``.

    The shape check comes first because both tensors are flattened for the
    similarity computation, which would otherwise hide a mis-shaped output
    with the same element count.

    Note that cosine similarity is insensitive to a uniform positive scaling
    of ``out``; the maximum absolute error is reported in the failure message
    but is not itself asserted.

    Args:
        out: Kernel output tensor.
        ref: Reference tensor.
        min_cos_sim: Exclusive lower bound on the cosine similarity.
        context: Prefix for assertion messages.
    """
    assert out.shape == ref.shape, (
        f"{context}: shape mismatch, out {tuple(out.shape)} "
        f"vs ref {tuple(ref.shape)}"
    )
    assert out.device.type == "cuda", f"{context}: output not on CUDA"
    assert torch.isfinite(out).all(), f"{context}: output contains NaN/Inf"
    cos = F.cosine_similarity(
        out.reshape(-1).float(), ref.reshape(-1).float(), dim=0
    ).item()
    abs_err = (out.float() - ref.float()).abs().max().item()
    msg = (f"{context}: cosine similarity {cos:.4f} < {min_cos_sim} "
           f"(abs_err={abs_err:.2e})")
    assert cos > min_cos_sim, msg
@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("K,N,desc", SHAPES_BF16, ids=[s[2] for s in SHAPES_BF16])
def test_bf16_tma(M, K, N, desc):
    """bf16 TMA kernel: output shape and agreement with the fp32 reference."""
    from vllm.model_executor.layers.fused_moe.router.ll_a_gemm_tma import ll_a_gemm_tma

    torch.manual_seed(42)
    bf16_cuda = dict(dtype=torch.bfloat16, device="cuda")
    # Activations are drawn before weights to keep the seeded RNG order.
    activations = torch.randn(M, K, **bf16_cuda)
    weights = torch.randn(N, K, **bf16_cuda)

    got = ll_a_gemm_tma(activations, weights)
    expected = _ref_bf16(activations, weights)

    assert got.shape == (M, N)
    _assert_correct(got, expected, context=f"bf16_tma M={M} {desc}")
@pytest.mark.parametrize("M", [1, 4, 8])
@pytest.mark.parametrize("sk,ns", [(2, 2), (4, 2), (4, 4), (8, 3)],
                         ids=["sk2_ns2", "sk4_ns2", "sk4_ns4", "sk8_ns3"])
@pytest.mark.parametrize("K,N,desc", SHAPES_FP8, ids=[s[2] for s in SHAPES_FP8])
def test_splitk_fp8(M, K, N, desc, sk, ns):
    """Split-K FP8 kernel matches the ``torch._scaled_mm`` reference.

    Skips combinations where the number of 256-wide K tiles (counted on the
    bfloat16 view, i.e. ``K // 2`` elements) is not divisible by ``sk`` or
    leaves fewer tiles per split than pipeline stages ``ns``.
    """
    from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import _get_compiled_splitk
    from cuda.bindings.driver import CUstream
    from torch.cuda import current_stream

    # The fp8 operands are viewed as bfloat16 below, halving the K extent.
    K_view = K // 2
    tiles = K_view // 256
    if tiles % sk != 0 or ns > tiles // sk:
        pytest.skip(f"tiles={tiles} incompatible with sk={sk} ns={ns}")

    torch.manual_seed(42)
    a_fp8, _ = _to_float8(torch.randn(M, K, device="cuda"))
    b_fp8, _ = _to_float8(torch.randn(N, K, device="cuda"))
    a8 = a_fp8.view(torch.bfloat16)
    b8 = b_fp8.view(torch.bfloat16)
    ref = _ref_fp8(a_fp8, b_fp8)

    # Output buffer is handed to the kernel with shape (N, M); weights are
    # passed first, activations second.
    out = torch.empty(N*M, dtype=torch.bfloat16, device="cuda").view(N, M)
    compiled = _get_compiled_splitk(True, True, b8, a8, out, sk, ns)
    compiled(b8, a8, out, CUstream(current_stream().cuda_stream), 1.0)
    torch.cuda.synchronize()

    # NOTE(review): .view(M, N) reinterprets the (N, M) buffer's memory; it is
    # not a transpose. This is only right if the kernel fills the buffer in
    # (M, N) row-major order -- confirm against the kernel's output layout.
    result = out.view(M, N)
    # Always true after the view above; kept for symmetry with sibling tests.
    assert result.shape == (M, N)
    _assert_correct(result, ref, min_cos_sim=0.98, context=f"sk{sk}_ns{ns} M={M} {desc}")
def _assert_scale_applied(out_scaled, out_unscaled, scale, context, rel_tol=0.02):
    """Assert that ``out_scaled`` is ``scale`` times as large as ``out_unscaled``.

    Cosine similarity does not change under uniform positive scaling, so a
    kernel that ignored ``scale`` would still pass a cosine-only comparison.
    Comparing the ratio of L2 norms pins the magnitude.
    """
    ratio = (out_scaled.float().norm() / out_unscaled.float().norm()).item()
    assert abs(ratio - scale) <= rel_tol * abs(scale), (
        f"{context}: output norm ratio {ratio:.4g} does not match scale {scale}"
    )


@pytest.mark.parametrize("scale", [0.5, 2.0, 0.01])
def test_scale_bf16(scale):
    """bf16 kernel: ``scale`` multiplies the output (direction and magnitude)."""
    from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import ll_a_gemm
    torch.manual_seed(42)
    a = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(2112, 7168, dtype=torch.bfloat16, device="cuda")
    out_scaled = ll_a_gemm(a, b, scale=scale)
    out_unscaled = ll_a_gemm(a, b, scale=1.0)
    ref = out_unscaled * scale
    _assert_correct(out_scaled, ref, min_cos_sim=0.999, context=f"scale={scale}")
    _assert_scale_applied(out_scaled, out_unscaled, scale, context=f"scale={scale}")


@pytest.mark.parametrize("scale", [0.5, 2.0])
def test_scale_fp8(scale):
    """FP8 kernel: ``scale`` multiplies the output (direction and magnitude)."""
    from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import ll_a_gemm
    torch.manual_seed(42)
    a_fp8, _ = _to_float8(torch.randn(4, 7168, device="cuda"))
    b_fp8, _ = _to_float8(torch.randn(2112, 7168, device="cuda"))
    a8, b8 = a_fp8.view(torch.bfloat16), b_fp8.view(torch.bfloat16)
    out_scaled = ll_a_gemm(a8, b8, is_fp8=True, scale=scale)
    out_unscaled = ll_a_gemm(a8, b8, is_fp8=True, scale=1.0)
    ref = out_unscaled * scale
    _assert_correct(out_scaled, ref, min_cos_sim=0.999, context=f"fp8_scale={scale}")
    _assert_scale_applied(out_scaled, out_unscaled, scale, context=f"fp8_scale={scale}")
def test_cudagraph_peeled():
    """Capture ll_a_gemm in a CUDA graph, replay it, and validate the output."""
    from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import ll_a_gemm

    torch.manual_seed(42)
    bf16_cuda = dict(dtype=torch.bfloat16, device="cuda")
    lhs = torch.randn(4, 7168, **bf16_cuda)
    rhs = torch.randn(2112, 7168, **bf16_cuda)

    # One eager call before capture.
    ll_a_gemm(lhs, rhs)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        captured_out = ll_a_gemm(lhs, rhs)

    remaining = 5
    while remaining:
        graph.replay()
        remaining -= 1
    torch.cuda.synchronize()

    expected = _ref_bf16(lhs, rhs)
    _assert_correct(captured_out, expected, context="CG_peeled")
dtype=torch.bfloat16, device="cuda") + b = torch.randn(2112, 7168, dtype=torch.bfloat16, device="cuda") + ll_a_gemm(a, b); torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out = ll_a_gemm(a, b) + + results = [] + for _ in range(20): + g.replay() + torch.cuda.synchronize() + results.append(out.clone()) + + for i, r in enumerate(results[1:], 1): + torch.testing.assert_close(results[0], r, atol=0, rtol=0, + msg=f"Replay {i} differs from replay 0") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/kernels/test_ll_router_gemm.py b/tests/kernels/test_ll_router_gemm.py new file mode 100644 index 000000000000..f38abe9d1509 --- /dev/null +++ b/tests/kernels/test_ll_router_gemm.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn.functional as F + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required" +) + + +@pytest.fixture(autouse=True, scope="module") +def _check_cutedsl(): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import is_available + if not is_available(): + pytest.skip("cuteDSL (CUTLASS Python) not installed") + + +# ===== Helpers ===== + +def _to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + amax = x.abs().amax().clamp(min=1e-12) + scale = finfo.max / amax + return (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype), scale.float().reciprocal() + + +def _assert_correct(out, ref, min_cos_sim=0.99, context=""): + assert out.device.type == "cuda", f"{context}: output not on CUDA" + assert torch.isfinite(out).all(), f"{context}: output contains NaN/Inf" + cos = F.cosine_similarity( + out.reshape(-1).float(), ref.reshape(-1).float(), dim=0 + ).item() + abs_err = (out.float() - ref.float()).abs().max().item() + msg = (f"{context}: cosine similarity {cos:.4f} < {min_cos_sim} " + 
f"(abs_err={abs_err:.2e})") + assert cos > min_cos_sim, msg + + +def _ref(a, b): + return torch.mm(a.float(), b.float().T) + + +# ===== Shapes ===== + +SHAPES = [ + (256, 7168, "DSV3_router"), + (256, 2048, "small_K"), + (128, 5120, "DeepSeek_V2"), + (8, 4096, "Mixtral"), + (64, 2880, "non_aligned_K"), +] + + +# ===== bf16 correctness ===== + +@pytest.mark.parametrize("M", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("N,K,desc", SHAPES, ids=[s[2] for s in SHAPES]) +def test_bf16(M, N, K, desc): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + out = ll_router_gemm(a, b) + ref = _ref(a, b) + assert out.dtype == torch.float32 + assert out.shape == (M, N) + _assert_correct(out, ref, context=f"bf16 {M}x{N}x{K}") + + +# ===== FP8 correctness ===== + +@pytest.mark.parametrize("M", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("N,K,desc", + [(n, k, d) for n, k, d in SHAPES if k % 2 == 0], + ids=[s[2] for s in SHAPES if s[1] % 2 == 0]) +def test_fp8(M, N, K, desc): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a_fp8, _ = _to_float8(torch.randn(M, K, device="cuda")) + b_fp8, _ = _to_float8(torch.randn(N, K, device="cuda")) + out = ll_router_gemm(a_fp8, b_fp8) + ref = _ref(a_fp8, b_fp8) + assert out.dtype == torch.float32 + assert out.shape == (M, N) + _assert_correct(out, ref, min_cos_sim=0.98, context=f"fp8 {M}x{N}x{K}") + + +# ===== Cross-validation: bf16 vs fp8 on same data ===== + +@pytest.mark.parametrize("M", [1, 4]) +def test_bf16_vs_fp8_agreement(M): + """bf16 and fp8 kernels on the same underlying data should roughly agree.""" + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + K, N = 4096, 128 + a_bf16 = torch.randn(M, K, dtype=torch.bfloat16, 
device="cuda") + b_bf16 = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + out_bf16 = ll_router_gemm(a_bf16, b_bf16) + + a_fp8 = a_bf16.to(torch.float8_e4m3fn) + b_fp8 = b_bf16.to(torch.float8_e4m3fn) + out_fp8 = ll_router_gemm(a_fp8, b_fp8) + + _assert_correct(out_fp8, out_bf16, min_cos_sim=0.95, + context=f"bf16_vs_fp8 M={M}") + + +# ===== Arbitrary N ===== + +@pytest.mark.parametrize("N", [1, 3, 7, 17, 64, 256]) +def test_arbitrary_N(N): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda") + b = torch.randn(N, 2048, dtype=torch.bfloat16, device="cuda") + out = ll_router_gemm(a, b) + ref = _ref(a, b) + assert out.shape == (4, N) + _assert_correct(out, ref, context=f"N={N}") + + +# ===== Arbitrary K ===== + +@pytest.mark.parametrize("K", [64, 128, 256, 512, 1024, 2048, 4096, 7168]) +def test_arbitrary_K(K): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(4, K, dtype=torch.bfloat16, device="cuda") + b = torch.randn(32, K, dtype=torch.bfloat16, device="cuda") + out = ll_router_gemm(a, b) + ref = _ref(a, b) + _assert_correct(out, ref, context=f"K={K}") + + +# ===== Numerical edge cases ===== + +def test_large_values(): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda") * 100 + b = torch.randn(64, 2048, dtype=torch.bfloat16, device="cuda") * 100 + out = ll_router_gemm(a, b) + ref = _ref(a, b) + assert torch.isfinite(out).all(), "Large values produced NaN/Inf" + _assert_correct(out, ref, context="large_values") + + +def test_near_zero_values(): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda") * 1e-4 + b = 
torch.randn(64, 2048, dtype=torch.bfloat16, device="cuda") * 1e-4 + out = ll_router_gemm(a, b) + assert torch.isfinite(out).all(), "Near-zero values produced NaN/Inf" + assert out.abs().max() < 1.0, "Near-zero inputs should produce near-zero output" + + +# ===== Deterministic ===== + +@pytest.mark.parametrize("M", [1, 16]) +def test_deterministic(M): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(M, 4096, dtype=torch.bfloat16, device="cuda") + b = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda") + out1 = ll_router_gemm(a, b) + out2 = ll_router_gemm(a, b) + torch.testing.assert_close(out1, out2, atol=0, rtol=0) + + +# ===== CUDA graph ===== + +def test_cudagraph(): + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 2048, dtype=torch.bfloat16, device="cuda") + ll_router_gemm(a, b); torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out = ll_router_gemm(a, b) + + for _ in range(5): + g.replay() + torch.cuda.synchronize() + + ref = _ref(a, b) + _assert_correct(out, ref, context="cudagraph") + + +def test_cudagraph_repeated_replay(): + """Many replays should produce identical results.""" + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + torch.manual_seed(42) + a = torch.randn(4, 4096, dtype=torch.bfloat16, device="cuda") + b = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda") + ll_router_gemm(a, b); torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out = ll_router_gemm(a, b) + + results = [] + for _ in range(20): + g.replay() + torch.cuda.synchronize() + results.append(out.clone()) + + for i, r in enumerate(results[1:], 1): + torch.testing.assert_close(results[0], r, atol=0, rtol=0, + msg=f"Replay {i} differs from 
replay 0") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 09b9a557fe45..18b03f3a3c19 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -197,6 +197,7 @@ def call_trtllm_fused_allreduce_norm( layout_code=layout_code, use_oneshot=use_oneshot, fp32_acc=fp32_acc, + trigger_completion_at_end=False, ) def call_trtllm_fused_allreduce_norm_fake( diff --git a/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py index b9f6f0c8f873..35b9b72d7f01 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py @@ -16,6 +16,7 @@ from ..base import MMLinearLayerConfig +from vllm.model_executor.layers.utils import _ll_gemm_fp8_scaled_mm @dataclass class Int8ScaledMMLinearLayerConfig(MMLinearLayerConfig): @@ -144,6 +145,12 @@ def apply_weights( x_s, x_s_ub, ) + + # Use LL FP8 GEMM for small M + ll_out = _ll_gemm_fp8_scaled_mm(x_2d_q, w, x_s, w_s, bias, output_shape) + if ll_out is not None: + return ll_out + return self.apply_scaled_mm( A=x_2d_q, B=w, diff --git a/vllm/model_executor/layers/fused_moe/router/_ll_a_gemm_kernels.py b/vllm/model_executor/layers/fused_moe/router/_ll_a_gemm_kernels.py new file mode 100644 index 000000000000..1827e588202b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/_ll_a_gemm_kernels.py @@ -0,0 +1,1374 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CuteDSL tiled A GEMM: C[M,N] = A[M,K] @ B[N,K]^T.""" + +import math + +import cutlass +import cutlass.cute as cute +from cuda.bindings.driver import CUstream +from cutlass._mlir import ir as _ir +from 
cutlass._mlir.dialects import arith as _arith +from cutlass._mlir.dialects import llvm as _llvm +from cutlass.cutlass_dsl import dsl_user_op +from cutlass.pipeline import sm90 as pipeline + + +@dsl_user_op +def bf16x2_pack(lo, hi, *, loc=None, ip=None): + """Pack 2 bf16 → 1 uint32 via vector insert + bitcast.""" + lo_ir = lo.ir_value(loc=loc, ip=ip) + hi_ir = hi.ir_value(loc=loc, ip=ip) + bf16_ty = lo_ir.type + vec_ty = _ir.VectorType.get([2], bf16_ty) + i32 = _ir.IntegerType.get_signless(32) + c0 = _arith.constant(i32, 0, loc=loc, ip=ip) + c1 = _arith.constant(i32, 1, loc=loc, ip=ip) + undef = _llvm.mlir_undef(vec_ty, loc=loc, ip=ip) + v0 = _llvm.insertelement(vec_ty, undef, lo_ir, c0, loc=loc, ip=ip) + v1 = _llvm.insertelement(vec_ty, v0, hi_ir, c1, loc=loc, ip=ip) + packed = _llvm.bitcast(i32, v1, loc=loc, ip=ip) + return cutlass.Uint32(packed) + + +@dsl_user_op +def mma_e4m3(d0, d1, d2, d3, a0, a1, a2, a3, b0, b1, *, loc=None, ip=None): + """mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32""" + f32 = cutlass.Float32.mlir_type + args = [ + a0.ir_value(loc=loc, ip=ip), + a1.ir_value(loc=loc, ip=ip), + a2.ir_value(loc=loc, ip=ip), + a3.ir_value(loc=loc, ip=ip), + b0.ir_value(loc=loc, ip=ip), + b1.ir_value(loc=loc, ip=ip), + d0.ir_value(loc=loc, ip=ip), + d1.ir_value(loc=loc, ip=ip), + d2.ir_value(loc=loc, ip=ip), + d3.ir_value(loc=loc, ip=ip), + ] + res = _llvm.inline_asm( + _llvm.StructType.get_literal([f32, f32, f32, f32]), + args, + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9},{$10,$11,$12,$13};", + "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", + has_side_effects=True, + loc=loc, + ip=ip, + ) + r0 = _llvm.extractvalue(f32, res, [0], loc=loc, ip=ip) + r1 = _llvm.extractvalue(f32, res, [1], loc=loc, ip=ip) + r2 = _llvm.extractvalue(f32, res, [2], loc=loc, ip=ip) + r3 = _llvm.extractvalue(f32, res, [3], loc=loc, ip=ip) + return ( + cutlass.Float32(r0), + cutlass.Float32(r1), + cutlass.Float32(r2), + cutlass.Float32(r3), + ) 
+ + +def _pack2(lo, hi, *, loc=None, ip=None): + """Pack 2 bf16 → 1 uint32 via vector insert + bitcast. + + Uses LLVM vector ops instead of scalar integer ops. + LLVM instcombine folds insert(extract(vec,0), extract(vec,1))→vec + back to the original i32 register from ldmatrix. + """ + bf16_ty = lo.type + vec_ty = _ir.VectorType.get([2], bf16_ty) + i32 = _ir.IntegerType.get_signless(32) + c0 = _arith.constant(i32, 0, loc=loc, ip=ip) + c1 = _arith.constant(i32, 1, loc=loc, ip=ip) + undef = _llvm.mlir_undef(vec_ty, loc=loc, ip=ip) + v0 = _llvm.insertelement(undef, lo, c0, loc=loc, ip=ip) + v1 = _llvm.insertelement(v0, hi, c1, loc=loc, ip=ip) + return _llvm.bitcast(i32, v1, loc=loc, ip=ip) + + +@dsl_user_op +def fused_fp8_mma_2n( + c0, + c1, + c2, + c3, + c4, + c5, + c6, + c7, + a0_lo, + a0_hi, + a1_lo, + a1_hi, + a2_lo, + a2_hi, + a3_lo, + a3_hi, + b0_lo, + b0_hi, + b1_lo, + b1_hi, + b2_lo, + b2_hi, + b3_lo, + b3_hi, + *, + loc=None, + ip=None, +): + """Fused: pack bf16 pairs + 2x mma.sync.m16n8k32.e4m3 (both N-atoms).""" + f32 = cutlass.Float32.mlir_type + + # Pack A: 8 bf16 → 4 uint32 + a0 = _pack2( + a0_lo.ir_value(loc=loc, ip=ip), a0_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + a1 = _pack2( + a1_lo.ir_value(loc=loc, ip=ip), a1_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + a2 = _pack2( + a2_lo.ir_value(loc=loc, ip=ip), a2_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + a3 = _pack2( + a3_lo.ir_value(loc=loc, ip=ip), a3_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + # Pack B N-atom 0: 4 bf16 → 2 uint32 + bn0 = _pack2( + b0_lo.ir_value(loc=loc, ip=ip), b0_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + bn1 = _pack2( + b1_lo.ir_value(loc=loc, ip=ip), b1_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + # MMA N-atom 0 + r0 = _llvm.inline_asm( + _llvm.StructType.get_literal([f32, f32, f32, f32]), + [ + a0, + a1, + a2, + a3, + bn0, + bn1, + c0.ir_value(loc=loc, ip=ip), + c1.ir_value(loc=loc, ip=ip), + c2.ir_value(loc=loc, ip=ip), + 
c3.ir_value(loc=loc, ip=ip), + ], + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9},{$10,$11,$12,$13};", + "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", + has_side_effects=True, + loc=loc, + ip=ip, + ) + + # Pack B N-atom 1 + bn2 = _pack2( + b2_lo.ir_value(loc=loc, ip=ip), b2_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + bn3 = _pack2( + b3_lo.ir_value(loc=loc, ip=ip), b3_hi.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + # MMA N-atom 1 + r1 = _llvm.inline_asm( + _llvm.StructType.get_literal([f32, f32, f32, f32]), + [ + a0, + a1, + a2, + a3, + bn2, + bn3, + c4.ir_value(loc=loc, ip=ip), + c5.ir_value(loc=loc, ip=ip), + c6.ir_value(loc=loc, ip=ip), + c7.ir_value(loc=loc, ip=ip), + ], + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9},{$10,$11,$12,$13};", + "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", + has_side_effects=True, + loc=loc, + ip=ip, + ) + + return ( + cutlass.Float32(_llvm.extractvalue(f32, r0, [0], loc=loc, ip=ip)), + cutlass.Float32(_llvm.extractvalue(f32, r0, [1], loc=loc, ip=ip)), + cutlass.Float32(_llvm.extractvalue(f32, r0, [2], loc=loc, ip=ip)), + cutlass.Float32(_llvm.extractvalue(f32, r0, [3], loc=loc, ip=ip)), + cutlass.Float32(_llvm.extractvalue(f32, r1, [0], loc=loc, ip=ip)), + cutlass.Float32(_llvm.extractvalue(f32, r1, [1], loc=loc, ip=ip)), + cutlass.Float32(_llvm.extractvalue(f32, r1, [2], loc=loc, ip=ip)), + cutlass.Float32(_llvm.extractvalue(f32, r1, [3], loc=loc, ip=ip)), + ) + + + +@dsl_user_op +def cluster_arrive_relaxed(*, loc=None, ip=None): + """barrier.cluster.arrive.aligned""" + i32 = _ir.IntegerType.get_signless(32) + _llvm.inline_asm(i32, [], + "barrier.cluster.arrive.aligned; mov.u32 $0, 0;", + "=r", has_side_effects=True, loc=loc, ip=ip) + + +@dsl_user_op +def cluster_wait(*, loc=None, ip=None): + """barrier.cluster.wait.aligned""" + i32 = _ir.IntegerType.get_signless(32) + _llvm.inline_asm(i32, [], + "barrier.cluster.wait.aligned; mov.u32 $0, 0;", + 
"=r", has_side_effects=True, loc=loc, ip=ip) + + +@dsl_user_op +def set_block_rank(smem_ptr, peer_rank, *, loc=None, ip=None): + """mapa.shared::cluster.u32""" + i32 = _ir.IntegerType.get_signless(32) + ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + rank_ir = peer_rank.ir_value(loc=loc, ip=ip) + res = _llvm.inline_asm( + i32, [ptr_i32, rank_ir], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", has_side_effects=False, loc=loc, ip=ip) + return cutlass.Int32(res) + + +@dsl_user_op +def st_shared_remote_f32(remote_addr, val, *, loc=None, ip=None): + """st.shared::cluster.f32""" + i32 = _ir.IntegerType.get_signless(32) + addr_ir = remote_addr.ir_value(loc=loc, ip=ip) + val_ir = val.ir_value(loc=loc, ip=ip) + _llvm.inline_asm(i32, [addr_ir, val_ir], + "st.shared::cluster.f32 [$0], $1; mov.u32 $2, 0;", + "r,f,=r", has_side_effects=True, loc=loc, ip=ip) + + +@dsl_user_op +def atom_add_global_f32(addr, val, *, loc=None, ip=None): + """atom.global.add.f32 — global memory fp32 atomic add.""" + i32 = _ir.IntegerType.get_signless(32) + i64 = _ir.IntegerType.get_signless(64) + addr_int = addr.toint(loc=loc, ip=ip) + val_ir = val.ir_value(loc=loc, ip=ip) + _llvm.inline_asm(i32, [addr_int, val_ir], + "atom.global.add.f32 $0, [$1], $2; mov.u32 $3, 0;", + "=f,l,f,=r", has_side_effects=True, loc=loc, ip=ip) + + +class LLAGemm: + """Warp-specialized low-latency A GEMM with PipelineCpAsync.""" + + def __init__( + self, + ab_dtype=cutlass.BFloat16, + acc_dtype=cutlass.Float32, + out_dtype=cutlass.BFloat16, + tile_n: int = 32, + tile_k: int = 512, + num_stages: int = 3, + is_fp8: bool = False, + num_dma_warps: int = 4, + split_k: int = 1, + transpose_output: bool = False, + ): + self.ab_dtype = ab_dtype + self.acc_dtype = acc_dtype + self.out_dtype = out_dtype + self.tile_m = 16 + # min tile_n = 1*8*2 = 16 + self.tile_n = tile_n + self.tile_k = tile_k + self.num_stages = num_stages + self.is_fp8 = is_fp8 + self.split_k = split_k + self.transpose_output = transpose_output + 
self.mma_shape = (16, 8, 16) + # (1,1,1) = 32 threads per MMA warp + # 4 MMA warps doing k-phase interleaving on tile_n=16 + self.atom_layout = (1, 1, 1) + self.num_mma_warps = 4 + self.num_dma_threads = num_dma_warps * 32 + self.num_mma_threads = self.num_mma_warps * 32 # 128 + self.num_threads = self.num_dma_threads + self.num_mma_threads + + def _make_smem_layout_AB(self, dtype, copy_bits, smem_tiler): + major_size = min(smem_tiler[1], 64) + swizzle_bits = int(math.log2(major_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + layout_atom_outer = cute.make_layout((8, major_size), stride=(major_size, 1)) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), 0, layout_atom_outer + ) + return cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2)) + + def _make_smem_layout_C(self, dtype, copy_bits, smem_tiler): + return cute.make_layout(smem_tiler, stride=(smem_tiler[1], 1)) + + def _make_gmem_tiled_copy(self, atom_copy, dtype, copy_bits, num_threads): + copy_elems = copy_bits // dtype.width + k_threads = cute.size(self.tile_k) // copy_elems + thread_layout = cute.make_layout( + (num_threads // k_threads, k_threads), stride=(k_threads, 1) + ) + value_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + @cute.jit + def __call__( + self, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream: CUstream, + scale: float = 1.0, + ): + bM, bN, bK = self.tile_m, self.tile_n, self.tile_k + copy_bits = 128 + + sA_layout = self._make_smem_layout_AB( + mA.element_type, copy_bits, (bM, bK, self.num_stages) + ) + sB_layout = self._make_smem_layout_AB( + mB.element_type, copy_bits, (bN, bK, self.num_stages) + ) + sC_layout = self._make_smem_layout_C(mC.element_type, copy_bits, (bM, bN)) + + atom_g2s = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp( + cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL + ), + mA.element_type, + num_bits_per_copy=copy_bits, + 
) + tiled_copy_A = self._make_gmem_tiled_copy( + atom_g2s, mA.element_type, copy_bits, self.num_dma_threads + ) + tiled_copy_B = self._make_gmem_tiled_copy( + atom_g2s, mB.element_type, copy_bits, self.num_dma_threads + ) + + atom_s2g = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), mC.element_type, num_bits_per_copy=copy_bits + ) + c_copy_elems = copy_bits // mC.element_type.width + cn_threads = bN // c_copy_elems + tiled_copy_C = cute.make_tiled_copy_tv( + atom_s2g, + cute.make_layout( + (self.num_mma_threads // cn_threads, cn_threads), stride=(cn_threads, 1) + ), + cute.make_layout((1, c_copy_elems)), + ) + + op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_shape) + perm_mnk = ( + self.atom_layout[0] * self.mma_shape[0], + self.atom_layout[1] * self.mma_shape[1] * (self.tile_n // 8), + self.atom_layout[2] * self.mma_shape[2], + ) + tiled_mma = cute.make_tiled_mma( + op, cute.make_layout(self.atom_layout), permutation_mnk=perm_mnk + ) + + grid_m = cute.ceil_div(cute.size(mC, mode=[0]), bM) + grid_n = cute.ceil_div(cute.size(mC, mode=[1]), bN) + + self.kernel( + mA, + mB, + mC, + scale, + sA_layout, + sB_layout, + sC_layout, + tiled_copy_A, + tiled_copy_B, + tiled_copy_C, + tiled_mma, + ).launch( + grid=[cute.size(grid_m), cute.size(grid_n), 1], + block=[self.num_threads, 1, 1], + stream=stream, + use_pdl=True, + ) + + @cute.kernel + def kernel( + self, + mA, + mB, + mC, + scale: cutlass.Float32, + sA_layout: cute.ComposedLayout, + sB_layout: cute.ComposedLayout, + sC_layout: cute.Layout, + tiled_copy_A: cute.TiledCopy, + tiled_copy_B: cute.TiledCopy, + tiled_copy_C: cute.TiledCopy, + tiled_mma: cute.TiledMma, + ): + bM, bN, bK = self.tile_m, self.tile_n, self.tile_k + num_stages = self.num_stages + tidx, _, _ = cute.arch.thread_idx() + bid_m, bid_n, _ = cute.arch.block_idx() + + warp_idx = tidx // 32 + is_dma = warp_idx < (self.num_dma_threads // 32) + dma_tidx = tidx + mma_tidx = tidx - self.num_dma_threads + + cta_tiler = (bM, bN, 
bK) + coord = (bid_m, bid_n, None) + gA = cute.local_tile(mA, tiler=cta_tiler, coord=coord, proj=(1, None, 1)) + gB = cute.local_tile(mB, tiler=cta_tiler, coord=coord, proj=(None, 1, 1)) + gC = cute.local_tile(mC, tiler=cta_tiler, coord=coord, proj=(1, 1, None)) + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile(mcA, tiler=cta_tiler, coord=coord, proj=(1, None, 1)) + cB = cute.local_tile(mcB, tiler=cta_tiler, coord=coord, proj=(None, 1, 1)) + + @cute.struct + class SharedStorage: + a: cute.struct.Align[ + cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], 16 + ] + b: cute.struct.Align[ + cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], 16 + ] + c: cute.struct.Align[ + cute.struct.MemRange[mC.element_type, cute.cosize(sC_layout)], 16 + ] + mbar: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int64, num_stages * 2], 8 + ] + + smem = cutlass.utils.SmemAllocator() + storage_ptr = smem.allocate(SharedStorage.size_in_bytes(), byte_alignment=16) + storage = SharedStorage(storage_ptr) + sA = storage.a.get_tensor(sA_layout) + sB = storage.b.get_tensor(sB_layout) + sC = storage.c.get_tensor(sC_layout) + + producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_dma_threads + ) + consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_mma_threads + ) + + mainloop_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.mbar.data_ptr(), + num_stages=num_stages, + producer_group=producer_group, + consumer_group=consumer_group, + ) + + k_tile_count = cute.size(gA, mode=[2]) + + if is_dma: + cute.arch.setmaxregister_decrease(40) + + thr_A = tiled_copy_A.get_slice(dma_tidx) + thr_B = tiled_copy_B.get_slice(dma_tidx) + tAgA = thr_A.partition_S(gA) + tAsA = thr_A.partition_D(sA) + tBgB = thr_B.partition_S(gB) + 
tBsB = thr_B.partition_D(sB) + tAcA = thr_A.partition_S(cA) + tBcB = thr_B.partition_S(cB) + + tApA = cute.make_rmem_tensor( + cute.make_layout( + ( + tAgA.shape[0][1], + cute.size(tAgA, mode=[1]), + cute.size(tAgA, mode=[2]), + ), + stride=(cute.size(tAgA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + for rv in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rv, m, 0] = cute.elem_less( + tAcA[(0, rv), m, 0, 0][0], mA.shape[0] + ) + tBpB = cute.make_rmem_tensor( + cute.make_layout( + ( + tBgB.shape[0][1], + cute.size(tBgB, mode=[1]), + cute.size(tBgB, mode=[2]), + ), + stride=(cute.size(tBgB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + for rv in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rv, n, 0] = cute.elem_less( + tBcB[(0, rv), n, 0, 0][0], mB.shape[0] + ) + + producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, num_stages + ) + + mainloop_pipeline.producer_acquire(producer_state) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, 0], + tBsB[None, None, None, producer_state.index], + pred=tBpB, + ) + + cute.arch.griddepcontrol_wait() + cute.arch.griddepcontrol_launch_dependents() + + cute.copy( + tiled_copy_A, + tAgA[None, None, None, 0], + tAsA[None, None, None, producer_state.index], + pred=tApA, + ) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + + for k_tile in range(1, k_tile_count): + mainloop_pipeline.producer_acquire(producer_state) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile], + tAsA[None, None, None, producer_state.index], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile], + tBsB[None, None, None, producer_state.index], + pred=tBpB, + ) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + + mainloop_pipeline.producer_tail(producer_state) + + else: + # ===== 4 MMA WARPS with k-phase interleaving ===== + cute.arch.setmaxregister_increase(232) + + lane_id = mma_tidx % 32 + 
mma_warp_idx = mma_tidx // 32 # 0-3 + NUM_MMA_WARPS: cutlass.Constexpr = self.num_mma_warps + + # Each warp uses the same tiled_mma (32 threads) + # All warps partition the SAME smem; they'll index different k_blocks + thr_mma = tiled_mma.get_slice(lane_id) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + tCrC.fill(0.0) + + atom_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), mA.element_type + ) + atom_s2r_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), mB.element_type + ) + tiled_s2r_A = cute.make_tiled_copy_A(atom_s2r_A, tiled_mma) + tiled_s2r_B = cute.make_tiled_copy_B(atom_s2r_B, tiled_mma) + thr_s2r_A = tiled_s2r_A.get_slice(lane_id) + thr_s2r_B = tiled_s2r_B.get_slice(lane_id) + tCsA_v = thr_s2r_A.partition_S(sA) + tCrA_v = thr_s2r_A.retile(tCrA) + tCsB_v = thr_s2r_B.partition_S(sB) + tCrB_v = thr_s2r_B.retile(tCrB) + + num_k_block = cute.size(tCrA, mode=[2]) + K_PER_WARP: cutlass.Constexpr = num_k_block // NUM_MMA_WARPS + + consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, num_stages + ) + + for k_tile in range(k_tile_count): + mainloop_pipeline.consumer_wait(consumer_state) + + tCsA_p = tCsA_v[None, None, None, consumer_state.index] + tCsB_p = tCsB_v[None, None, None, consumer_state.index] + + # K-phase: each warp directly computes its k_block index + # No branch — each warp loops K_PER_WARP times + # Reuse rmem slot 0 (no need to index by k_block) + if not self.is_fp8: + for ki in cutlass.range(K_PER_WARP, unroll_full=True): + k_block = ki * NUM_MMA_WARPS + mma_warp_idx + cute.copy( + tiled_s2r_A, + tCsA_p[None, None, k_block], + tCrA_v[None, None, 0], + ) + cute.copy( + tiled_s2r_B, + tCsB_p[None, None, k_block], + tCrB_v[None, None, 0], + ) + cute.gemm( + 
tiled_mma, + tCrC, + tCrA[None, None, 0], + tCrB[None, None, 0], + tCrC, + ) + else: + # fp8: keep accumulators as scalars, avoid + # fragment load/store per k_block + c0 = tCrC[0] + c1 = tCrC[1] + c2 = tCrC[2] + c3 = tCrC[3] + c4 = tCrC[4] + c5 = tCrC[5] + c6 = tCrC[6] + c7 = tCrC[7] + a_s = tCrA[None, None, 0] + b_s = tCrB[None, None, 0] + for ki in cutlass.range(K_PER_WARP, unroll_full=True): + k_block = ki * NUM_MMA_WARPS + mma_warp_idx + cute.copy( + tiled_s2r_A, + tCsA_p[None, None, k_block], + tCrA_v[None, None, 0], + ) + cute.copy( + tiled_s2r_B, + tCsB_p[None, None, k_block], + tCrB_v[None, None, 0], + ) + c0, c1, c2, c3, c4, c5, c6, c7 = fused_fp8_mma_2n( + c0, + c1, + c2, + c3, + c4, + c5, + c6, + c7, + a_s[0], + a_s[1], + a_s[2], + a_s[3], + a_s[4], + a_s[5], + a_s[6], + a_s[7], + b_s[0], + b_s[1], + b_s[2], + b_s[3], + b_s[4], + b_s[5], + b_s[6], + b_s[7], + ) + tCrC[0] = c0 + tCrC[1] = c1 + tCrC[2] = c2 + tCrC[3] = c3 + tCrC[4] = c4 + tCrC[5] = c5 + tCrC[6] = c6 + tCrC[7] = c7 + + mainloop_pipeline.consumer_release(consumer_state) + consumer_state.advance() + + # Fused epilogue: reduce + direct global store (1 sync, no sC) + smem_red_ptr = cute.arch.alloc_smem( + cutlass.Float32, bM * bN * NUM_MMA_WARPS, alignment=16 + ) + + # Each warp writes partial C via MMA partition (vectorized) + smem_warp = cute.make_tensor( + smem_red_ptr + mma_warp_idx * bM * bN, + cute.make_layout((bM, bN), stride=(bN, 1)), + ) + tCsC_partial = thr_mma.partition_C(smem_warp) + cute.autovec_copy(tCrC, tCsC_partial) + cute.arch.sync_threads() + + # Reduce + write directly to global (skip sC) + # 128 threads handle 256 elements (2 per thread) + num_elems: cutlass.Constexpr = bM * bN + elems_per_thread: cutlass.Constexpr = num_elems // self.num_mma_threads + N_global = cute.size(mC, mode=[1]) + for ei in cutlass.range_constexpr(elems_per_thread): + idx = ei * self.num_mma_threads + mma_tidx + m = idx // bN + n = idx % bN + global_m = bid_m * bM + m + global_n = bid_n * bN + n + 
if global_m < cute.size(mC, mode=[0]): + if global_n < N_global: + total = cutlass.Float32(0.0) + for w in cutlass.range_constexpr(NUM_MMA_WARPS): + p = smem_red_ptr + w * bM * bN + idx + t = cute.make_tensor(p, cute.make_layout((1,))) + r = cute.make_rmem_tensor((1,), cutlass.Float32) + cute.autovec_copy(t, r) + total = total + r[0] + # Direct global store via output pointer + out_p = (mC.iterator + global_m * N_global + global_n).align(2) + out_t = cute.make_tensor(out_p, cute.make_layout((1,))) + out_r = cute.make_rmem_tensor((1,), self.out_dtype) + out_r[0] = (total * scale).to(self.out_dtype) + cute.autovec_copy(out_r, out_t) + + cute.arch.sync_threads() + + + @cute.jit + def call_splitk(self, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, + stream: CUstream, scale: float = 1.0): + bM, bN, bK = self.tile_m, self.tile_n, self.tile_k + copy_bits = 128 + sA_layout = self._make_smem_layout_AB(mA.element_type, copy_bits, (bM, bK, self.num_stages)) + sB_layout = self._make_smem_layout_AB(mB.element_type, copy_bits, (bN, bK, self.num_stages)) + sC_layout = self._make_smem_layout_C(mC.element_type, copy_bits, (bM, bN)) + atom_g2s = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL), mA.element_type, num_bits_per_copy=copy_bits) + tiled_copy_A = self._make_gmem_tiled_copy(atom_g2s, mA.element_type, copy_bits, self.num_dma_threads) + tiled_copy_B = self._make_gmem_tiled_copy(atom_g2s, mB.element_type, copy_bits, self.num_dma_threads) + atom_s2g = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type, num_bits_per_copy=copy_bits) + c_copy_elems = copy_bits // mC.element_type.width + cn_threads = bN // c_copy_elems + tiled_copy_C = cute.make_tiled_copy_tv(atom_s2g, cute.make_layout((self.num_mma_threads // cn_threads, cn_threads), stride=(cn_threads, 1)), cute.make_layout((1, c_copy_elems))) + op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_shape) + perm_mnk = (self.atom_layout[0] 
* self.mma_shape[0], self.atom_layout[1] * self.mma_shape[1] * (self.tile_n // 8), self.atom_layout[2] * self.mma_shape[2]) + tiled_mma = cute.make_tiled_mma(op, cute.make_layout(self.atom_layout), permutation_mnk=perm_mnk) + grid_m = cute.ceil_div(cute.size(mC, mode=[0]), bM) + grid_n = cute.ceil_div(cute.size(mC, mode=[1]), bN) + self.kernel_splitk(mA, mB, mC, scale, sA_layout, sB_layout, sC_layout, + tiled_copy_A, tiled_copy_B, tiled_copy_C, tiled_mma).launch( + grid=[cute.size(grid_m), cute.size(grid_n), self.split_k], + block=[self.num_threads, 1, 1], + cluster=[1, 1, self.split_k], + stream=stream, use_pdl=True) + + @cute.kernel + def kernel_splitk(self, mA, mB, mC, scale: cutlass.Float32, + sA_layout: cute.ComposedLayout, sB_layout: cute.ComposedLayout, sC_layout: cute.Layout, + tiled_copy_A: cute.TiledCopy, tiled_copy_B: cute.TiledCopy, + tiled_copy_C: cute.TiledCopy, tiled_mma: cute.TiledMma): + bM, bN, bK = self.tile_m, self.tile_n, self.tile_k + num_stages = self.num_stages + tidx, _, _ = cute.arch.thread_idx() + bid_m, bid_n, bid_z = cute.arch.block_idx() + warp_idx = tidx // 32 + is_dma = warp_idx < (self.num_dma_threads // 32) + dma_tidx = tidx + mma_tidx = tidx - self.num_dma_threads + N_out = cute.size(mC, mode=[1]) + M_out = cute.size(mC, mode=[0]) + cta_tiler = (bM, bN, bK) + coord = (bid_m, bid_n, None) + gA = cute.local_tile(mA, tiler=cta_tiler, coord=coord, proj=(1, None, 1)) + gB = cute.local_tile(mB, tiler=cta_tiler, coord=coord, proj=(None, 1, 1)) + gC = cute.local_tile(mC, tiler=cta_tiler, coord=coord, proj=(1, 1, None)) + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile(mcA, tiler=cta_tiler, coord=coord, proj=(1, None, 1)) + cB = cute.local_tile(mcB, tiler=cta_tiler, coord=coord, proj=(None, 1, 1)) + + @cute.struct + class SharedStorage: + a: 
cute.struct.Align[cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], 16] + b: cute.struct.Align[cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], 16] + c: cute.struct.Align[cute.struct.MemRange[mC.element_type, cute.cosize(sC_layout)], 16] + mbar: cute.struct.Align[cute.struct.MemRange[cutlass.Int64, num_stages * 2], 8] + smem = cutlass.utils.SmemAllocator() + storage_ptr = smem.allocate(SharedStorage.size_in_bytes(), byte_alignment=16) + storage = SharedStorage(storage_ptr) + sA = storage.a.get_tensor(sA_layout) + sB = storage.b.get_tensor(sB_layout) + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_dma_threads) + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_mma_threads) + mainloop_pipeline = pipeline.PipelineCpAsync.create(barrier_storage=storage.mbar.data_ptr(), num_stages=num_stages, producer_group=producer_group, consumer_group=consumer_group) + k_tile_count_full = cute.size(gA, mode=[2]) + tiles_per_split = k_tile_count_full // self.split_k + k_start = bid_z * tiles_per_split + + if is_dma: + cute.arch.setmaxregister_decrease(40) + thr_A = tiled_copy_A.get_slice(dma_tidx); thr_B = tiled_copy_B.get_slice(dma_tidx) + tAgA = thr_A.partition_S(gA); tAsA = thr_A.partition_D(sA) + tBgB = thr_B.partition_S(gB); tBsB = thr_B.partition_D(sB) + tAcA = thr_A.partition_S(cA); tBcB = thr_B.partition_S(cB) + tApA = cute.make_rmem_tensor(cute.make_layout((tAgA.shape[0][1], cute.size(tAgA, mode=[1]), cute.size(tAgA, mode=[2])), stride=(cute.size(tAgA, mode=[1]), 1, 0)), cutlass.Boolean) + for rv in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rv, m, 0] = cute.elem_less(tAcA[(0, rv), m, 0, 0][0], mA.shape[0]) + tBpB = cute.make_rmem_tensor(cute.make_layout((tBgB.shape[0][1], cute.size(tBgB, mode=[1]), cute.size(tBgB, mode=[2])), stride=(cute.size(tBgB, mode=[1]), 1, 0)), cutlass.Boolean) + for rv in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rv, n, 0] = 
cute.elem_less(tBcB[(0, rv), n, 0, 0][0], mB.shape[0]) + producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, num_stages) + # First tile: PDL overlap — prefetch B (weight) before wait, load A (activation) after + mainloop_pipeline.producer_acquire(producer_state) + cute.copy(tiled_copy_B, tBgB[None, None, None, k_start], tBsB[None, None, None, producer_state.index], pred=tBpB) + cute.arch.griddepcontrol_wait() + cute.copy(tiled_copy_A, tAgA[None, None, None, k_start], tAsA[None, None, None, producer_state.index], pred=tApA) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + + # Remaining tiles: normal + for local_k in range(1, tiles_per_split): + k_tile = k_start + local_k + mainloop_pipeline.producer_acquire(producer_state) + cute.copy(tiled_copy_A, tAgA[None, None, None, k_tile], tAsA[None, None, None, producer_state.index], pred=tApA) + cute.copy(tiled_copy_B, tBgB[None, None, None, k_tile], tBsB[None, None, None, producer_state.index], pred=tBpB) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + mainloop_pipeline.producer_tail(producer_state) + else: + cute.arch.setmaxregister_increase(232) + lane_id = mma_tidx % 32 + mma_warp_idx = mma_tidx // 32 + NUM_MMA_WARPS: cutlass.Constexpr = self.num_mma_warps + thr_mma = tiled_mma.get_slice(lane_id) + tCsA = thr_mma.partition_A(sA); tCsB = thr_mma.partition_B(sB) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC); tCrC.fill(0.0) + atom_s2r_A = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), mA.element_type) + atom_s2r_B = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), mB.element_type) + tiled_s2r_A = cute.make_tiled_copy_A(atom_s2r_A, tiled_mma) + tiled_s2r_B = cute.make_tiled_copy_B(atom_s2r_B, tiled_mma) + thr_s2r_A = tiled_s2r_A.get_slice(lane_id); 
thr_s2r_B = tiled_s2r_B.get_slice(lane_id) + tCsA_v = thr_s2r_A.partition_S(sA); tCrA_v = thr_s2r_A.retile(tCrA) + tCsB_v = thr_s2r_B.partition_S(sB); tCrB_v = thr_s2r_B.retile(tCrB) + num_k_block = cute.size(tCrA, mode=[2]) + K_PER_WARP: cutlass.Constexpr = num_k_block // NUM_MMA_WARPS + consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, num_stages) + for local_k in range(tiles_per_split): + mainloop_pipeline.consumer_wait(consumer_state) + tCsA_p = tCsA_v[None, None, None, consumer_state.index] + tCsB_p = tCsB_v[None, None, None, consumer_state.index] + if not self.is_fp8: + for ki in cutlass.range(K_PER_WARP, unroll_full=True): + k_block = ki * NUM_MMA_WARPS + mma_warp_idx + cute.copy(tiled_s2r_A, tCsA_p[None, None, k_block], tCrA_v[None, None, 0]) + cute.copy(tiled_s2r_B, tCsB_p[None, None, k_block], tCrB_v[None, None, 0]) + cute.gemm(tiled_mma, tCrC, tCrA[None, None, 0], tCrB[None, None, 0], tCrC) + else: + c0=tCrC[0];c1=tCrC[1];c2=tCrC[2];c3=tCrC[3];c4=tCrC[4];c5=tCrC[5];c6=tCrC[6];c7=tCrC[7] + a_s=tCrA[None,None,0]; b_s=tCrB[None,None,0] + for ki in cutlass.range(K_PER_WARP, unroll_full=True): + k_block = ki * NUM_MMA_WARPS + mma_warp_idx + cute.copy(tiled_s2r_A, tCsA_p[None, None, k_block], tCrA_v[None, None, 0]) + cute.copy(tiled_s2r_B, tCsB_p[None, None, k_block], tCrB_v[None, None, 0]) + c0,c1,c2,c3,c4,c5,c6,c7 = fused_fp8_mma_2n(c0,c1,c2,c3,c4,c5,c6,c7,a_s[0],a_s[1],a_s[2],a_s[3],a_s[4],a_s[5],a_s[6],a_s[7],b_s[0],b_s[1],b_s[2],b_s[3],b_s[4],b_s[5],b_s[6],b_s[7]) + tCrC[0]=c0;tCrC[1]=c1;tCrC[2]=c2;tCrC[3]=c3;tCrC[4]=c4;tCrC[5]=c5;tCrC[6]=c6;tCrC[7]=c7 + mainloop_pipeline.consumer_release(consumer_state) + consumer_state.advance() + + # === CLUSTER REDUCTION EPILOGUE === + # Step 1: reduce MMA warps → per-thread fp32 partial in smem + smem_red_ptr = cute.arch.alloc_smem(cutlass.Float32, bM * bN * NUM_MMA_WARPS, alignment=16) + smem_warp = cute.make_tensor(smem_red_ptr + mma_warp_idx * bM * bN, cute.make_layout((bM, bN), 
stride=(bN, 1))) + tCsC_partial = thr_mma.partition_C(smem_warp) + cute.autovec_copy(tCrC, tCsC_partial) + cute.arch.sync_threads() + + num_elems: cutlass.Constexpr = bM * bN + elems_per_thread: cutlass.Constexpr = num_elems // self.num_mma_threads + + # Reduce across warps, scale, store to smem partials buffer + partial_buf = cute.arch.alloc_smem(cutlass.Float32, bM * bN * self.split_k, alignment=16) + cta_rank = cute.arch.block_idx_in_cluster() + + for ei in cutlass.range_constexpr(elems_per_thread): + idx = ei * self.num_mma_threads + mma_tidx + m2 = idx // bN + n2 = idx % bN + gm2 = bid_m * bM + m2 + gn2 = bid_n * bN + n2 + total = cutlass.Float32(0.0) + if gm2 < M_out: + if gn2 < N_out: + for w in cutlass.range_constexpr(NUM_MMA_WARPS): + p = smem_red_ptr + w * bM * bN + idx + t = cute.make_tensor(p, cute.make_layout((1,))) + r = cute.make_rmem_tensor((1,), cutlass.Float32) + cute.autovec_copy(t, r) + total = total + r[0] + total = total * scale + pb_p = partial_buf + cta_rank * num_elems + idx + pb_t = cute.make_tensor(pb_p, cute.make_layout((1,))) + pb_r = cute.make_rmem_tensor((1,), cutlass.Float32) + pb_r[0] = total + cute.autovec_copy(pb_r, pb_t) + + cute.arch.sync_threads() + + # Step 2: send partials to all peer CTAs via st.shared::cluster + for ei in cutlass.range_constexpr(elems_per_thread): + idx = ei * self.num_mma_threads + mma_tidx + my_slot = partial_buf + cta_rank * num_elems + idx + my_val_t = cute.make_tensor(my_slot, cute.make_layout((1,))) + my_val_r = cute.make_rmem_tensor((1,), cutlass.Float32) + cute.autovec_copy(my_val_t, my_val_r) + for peer in cutlass.range_constexpr(self.split_k): + remote = set_block_rank(my_slot, cutlass.Int32(peer)) + st_shared_remote_f32(remote, my_val_r[0]) + + # Step 3: cluster barrier + cluster_arrive_relaxed() + cluster_wait() + + # Signal next kernel AFTER cluster reduction is complete + if mma_tidx == 0: + cute.arch.griddepcontrol_launch_dependents() + cute.arch.sync_threads() + + # Step 4: local reduction + 
global output write + for ei in cutlass.range_constexpr(elems_per_thread): + idx = ei * self.num_mma_threads + mma_tidx + m = idx // bN + n = idx % bN + global_m = bid_m * bM + m + global_n = bid_n * bN + n + if global_m < M_out: + if global_n < N_out: + acc = cutlass.Float32(0.0) + for sk in cutlass.range_constexpr(self.split_k): + cb_p = partial_buf + sk * num_elems + idx + cb_t = cute.make_tensor(cb_p, cute.make_layout((1,))) + cb_r = cute.make_rmem_tensor((1,), cutlass.Float32) + cute.autovec_copy(cb_t, cb_r) + acc = acc + cb_r[0] + out_r0 = global_n if self.transpose_output else global_m + out_r1 = global_m if self.transpose_output else global_n + out_s = M_out if self.transpose_output else N_out + out_p = (mC.iterator + out_r0 * out_s + out_r1).align(2) + out_t = cute.make_tensor(out_p, cute.make_layout((1,))) + out_r = cute.make_rmem_tensor((1,), self.out_dtype) + out_r[0] = acc.to(self.out_dtype) + cute.autovec_copy(out_r, out_t) + + cute.arch.sync_threads() + + @cute.jit + def call_mpar(self, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, + stream: CUstream, scale: float = 1.0): + """M-parallel warp dispatch: 4 MMA warps handle 4 m16 sub-tiles = tile_m=64.""" + bM_total = 16 * self.num_mma_warps # 64 + bN, bK = self.tile_n, self.tile_k + copy_bits = 128 + + # smem layout for the FULL tile_m (64 rows) + sA_layout = self._make_smem_layout_AB(mA.element_type, copy_bits, (bM_total, bK, self.num_stages)) + sB_layout = self._make_smem_layout_AB(mB.element_type, copy_bits, (bN, bK, self.num_stages)) + # Per-warp smem layout (16 rows) for MMA partitioning + sA_warp_layout = self._make_smem_layout_AB(mA.element_type, copy_bits, (16, bK, self.num_stages)) + + atom_g2s = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL), mA.element_type, num_bits_per_copy=copy_bits) + tiled_copy_A = self._make_gmem_tiled_copy(atom_g2s, mA.element_type, copy_bits, self.num_dma_threads) + tiled_copy_B = 
self._make_gmem_tiled_copy(atom_g2s, mB.element_type, copy_bits, self.num_dma_threads) + atom_s2g = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type, num_bits_per_copy=copy_bits) + c_copy_elems = copy_bits // mC.element_type.width + cn_threads = bN // c_copy_elems + tiled_copy_C = cute.make_tiled_copy_tv(atom_s2g, cute.make_layout((self.num_mma_threads // cn_threads, cn_threads), stride=(cn_threads, 1)), cute.make_layout((1, c_copy_elems))) + op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_shape) + perm_mnk = (self.atom_layout[0] * self.mma_shape[0], self.atom_layout[1] * self.mma_shape[1] * (self.tile_n // 8), self.atom_layout[2] * self.mma_shape[2]) + tiled_mma = cute.make_tiled_mma(op, cute.make_layout(self.atom_layout), permutation_mnk=perm_mnk) + + grid_m = cute.ceil_div(cute.size(mC, mode=[0]), bM_total) + grid_n = cute.ceil_div(cute.size(mC, mode=[1]), bN) + scale_val = cutlass.Float32(scale) + + self.kernel_mpar(mA, mB, mC, scale_val, sA_layout, sB_layout, sA_warp_layout, + tiled_copy_A, tiled_copy_B, tiled_copy_C, tiled_mma).launch( + grid=[cute.size(grid_m), cute.size(grid_n), 1], + block=[self.num_threads, 1, 1], + stream=stream, use_pdl=True) + + @cute.kernel + def kernel_mpar(self, mA, mB, mC, scale: cutlass.Float32, + sA_layout: cute.ComposedLayout, sB_layout: cute.ComposedLayout, + sA_warp_layout: cute.ComposedLayout, + tiled_copy_A: cute.TiledCopy, tiled_copy_B: cute.TiledCopy, + tiled_copy_C: cute.TiledCopy, tiled_mma: cute.TiledMma): + """M-parallel warp kernel: 4 MMA warps, each handles own m16 sub-tile.""" + bN, bK = self.tile_n, self.tile_k + NUM_MMA_WARPS: cutlass.Constexpr = self.num_mma_warps + bM_total = 16 * NUM_MMA_WARPS # 64 + num_stages = self.num_stages + tidx, _, _ = cute.arch.thread_idx() + bid_m, bid_n, _ = cute.arch.block_idx() + warp_idx = tidx // 32 + is_dma = warp_idx < (self.num_dma_threads // 32) + dma_tidx = tidx + mma_tidx = tidx - self.num_dma_threads + N_out = cute.size(mC, 
mode=[1]) + M_out = cute.size(mC, mode=[0]) + + # Global tiles: A uses bM_total (64 rows), B uses bN (8 rows) + cta_tiler_A = (bM_total, bN, bK) + cta_tiler_B = (bM_total, bN, bK) + coord = (bid_m, bid_n, None) + gA = cute.local_tile(mA, tiler=cta_tiler_A, coord=coord, proj=(1, None, 1)) + gB = cute.local_tile(mB, tiler=cta_tiler_B, coord=coord, proj=(None, 1, 1)) + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile(mcA, tiler=cta_tiler_A, coord=coord, proj=(1, None, 1)) + cB = cute.local_tile(mcB, tiler=cta_tiler_B, coord=coord, proj=(None, 1, 1)) + + @cute.struct + class SharedStorage: + a: cute.struct.Align[cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], 16] + b: cute.struct.Align[cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], 16] + mbar: cute.struct.Align[cute.struct.MemRange[cutlass.Int64, num_stages * 2], 8] + smem = cutlass.utils.SmemAllocator() + storage_ptr = smem.allocate(SharedStorage.size_in_bytes(), byte_alignment=16) + storage = SharedStorage(storage_ptr) + sA = storage.a.get_tensor(sA_layout) + sB = storage.b.get_tensor(sB_layout) + + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_dma_threads) + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_mma_threads) + mainloop_pipeline = pipeline.PipelineCpAsync.create(barrier_storage=storage.mbar.data_ptr(), num_stages=num_stages, producer_group=producer_group, consumer_group=consumer_group) + k_tile_count = cute.size(gA, mode=[2]) + + if is_dma: + cute.arch.setmaxregister_decrease(40) + thr_A = tiled_copy_A.get_slice(dma_tidx) + thr_B = tiled_copy_B.get_slice(dma_tidx) + tAgA = thr_A.partition_S(gA); tAsA = thr_A.partition_D(sA) + tBgB = thr_B.partition_S(gB); tBsB = thr_B.partition_D(sB) + tAcA = thr_A.partition_S(cA); tBcB = 
thr_B.partition_S(cB) + tApA = cute.make_rmem_tensor(cute.make_layout((tAgA.shape[0][1], cute.size(tAgA, mode=[1]), cute.size(tAgA, mode=[2])), stride=(cute.size(tAgA, mode=[1]), 1, 0)), cutlass.Boolean) + for rv in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rv, m, 0] = cute.elem_less(tAcA[(0, rv), m, 0, 0][0], mA.shape[0]) + tBpB = cute.make_rmem_tensor(cute.make_layout((tBgB.shape[0][1], cute.size(tBgB, mode=[1]), cute.size(tBgB, mode=[2])), stride=(cute.size(tBgB, mode=[1]), 1, 0)), cutlass.Boolean) + for rv in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rv, n, 0] = cute.elem_less(tBcB[(0, rv), n, 0, 0][0], mB.shape[0]) + + producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, num_stages) + # First tile: PDL overlap + mainloop_pipeline.producer_acquire(producer_state) + cute.copy(tiled_copy_B, tBgB[None, None, None, 0], tBsB[None, None, None, producer_state.index], pred=tBpB) + cute.arch.griddepcontrol_wait() + cute.copy(tiled_copy_A, tAgA[None, None, None, 0], tAsA[None, None, None, producer_state.index], pred=tApA) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + + for k_tile in range(1, k_tile_count): + mainloop_pipeline.producer_acquire(producer_state) + cute.copy(tiled_copy_A, tAgA[None, None, None, k_tile], tAsA[None, None, None, producer_state.index], pred=tApA) + cute.copy(tiled_copy_B, tBgB[None, None, None, k_tile], tBsB[None, None, None, producer_state.index], pred=tBpB) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + mainloop_pipeline.producer_tail(producer_state) + else: + cute.arch.setmaxregister_increase(232) + lane_id = mma_tidx % 32 + mma_warp_idx = mma_tidx // 32 + + # Each warp gets its own 16-row sub-tile of sA via local_tile + sA_warp_tiler = (16, cute.size(sA, mode=[1]), cute.size(sA, mode=[2])) + sA_warp = cute.local_tile(sA, tiler=sA_warp_tiler, coord=(mma_warp_idx, 0, 0)) + + ab_width = mA.element_type.width + 
sA_warp_view = cute.recast_tensor(sA_warp, cutlass.BFloat16) + sB_view = cute.recast_tensor(sB, cutlass.BFloat16) + + thr_mma = tiled_mma.get_slice(lane_id) + tCsA = thr_mma.partition_A(sA_warp_view) + tCsB = thr_mma.partition_B(sB_view) + + # Output: each warp writes to its own M-sub-tile region + # gC_warp covers output rows [bid_m*64 + warp*16 : bid_m*64 + (warp+1)*16] + cta_tiler_C = (16, bN, bK) + warp_bid_m = bid_m * NUM_MMA_WARPS + mma_warp_idx + coord_c = (warp_bid_m, bid_n, None) + gC_warp = cute.local_tile(mC, tiler=cta_tiler_C, coord=coord_c, proj=(1, 1, None)) + tCgC = thr_mma.partition_C(gC_warp) + + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + tCrC.fill(0.0) + + atom_s2r_A = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), cutlass.BFloat16) + atom_s2r_B = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), cutlass.BFloat16) + tiled_s2r_A = cute.make_tiled_copy_A(atom_s2r_A, tiled_mma) + tiled_s2r_B = cute.make_tiled_copy_B(atom_s2r_B, tiled_mma) + thr_s2r_A = tiled_s2r_A.get_slice(lane_id) + thr_s2r_B = tiled_s2r_B.get_slice(lane_id) + tCsA_v = thr_s2r_A.partition_S(sA_warp_view) + tCrA_v = thr_s2r_A.retile(tCrA) + tCsB_v = thr_s2r_B.partition_S(sB_view) + tCrB_v = thr_s2r_B.retile(tCrB) + + num_k_block = cute.size(tCrA, mode=[2]) + consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, num_stages) + + # Main K loop: each warp processes ALL k-tiles (no K-parallelism) + for k_tile in range(k_tile_count): + mainloop_pipeline.consumer_wait(consumer_state) + tCsA_p = tCsA_v[None, None, None, consumer_state.index] + tCsB_p = tCsB_v[None, None, None, consumer_state.index] + if not self.is_fp8: + for ki in range(num_k_block): + cute.copy(tiled_s2r_A, tCsA_p[None, None, ki], tCrA_v[None, None, 0]) + cute.copy(tiled_s2r_B, tCsB_p[None, None, ki], tCrB_v[None, None, 0]) + 
cute.gemm(tiled_mma, tCrC, tCrA[None, None, 0], tCrB[None, None, 0], tCrC) + else: + a_s = tCrA[None, None, 0]; b_s = tCrB[None, None, 0] + c0=tCrC[0];c1=tCrC[1];c2=tCrC[2];c3=tCrC[3];c4=tCrC[4];c5=tCrC[5];c6=tCrC[6];c7=tCrC[7] + for ki in range(num_k_block): + cute.copy(tiled_s2r_A, tCsA_p[None, None, ki], tCrA_v[None, None, 0]) + cute.copy(tiled_s2r_B, tCsB_p[None, None, ki], tCrB_v[None, None, 0]) + c0,c1,c2,c3,c4,c5,c6,c7 = fused_fp8_mma_2n(c0,c1,c2,c3,c4,c5,c6,c7,a_s[0],a_s[1],a_s[2],a_s[3],a_s[4],a_s[5],a_s[6],a_s[7],b_s[0],b_s[1],b_s[2],b_s[3],b_s[4],b_s[5],b_s[6],b_s[7]) + tCrC[0]=c0;tCrC[1]=c1;tCrC[2]=c2;tCrC[3]=c3;tCrC[4]=c4;tCrC[5]=c5;tCrC[6]=c6;tCrC[7]=c7 + mainloop_pipeline.consumer_release(consumer_state) + consumer_state.advance() + + # Epilogue: write accumulators to smem via MMA partition, then to global + # Each warp uses its own smem region for the m16xn8 output + smem_c_ptr = cute.arch.alloc_smem(cutlass.Float32, 16 * bN * NUM_MMA_WARPS, alignment=16) + smem_warp_c = cute.make_tensor(smem_c_ptr + mma_warp_idx * 16 * bN, cute.make_layout((16, bN), stride=(bN, 1))) + tCsC_w = thr_mma.partition_C(smem_warp_c) + cute.autovec_copy(tCrC, tCsC_w) + cute.arch.sync_threads() + + # Scale + write to global — each warp writes its own 16xbN tile + num_warp_elems: cutlass.Constexpr = 16 * bN + warp_elems_per_thread: cutlass.Constexpr = num_warp_elems // 32 + global_m_base = bid_m * bM_total + mma_warp_idx * 16 + for ei in cutlass.range_constexpr(warp_elems_per_thread): + idx = ei * 32 + lane_id + m = idx // bN + n = idx % bN + global_m = global_m_base + m + global_n = bid_n * bN + n + if global_m < M_out: + if global_n < N_out: + s_p = smem_c_ptr + mma_warp_idx * 16 * bN + m * bN + n + s_t = cute.make_tensor(s_p, cute.make_layout((1,))) + s_r = cute.make_rmem_tensor((1,), cutlass.Float32) + cute.autovec_copy(s_t, s_r) + val = s_r[0] * scale + out_p = (mC.iterator + global_m * N_out + global_n).align(2) + out_t = cute.make_tensor(out_p, cute.make_layout((1,))) 
+ out_r = cute.make_rmem_tensor((1,), self.out_dtype) + out_r[0] = val.to(self.out_dtype) + cute.autovec_copy(out_r, out_t) + + if mma_tidx == 0: + cute.arch.griddepcontrol_launch_dependents() + cute.arch.sync_threads() + + @cute.jit + def call_splitk_atomic(self, mA: cute.Tensor, mB: cute.Tensor, + mC: cute.Tensor, stream: CUstream, scale: float = 1.0): + """Split-K with global atomic reduction (no cluster barriers).""" + bM, bN, bK = self.tile_m, self.tile_n, self.tile_k + copy_bits = 128 + sA_layout = self._make_smem_layout_AB(mA.element_type, copy_bits, (bM, bK, self.num_stages)) + sB_layout = self._make_smem_layout_AB(mB.element_type, copy_bits, (bN, bK, self.num_stages)) + atom_g2s = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL), mA.element_type, num_bits_per_copy=copy_bits) + tiled_copy_A = self._make_gmem_tiled_copy(atom_g2s, mA.element_type, copy_bits, self.num_dma_threads) + tiled_copy_B = self._make_gmem_tiled_copy(atom_g2s, mB.element_type, copy_bits, self.num_dma_threads) + op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_shape) + perm_mnk = (self.atom_layout[0] * self.mma_shape[0], self.atom_layout[1] * self.mma_shape[1] * (self.tile_n // 8), self.atom_layout[2] * self.mma_shape[2]) + tiled_mma = cute.make_tiled_mma(op, cute.make_layout(self.atom_layout), permutation_mnk=perm_mnk) + grid_m = cute.ceil_div(cute.size(mC, mode=[0]), bM) + grid_n = cute.ceil_div(cute.size(mC, mode=[1]), bN) + scale_val = cutlass.Float32(scale) + # NO cluster — each CTA is independent + self.kernel_splitk_atomic(mA, mB, mC, scale_val, sA_layout, sB_layout, + tiled_copy_A, tiled_copy_B, tiled_mma).launch( + grid=[cute.size(grid_m), cute.size(grid_n), self.split_k], + block=[self.num_threads, 1, 1], + stream=stream, use_pdl=True) + + @cute.kernel + def kernel_splitk_atomic(self, mA, mB, mC, scale: cutlass.Float32, + sA_layout: cute.ComposedLayout, sB_layout: cute.ComposedLayout, + tiled_copy_A: 
cute.TiledCopy, tiled_copy_B: cute.TiledCopy, + tiled_mma: cute.TiledMma): + """Split-K kernel with global atomic reduction — no cluster barriers.""" + bM, bN, bK = self.tile_m, self.tile_n, self.tile_k + num_stages = self.num_stages + tidx, _, _ = cute.arch.thread_idx() + bid_m, bid_n, bid_z = cute.arch.block_idx() + warp_idx = tidx // 32 + is_dma = warp_idx < (self.num_dma_threads // 32) + dma_tidx = tidx + mma_tidx = tidx - self.num_dma_threads + N_out = cute.size(mC, mode=[1]) + M_out = cute.size(mC, mode=[0]) + cta_tiler = (bM, bN, bK) + coord = (bid_m, bid_n, None) + gA = cute.local_tile(mA, tiler=cta_tiler, coord=coord, proj=(1, None, 1)) + gB = cute.local_tile(mB, tiler=cta_tiler, coord=coord, proj=(None, 1, 1)) + gC = cute.local_tile(mC, tiler=cta_tiler, coord=coord, proj=(1, 1, None)) + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile(mcA, tiler=cta_tiler, coord=coord, proj=(1, None, 1)) + cB = cute.local_tile(mcB, tiler=cta_tiler, coord=coord, proj=(None, 1, 1)) + + @cute.struct + class SharedStorage: + a: cute.struct.Align[cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], 16] + b: cute.struct.Align[cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], 16] + mbar: cute.struct.Align[cute.struct.MemRange[cutlass.Int64, num_stages * 2], 8] + smem = cutlass.utils.SmemAllocator() + storage_ptr = smem.allocate(SharedStorage.size_in_bytes(), byte_alignment=16) + storage = SharedStorage(storage_ptr) + sA = storage.a.get_tensor(sA_layout) + sB = storage.b.get_tensor(sB_layout) + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_dma_threads) + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_mma_threads) + mainloop_pipeline = pipeline.PipelineCpAsync.create(barrier_storage=storage.mbar.data_ptr(), 
num_stages=num_stages, producer_group=producer_group, consumer_group=consumer_group) + k_tile_count_full = cute.size(gA, mode=[2]) + tiles_per_split = k_tile_count_full // self.split_k + k_start = bid_z * tiles_per_split + + if is_dma: + cute.arch.setmaxregister_decrease(40) + thr_A = tiled_copy_A.get_slice(dma_tidx); thr_B = tiled_copy_B.get_slice(dma_tidx) + tAgA = thr_A.partition_S(gA); tAsA = thr_A.partition_D(sA) + tBgB = thr_B.partition_S(gB); tBsB = thr_B.partition_D(sB) + tAcA = thr_A.partition_S(cA); tBcB = thr_B.partition_S(cB) + tApA = cute.make_rmem_tensor(cute.make_layout((tAgA.shape[0][1], cute.size(tAgA, mode=[1]), cute.size(tAgA, mode=[2])), stride=(cute.size(tAgA, mode=[1]), 1, 0)), cutlass.Boolean) + for rv in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rv, m, 0] = cute.elem_less(tAcA[(0, rv), m, 0, 0][0], mA.shape[0]) + tBpB = cute.make_rmem_tensor(cute.make_layout((tBgB.shape[0][1], cute.size(tBgB, mode=[1]), cute.size(tBgB, mode=[2])), stride=(cute.size(tBgB, mode=[1]), 1, 0)), cutlass.Boolean) + for rv in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rv, n, 0] = cute.elem_less(tBcB[(0, rv), n, 0, 0][0], mB.shape[0]) + producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, num_stages) + mainloop_pipeline.producer_acquire(producer_state) + cute.copy(tiled_copy_B, tBgB[None, None, None, k_start], tBsB[None, None, None, producer_state.index], pred=tBpB) + cute.arch.griddepcontrol_wait() + cute.copy(tiled_copy_A, tAgA[None, None, None, k_start], tAsA[None, None, None, producer_state.index], pred=tApA) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + for local_k in range(1, tiles_per_split): + k_tile = k_start + local_k + mainloop_pipeline.producer_acquire(producer_state) + cute.copy(tiled_copy_A, tAgA[None, None, None, k_tile], tAsA[None, None, None, producer_state.index], pred=tApA) + cute.copy(tiled_copy_B, tBgB[None, None, None, k_tile], tBsB[None, None, 
None, producer_state.index], pred=tBpB) + mainloop_pipeline.producer_commit(producer_state) + producer_state.advance() + mainloop_pipeline.producer_tail(producer_state) + else: + cute.arch.setmaxregister_increase(232) + lane_id = mma_tidx % 32 + mma_warp_idx = mma_tidx // 32 + NUM_MMA_WARPS: cutlass.Constexpr = self.num_mma_warps + thr_mma = tiled_mma.get_slice(lane_id) + sA_view = cute.recast_tensor(sA, cutlass.BFloat16) + sB_view = cute.recast_tensor(sB, cutlass.BFloat16) + tCsA = thr_mma.partition_A(sA_view); tCsB = thr_mma.partition_B(sB_view) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC); tCrC.fill(0.0) + atom_s2r_A = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), cutlass.BFloat16) + atom_s2r_B = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(False, 4), cutlass.BFloat16) + tiled_s2r_A = cute.make_tiled_copy_A(atom_s2r_A, tiled_mma) + tiled_s2r_B = cute.make_tiled_copy_B(atom_s2r_B, tiled_mma) + thr_s2r_A = tiled_s2r_A.get_slice(lane_id); thr_s2r_B = tiled_s2r_B.get_slice(lane_id) + tCsA_v = thr_s2r_A.partition_S(sA_view); tCrA_v = thr_s2r_A.retile(tCrA) + tCsB_v = thr_s2r_B.partition_S(sB_view); tCrB_v = thr_s2r_B.retile(tCrB) + num_k_block = cute.size(tCrA, mode=[2]) + K_PER_WARP: cutlass.Constexpr = num_k_block // NUM_MMA_WARPS + consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, num_stages) + for local_k in range(tiles_per_split): + mainloop_pipeline.consumer_wait(consumer_state) + tCsA_p = tCsA_v[None, None, None, consumer_state.index] + tCsB_p = tCsB_v[None, None, None, consumer_state.index] + if not self.is_fp8: + for ki in cutlass.range(K_PER_WARP, unroll_full=True): + k_block = ki * NUM_MMA_WARPS + mma_warp_idx + cute.copy(tiled_s2r_A, tCsA_p[None, None, k_block], tCrA_v[None, None, 0]) + cute.copy(tiled_s2r_B, tCsB_p[None, None, k_block], 
tCrB_v[None, None, 0]) + cute.gemm(tiled_mma, tCrC, tCrA[None, None, 0], tCrB[None, None, 0], tCrC) + else: + a_s=tCrA[None,None,0]; b_s=tCrB[None,None,0] + c0=tCrC[0];c1=tCrC[1];c2=tCrC[2];c3=tCrC[3];c4=tCrC[4];c5=tCrC[5];c6=tCrC[6];c7=tCrC[7] + for ki in cutlass.range(K_PER_WARP, unroll_full=True): + k_block = ki * NUM_MMA_WARPS + mma_warp_idx + cute.copy(tiled_s2r_A, tCsA_p[None, None, k_block], tCrA_v[None, None, 0]) + cute.copy(tiled_s2r_B, tCsB_p[None, None, k_block], tCrB_v[None, None, 0]) + c0,c1,c2,c3,c4,c5,c6,c7 = fused_fp8_mma_2n(c0,c1,c2,c3,c4,c5,c6,c7,a_s[0],a_s[1],a_s[2],a_s[3],a_s[4],a_s[5],a_s[6],a_s[7],b_s[0],b_s[1],b_s[2],b_s[3],b_s[4],b_s[5],b_s[6],b_s[7]) + tCrC[0]=c0;tCrC[1]=c1;tCrC[2]=c2;tCrC[3]=c3;tCrC[4]=c4;tCrC[5]=c5;tCrC[6]=c6;tCrC[7]=c7 + mainloop_pipeline.consumer_release(consumer_state) + consumer_state.advance() + + # === WORKSPACE WRITE EPILOGUE (no reduction) === + # Each CTA writes its scaled partial to mC at its bid_z slice + # mC layout: [split_k * M_tiles, N_tiles] — bid_z offsets by M_tiles + smem_red_ptr = cute.arch.alloc_smem(cutlass.Float32, bM * bN * NUM_MMA_WARPS, alignment=16) + smem_warp = cute.make_tensor(smem_red_ptr + mma_warp_idx * bM * bN, cute.make_layout((bM, bN), stride=(bN, 1))) + tCsC_partial = thr_mma.partition_C(smem_warp) + cute.autovec_copy(tCrC, tCsC_partial) + cute.arch.sync_threads() + + num_elems: cutlass.Constexpr = bM * bN + elems_per_thread: cutlass.Constexpr = num_elems // self.num_mma_threads + + for ei in cutlass.range_constexpr(elems_per_thread): + idx = ei * self.num_mma_threads + mma_tidx + m = idx // bN + n = idx % bN + global_m = bid_m * bM + m + global_n = bid_n * bN + n + if global_m < M_out: + if global_n < N_out: + total = cutlass.Float32(0.0) + for w in cutlass.range_constexpr(NUM_MMA_WARPS): + p = smem_red_ptr + w * bM * bN + idx + t = cute.make_tensor(p, cute.make_layout((1,))) + r = cute.make_rmem_tensor((1,), cutlass.Float32) + cute.autovec_copy(t, r) + total = total + r[0] + total = 
total * scale + # Write to workspace: offset by bid_z * M_logical + M_logical = M_out // self.split_k + ws_m = bid_z * M_logical + global_m + out_p = (mC.iterator + ws_m * N_out + global_n).align(2) + out_t = cute.make_tensor(out_p, cute.make_layout((1,))) + out_r = cute.make_rmem_tensor((1,), self.out_dtype) + out_r[0] = total.to(self.out_dtype) + cute.autovec_copy(out_r, out_t) + + if mma_tidx == 0: + cute.arch.griddepcontrol_launch_dependents() + + cute.arch.sync_threads() + diff --git a/vllm/model_executor/layers/fused_moe/router/_ll_router_gemm_kernels.py b/vllm/model_executor/layers/fused_moe/router/_ll_router_gemm_kernels.py new file mode 100644 index 000000000000..230a96ca1280 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/_ll_router_gemm_kernels.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import operator + +import cutlass +import cutlass.cute as cute +from cuda.bindings.driver import CUstream +from cutlass._mlir import ir as _ir +from cutlass._mlir.dialects import arith as _arith +from cutlass._mlir.dialects import llvm as _llvm +from cutlass.cutlass_dsl import dsl_user_op + + +# ===== fp8 pair conversion via PTX ===== + +@dsl_user_op +def fp8x2_cvt(packed_i16, *, loc=None, ip=None): + """Convert packed Int16 (2x fp8 e4m3) -> 2x Float32 via PTX.""" + i16_ir = packed_i16.ir_value(loc=loc, ip=ip) + i32_f16x2 = _llvm.inline_asm( + _ir.IntegerType.get_signless(32), + [i16_ir], + "cvt.rn.f16x2.e4m3x2 $0, $1;", + "=r,h", + has_side_effects=False, + loc=loc, + ip=ip, + ) + lo16 = _arith.trunci(_ir.IntegerType.get_signless(16), i32_f16x2, loc=loc, ip=ip) + f32_lo = _llvm.inline_asm( + cutlass.Float32.mlir_type, + [lo16], + "cvt.f32.f16 $0, $1;", + "=f,h", + has_side_effects=False, + loc=loc, + ip=ip, + ) + hi32 = _arith.shrui( + i32_f16x2, + _arith.constant(_ir.IntegerType.get_signless(32), 16, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + hi16 = 
@cute.kernel
def dotprod_fp8(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    M: cutlass.Constexpr,
    K_pairs: cutlass.Int32,
    N_dim: cutlass.Int32,
):
    """Compute one output column ``n_idx`` of C for all M rows of A.

    Each CTA (``block_idx()[0]``) handles one row of ``gB``; its 128 threads
    stride over K, accumulate per-row partial dot products in fp32 registers,
    and the partials are combined with a warp reduction followed by a
    shared-memory cross-warp sum.

    ``gA``/``gB`` are indexed as flat buffers of Int16 elements; every Int16 is
    expanded by ``fp8x2_cvt`` into two Float32 values (presumably two packed
    fp8 values -- confirm against ``fp8x2_cvt``).

    Args:
        gA: flat activations, row ``m`` starts at ``m * K_pairs``.
        gB: flat weights, row ``n_idx`` starts at ``n_idx * K_pairs``.
        gC: flat output, written at ``m * N_dim + n_idx``.
        M: number of A rows, compile-time constant (loops are unrolled on it).
        K_pairs: row length in Int16 elements (runtime value).
        N_dim: output row stride (runtime value).
    """
    cute.arch.setmaxregister_increase(128)
    tidx = cute.arch.thread_idx()[0]
    n_idx = cute.arch.block_idx()[0]

    # Main loop: 8 Int16 per thread per step (16 bytes, matches .align(16)).
    VPT: cutlass.Constexpr = 8
    # Float32 values produced from VPT Int16 elements (2 per element).
    F_PER_T: cutlass.Constexpr = 16
    # Threads per CTA; must match the block size used by the launcher.
    BS: cutlass.Constexpr = 128
    # Int16 elements consumed by the whole CTA per main-loop step.
    KPI: cutlass.Constexpr = VPT * BS
    # Vector tail: 4 Int16 per thread per step (8 bytes, matches .align(8)).
    VPT_T: cutlass.Constexpr = 4
    F_PER_T_T: cutlass.Constexpr = 8
    KPT: cutlass.Constexpr = VPT_T * BS

    # Split K into full main steps and a vector tail. k_rem < KPI == 2 * KPT,
    # so k_tail is 0 or 1.
    k_main = K_pairs // KPI
    k_rem = K_pairs - k_main * KPI
    k_tail = k_rem // KPT

    # One fp32 accumulator per A row.
    acc = cute.make_rmem_tensor((M,), cutlass.Float32)
    for m in cutlass.range_constexpr(M):
        acc[m] = cutlass.Float32(0.0)

    if k_main > 0:
        # Peeled first main-loop step (ki == 0). The B load is issued before
        # the (currently disabled) grid-dependency wait below.
        kb0 = tidx * VPT
        bp0 = (gB.iterator + (n_idx * K_pairs + kb0)).align(16)
        bt0 = cute.make_tensor(bp0, cute.make_layout((VPT,)))
        br0 = cute.make_rmem_tensor((VPT,), cutlass.Int16)
        cute.autovec_copy(bt0, br0)

        #cute.arch.griddepcontrol_wait()

        # Expand B once; it is reused for every A row.
        bf0 = cute.make_rmem_tensor((F_PER_T,), cutlass.Float32)
        for p in cutlass.range_constexpr(VPT):
            bf0[p * 2], bf0[p * 2 + 1] = fp8x2_cvt(br0[p])
        for m in cutlass.range_constexpr(M):
            ap0 = (gA.iterator + (m * K_pairs + kb0)).align(16)
            at0 = cute.make_tensor(ap0, cute.make_layout((VPT,)))
            ar0 = cute.make_rmem_tensor((VPT,), cutlass.Int16)
            cute.autovec_copy(at0, ar0)
            for p in cutlass.range_constexpr(VPT):
                a00, a10 = fp8x2_cvt(ar0[p])
                acc[m] = acc[m] + a00 * bf0[p * 2] + a10 * bf0[p * 2 + 1]

        # Remaining k_main - 1 main steps (runtime loop, unroll hint of 4).
        for ki in cutlass.range(k_main - 1, unroll=4):
            kb = (ki + 1) * KPI + tidx * VPT
            bp = (gB.iterator + (n_idx * K_pairs + kb)).align(16)
            bt = cute.make_tensor(bp, cute.make_layout((VPT,)))
            br = cute.make_rmem_tensor((VPT,), cutlass.Int16)
            cute.autovec_copy(bt, br)
            bf = cute.make_rmem_tensor((F_PER_T,), cutlass.Float32)
            for p in cutlass.range_constexpr(VPT):
                bf[p * 2], bf[p * 2 + 1] = fp8x2_cvt(br[p])
            for m in cutlass.range_constexpr(M):
                ap = (gA.iterator + (m * K_pairs + kb)).align(16)
                at = cute.make_tensor(ap, cute.make_layout((VPT,)))
                ar = cute.make_rmem_tensor((VPT,), cutlass.Int16)
                cute.autovec_copy(at, ar)
                for p in cutlass.range_constexpr(VPT):
                    a0, a1 = fp8x2_cvt(ar[p])
                    acc[m] = acc[m] + a0 * bf[p * 2] + a1 * bf[p * 2 + 1]
    #else:
    #    cute.arch.griddepcontrol_wait()

    # Vector tail: same scheme with half-width (VPT_T) loads.
    for ti in cutlass.range(k_tail):
        kb = k_main * KPI + ti * KPT + tidx * VPT_T
        bp = (gB.iterator + (n_idx * K_pairs + kb)).align(8)
        bt = cute.make_tensor(bp, cute.make_layout((VPT_T,)))
        br = cute.make_rmem_tensor((VPT_T,), cutlass.Int16)
        cute.autovec_copy(bt, br)
        bf = cute.make_rmem_tensor((F_PER_T_T,), cutlass.Float32)
        for p in cutlass.range_constexpr(VPT_T):
            bf[p * 2], bf[p * 2 + 1] = fp8x2_cvt(br[p])
        for m in cutlass.range_constexpr(M):
            ap = (gA.iterator + (m * K_pairs + kb)).align(8)
            at = cute.make_tensor(ap, cute.make_layout((VPT_T,)))
            ar = cute.make_rmem_tensor((VPT_T,), cutlass.Int16)
            cute.autovec_copy(at, ar)
            for p in cutlass.range_constexpr(VPT_T):
                a0, a1 = fp8x2_cvt(ar[p])
                acc[m] = acc[m] + a0 * bf[p * 2] + a1 * bf[p * 2 + 1]

    # Scalar tail: one Int16 per thread for whatever the vector paths left.
    kp = k_main * KPI + k_tail * KPT
    kr = K_pairs - kp
    ks_full = kr // BS
    ks_part = kr - ks_full * BS
    # Full passes where every thread has an element.
    for si in cutlass.range(ks_full):
        ko = kp + si * BS + tidx
        bp = (gB.iterator + (n_idx * K_pairs + ko)).align(2)
        bt = cute.make_tensor(bp, cute.make_layout((1,)))
        br = cute.make_rmem_tensor((1,), cutlass.Int16)
        cute.autovec_copy(bt, br)
        b0, b1 = fp8x2_cvt(br[0])
        for m in cutlass.range_constexpr(M):
            ap = (gA.iterator + (m * K_pairs + ko)).align(2)
            at = cute.make_tensor(ap, cute.make_layout((1,)))
            ar = cute.make_rmem_tensor((1,), cutlass.Int16)
            cute.autovec_copy(at, ar)
            a0, a1 = fp8x2_cvt(ar[0])
            acc[m] = acc[m] + a0 * b0 + a1 * b1
    # Final partial pass: only the first ks_part threads have an element.
    if tidx < ks_part:
        ko = kp + ks_full * BS + tidx
        bp = (gB.iterator + (n_idx * K_pairs + ko)).align(2)
        bt = cute.make_tensor(bp, cute.make_layout((1,)))
        br = cute.make_rmem_tensor((1,), cutlass.Int16)
        cute.autovec_copy(bt, br)
        b0, b1 = fp8x2_cvt(br[0])
        for m in cutlass.range_constexpr(M):
            ap = (gA.iterator + (m * K_pairs + ko)).align(2)
            at = cute.make_tensor(ap, cute.make_layout((1,)))
            ar = cute.make_rmem_tensor((1,), cutlass.Int16)
            cute.autovec_copy(at, ar)
            a0, a1 = fp8x2_cvt(ar[0])
            acc[m] = acc[m] + a0 * b0 + a1 * b1

    # Reduction
    WS: cutlass.Constexpr = 32
    NW: cutlass.Constexpr = BS // WS
    # Step 1: sum each accumulator across the lanes of a warp.
    for m in cutlass.range_constexpr(M):
        acc[m] = cute.arch.warp_reduction(acc[m], operator.add)
    wid = tidx // WS
    lid = tidx % WS
    # Step 2: lane 0 of every warp publishes its partial to an (M, NW) smem
    # tile; the launcher must reserve M * NW * 4 bytes for it.
    sp = cute.arch.alloc_smem(cutlass.Float32, M * NW, alignment=16)
    sm = cute.make_tensor(sp, cute.make_layout((M, NW)))
    for m in cutlass.range_constexpr(M):
        if lid == 0:
            sm[m, wid] = acc[m]
    cute.arch.sync_threads()
    # Step 3: thread 0 sums the NW partials and writes C[m, n_idx].
    if tidx == 0:
        for m in cutlass.range_constexpr(M):
            t = cutlass.Float32(0.0)
            for w in cutlass.range_constexpr(NW):
                t = t + sm[m, w]
            gC[m * N_dim + n_idx] = t.to(cutlass.Float32)
    #cute.arch.griddepcontrol_launch_dependents()
def make_host_bf16(k_val: int):
    """Create bf16 router kernel for a given K.

    K is baked in at build time: the loop trip counts below are plain Python
    ints captured by the nested kernel, so every K loop is unrolled through
    ``range_constexpr``. A new kernel must therefore be built per distinct K.

    Args:
        k_val: row length K, in elements, of both A and B.

    Returns:
        The ``@cute.jit`` launcher ``host_bf16_lf(gA, gB, gC, M, K_dim, N_dim,
        stream)``, which launches one 256-thread CTA per output column.
    """
    _VPT = 8; _BS = 256; _KPI = _VPT * _BS  # 128-bit loads, 256 threads
    _k_main = k_val // _KPI  # main loop iters
    _VPT_T = 4; _KPT = _VPT_T * _BS  # 64-bit tail loads
    _k_tail = (k_val - _k_main * _KPI) // _KPT
    # Elements covered by the two vectorized phases.
    _k_done = _k_main * _KPI + _k_tail * _KPT
    # Leftover handled one element per thread: full passes plus a partial one.
    _scalar_rem = k_val - _k_done
    _ks_full = _scalar_rem // _BS
    _ks_part = _scalar_rem % _BS

    @cute.kernel
    def dotprod_bf16_lf(
        gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor,
        M: cutlass.Constexpr, K_dim: cutlass.Constexpr, N_dim: cutlass.Int32,
    ):
        # Computes C[m, n_idx] for all m in [0, M): dot product of A row m with
        # B row n_idx, accumulated in fp32 and reduced across the CTA.
        cute.arch.setmaxregister_increase(128)  #TODO(roberto): limit to 64?
        tidx = cute.arch.thread_idx()[0]
        n_idx = cute.arch.block_idx()[0]  # one CTA per expert
        VPT: cutlass.Constexpr = _VPT
        BS: cutlass.Constexpr = _BS
        KPI: cutlass.Constexpr = _KPI
        K_MAIN: cutlass.Constexpr = _k_main
        # Element type is taken from B and reused for the A register tiles, so
        # A and B are expected to share a dtype.
        elem = gB.element_type
        b_base = gB.iterator + n_idx * K_dim  # precomputed B row base
        tid_off = tidx * VPT

        acc = cute.make_rmem_tensor((M,), cutlass.Float32)
        for m in cutlass.range_constexpr(M):
            acc[m] = cutlass.Float32(0.0)

        # Main K-loop (fully unrolled via range_constexpr)
        for ki in cutlass.range_constexpr(K_MAIN):
            kb = ki * KPI + tid_off
            # Load B tile
            bp = (b_base + kb).align(16)
            bt = cute.make_tensor(bp, cute.make_layout((VPT,)))
            br = cute.make_rmem_tensor((VPT,), elem)
            cute.autovec_copy(bt, br)
            # Batch-load all A tokens into registers
            ar_all = cute.make_rmem_tensor((M, VPT), elem)
            for m in cutlass.range_constexpr(M):
                ap = (gA.iterator + (m * K_dim + kb)).align(16)
                at = cute.make_tensor(ap, cute.make_layout((VPT,)))
                ar = cute.make_rmem_tensor((VPT,), elem)
                cute.autovec_copy(at, ar)
                for v in cutlass.range_constexpr(VPT):
                    ar_all[m, v] = ar[v]
            # Compute (all data in registers)
            for m in cutlass.range_constexpr(M):
                for v in cutlass.range_constexpr(VPT):
                    acc[m] = acc[m] + ar_all[m, v].to(cutlass.Float32) * br[v].to(cutlass.Float32)

        VPT_T: cutlass.Constexpr = _VPT_T
        KPT: cutlass.Constexpr = _KPT
        # Offset where the main loop stopped.
        K_DONE: cutlass.Constexpr = _k_main * _KPI
        tid_off_t = tidx * VPT_T
        # Vectorized tail (64-bit loads for K remainder)
        for ti in cutlass.range_constexpr(_k_tail):
            kb = K_DONE + ti * KPT + tid_off_t
            bp = (b_base + kb).align(8)
            bt = cute.make_tensor(bp, cute.make_layout((VPT_T,)))
            br = cute.make_rmem_tensor((VPT_T,), elem)
            cute.autovec_copy(bt, br)
            for m in cutlass.range_constexpr(M):
                ap = (gA.iterator + (m * K_dim + kb)).align(8)
                at = cute.make_tensor(ap, cute.make_layout((VPT_T,)))
                ar = cute.make_rmem_tensor((VPT_T,), elem)
                cute.autovec_copy(at, ar)
                for v in cutlass.range_constexpr(VPT_T):
                    acc[m] = acc[m] + ar[v].to(cutlass.Float32) * br[v].to(cutlass.Float32)

        # Scalar tail (one element per thread for non-aligned K)
        K_DONE_ALL: cutlass.Constexpr = _k_done
        for si in cutlass.range_constexpr(_ks_full):
            ko = K_DONE_ALL + si * BS + tidx
            bp = (b_base + ko).align(2)
            bt = cute.make_tensor(bp, cute.make_layout((1,)))
            br = cute.make_rmem_tensor((1,), elem)
            cute.autovec_copy(bt, br)
            bv = br[0].to(cutlass.Float32)
            for m in cutlass.range_constexpr(M):
                ap = (gA.iterator + (m * K_dim + ko)).align(2)
                at = cute.make_tensor(ap, cute.make_layout((1,)))
                ar = cute.make_rmem_tensor((1,), elem)
                cute.autovec_copy(at, ar)
                acc[m] = acc[m] + ar[0].to(cutlass.Float32) * bv

        # Build-time check: the partial pass is only emitted when K leaves a
        # remainder smaller than one full pass of BS threads.
        if _ks_part > 0:
            KS_PART: cutlass.Constexpr = _ks_part
            ko_p = K_DONE_ALL + _ks_full * BS + tidx
            # Runtime guard: only the first KS_PART threads own an element.
            if tidx < KS_PART:
                bp2 = (b_base + ko_p).align(2)
                bt2 = cute.make_tensor(bp2, cute.make_layout((1,)))
                br2 = cute.make_rmem_tensor((1,), elem)
                cute.autovec_copy(bt2, br2)
                bv2 = br2[0].to(cutlass.Float32)
                for m in cutlass.range_constexpr(M):
                    ap2 = (gA.iterator + (m * K_dim + ko_p)).align(2)
                    at2 = cute.make_tensor(ap2, cute.make_layout((1,)))
                    ar2 = cute.make_rmem_tensor((1,), elem)
                    cute.autovec_copy(at2, ar2)
                    acc[m] = acc[m] + ar2[0].to(cutlass.Float32) * bv2

        # Warp + cross-warp reduction
        WS: cutlass.Constexpr = 32
        NW: cutlass.Constexpr = BS // WS
        for m in cutlass.range_constexpr(M):
            acc[m] = cute.arch.warp_reduction(acc[m], operator.add)
        wid = tidx // WS
        lid = tidx % WS
        # (M, NW) fp32 staging tile; the launcher reserves M * 4 * 8 bytes.
        sp = cute.arch.alloc_smem(cutlass.Float32, M * NW, alignment=16)
        sm = cute.make_tensor(sp, cute.make_layout((M, NW)))
        for m in cutlass.range_constexpr(M):
            if lid == 0:
                sm[m, wid] = acc[m]
        cute.arch.sync_threads()
        # Thread 0 sums the per-warp partials and writes C[m, n_idx].
        if tidx == 0:
            for m in cutlass.range_constexpr(M):
                t = cutlass.Float32(0.0)
                for w in cutlass.range_constexpr(NW):
                    t = t + sm[m, w]
                gC[m * N_dim + n_idx] = t.to(cutlass.Float32)

    @cute.jit
    def host_bf16_lf(
        gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor,
        M: cutlass.Constexpr, K_dim: cutlass.Constexpr,
        N_dim: cutlass.Int32, stream: CUstream,
    ):
        # block and smem are hard-coded to match _BS == 256 (8 warps) above:
        # smem = M rows * 4 bytes * 8 warps.
        dotprod_bf16_lf(gA, gB, gC, M, K_dim, N_dim).launch(
            grid=[N_dim, 1, 1], block=[256, 1, 1],
            smem=M * 4 * 8, stream=stream,
            use_pdl=False,  #TODO(roberto): needs investigation.
        )

    return host_bf16_lf
F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init (e.g. when the required dtype depends on the expert quantization @@ -43,7 +44,7 @@ def __init__( ) # If fp32 compute is required and no specialized kernel is available, - # store weights in fp32 so Tier 3 computes in fp32 natively. + # store weights in fp32 so Tier 4 computes in fp32 natively. if force_fp32_compute and not can_use_specialized_kernels: params_dtype = torch.float32 @@ -57,6 +58,16 @@ def __init__( ) self.out_dtype = out_dtype + # cuteDSL ll_router_gemm eligibility + self.allow_ll_router_gemm = False + if can_use_specialized_kernels: + try: + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import is_available + self.allow_ll_router_gemm = False #is_available() #TODO(roberto): needs investigation. No improvements e2e on DSV3 (vs. ptx version). + # can be interesting for other models that uses GateLinear directly. + except ImportError: + pass + # DSV3 specialized kernel eligibility (SM90+, exact dims) self.allow_specialized_router_gemm = can_use_specialized_kernels self.allow_dsv3_router_gemm = ( @@ -94,7 +105,13 @@ def forward( ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: import vllm._custom_ops as ops - # Tier 1: DSV3 specialized kernel + # Tier 1: cuteDSL ll_router_gemm + if self.allow_ll_router_gemm and x.shape[0] <= 16: + from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import ll_router_gemm + output = ll_router_gemm(x, self.weight) + return output, None + + # Tier 2: DSV3 specialized kernel if self.allow_dsv3_router_gemm and x.shape[0] <= 16: output = ops.dsv3_router_gemm( hidden_states=x, @@ -103,12 +120,12 @@ def forward( ) return output, None - # Tier 2: cuBLAS bf16→fp32 + # Tier 3: cuBLAS bf16→fp32 if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None - # Tier 3: F.linear (ReplicatedLinear) + # Tier 4: 
# Tri-state probe result: None = not probed yet, otherwise the cached answer.
_ll_a_gemm_available = None


def is_available() -> bool:
    """Return True when the runtime needed by the ll_a_gemm kernels imports.

    Every entry point in this module imports both ``cutlass.cute`` and
    ``cuda.bindings.driver``, so both are probed here; checking only the
    former would report "available" and then fail at the first GEMM call.

    The result is cached: callers consult this on hot paths, and Python does
    not cache *failed* imports, so an uncached probe would repeat the full
    module search on every call when cuteDSL is absent.

    Returns:
        ``True`` if both packages import, ``False`` on ``ImportError``.
    """
    global _ll_a_gemm_available
    if _ll_a_gemm_available is None:
        try:
            import cutlass.cute  # noqa: F401
            from cuda.bindings.driver import CUstream  # noqa: F401
        except ImportError:
            _ll_a_gemm_available = False
        else:
            _ll_a_gemm_available = True
    return _ll_a_gemm_available
def ll_a_gemm(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    is_fp8: bool = False,
    scale: float = 1.0,
) -> torch.Tensor:
    """Run the low-latency A GEMM and return a bf16 ``[M, N]`` result.

    For ``M <= 8`` the operands are swapped: the kernel is given
    ``(weight, hidden_states)`` and writes an ``[N, M]`` buffer, which is
    returned transposed (a non-contiguous view). Otherwise the kernel is given
    ``(hidden_states, weight)`` and writes ``[M, N]`` directly.

    Args:
        hidden_states: ``[M, K]`` input; ``shape[0]`` is read as M.
        weight: ``[N, K]`` weight; ``shape[0]`` is read as N.
        is_fp8: forwarded to kernel selection.
        scale: forwarded to the compiled kernel as its last argument.

    Returns:
        bf16 tensor of shape ``[M, N]`` on ``hidden_states.device``.
    """
    from cuda.bindings.driver import CUstream
    from torch.cuda import current_stream

    num_tokens = hidden_states.shape[0]
    out_features = weight.shape[0]
    use_swapped = num_tokens <= 8

    # Choose operand order and output layout once instead of duplicating the
    # launch sequence per branch.
    if use_swapped:
        lhs, rhs = weight, hidden_states
        out_dims = (out_features, num_tokens)
    else:
        lhs, rhs = hidden_states, weight
        out_dims = (num_tokens, out_features)

    result = torch.empty(out_dims[0], out_dims[1], dtype=torch.bfloat16,
                         device=hidden_states.device)
    kernel = _get_compiled(is_fp8, use_swapped, lhs, rhs, result)
    kernel(lhs, rhs, result, CUstream(current_stream().cuda_stream), scale)

    if use_swapped:
        return result.T
    return result
def _select_split_k(tiles: int, n: int) -> int:
    """Pick a split-K factor for ``tiles`` K-tiles and ``n`` output rows.

    Starts from a size-based preference (12, 6, 4, 2 or 1) and then
    decrements until the factor evenly divides ``tiles``; always returns >= 1.
    """
    if tiles >= 12 and n <= 256:
        split_k = 12
    elif tiles >= 6 and n <= 1536:
        split_k = 6
    elif tiles >= 4:
        split_k = 4
    elif tiles >= 2:
        split_k = 2
    else:
        split_k = 1
    while split_k > 1 and tiles % split_k != 0:
        split_k -= 1
    return split_k


def ll_a_gemm_fp8(
    hidden_states: torch.Tensor,
    weight_fp8_viewed: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """Quantize activations to fp8 and run the (split-K) low-latency A GEMM.

    Args:
        hidden_states: ``[M, K]`` activations in a floating-point dtype.
        weight_fp8_viewed: fp8 weight already viewed as bf16, ``[N, K/2]``.
        input_scale: scale the activations are divided by before the cast.
        weight_scale: weight dequantization scale.

    Returns:
        ``[M, N]`` result multiplied by ``input_scale * weight_scale``.
    """
    from cuda.bindings.driver import CUstream
    from torch.cuda import current_stream

    M = hidden_states.shape[0]

    # Quantize input to FP8, view as bf16.
    # float8_e4m3fn has no inf and the cast does not saturate: out-of-range
    # values become NaN. Clamp to the representable range first.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (
        (hidden_states / input_scale)
        .clamp(min=-fp8_max, max=fp8_max)
        .to(torch.float8_e4m3fn)
    )
    # Force tight strides for M=1 (PyTorch keeps loose stride[0])
    if M == 1:
        buf = torch.empty_like(x_fp8)
        buf.copy_(x_fp8)
        x_fp8 = buf
    x8 = x_fp8.view(torch.bfloat16)

    w8 = weight_fp8_viewed  # already [N, K/2] bf16-viewed
    N = w8.shape[0]
    K_view = w8.shape[1]

    # Select split_k
    #TODO(roberto): Implement an autotuner.
    split_k = _select_split_k(K_view // 256, N)

    if split_k == 1:
        out = ll_a_gemm(x8, w8, is_fp8=True)
    else:
        swapped = M <= 8
        if swapped:
            # One [N, M] partial per K-split, stacked along dim 0.
            out_buf = torch.empty(split_k * N, M, dtype=torch.bfloat16,
                                  device=x8.device)
            compiled = _get_compiled_splitk(True, True, w8, x8, out_buf, split_k)
            stream = CUstream(current_stream().cuda_stream)
            compiled(w8, x8, out_buf, stream)
            out = out_buf.view(split_k, N, M).sum(dim=0).T
        else:
            # One [M, N] partial per K-split, stacked along dim 0.
            out_buf = torch.empty(split_k * M, N, dtype=torch.bfloat16,
                                  device=x8.device)
            compiled = _get_compiled_splitk(True, False, x8, w8, out_buf, split_k)
            stream = CUstream(current_stream().cuda_stream)
            compiled(x8, w8, out_buf, stream)
            out = out_buf.view(split_k, M, N).sum(dim=0)

    out = out * (input_scale * weight_scale)
    return out
# Cache: (M, is_fp8, K or None, input dtype, output dtype) -> compiled callable
_compiled_cache: dict[tuple, object] = {}


def _get_compiled(M: int, is_fp8: bool, K: int, N: int, a_flat, b_flat, c_flat):
    """Get or compile the router kernel for this problem signature.

    The cache key contains only what is baked into the compiled kernel:

    * ``M`` - always a compile-time constant.
    * ``K`` - compile-time for bf16 (``make_host_bf16(K)``); the fp8 kernel
      takes it as a runtime argument, so fp8 entries are shared across K.
    * input / output dtypes - the compiled kernel is specialized on the
      element types of the example tensors.

    Args:
        M: number of tokens.
        is_fp8: select the fp8 kernel instead of the bf16 one.
        K: hidden size in elements of the original (un-viewed) input.
        N: number of output columns; only used as an example runtime value.
        a_flat, b_flat, c_flat: flattened example tensors used for compilation.

    Returns:
        The compiled callable.
    """
    key = (M, is_fp8, None if is_fp8 else K, a_flat.dtype, c_flat.dtype)
    # Look up before importing anything: this runs on every forward call.
    cached = _compiled_cache.get(key)
    if cached is not None:
        return cached

    import cutlass.cute as cute
    from cuda.bindings.driver import CUstream
    from cutlass.cute.runtime import from_dlpack
    from torch.cuda import current_stream

    if is_fp8:
        from ._ll_router_gemm_kernels import host_fp8
        host_fn = host_fp8
    else:
        from ._ll_router_gemm_kernels import make_host_bf16
        host_fn = make_host_bf16(K)

    a_c = from_dlpack(a_flat, assumed_align=32,
                      enable_tvm_ffi=True).mark_layout_dynamic()
    b_c = from_dlpack(b_flat, assumed_align=32,
                      enable_tvm_ffi=True).mark_layout_dynamic()
    c_c = from_dlpack(c_flat, assumed_align=32,
                      enable_tvm_ffi=True).mark_layout_dynamic()

    # fp8 inputs are passed as int16 views, so the kernel sees K/2 elements.
    K_eff = K // 2 if is_fp8 else K
    stream = CUstream(current_stream().cuda_stream)

    compiled = cute.compile(host_fn, a_c, b_c, c_c, M, K_eff, N, stream,
                            options="--enable-tvm-ffi --ptxas-options -maxrregcount=64")
    _compiled_cache[key] = compiled
    logger.debug("Compiled ll_router_gemm: M=%d, is_fp8=%s, K=%d", M, is_fp8, K)
    return compiled
# Memoized result of the ll_a_gemm availability probe (None = not probed yet).
_ll_gemm_available = None


def _check_ll_gemm() -> bool:
    """Return True when the low-latency cuteDSL GEMM backend can be used.

    The probe result is memoized because this is consulted on every
    ``_ll_gemm_fp8_scaled_mm`` call; Python does not cache failed imports, so
    re-probing per GEMM would repeat the module search each time the backend
    is missing.
    """
    global _ll_gemm_available
    if _ll_gemm_available is None:
        try:
            from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import is_available
        except ImportError:
            _ll_gemm_available = False
        else:
            _ll_gemm_available = bool(is_available())
    return _ll_gemm_available
x.dtype == torch.bfloat16 + and weight.dtype == torch.bfloat16 + and N <= 4096 #TODO(roberto): larger Ns require other kernel designs. + and bias is None #TODO(roberto): support bias + and x.is_contiguous() + and weight.is_contiguous() + ): + from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import ll_a_gemm + out_shape = (*x.shape[:-1], N) + result = ll_a_gemm(x.view(num_tokens, -1), weight) + return result.view(out_shape) + return torch.nn.functional.linear(x, weight, bias) def default_unquantized_gemm( layer: torch.nn.Module, @@ -97,6 +126,63 @@ def default_unquantized_gemm( ): return torch.nn.functional.linear(x, weight, bias) +def _ll_gemm_fp8_scaled_mm( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, +) -> torch.Tensor | None: + """Low-latency FP8 split-K GEMM for small-M decode. + Returns output tensor or None to fall through to scaled_mm.""" + if not _check_ll_gemm(): + return None + num_tokens = A.shape[0] + if num_tokens > 8: #TODO(roberto): Try to push this barrier up to 16. + return None + + x8 = A.view(torch.bfloat16) + from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import _get_compiled_splitk + from cuda.bindings.driver import CUstream + from torch.cuda import current_stream + + w8 = B.T.view(torch.bfloat16) + N = w8.shape[0] + K_view = w8.shape[1] + + if N > 4096: # Not a LL kernel. Need other kernel designs. + return None + + #TODO(roberto): sk values do not guarantee tiles%sk==0 for all K. + # Should pick sk that divides tiles. + # Longer-term: implement an autotuner. 
+ tiles = K_view // 256 + if N <= 256 and tiles >= 8: + split_k, ns = 8, min(3, tiles // 8) + elif N <= 1536 and tiles >= 4: + split_k, ns = 4, min(4, tiles // 4) + elif N <= 3072 and tiles >= 4: + split_k, ns = 4, min(2, tiles // 4) + else: + return None + + if tiles % split_k != 0: + return None + + combined_scale = (As * Bs).item() + out_buf = torch.empty(N * num_tokens, dtype=torch.bfloat16, device=x8.device) + out_for_kernel = out_buf.view(N, num_tokens) + compiled = _get_compiled_splitk(True, True, w8, x8, out_for_kernel, split_k, ns) + stream = CUstream(current_stream().cuda_stream) + compiled(w8, x8, out_for_kernel, stream, combined_scale) + out = out_buf.view(num_tokens, N) + # TODO(roberto): fuse bias into the kernel + if bias is not None: + out = out + bias + return out.view(*output_shape) + + def use_aiter_triton_gemm(n, m, k, dtype): if ( @@ -311,5 +397,7 @@ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: return rocm_unquantized_gemm elif current_platform.is_cpu(): return cpu_unquantized_gemm + elif _check_ll_gemm(): + return _ll_gemm_unquantized_gemm else: return default_unquantized_gemm diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index fbbd5da1fd90..64507d4571c9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -83,6 +83,7 @@ maybe_remap_kv_scale_name, ) from vllm.model_executor.models.utils import sequence_parallel_chunk + from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.torch_utils import direct_register_custom_op @@ -766,6 +767,13 @@ def _min_latency_fused_qkv_a_proj_impl( """ num_tokens = input_.shape[0] if 0 < num_tokens <= 16: + try: + from vllm.model_executor.layers.fused_moe.router.ll_a_gemm import ( + ll_a_gemm, + ) + return ll_a_gemm(input_, weight) + except ImportError: + pass output = torch.empty( num_tokens, weight.shape[0], @@ -824,6 +832,7 @@ def 
__init__( ) ) + def forward( self, input_,