Skip to content
Draft
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
347c0ed
generalized low-latency router GEMM via cuteDSL
LopezCastroRoberto Apr 15, 2026
7b9fa1f
generalized low-latency router + A GEMM via cuteDSL (bf16/fp8)
LopezCastroRoberto Apr 17, 2026
df92ff2
add a_gemm API
LopezCastroRoberto Apr 17, 2026
ef0182b
update
LopezCastroRoberto Apr 17, 2026
6eda5a5
add TinyGEMM to the benchmark
LopezCastroRoberto Apr 22, 2026
18bc6b6
update
LopezCastroRoberto Apr 22, 2026
964dac4
add tma version
LopezCastroRoberto Apr 22, 2026
7c1546d
update benchmarks
LopezCastroRoberto Apr 22, 2026
1d3a2e8
update
LopezCastroRoberto Apr 22, 2026
88f444b
update
LopezCastroRoberto Apr 22, 2026
eb7bffb
update
LopezCastroRoberto Apr 22, 2026
737eef7
update
LopezCastroRoberto Apr 22, 2026
4c68740
Merge branch 'main' into feature/ll_gemm_pdl
LopezCastroRoberto Apr 28, 2026
707d385
update
LopezCastroRoberto Apr 28, 2026
cd802fe
add a2a flashinfer
LopezCastroRoberto Apr 28, 2026
b3a512e
update
LopezCastroRoberto Apr 28, 2026
14728c8
update linear
LopezCastroRoberto May 4, 2026
b7d839b
update linear
LopezCastroRoberto May 5, 2026
efd8055
update
LopezCastroRoberto May 6, 2026
5111c65
update
LopezCastroRoberto May 7, 2026
d3d5316
update
LopezCastroRoberto May 9, 2026
4069ed1
update
LopezCastroRoberto May 9, 2026
d6ed40c
add benchmark + cleanup
LopezCastroRoberto May 11, 2026
27a7311
benchmark cleanup
LopezCastroRoberto May 11, 2026
236c314
remove pdl overlap bench
LopezCastroRoberto May 11, 2026
d60f987
improve test coverage
LopezCastroRoberto May 11, 2026
47fc229
cleanup fp8 path
LopezCastroRoberto May 11, 2026
f2e9e61
cleanup bf16 path
LopezCastroRoberto May 11, 2026
fefad72
remove space
LopezCastroRoberto May 11, 2026
7910943
simplify bf16 qkv_a_proj_impl integration+
LopezCastroRoberto May 11, 2026
49da9c2
clean up router
LopezCastroRoberto May 12, 2026
00426fd
rm tma ll gemm
LopezCastroRoberto May 12, 2026
29d3cd6
rm tma ll gemm
LopezCastroRoberto May 12, 2026
6d32d42
cleanup ll_a_gemm
LopezCastroRoberto May 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
384 changes: 384 additions & 0 deletions benchmarks/kernels/bench_ll_a_gemm.py

Large diffs are not rendered by default.

142 changes: 142 additions & 0 deletions benchmarks/kernels/bench_ll_router_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import os

import torch
from triton.testing import do_bench_cudagraph as _do_bench_cg

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.router.ll_router_gemm import (
ll_router_gemm,
)
from vllm.triton_utils import triton

# The DSV3 provider is only benchmarked when `ops` exposes `dsv3_router_gemm`.
_HAS_DSV3 = hasattr(ops, "dsv3_router_gemm")

