Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions benchmarks/bench_router_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import torch

from flashinfer.testing.utils import bench_gpu_time_with_cudagraph
from flashinfer.dsv3_ops import mm_M1_16_K7168_N128, mm_M1_16_K7168_N256


@torch.compile
def reference_torch(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
):
return torch.nn.functional.linear(x, weight, bias)


def get_data_torch(num_tokens, num_experts, hidden_dim):
    """Build random bf16 CUDA inputs for the torch baseline.

    Returns ``(activations, weights)`` with shapes ``(num_tokens, hidden_dim)``
    and ``(num_experts, hidden_dim)`` — the row-major layout ``F.linear`` expects.
    """
    common = dict(device="cuda", dtype=torch.bfloat16)
    activations = torch.randn(num_tokens, hidden_dim, **common)
    weights = torch.randn(num_experts, hidden_dim, **common)
    return activations, weights


def get_data_flashinfer(num_tokens, num_experts, hidden_dim, output_dtype):
    """Build random CUDA inputs for the flashinfer router-GEMM kernels.

    Returns ``(mat_a, mat_b, out)``: bf16 activations ``(num_tokens, hidden_dim)``,
    a bf16 weight matrix transposed to ``(hidden_dim, num_experts)`` (a view, not
    a copy), and an uninitialized output buffer of ``output_dtype``.
    """
    activations = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
    weights_t = torch.randn(num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16).t()
    out_buf = torch.empty(num_tokens, num_experts, device="cuda", dtype=output_dtype)
    return activations, weights_t, out_buf


def bench_router_gemm(gemm_fn, data, M, N, K, reps=1000, warmup_reps=1000):
    """Time one router-GEMM callable under CUDA graph capture and print a summary.

    ``data`` is splatted into ``gemm_fn``; if it carries a 4th element it is
    the ``launch_with_pdl`` flag and is echoed in the report.

    NOTE(review): ``reps``/``warmup_reps`` are forwarded as
    ``repeat_time_ms``/``dry_run_time_ms``, i.e. interpreted as milliseconds of
    measurement time rather than iteration counts — confirm against
    ``bench_gpu_time_with_cudagraph``'s contract.
    """
    timings_ms = bench_gpu_time_with_cudagraph(
        lambda: gemm_fn(*data),
        dry_run_time_ms=warmup_reps,
        repeat_time_ms=reps,
    )
    median_ms = np.median(timings_ms)
    # 2*M*N*K FLOPs per GEMM; dividing ms by 1e9 yields TFLOP/s directly.
    tflops = (2 * M * N * K) / median_ms / 1e9
    pdl_desc = f" launch_with_pdl={data[3]}" if len(data) > 3 else ""
    print(
        f"Router GEMM function {gemm_fn} | num_tokens={M}, num_experts={N}{pdl_desc}"
        f" | Median execution time: {1000 * median_ms:.3f} us | TFLOPs/s: {tflops:.3f}"
    )


def main():
    """Sweep token counts and router widths, benchmarking torch vs. flashinfer.

    For every ``num_tokens`` in {1, 2, 4, 8, 16} and each (experts, out-dtype,
    kernel) configuration, times the compiled torch baseline and the dedicated
    flashinfer kernel with PDL both off and on.
    """
    hidden_dim = 7168
    configs = [
        (128, torch.bfloat16, mm_M1_16_K7168_N128),
        (256, torch.float32, mm_M1_16_K7168_N256),
    ]
    for num_tokens in (1, 2, 4, 8, 16):
        for num_experts, output_dtype, flashinfer_fn in configs:
            torch_data = get_data_torch(
                num_tokens=num_tokens, hidden_dim=hidden_dim, num_experts=num_experts
            )
            bench_router_gemm(
                reference_torch, torch_data, num_tokens, num_experts, hidden_dim
            )

            fi_data = get_data_flashinfer(
                num_tokens=num_tokens,
                hidden_dim=hidden_dim,
                num_experts=num_experts,
                output_dtype=output_dtype,
            )
            for launch_with_pdl in (False, True):
                bench_router_gemm(
                    flashinfer_fn,
                    (*fi_data, launch_with_pdl),
                    num_tokens,
                    num_experts,
                    hidden_dim,
                )

        # Blank line between token-count groups for readability.
        print()


# Run the full benchmark sweep when executed as a script.
if __name__ == "__main__":
    main()
113 changes: 33 additions & 80 deletions csrc/dsv3_router_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
#include "tvm_ffi_utils.h"

