Merged
1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
@@ -222,6 +222,7 @@ set(SOURCES
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/dsv3_router_gemm.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
56 changes: 56 additions & 0 deletions sgl-kernel/benchmark/bench_dsv3_router_gemm.py
@@ -0,0 +1,56 @@
import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch", "sgl-kernel"],
line_names=["torch", "dsv3_router_gemm"],
styles=[("blue", "-"), ("orange", "-")],
ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={},
)
)
def benchmark(num_tokens, impl):
# M: num_tokens, K: hidden_dim, N: num_experts
M, K, N = num_tokens, 7168, 256

mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

quantiles = [0.5, 0.2, 0.8]

if impl == "torch":

def runner():
F.linear(mat_a, mat_b).to(torch.float32)

elif impl == "sgl-kernel":

def runner():
dsv3_router_gemm(mat_a, mat_b)

ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

def tflops(t_ms):
flops = 2 * M * K * N
return flops / (t_ms * 1e-3) / 1e12

return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()

benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm")
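
In addition to the throughput benchmark above, a quick correctness check against the torch reference can be useful. The following is a sketch, not part of this PR; it assumes the Python-side dsv3_router_gemm(mat_a, mat_b) wrapper (as used in runner() above) returns a new [num_tokens, num_experts] float32 tensor, and the tolerances are illustrative only.

# Sketch of a correctness check (not part of this PR). Assumes the Python
# wrapper returns a new float32 tensor of shape [num_tokens, num_experts].
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_router_gemm

num_tokens, hidden_dim, num_experts = 4, 7168, 256
mat_a = torch.randn((num_tokens, hidden_dim), dtype=torch.bfloat16, device="cuda")
mat_b = torch.randn((num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda")

ref = F.linear(mat_a, mat_b).to(torch.float32)  # [num_tokens, num_experts]
out = dsv3_router_gemm(mat_a, mat_b)

# bf16 inputs accumulated in fp32; loose tolerances are illustrative only.
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)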
3 changes: 3 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
@@ -158,6 +158,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor expert_offsets, Tensor sf_offsets) -> ()");
m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);

m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);

/*
* From csrc/moe
*/
286 changes: 286 additions & 0 deletions sgl-kernel/csrc/gemm/dsv3_router_gemm.cu
@@ -0,0 +1,286 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"

// Custom FMA implementation using PTX assembly instructions
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
}

// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
__nv_bfloat16 const* bf16_ptr = reinterpret_cast<__nv_bfloat16 const*>(&vec);

#pragma unroll
for (int i = 0; i < VPT; i++) {
dst[i] = __bfloat162float(bf16_ptr[i]);
}
}

template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) {
// Each block handles one expert column
int const n_idx = blockIdx.x;
int const tid = threadIdx.x;
constexpr int kWarpSize = 32;
constexpr int kNumWarps = kBlockSize / kWarpSize;
// Constants for this kernel
constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration; // Total K iterations

// Initialize accumulators for all M rows
float acc[kNumTokens] = {};

// Shared memory for warp-level reduction
__shared__ float sm_reduction[kNumTokens][kNumWarps];

// B matrix is in column-major order, so we can directly load a column for the n_idx expert
T const* b_col = mat_b + n_idx * kHiddenDim;

// Pre-compute k_base values for each iteration to help compiler optimize
int k_bases[k_iterations];
#pragma unroll
for (int ki = 0; ki < k_iterations; ki++) {
k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif

// Process the GEMM in chunks
for (int ki = 0; ki < k_iterations; ki++) {
int const k_base = k_bases[ki];

// Load B matrix values using vector load (8 bf16 values)
uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);

// Convert B values to float
float b_float[VPT];
bf16_uint4_to_float8<VPT>(b_vec, b_float);

// Process each token
#pragma unroll
for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
// Load this token's row of the A matrix using a vector load
uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);

// Convert A values to float
float a_float[VPT];
bf16_uint4_to_float8<VPT>(a_vec, a_float);

// Process elements in this chunk
#pragma unroll
for (int k = 0; k < VPT; k++) {
float a = a_float[k];
float b = b_float[k];
acc[m_idx] += a * b;
}
}
}

// Perform warp-level reduction
int const warpId = tid / kWarpSize;
int const laneId = tid % kWarpSize;

// Register for warp-level reduction results
float warp_result[kNumTokens];

#pragma unroll
for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
warp_result[m_idx] = acc[m_idx];
}

// Perform warp-level reduction using optimized butterfly pattern
#pragma unroll
for (int m = 0; m < kNumTokens; m++) {
float sum = warp_result[m];

// Butterfly reduction pattern
sum += __shfl_xor_sync(0xffffffff, sum, 16);
sum += __shfl_xor_sync(0xffffffff, sum, 8);
sum += __shfl_xor_sync(0xffffffff, sum, 4);
sum += __shfl_xor_sync(0xffffffff, sum, 2);
sum += __shfl_xor_sync(0xffffffff, sum, 1);

// Only the first thread in each warp stores to shared memory
if (laneId == 0) {
sm_reduction[m][warpId] = sum;
}
}

__syncthreads();

// Final reduction across warps (only first thread)
if (tid == 0) {
#pragma unroll
for (int m = 0; m < kNumTokens; m++) {
float final_sum = 0.0f;

// Sum across the kNumWarps
#pragma unroll
for (int w = 0; w < kNumWarps; w++) {
final_sum += sm_reduction[m][w];
}

// Write final result
out[m * kNumExperts + n_idx] = final_sum;
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
constexpr int VPT = 16 / sizeof(T);
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(
&config, router_gemm_kernel<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output, mat_a, mat_b);
}

template void
invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template void
invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
static void
unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
if (num_tokens == kBegin) {
invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(num_tokens, output, input, weights, stream);
}
}
};

template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
static void
unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
if (num_tokens == kEnd) {
invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
}
}
};

void dsv3_router_gemm(
torch::Tensor& output, // [num_tokens, num_experts]
const torch::Tensor& mat_a, // [num_tokens, hidden_dim]
const torch::Tensor& mat_b // [num_experts, hidden_dim]
) {
TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);

const int num_tokens = mat_a.size(0);
constexpr int num_experts = 256;
constexpr int hidden_dim = 7168;

TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
TORCH_CHECK(
num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be between 1 and 16 for router_gemm");
TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32");

auto const sm = getSMVersion();
TORCH_CHECK(sm >= 90, "dsv3_router_gemm requires CUDA arch >= SM_90");

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

LoopUnroller<1, 16, num_experts, hidden_dim>::unroll(
num_tokens,
reinterpret_cast<float*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
}
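
The launch configuration above implies the following tiling for the bf16 instantiations; this short sketch simply re-derives the constants from the code and is not part of the diff.

# Tiling arithmetic implied by invokeRouterGemm for bf16 (derivation only):
# VPT = 16 bytes / sizeof(bf16) = 8 values per uint4 vector load,
# 128 threads per block, so one K iteration covers 8 * 128 = 1024 elements
# and hidden_dim = 7168 needs 7168 / 1024 = 7 iterations per block.
# The grid launches one block per expert column, i.e. 256 blocks.
VPT = 16 // 2
kBlockSize = 128
kHiddenDim = 7168
elems_per_iter = VPT * kBlockSize            # 1024
k_iterations = kHiddenDim // elems_per_iter  # 7
assert kHiddenDim % elems_per_iter == 0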
1 change: 1 addition & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
@@ -200,6 +200,7 @@ void bmm_fp8(
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream);
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);

void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);

1 change: 1 addition & 0 deletions sgl-kernel/python/sgl_kernel/__init__.py
@@ -34,6 +34,7 @@
bmm_fp8,
cutlass_scaled_fp4_mm,
dsv3_fused_a_gemm,
dsv3_router_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
int8_scaled_mm,
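
The benchmark calls dsv3_router_gemm(mat_a, mat_b) with two arguments, while the op registered in common_extension.cc takes an out parameter (Tensor! output). The exported wrapper itself is not shown in this diff; below is a hypothetical sketch of how such a wrapper could allocate the output and dispatch to the op. The function name mirrors the export above, but the exact dispatch path and internals are assumptions.

# Hypothetical wrapper sketch -- the actual sgl_kernel implementation is not
# shown in this diff. Assumes the op registered in common_extension.cc:
#   dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()
import torch

def dsv3_router_gemm(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
    num_tokens, num_experts = mat_a.shape[0], mat_b.shape[0]
    output = torch.empty(
        (num_tokens, num_experts), dtype=torch.float32, device=mat_a.device
    )
    torch.ops.sgl_kernel.dsv3_router_gemm(output, mat_a, mat_b)
    return output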