# Provider identifiers, used both as benchmark line values and as line names.
_providers = [
    "ll-router-bf16",
    "ll-router-fp8",
    "cublas-bf16",
    *(["dsv3-trtllm"] if _HAS_DSV3 else []),
]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M"],
        x_vals=[1, 2, 4, 8, 16],
        x_log=False,
        line_arg="provider",
        line_vals=_providers,
        line_names=_providers,
        ylabel="Latency (us, lower is better)",
        plot_name="LL Router GEMM",
        args={},
    )
)
def benchmark(M, provider, N, K):
    """Time one router-GEMM provider for an (M, K) x (N, K)^T problem.

    Reads the module globals ``args``, ``_ap``, ``_wp``, ``_op`` and
    ``_l2_tp`` that the ``__main__`` section of this script sets up, so it
    is meant to be driven via ``benchmark.run(...)`` from that section.

    Args:
        M: Number of rows of the activation matrix (the swept x-axis value).
        provider: One of the identifiers in ``_providers``.
        N: Number of rows of the weight matrix.
        K: Shared inner dimension.

    Returns:
        A ``(q50, q20, q80)`` tuple of latencies in microseconds (the
        quantiles requested from ``do_bench_cudagraph``). All three values
        are NaN when ``provider`` is ``"dsv3-trtllm"`` and the shape is not
        one it supports.

    Raises:
        ValueError: If ``provider`` is not a known identifier.
    """
    device = "cuda"
    quantiles = [0.5, 0.2, 0.8]

    # Build the inputs and a zero-argument callable for the kernel under test.
    if provider == "ll-router-bf16":
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N, K, dtype=torch.bfloat16, device=device)

        def kernel():
            return ll_router_gemm(a, b)

    elif provider == "ll-router-fp8":
        a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
        b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)

        def kernel():
            return ll_router_gemm(a, b)

    elif provider == "cublas-bf16":
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N, K, dtype=torch.bfloat16, device=device)
        # Preallocated so the timed region contains only the matmul.
        out = torch.empty(M, N, dtype=torch.bfloat16, device=device)

        def kernel():
            return torch.mm(a, b.T, out=out)

    elif provider == "dsv3-trtllm":
        # DSV3 only supports N∈{256,384}, K=7168
        if N not in (256, 384) or K != 7168:
            return float("nan"), float("nan"), float("nan")
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N, K, dtype=torch.bfloat16, device=device)

        def kernel():
            return ops.dsv3_router_gemm(a, b, torch.float32)

    else:
        # Previously an unknown provider surfaced as UnboundLocalError on
        # the return statement; fail with an explicit message instead.
        raise ValueError(f"Unknown provider: {provider!r}")

    if args.l2_pollute:
        # Run the pollution matmul right before the kernel inside the same
        # captured region, then subtract its separately measured cost.
        def timed():
            torch.mm(_ap, _wp.T, out=_op)
            kernel()

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            timed, quantiles=quantiles
        )
        ms -= _l2_tp
        min_ms -= _l2_tp
        max_ms -= _l2_tp
    else:
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            kernel, quantiles=quantiles
        )

    # Return latency in us
    return ms * 1000, min_ms * 1000, max_ms * 1000


# Benchmarked problem shapes as (N, K, description) triples; the description
# is only used as a human-readable label when printing results.
SHAPES = [
    # N,   K,    label
    (256, 7168, "DSV3 router"),
    (256, 2048, "Small K"),
    (128, 5120, "DeepSeek V2"),
    (8, 4096, "Mixtral-8x7B"),
    (64, 2880, "Non-aligned K"),
]

if __name__ == "__main__":
    # Script entry point: parse options, optionally prepare the L2-pollution
    # prefix, then run the benchmark once per entry in SHAPES.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Root directory for results; each shape gets its own "
        "n{N}_k{K} subdirectory",
    )
    parser.add_argument(
        "--l2-pollute",
        action="store_true",
        help="Prefix each kernel with a large matmul to pollute L2",
    )
    # `args` is deliberately a module global: benchmark() reads
    # args.l2_pollute.
    args = parser.parse_args()

    print(f"Device: {torch.cuda.get_device_name()}")

    q_l2 = [0.5, 0.2, 0.8]
    if args.l2_pollute:
        # Pollution operands (a 2048 x 7168 bf16 weight, about 28 MiB) and
        # the measured cost of the pollution matmul alone. benchmark() reads
        # these globals and subtracts `_l2_tp` from its timings.
        _wp = torch.randn(2048, 7168, dtype=torch.bfloat16, device="cuda")
        _ap = torch.randn(1, 7168, dtype=torch.bfloat16, device="cuda")
        _op = torch.empty(1, 2048, dtype=torch.bfloat16, device="cuda")
        _l2_tp = _do_bench_cg(
            lambda: torch.mm(_ap, _wp.T, out=_op), rep=200, quantiles=q_l2
        )[0]
        print(f"L2 pollution: ON (prefix overhead: {_l2_tp*1000:.1f}us)")
    else:
        _l2_tp = 0
        print("L2 pollution: OFF")
    print()

    for N, K, desc in SHAPES:
        print(f"{desc}, N={N} K={K}:")
        if args.save_path:
            # Every run writes files named after the fixed plot name, so a
            # shared directory would let each shape overwrite the previous
            # shape's results. Give each shape its own subdirectory.
            save_dir = os.path.join(args.save_path, f"n{N}_k{K}")
        else:
            save_dir = f"bench_ll_router_n{N}_k{K}"
        os.makedirs(save_dir, exist_ok=True)
        benchmark.run(
            print_data=True,
            show_plots=False,
            save_path=save_dir,
            N=N,
            K=K,
        )
        print()
Loading