namespace flashinfer::trtllm_dsv3_router_gemm {
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream,

// Note: Explicit template instantiations are not needed here because
// LoopUnroller already forces instantiation of all required specializations.
template <typename Tin, typename Tout, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemm(Tout* output, Tin const* mat_a, Tin const* mat_b, cudaStream_t stream,
bool use_pdl = false) {
constexpr int VPT = 16 / sizeof(T);
constexpr int VPT = 16 / sizeof(Tin);
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
Expand All @@ -18,83 +21,20 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_
config.numAttrs = 1;
config.attrs = attrs;
auto status = cudaLaunchKernelEx(
&config, router_gemm_kernel<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output,
mat_a, mat_b);
&config, router_gemm_kernel<Tin, Tout, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
output, mat_a, mat_b);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "cudaLaunchKernelEx failed with error code " << cudaGetErrorString(status);
}

template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
template <typename Tout>
static void unroll(int num_tokens, Tout* output, __nv_bfloat16 const* input,
__nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
if (num_tokens == kBegin) {
invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights,
stream, launch_with_pdl);
invokeRouterGemm<__nv_bfloat16, Tout, kBegin, kNumExperts, kHiddenDim>(
output, input, weights, stream, launch_with_pdl);
} else {
LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(
num_tokens, output, input, weights, stream, launch_with_pdl);
Expand All @@ -104,24 +44,26 @@ struct LoopUnroller {

template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
template <typename Tout>
static void unroll(int num_tokens, Tout* output, __nv_bfloat16 const* input,
__nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
if (num_tokens == kEnd) {
invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream,
launch_with_pdl);
invokeRouterGemm<__nv_bfloat16, Tout, kEnd, kNumExperts, kHiddenDim>(output, input, weights,
stream, launch_with_pdl);
} else {
throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
}
}
};

void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) {
template <typename Tout, int64_t tout_code, int kNumExperts, int kBegin, int kEnd>
void generic_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out,
bool launch_with_pdl) {
int const num_tokens = mat_a.sizes()[0];
int const num_experts = mat_b.sizes()[1];
int const hidden_dim = mat_a.sizes()[1];
auto const out_dtype_ = out.dtype();
auto const data_type = mat_a.dtype();
constexpr int kNumExperts = 256;
constexpr int kHiddenDim = 7168;
std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};
TVM_FFI_ICHECK(mat_a.dim() == 2 && mat_b.dim() == 2) << "mat_a and mat_b must be 2D tensors";
Expand All @@ -132,21 +74,32 @@ void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, boo
bool use_custom_kernel = false;
if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
encode_dlpack_dtype(out_dtype_) == float32_code) {
encode_dlpack_dtype(out_dtype_) == tout_code) {
use_custom_kernel = true;
}

if (use_custom_kernel) {
LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll(
num_tokens, reinterpret_cast<float*>(out.data_ptr()),
LoopUnroller<kBegin, kEnd, kNumExperts, kHiddenDim>::unroll(
num_tokens, reinterpret_cast<Tout*>(out.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl);
} else {
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input tensor size";
}
}

// DSv3 router GEMM entry point: 256 experts, fp32 output, num_tokens 1..16
// (template args <Tout=float, float32_code, kNumExperts=256, kBegin=1, kEnd=16>).
void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) {
  generic_router_gemm_op<float, float32_code, 256, 1, 16>(mat_a, mat_b, out, launch_with_pdl);
}

// ML3 router GEMM entry point: 128 experts, bf16 output, num_tokens 1..16
// (template args <Tout=__nv_bfloat16, bfloat16_code, kNumExperts=128, kBegin=1, kEnd=16>).
void ml3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) {
  generic_router_gemm_op<__nv_bfloat16, bfloat16_code, 128, 1, 16>(mat_a, mat_b, out,
                                                                   launch_with_pdl);
}

// Export both router-GEMM entry points through the TVM FFI so they are
// callable from the Python bindings by name.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(dsv3_router_gemm_op,
                              flashinfer::trtllm_dsv3_router_gemm::dsv3_router_gemm_op);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(ml3_router_gemm_op,
                              flashinfer::trtllm_dsv3_router_gemm::ml3_router_gemm_op);

} // namespace flashinfer::trtllm_dsv3_router_gemm
3 changes: 2 additions & 1 deletion flashinfer/dsv3_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from flashinfer.gemm import mm_M1_16_K7168_N256
from flashinfer.gemm import mm_M1_16_K7168_N128, mm_M1_16_K7168_N256
from flashinfer.fused_moe import fused_topk_deepseek
from flashinfer.concat_ops import concat_mla_k

__all__ = [
"mm_M1_16_K7168_N128",
"mm_M1_16_K7168_N256",
"fused_topk_deepseek",
"concat_mla_k",
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .gemm_base import fp8_blockscale_gemm_sm90 as fp8_blockscale_gemm_sm90

from .routergemm_dsv3 import (
mm_M1_16_K7168_N128 as mm_M1_16_K7168_N128,
mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256,
)

Expand All @@ -38,5 +39,6 @@
"gemm_fp8_nt_groupwise",
"group_gemm_fp8_nt_groupwise",
"fp8_blockscale_gemm_sm90",
"mm_M1_16_K7168_N128",
"mm_M1_16_K7168_N256",
]
Loading