From c499c9836cb9c54b2ecbad0281edee1b077138d9 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 09:10:39 -0700 Subject: [PATCH 01/11] [Kernel] Add gpt-oss Router GEMM kernel Signed-off-by: Xin Yang --- CMakeLists.txt | 1 + benchmarks/kernels/benchmark_router_gemm.py | 68 +++ csrc/moe/moe_ops.h | 4 + csrc/moe/tinygemm2.cu | 137 ++++++ csrc/moe/tinygemm2.cuh | 447 ++++++++++++++++++++ csrc/moe/torch_bindings.cpp | 6 + tests/kernels/moe/test_router_gemm.py | 33 ++ vllm/_custom_ops.py | 13 + vllm/model_executor/models/gpt_oss.py | 76 +++- 9 files changed, 783 insertions(+), 2 deletions(-) create mode 100644 benchmarks/kernels/benchmark_router_gemm.py create mode 100644 csrc/moe/tinygemm2.cu create mode 100644 csrc/moe/tinygemm2.cuh create mode 100644 tests/kernels/moe/test_router_gemm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index bbadfdc5e9e3..48467c268b6c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -999,6 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" + "csrc/moe/tinygemm2.cu" "csrc/moe/router_gemm.cu") endif() diff --git a/benchmarks/kernels/benchmark_router_gemm.py b/benchmarks/kernels/benchmark_router_gemm.py new file mode 100644 index 000000000000..64bde1e14fc8 --- /dev/null +++ b/benchmarks/kernels/benchmark_router_gemm.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.triton_utils import triton + +num_tokens_range = [2**x for x in range(14)] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=num_tokens_range, + x_log=False, + line_arg="impl", + line_vals=["torch-32", "tinygemm2-32", "torch-128", "tinygemm2-128"], + line_names=(["torch-32", "tinygemm2-32", "torch-128", "tinygemm2-128"]), + styles=([("blue", "-"), 
("orange", "-"), ("green", "-"), ("red", "-")]), + ylabel="TFLOPs", + plot_name="router gemm throughput", + args={}, + ) +) +def benchmark(num_tokens, impl): + # M: num_tokens, K: hidden_dim, N: num_experts + M, K = num_tokens, 2880 + + if impl == "torch-32" or impl == "tinygemm2-32": + N = 32 + elif impl == "torch-128" or impl == "tinygemm2-128": + N = 128 + else: + raise ValueError(f"Unknown impl: {impl}") + + mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous() + mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous() + bias = torch.randn(N, dtype=torch.bfloat16, device="cuda").contiguous() + + quantiles = [0.5, 0.2, 0.8] + + if impl == "torch-32" or impl == "torch-128": + + def runner(): + F.linear(mat_a, mat_b, bias) + elif impl == "tinygemm2-32" or impl == "tinygemm2-128": + + def runner(): + ops.tinygemm2(mat_a, mat_b, bias) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles) + + def tflops(t_ms): + flops = 2 * M * K * N + return flops / (t_ms * 1e-3) / 1e12 + + return tflops(ms), tflops(max_ms), tflops(min_ms) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + args = parser.parse_args() + + benchmark.run(print_data=True) diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index d8d962887dab..9d47fe12f68b 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); + +// gpt-oss optimized router GEMM kernel for SM90+ +void tinygemm2(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, + torch::Tensor bias); #endif diff --git a/csrc/moe/tinygemm2.cu b/csrc/moe/tinygemm2.cu new file mode 100644 index 000000000000..617e8bc87e69 --- /dev/null +++ b/csrc/moe/tinygemm2.cu @@ -0,0 +1,137 @@ +/* 
+ * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "tinygemm2.cuh" + +void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, + __nv_bfloat16* bias, int batch_size, int output_features, + int input_features, cudaStream_t stream) { + static int const WARP_TILE_M = 16; + static int const TILE_M = WARP_TILE_M; + static int const TILE_N = 8; + static int const TILE_K = 64; + static int const STAGES = 16; + static int const STAGE_UNROLL = 4; + static bool const PROFILE = false; + + CUtensorMap weight_map{}; + CUtensorMap activation_map{}; + + constexpr uint32_t rank = 2; + uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features}; + uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)}; + uint32_t box_size[rank] = {TILE_K, TILE_M}; + uint32_t elem_stride[rank] = {1, 1}; + + CUresult res = cuTensorMapEncodeTiled( + &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, + gB, size, stride, box_size, elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + 
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + assert(res == 0); + + size[1] = batch_size; + box_size[1] = TILE_N; + + res = cuTensorMapEncodeTiled( + &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + rank, gA, size, stride, box_size, elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + assert(res == 0); + + int smem_size = STAGES * STAGE_UNROLL * + (TILE_M * TILE_K * sizeof(__nv_bfloat16) + + TILE_N * TILE_K * sizeof(__nv_bfloat16)); + + gpuErrChk(cudaFuncSetAttribute( + tinygemm_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int tiles_m = (output_features + TILE_M - 1) / TILE_M; + int tiles_n = (batch_size + TILE_N - 1) / TILE_N; + + dim3 grid(tiles_m, tiles_n); + dim3 block(384); + + cudaLaunchConfig_t config; + cudaLaunchAttribute attrs[1]; + config.gridDim = grid; + config.blockDim = block; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + config.attrs = attrs; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.numAttrs = 1; + + cudaLaunchKernelEx(&config, + &tinygemm_kernel, + gC, gA, gB, bias, output_features, batch_size, + input_features, weight_map, activation_map, nullptr); +} + +void tinygemm2_cuda_forward(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias) { + auto const batch_size = input.size(0); + auto const input_dim = input.size(1); + auto const output_dim = weight.size(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + + if (input.scalar_type() == at::ScalarType::BFloat16) { + launch_tinygemm2((__nv_bfloat16*)input.data_ptr(), + (__nv_bfloat16*)weight.data_ptr(), + 
(__nv_bfloat16*)output.mutable_data_ptr(), + (__nv_bfloat16*)bias.data_ptr(), batch_size, output_dim, + input_dim, stream); + } else { + assert(false); + } +} + +void tinygemm2(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, + torch::Tensor bias) { + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(weight.dim() == 2, "weight must be 2D"); + TORCH_CHECK(bias.dim() == 1, "bias must be 1D"); + TORCH_CHECK(input.sizes()[1] == weight.sizes()[1], + "input.size(1) must match weight.size(1)"); + TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0], + "weight.size(0) must match bias.size(0)"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "input tensor must be bfloat16"); + TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16, + "weight tensor must be bfloat16"); + TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16, + "bias tensor must be bfloat16"); + tinygemm2_cuda_forward(output, input, weight, bias); +} diff --git a/csrc/moe/tinygemm2.cuh b/csrc/moe/tinygemm2.cuh new file mode 100644 index 000000000000..202430bc61a2 --- /dev/null +++ b/csrc/moe/tinygemm2.cuh @@ -0,0 +1,447 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cuda_bf16.h" +#include +#include +#include + +#include "cuda_pipeline.h" +#include +#include +#include +#include + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; +namespace ptx = cuda::ptx; + +#define gpuErrChk(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } + +inline void gpuAssert(cudaError_t code, char const* file, int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, + line); + if (abort) { + throw std::runtime_error(cudaGetErrorString(code)); + } + } +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +__device__ uint64_t gclock64() { + unsigned long long int rv; + asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv)); + return rv; +} + +__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) { + int dst; + asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_ptr)); + int* rvi = reinterpret_cast(&rv[0]); + rvi[0] = dst; +} + +__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) { + int x, y; + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(x), "=r"(y) + : "r"(smem_ptr)); + + int* rvi = reinterpret_cast(&rv[0]); + rvi[0] = x; + rvi[1] = y; +} + +__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) { + int x, y, z, w; + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(smem_ptr)); + int* rvi = reinterpret_cast(&rv[0]); + rvi[0] = x; + rvi[1] = y; + rvi[2] = z; + rvi[3] = w; +} + +__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2], + float c[4]) { + uint32_t const* A = reinterpret_cast(&a[0]); + uint32_t const* B = reinterpret_cast(&b[0]); + float const* C = reinterpret_cast(&c[0]); + float* D = 
reinterpret_cast(&d[0]); + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), + "f"(C[3])); +} + +__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4], + float c[4]) { + uint32_t const* A = reinterpret_cast(&a[0]); + uint32_t const* B = reinterpret_cast(&b[0]); + float const* C = reinterpret_cast(&c[0]); + float* D = reinterpret_cast(&d[0]); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); +} + +__device__ void bar_wait(uint32_t bar_ptr, int phase) { + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" ::"r"(bar_ptr), + "r"(phase)); +} + +__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) { + uint32_t success; + #ifdef INTERNAL + asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory"); + #endif + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(success) + : "r"(bar_ptr), "r"(phase)); + return success; +} + +__device__ uint32_t elect_one_sync() { + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +} +#endif + +struct Profile { + uint64_t start; + uint64_t weight_load_start; + uint64_t act_load_start; + uint64_t 
compute_start; + uint64_t complete; +}; + +template +__global__ __launch_bounds__(384, 1) void tinygemm_kernel( + __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations, + __nv_bfloat16* bias, int M, int N, int K, + const __grid_constant__ CUtensorMap weight_map, + const __grid_constant__ CUtensorMap activation_map, + Profile* profile = nullptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + + if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0) + profile[blockIdx.x].start = gclock64(); + + extern __shared__ __align__(128) char smem[]; + + __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0]; + __nv_bfloat16* sh_activations = + (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K * + sizeof(__nv_bfloat16)]; + + #pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ barrier bar_wt_ready[STAGES]; + __shared__ barrier bar_act_ready[STAGES]; + __shared__ barrier bar_data_consumed[STAGES]; + + __shared__ float4 reduction_buffer[128]; + + __shared__ nv_bfloat16 sh_bias[TILE_M]; + + if (threadIdx.x == 0) { + for (int i = 0; i < STAGES; i++) { + init(&bar_wt_ready[i], 1); + init(&bar_act_ready[i], 1); + init(&bar_data_consumed[i], 32); + } + ptx::fence_proxy_async(ptx::space_shared); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&weight_map)) + : "memory"); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&activation_map)) + : "memory"); + } + __syncthreads(); + + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + int phase = 0; + + int mib = blockIdx.x * TILE_M; + int ni = blockIdx.y * TILE_N; + + float accum[4]; + for (int i = 0; i < 4; i++) accum[i] = 0.f; + + int const K_LOOPS_DMA = + (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL)); + int const K_LOOPS_COMPUTE = K_LOOPS_DMA; + + // Data loading thread + if (warp_id >= 4 && elect_one_sync()) { + int stage = warp_id % 4; + + bool weight_warp = warp_id < 8; + if (!weight_warp) { + 
cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); + } + + for (int ki = 0; ki < K_LOOPS_DMA; ki++) { + int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL; + + uint64_t desc_ptr_wt = reinterpret_cast(&weight_map); + uint64_t desc_ptr_act = reinterpret_cast(&activation_map); + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16); + int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16); + + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + + if (weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt)); + if (!weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act)); + + if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp) + profile[blockIdx.x].weight_load_start = gclock64(); + if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp) + profile[blockIdx.x].act_load_start = gclock64(); + + for (int i = 0; i < STAGE_UNROLL; i++) { + uint32_t smem_ptr_wt = __cvta_generic_to_shared( + &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]); + uint32_t crd0 = k + i * TILE_K; + uint32_t crd1 = mib; + if (weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0), + "r"(crd1) + : "memory"); + + uint32_t smem_ptr_act = __cvta_generic_to_shared( + &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]); + crd0 = k + i * TILE_K; + crd1 = ni; + if (!weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_act), "l"(desc_ptr_act), 
"r"(bar_ptr_act), + "r"(crd0), "r"(crd1) + : "memory"); + } + + stage += 4; + if (stage >= STAGES) { + stage = warp_id % 4; + phase ^= 1; + } + } + // Wait for pending loads to be consumed before exiting, to avoid race + for (int i = 0; i < (STAGES / 4) - 1; i++) { + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + stage += 4; + if (stage >= STAGES) { + stage = warp_id % 4; + phase ^= 1; + } + } + } + // Compute threads + else if (warp_id < 4) { + // Sneak the bias load into the compute warps since they're just waiting for + // stuff anyway + if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x]; + + int stage = warp_id; + + int phase = 0; + int lane_id_div8 = lane_id / 8; + int lane_id_mod8 = lane_id % 8; + + int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0; + int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0; + + int row_wt = lane_id_mod8 + lane_row_offset_wt; + int row_act = lane_id_mod8; + + int row_offset_wt = (reinterpret_cast(sh_weights) / 128) % 8; + int row_offset_act = row_offset_wt; + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + + bool weight_ready = bar_try_wait(bar_ptr_wt, phase); + bool act_ready = bar_try_wait(bar_ptr_act, phase); + + #pragma unroll 2 + for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) { + int next_stage = stage + 4; + int next_phase = phase; + if (next_stage >= STAGES) { + next_stage = warp_id; + next_phase ^= 1; + } + + while (!weight_ready || !act_ready) { + weight_ready = bar_try_wait(bar_ptr_wt, phase); + act_ready = bar_try_wait(bar_ptr_act, phase); + } + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0) + profile[blockIdx.x].compute_start = gclock64(); + + if (ki + 1 < K_LOOPS_COMPUTE) { + weight_ready = bar_try_wait( + __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase); + act_ready = bar_try_wait( + 
__cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase); + } + + #pragma unroll + for (int su = 0; su < STAGE_UNROLL; su++) { + __nv_bfloat16* ptr_weights = + &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K]; + __nv_bfloat16* ptr_act = + &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K]; + + #pragma unroll + for (int kii = 0; kii < TILE_K / 16; kii++) { + __nv_bfloat16 a[8]; + __nv_bfloat16 b[4]; + + int col = 2 * kii + lane_col_offset_wt; + int col_sw = ((row_wt + row_offset_wt) % 8) ^ col; + + ldmatrix4(a, __cvta_generic_to_shared( + &ptr_weights[row_wt * TILE_K + col_sw * 8])); + + col = 2 * kii + lane_id_div8; + col_sw = ((row_act + row_offset_act) % 8) ^ col; + + ldmatrix2(b, __cvta_generic_to_shared( + &ptr_act[row_act * TILE_K + 8 * col_sw])); + + HMMA_16816(accum, a, b, accum); + } + } + + uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]); + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c)); + + stage = next_stage; + phase = next_phase; + } + + float4 accum4; + accum4.x = accum[0]; + accum4.y = accum[1]; + accum4.z = accum[2]; + accum4.w = accum[3]; + reduction_buffer[threadIdx.x] = accum4; + + __syncthreads(); + + if (warp_id == 0) { + int mi = mib + warp_id * WARP_TILE_M; + int tm = mi + lane_id / 4; + int tn = ni + 2 * (lane_id % 4); + + float4 accum1 = reduction_buffer[32 + threadIdx.x]; + float4 accum2 = reduction_buffer[64 + threadIdx.x]; + float4 accum3 = reduction_buffer[96 + threadIdx.x]; + + accum[0] = accum[0] + accum1.x + accum2.x + accum3.x; + accum[1] = accum[1] + accum1.y + accum2.y + accum3.y; + accum[2] = accum[2] + accum1.z + accum2.z + accum3.z; + accum[3] = accum[3] + accum1.w + accum2.w + accum3.w; + + float bias_lo = __bfloat162float(sh_bias[tm - mib]); + float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]); + + if (tn < N && tm < M) + output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo); + if (tn + 1 < N && tm < M) + output[(tn + 1) * M + tm] = 
__float2bfloat16(accum[1] + bias_lo); + if (tn < N && tm + 8 < M) + output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi); + if (tn + 1 < N && tm + 8 < M) + output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi); + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) + profile[blockIdx.x].complete = gclock64(); + } + } +#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7b627a6f8760..be3ff2d5c493 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file + + // gpt-oss optimized router GEMM kernel for SM90+ + m.def( + "tinygemm2(Tensor! output, Tensor input, Tensor weights, Tensor bias) -> " + "()"); + m.impl("tinygemm2", torch::kCUDA, &tinygemm2); #endif } diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py new file mode 100644 index 000000000000..5dcdb5a62d44 --- /dev/null +++ b/tests/kernels/moe/test_router_gemm.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for optimized router GEMM kernel + +Run `pytest tests/kernels/moe/test_router_gemm.py`. 
+""" + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + + +@pytest.mark.skipif( + not (current_platform.is_cuda() and current_platform.has_device_capability(90)), + reason="This test is skipped on non-CUDA platform.", +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) +@pytest.mark.parametrize("output_dim", [32, 64, 128]) +def test_tinygemm2(batch_size, input_dim, output_dim): + set_random_seed(0) + x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16) + bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16) + + output = ops.tinygemm2(x, weight, bias) + output_ref = torch.nn.functional.linear(x, weight, bias) + + assert output.shape == (batch_size, output_dim) + assert torch.allclose(output, output_ref, rtol=1e-2, atol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a01f44e1649d..7307e5f828a5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2362,6 +2362,19 @@ def dsv3_router_gemm( return output +def tinygemm2( + hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + output = torch.empty( + hidden_states.shape[0], + weight.shape[0], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + torch.ops._moe_C.tinygemm2(output, hidden_states, weight, bias) + return output + + def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index c3111489c0ca..3a54da6f9de1 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -8,6 +8,7 @@ from torch import nn from transformers import GptOssConfig +import vllm._custom_ops as ops from vllm.compilation.decorators 
import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( @@ -45,6 +46,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionType from .interfaces import ( @@ -155,6 +157,77 @@ def forward( return output +def tinygemm2_impl( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + """ + Dynamically run min-latency gemm if num_tokens <= 128. + This must be wrapped in a custom op because our torch.compile integration + does not support runtime dispatching on num_tokens. + """ + if x.shape[0] <= 128: + return ops.tinygemm2(x, weight, bias) + else: + return torch.nn.functional.linear(x, weight, bias) + + +def tinygemm2_fake( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return x.new_empty((x.shape[0], weight.shape[0])) + + +direct_register_custom_op( + op_name="tinygemm2", + op_func=tinygemm2_impl, + fake_impl=tinygemm2_fake, +) + + +class GptOssRouter(ReplicatedLinear): + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + return_bias: bool = True, + ): + assert quant_config is None + super().__init__( + input_size, + output_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.router", + return_bias=return_bias, + ) + + assert hasattr(self, "weight") + assert hasattr(self, "bias") + + # Check if tinygemm2 kernel can be used. + # This kernel supports PDL and is optimized for low batch size. 
+ self.use_tinygemm = ( + self.weight.dtype == torch.bfloat16 + and current_platform.is_cuda() + and ( + current_platform.is_device_capability(90) + or current_platform.is_device_capability_family(100) + ) + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + if self.use_tinygemm: + return torch.ops.vllm.tinygemm2(x, self.weight, self.bias) + else: + return super().forward(x) + + class MLPBlock(torch.nn.Module): def __init__( self, @@ -175,7 +248,7 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = ReplicatedLinear( + self.router = GptOssRouter( config.hidden_size, config.num_local_experts, bias=True, @@ -273,7 +346,6 @@ def __init__( self.config = vllm_config.model_config.hf_config self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config - self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, From 95ffaafad88bfce1b661c6fc5c60a36d354314eb Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 13:51:39 -0700 Subject: [PATCH 02/11] Review changes Signed-off-by: Xin Yang --- csrc/moe/tinygemm2.cu | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/csrc/moe/tinygemm2.cu b/csrc/moe/tinygemm2.cu index 617e8bc87e69..89b13ac11a43 100644 --- a/csrc/moe/tinygemm2.cu +++ b/csrc/moe/tinygemm2.cu @@ -52,7 +52,9 @@ void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - assert(res == 0); + TORCH_CHECK(res == CUDA_SUCCESS, + "cuTensorMapEncodeTiled failed for weight_map, error code=", + static_cast(res)); size[1] = batch_size; box_size[1] = TILE_N; @@ -64,7 +66,9 @@ void 
launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - assert(res == 0); + TORCH_CHECK(res == CUDA_SUCCESS, + "cuTensorMapEncodeTiled failed for activation_map, error code=", + static_cast(res)); int smem_size = STAGES * STAGE_UNROLL * (TILE_M * TILE_K * sizeof(__nv_bfloat16) + @@ -114,7 +118,7 @@ void tinygemm2_cuda_forward(torch::Tensor& output, torch::Tensor input, (__nv_bfloat16*)bias.data_ptr(), batch_size, output_dim, input_dim, stream); } else { - assert(false); + throw std::invalid_argument("Unsupported dtype, only supports bfloat16"); } } From 4d1bd579ea62515858fc2cab9dc61c80bedbbbe5 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 13:57:28 -0700 Subject: [PATCH 03/11] Review changes Signed-off-by: Xin Yang --- CMakeLists.txt | 2 +- benchmarks/kernels/benchmark_router_gemm.py | 24 +++++++--- .../{tinygemm2.cu => gpt_oss_router_gemm.cu} | 45 ++++++++++--------- ...{tinygemm2.cuh => gpt_oss_router_gemm.cuh} | 2 +- csrc/moe/moe_ops.h | 4 +- csrc/moe/torch_bindings.cpp | 6 +-- tests/kernels/moe/test_router_gemm.py | 4 +- vllm/_custom_ops.py | 4 +- vllm/model_executor/models/gpt_oss.py | 20 ++++----- 9 files changed, 63 insertions(+), 48 deletions(-) rename csrc/moe/{tinygemm2.cu => gpt_oss_router_gemm.cu} (75%) rename csrc/moe/{tinygemm2.cuh => gpt_oss_router_gemm.cuh} (99%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 48467c268b6c..693070b5f476 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -999,7 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" - "csrc/moe/tinygemm2.cu" + "csrc/moe/gpt_oss_router_gemm.cu" "csrc/moe/router_gemm.cu") endif() diff --git a/benchmarks/kernels/benchmark_router_gemm.py b/benchmarks/kernels/benchmark_router_gemm.py index 
64bde1e14fc8..137eca369fb7 100644 --- a/benchmarks/kernels/benchmark_router_gemm.py +++ b/benchmarks/kernels/benchmark_router_gemm.py @@ -18,8 +18,20 @@ x_vals=num_tokens_range, x_log=False, line_arg="impl", - line_vals=["torch-32", "tinygemm2-32", "torch-128", "tinygemm2-128"], - line_names=(["torch-32", "tinygemm2-32", "torch-128", "tinygemm2-128"]), + line_vals=[ + "torch-32", + "gpt_oss_router_gemm-32", + "torch-128", + "gpt_oss_router_gemm-128", + ], + line_names=( + [ + "torch-32", + "gpt_oss_router_gemm-32", + "torch-128", + "gpt_oss_router_gemm-128", + ] + ), styles=([("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]), ylabel="TFLOPs", plot_name="router gemm throughput", @@ -30,9 +42,9 @@ def benchmark(num_tokens, impl): # M: num_tokens, K: hidden_dim, N: num_experts M, K = num_tokens, 2880 - if impl == "torch-32" or impl == "tinygemm2-32": + if impl == "torch-32" or impl == "gpt_oss_router_gemm-32": N = 32 - elif impl == "torch-128" or impl == "tinygemm2-128": + elif impl == "torch-128" or impl == "gpt_oss_router_gemm-128": N = 128 else: raise ValueError(f"Unknown impl: {impl}") @@ -47,10 +59,10 @@ def benchmark(num_tokens, impl): def runner(): F.linear(mat_a, mat_b, bias) - elif impl == "tinygemm2-32" or impl == "tinygemm2-128": + elif impl == "gpt_oss_router_gemm-32" or impl == "gpt_oss_router_gemm-128": def runner(): - ops.tinygemm2(mat_a, mat_b, bias) + ops.gpt_oss_router_gemm(mat_a, mat_b, bias) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles) diff --git a/csrc/moe/tinygemm2.cu b/csrc/moe/gpt_oss_router_gemm.cu similarity index 75% rename from csrc/moe/tinygemm2.cu rename to csrc/moe/gpt_oss_router_gemm.cu index 89b13ac11a43..0294cd36aa8f 100644 --- a/csrc/moe/tinygemm2.cu +++ b/csrc/moe/gpt_oss_router_gemm.cu @@ -23,11 +23,12 @@ #include #include #include -#include "tinygemm2.cuh" +#include "gpt_oss_router_gemm.cuh" -void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, - 
__nv_bfloat16* bias, int batch_size, int output_features, - int input_features, cudaStream_t stream) { +void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB, + __nv_bfloat16* gC, __nv_bfloat16* bias, + int batch_size, int output_features, + int input_features, cudaStream_t stream) { static int const WARP_TILE_M = 16; static int const TILE_M = WARP_TILE_M; static int const TILE_N = 8; @@ -75,8 +76,8 @@ void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, TILE_N * TILE_K * sizeof(__nv_bfloat16)); gpuErrChk(cudaFuncSetAttribute( - tinygemm_kernel, + gpt_oss_router_gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); int tiles_m = (output_features + TILE_M - 1) / TILE_M; @@ -96,15 +97,17 @@ void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, attrs[0].val.programmaticStreamSerializationAllowed = 1; config.numAttrs = 1; - cudaLaunchKernelEx(&config, - &tinygemm_kernel, - gC, gA, gB, bias, output_features, batch_size, - input_features, weight_map, activation_map, nullptr); + cudaLaunchKernelEx( + &config, + &gpt_oss_router_gemm_kernel, + gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, + activation_map, nullptr); } -void tinygemm2_cuda_forward(torch::Tensor& output, torch::Tensor input, - torch::Tensor weight, torch::Tensor bias) { +void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output, + torch::Tensor input, torch::Tensor weight, + torch::Tensor bias) { auto const batch_size = input.size(0); auto const input_dim = input.size(1); auto const output_dim = weight.size(0); @@ -112,18 +115,18 @@ void tinygemm2_cuda_forward(torch::Tensor& output, torch::Tensor input, auto stream = at::cuda::getCurrentCUDAStream(); if (input.scalar_type() == at::ScalarType::BFloat16) { - launch_tinygemm2((__nv_bfloat16*)input.data_ptr(), - (__nv_bfloat16*)weight.data_ptr(), - (__nv_bfloat16*)output.mutable_data_ptr(), - (__nv_bfloat16*)bias.data_ptr(), batch_size, output_dim, - 
input_dim, stream); + launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(), + (__nv_bfloat16*)weight.data_ptr(), + (__nv_bfloat16*)output.mutable_data_ptr(), + (__nv_bfloat16*)bias.data_ptr(), batch_size, + output_dim, input_dim, stream); } else { throw std::invalid_argument("Unsupported dtype, only supports bfloat16"); } } -void tinygemm2(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, - torch::Tensor bias) { +void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias) { TORCH_CHECK(input.dim() == 2, "input must be 2D"); TORCH_CHECK(weight.dim() == 2, "weight must be 2D"); TORCH_CHECK(bias.dim() == 1, "bias must be 1D"); @@ -137,5 +140,5 @@ void tinygemm2(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, "weight tensor must be bfloat16"); TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16, "bias tensor must be bfloat16"); - tinygemm2_cuda_forward(output, input, weight, bias); + gpt_oss_router_gemm_cuda_forward(output, input, weight, bias); } diff --git a/csrc/moe/tinygemm2.cuh b/csrc/moe/gpt_oss_router_gemm.cuh similarity index 99% rename from csrc/moe/tinygemm2.cuh rename to csrc/moe/gpt_oss_router_gemm.cuh index 202430bc61a2..5cc653f19cfb 100644 --- a/csrc/moe/tinygemm2.cuh +++ b/csrc/moe/gpt_oss_router_gemm.cuh @@ -175,7 +175,7 @@ struct Profile { template -__global__ __launch_bounds__(384, 1) void tinygemm_kernel( +__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel( __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations, __nv_bfloat16* bias, int M, int N, int K, const __grid_constant__ CUtensorMap weight_map, diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 9d47fe12f68b..de931dc76467 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -72,6 +72,6 @@ void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); // gpt-oss optimized router GEMM kernel for SM90+ -void 
tinygemm2(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, - torch::Tensor bias); +void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias); #endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index be3ff2d5c493..4cd74366ea4d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -135,9 +135,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // gpt-oss optimized router GEMM kernel for SM90+ m.def( - "tinygemm2(Tensor! output, Tensor input, Tensor weights, Tensor bias) -> " - "()"); - m.impl("tinygemm2", torch::kCUDA, &tinygemm2); + "gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, " + "Tensor bias) -> ()"); + m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm); #endif } diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py index 5dcdb5a62d44..28ed0b7e8138 100644 --- a/tests/kernels/moe/test_router_gemm.py +++ b/tests/kernels/moe/test_router_gemm.py @@ -20,13 +20,13 @@ @pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) @pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) @pytest.mark.parametrize("output_dim", [32, 64, 128]) -def test_tinygemm2(batch_size, input_dim, output_dim): +def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim): set_random_seed(0) x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16) weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16) bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16) - output = ops.tinygemm2(x, weight, bias) + output = ops.gpt_oss_router_gemm(x, weight, bias) output_ref = torch.nn.functional.linear(x, weight, bias) assert output.shape == (batch_size, output_dim) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7307e5f828a5..a45caac7c9e2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2362,7 +2362,7 @@ def dsv3_router_gemm( 
return output -def tinygemm2( +def gpt_oss_router_gemm( hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: output = torch.empty( @@ -2371,7 +2371,7 @@ def tinygemm2( device=hidden_states.device, dtype=hidden_states.dtype, ) - torch.ops._moe_C.tinygemm2(output, hidden_states, weight, bias) + torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias) return output diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 3a54da6f9de1..f579dd996138 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -157,7 +157,7 @@ def forward( return output -def tinygemm2_impl( +def gpt_oss_router_gemm_impl( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: """ @@ -166,21 +166,21 @@ def tinygemm2_impl( does not support runtime dispatching on num_tokens. """ if x.shape[0] <= 128: - return ops.tinygemm2(x, weight, bias) + return ops.gpt_oss_router_gemm(x, weight, bias) else: return torch.nn.functional.linear(x, weight, bias) -def tinygemm2_fake( +def gpt_oss_router_gemm_fake( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: return x.new_empty((x.shape[0], weight.shape[0])) direct_register_custom_op( - op_name="tinygemm2", - op_func=tinygemm2_impl, - fake_impl=tinygemm2_fake, + op_name="gpt_oss_router_gemm", + op_func=gpt_oss_router_gemm_impl, + fake_impl=gpt_oss_router_gemm_fake, ) @@ -207,9 +207,9 @@ def __init__( assert hasattr(self, "weight") assert hasattr(self, "bias") - # Check if tinygemm2 kernel can be used. + # Check if gpt_oss_router_gemm kernel can be used. # This kernel supports PDL and is optimized for low batch size. 
- self.use_tinygemm = ( + self.use_gpt_oss_router_gemm = ( self.weight.dtype == torch.bfloat16 and current_platform.is_cuda() and ( @@ -222,8 +222,8 @@ def forward( self, x: torch.Tensor, ) -> torch.Tensor: - if self.use_tinygemm: - return torch.ops.vllm.tinygemm2(x, self.weight, self.bias) + if self.use_gpt_oss_router_gemm: + return torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) else: return super().forward(x) From b11bf8e8715527be777a4e7df86d7a59f6840d7b Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 14:06:27 -0700 Subject: [PATCH 04/11] Update test Signed-off-by: Xin Yang --- tests/kernels/moe/test_router_gemm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py index 28ed0b7e8138..a6e3514101b2 100644 --- a/tests/kernels/moe/test_router_gemm.py +++ b/tests/kernels/moe/test_router_gemm.py @@ -28,6 +28,4 @@ def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim): output = ops.gpt_oss_router_gemm(x, weight, bias) output_ref = torch.nn.functional.linear(x, weight, bias) - - assert output.shape == (batch_size, output_dim) - assert torch.allclose(output, output_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) From 758426b54019dd4b1d5cc9719f99ef8b5177e121 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 15:53:50 -0700 Subject: [PATCH 05/11] Add into GateLinear Signed-off-by: Xin Yang --- .../layers/fused_moe/router/gate_linear.py | 58 +++++++++++-- vllm/model_executor/models/gpt_oss.py | 82 +------------------ 2 files changed, 55 insertions(+), 85 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 77d8e756026d..e8ed8a5249d1 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -3,9 +3,11 @@ 
import torch from torch.nn.parameter import Parameter +import vllm._custom_ops as ops from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op @PluggableLayer.register("gate_linear") @@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear): """MoE gate linear layer with three-tier GEMM dispatch: 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) - 3. F.linear via ReplicatedLinear (ultimate fallback) + 2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims) + 3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) + 4. F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init (e.g. when the required dtype depends on the expert quantization @@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear): DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + # Dimensions supported by the gpt-oss specialized kernel + GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] + GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] + def __init__( self, input_size: int, @@ -65,6 +72,15 @@ def __init__( and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES ) + # gpt-oss specialized kernel eligibility (SM90+, exact dims) + self.allow_gpt_oss_router_gemm = ( + self.weight.dtype == torch.bfloat16 + and current_platform.is_cuda() + and is_hopper_or_blackwell + and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS + and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES + ) + # cuBLAS bf16→fp32 eligibility self.allow_cublas_router_gemm = ( self.allow_specialized_router_gemm @@ -92,8 +108,6 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: def forward( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: - import vllm._custom_ops as ops - # Tier 1: DSV3 
specialized kernel if self.allow_dsv3_router_gemm and x.shape[0] <= 16: output = ops.dsv3_router_gemm( @@ -103,15 +117,47 @@ def forward( ) return output, None - # Tier 2: cuBLAS bf16→fp32 + # Tier 2: gpt-oss specialized kernel + if self.allow_gpt_oss_router_gemm: + output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) + return output, None + + # Tier 3: cuBLAS bf16→fp32 if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None - # Tier 3: F.linear (ReplicatedLinear) + # Tier 4: F.linear (ReplicatedLinear) if self.out_dtype is not None and x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) output, output_bias = super().forward(x) if self.out_dtype is not None and output.dtype != self.out_dtype: output = output.to(self.out_dtype) return output, output_bias + + +def gpt_oss_router_gemm_impl( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + """ + Dynamically run min-latency gemm if num_tokens <= 128. + This must be wrapped in a custom op because our torch.compile integration + does not support runtime dispatching on num_tokens. 
+ """ + if x.shape[0] <= 128: + return ops.gpt_oss_router_gemm(x, weight, bias) + else: + return torch.nn.functional.linear(x, weight, bias) + + +def gpt_oss_router_gemm_fake( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return x.new_empty((x.shape[0], weight.shape[0])) + + +direct_register_custom_op( + op_name="gpt_oss_router_gemm", + op_func=gpt_oss_router_gemm_impl, + fake_impl=gpt_oss_router_gemm_fake, +) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index f579dd996138..482056250a1e 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -8,7 +8,6 @@ from torch import nn from transformers import GptOssConfig -import vllm._custom_ops as ops from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( @@ -21,12 +20,11 @@ tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -46,7 +44,6 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv -from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionType from .interfaces import ( @@ -157,77 +154,6 @@ def forward( return output -def gpt_oss_router_gemm_impl( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - """ - Dynamically run min-latency gemm if num_tokens <= 128. 
- This must be wrapped in a custom op because our torch.compile integration - does not support runtime dispatching on num_tokens. - """ - if x.shape[0] <= 128: - return ops.gpt_oss_router_gemm(x, weight, bias) - else: - return torch.nn.functional.linear(x, weight, bias) - - -def gpt_oss_router_gemm_fake( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - return x.new_empty((x.shape[0], weight.shape[0])) - - -direct_register_custom_op( - op_name="gpt_oss_router_gemm", - op_func=gpt_oss_router_gemm_impl, - fake_impl=gpt_oss_router_gemm_fake, -) - - -class GptOssRouter(ReplicatedLinear): - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - return_bias: bool = True, - ): - assert quant_config is None - super().__init__( - input_size, - output_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.router", - return_bias=return_bias, - ) - - assert hasattr(self, "weight") - assert hasattr(self, "bias") - - # Check if gpt_oss_router_gemm kernel can be used. - # This kernel supports PDL and is optimized for low batch size. 
- self.use_gpt_oss_router_gemm = ( - self.weight.dtype == torch.bfloat16 - and current_platform.is_cuda() - and ( - current_platform.is_device_capability(90) - or current_platform.is_device_capability_family(100) - ) - ) - - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - if self.use_gpt_oss_router_gemm: - return torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) - else: - return super().forward(x) - - class MLPBlock(torch.nn.Module): def __init__( self, @@ -248,13 +174,11 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = GptOssRouter( + self.router = GateLinear( config.hidden_size, config.num_local_experts, bias=True, - quant_config=None, prefix=f"{prefix}.router", - return_bias=False, ) assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( @@ -282,7 +206,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: - g = self.router(x) + g, _ = self.router(x) x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: From edc97b9c27009e0d3659afaa3c4f3e1db580b2d3 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 17:06:08 -0700 Subject: [PATCH 06/11] Update test Signed-off-by: Xin Yang --- tests/kernels/moe/test_router_gemm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py index a6e3514101b2..c4780ae45069 100644 --- a/tests/kernels/moe/test_router_gemm.py +++ b/tests/kernels/moe/test_router_gemm.py @@ -14,8 +14,14 @@ @pytest.mark.skipif( - not (current_platform.is_cuda() and current_platform.has_device_capability(90)), - reason="This test is skipped on non-CUDA platform.", + not ( + current_platform.is_cuda() + and ( + 
current_platform.is_device_capability(90) + or current_platform.is_device_capability_family(100) + ) + ), + reason="This test only runs on CUDA Hopper or Blackwell platform.", ) @pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) @pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) From 991c20b0a6748bfa7081ab772577fe24505674ca Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 17:06:29 -0700 Subject: [PATCH 07/11] Update benchmark script Signed-off-by: Xin Yang --- benchmarks/kernels/benchmark_router_gemm.py | 174 +++++++++++++------- 1 file changed, 114 insertions(+), 60 deletions(-) diff --git a/benchmarks/kernels/benchmark_router_gemm.py b/benchmarks/kernels/benchmark_router_gemm.py index 137eca369fb7..cc63f8904c27 100644 --- a/benchmarks/kernels/benchmark_router_gemm.py +++ b/benchmarks/kernels/benchmark_router_gemm.py @@ -1,80 +1,134 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse - import torch import torch.nn.functional as F from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser -num_tokens_range = [2**x for x in range(14)] - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["num_tokens"], - x_vals=num_tokens_range, - x_log=False, - line_arg="impl", - line_vals=[ - "torch-32", - "gpt_oss_router_gemm-32", - "torch-128", - "gpt_oss_router_gemm-128", - ], - line_names=( - [ - "torch-32", - "gpt_oss_router_gemm-32", - "torch-128", - "gpt_oss_router_gemm-128", - ] - ), - styles=([("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]), - ylabel="TFLOPs", - plot_name="router gemm throughput", - args={}, - ) -) -def benchmark(num_tokens, impl): - # M: num_tokens, K: hidden_dim, N: num_experts - M, K = num_tokens, 2880 - - if impl == "torch-32" or impl == 
"gpt_oss_router_gemm-32": - N = 32 - elif impl == "torch-128" or impl == "gpt_oss_router_gemm-128": - N = 128 - else: - raise ValueError(f"Unknown impl: {impl}") - - mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous() - mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous() - bias = torch.randn(N, dtype=torch.bfloat16, device="cuda").contiguous() +# Dimensions supported by the DSV3 specialized kernel +DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] +DSV3_SUPPORTED_HIDDEN_SIZES = [7168] - quantiles = [0.5, 0.2, 0.8] +# Dimensions supported by the gpt-oss specialized kernel +GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] +GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] - if impl == "torch-32" or impl == "torch-128": - def runner(): - F.linear(mat_a, mat_b, bias) - elif impl == "gpt_oss_router_gemm-32" or impl == "gpt_oss_router_gemm-128": +def get_batch_size_range(max_batch_size): + return [2**x for x in range(14) if 2**x <= max_batch_size] - def runner(): - ops.gpt_oss_router_gemm(mat_a, mat_b, bias) - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles) - - def tflops(t_ms): - flops = 2 * M * K * N - return flops / (t_ms * 1e-3) / 1e12 - - return tflops(ms), tflops(max_ms), tflops(min_ms) +def get_model_params(config): + if config.architectures[0] in ( + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "DeepseekV32ForCausalLM", + ): + num_experts = config.n_routed_experts + hidden_size = config.hidden_size + elif config.architectures[0] in ("GptOssForCausalLM",): + num_experts = config.num_local_experts + hidden_size = config.hidden_size + else: + raise ValueError(f"Unsupported architecture: {config.architectures}") + return num_experts, hidden_size + + +def get_benchmark(model, max_batch_size, trust_remote_code): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=get_batch_size_range(max_batch_size), + x_log=False, + line_arg="provider", + line_vals=[ + 
"torch", + "vllm", + ], + line_names=["PyTorch", "vLLM"], + styles=([("blue", "-"), ("red", "-")]), + ylabel="TFLOPs", + plot_name=f"{model} router gemm throughput", + args={}, + ) + ) + def benchmark(batch_size, provider): + config = get_config(model=model, trust_remote_code=trust_remote_code) + num_experts, hidden_size = get_model_params(config) + + mat_a = torch.randn( + (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + mat_b = torch.randn( + (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + bias = torch.randn( + num_experts, dtype=torch.bfloat16, device="cuda" + ).contiguous() + + is_hopper_or_blackwell = current_platform.is_device_capability( + 90 + ) or current_platform.is_device_capability_family(100) + allow_dsv3_router_gemm = ( + is_hopper_or_blackwell + and num_experts in DSV3_SUPPORTED_NUM_EXPERTS + and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES + ) + allow_gpt_oss_router_gemm = ( + is_hopper_or_blackwell + and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS + and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES + ) + + has_bias = False + if allow_gpt_oss_router_gemm: + has_bias = True + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + + def runner(): + if has_bias: + F.linear(mat_a, mat_b, bias) + else: + F.linear(mat_a, mat_b) + elif provider == "vllm": + + def runner(): + if allow_dsv3_router_gemm: + ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16) + elif allow_gpt_oss_router_gemm: + ops.gpt_oss_router_gemm(mat_a, mat_b, bias) + else: + raise ValueError("Unsupported router gemm") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + runner, quantiles=quantiles + ) + + def tflops(t_ms): + flops = 2 * batch_size * hidden_size * num_experts + return flops / (t_ms * 1e-3) / 1e12 + + return tflops(ms), tflops(max_ms), tflops(min_ms) + + return benchmark if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = FlexibleArgumentParser() + 
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b") + parser.add_argument("--max-batch-size", default=16, type=int) + parser.add_argument("--trust-remote-code", action="store_true") args = parser.parse_args() + # Get the benchmark function + benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code) + # Run performance benchmark benchmark.run(print_data=True) From 5db217e60cf1e2fb53971ecf9bcf2f9292328873 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Mon, 16 Mar 2026 18:13:22 -0700 Subject: [PATCH 08/11] Update Signed-off-by: Xin Yang --- tests/kernels/moe/test_router_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py index c4780ae45069..906e47708f29 100644 --- a/tests/kernels/moe/test_router_gemm.py +++ b/tests/kernels/moe/test_router_gemm.py @@ -21,7 +21,7 @@ or current_platform.is_device_capability_family(100) ) ), - reason="This test only runs on CUDA Hopper or Blackwell platform.", + reason="This test only runs on Hopper or Blackwell GPUs.", ) @pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) @pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) From a42846a45a8202ad89733e2e110d9979cf6a6da5 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Tue, 17 Mar 2026 15:37:51 -0700 Subject: [PATCH 09/11] Add GateLinearWithLoRA Signed-off-by: Xin Yang --- vllm/lora/layers/__init__.py | 2 ++ vllm/lora/layers/gate_linear.py | 30 ++++++++++++++++++++++++++++++ vllm/lora/utils.py | 2 ++ 3 files changed, 34 insertions(+) create mode 100644 vllm/lora/layers/gate_linear.py diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py index 1f3fdea2cdaf..235f40b73852 100644 --- a/vllm/lora/layers/__init__.py +++ b/vllm/lora/layers/__init__.py @@ -13,6 +13,7 @@ QKVParallelLinearWithShardedLoRA, ) from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA +from vllm.lora.layers.gate_linear import 
GateLinearWithLoRA from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.row_parallel_linear import ( @@ -38,6 +39,7 @@ "RowParallelLinearWithLoRA", "RowParallelLinearWithShardedLoRA", "ReplicatedLinearWithLoRA", + "GateLinearWithLoRA", "LoRAMapping", "LoRAMappingType", "FusedMoEWithLoRA", diff --git a/vllm/lora/layers/gate_linear.py b/vllm/lora/layers/gate_linear.py new file mode 100644 index 000000000000..9bcaaa5b8e20 --- /dev/null +++ b/vllm/lora/layers/gate_linear.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.custom_op import maybe_get_oot_by_class +from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear + +from .replicated_linear import ReplicatedLinearWithLoRA + + +class GateLinearWithLoRA(ReplicatedLinearWithLoRA): + def __init__(self, base_layer: GateLinear) -> None: + super().__init__( + base_layer, + ) + + # GateLinearWithLoRA should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. 
+ @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is maybe_get_oot_by_class(GateLinear) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 2349ace70846..75ed9674af56 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -21,6 +21,7 @@ ColumnParallelLinearWithShardedLoRA, FusedMoE3DWithLoRA, FusedMoEWithLoRA, + GateLinearWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearVariableSliceWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -81,6 +82,7 @@ def get_lora_id(): MergedQKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, ReplicatedLinearWithLoRA, + GateLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, From ac52c1dbc5a75d05a987b1f1d89e6beee1506cc1 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Tue, 17 Mar 2026 16:48:03 -0700 Subject: [PATCH 10/11] Use flashinfer tinygemm Signed-off-by: Xin Yang --- CMakeLists.txt | 1 - benchmarks/kernels/benchmark_router_gemm.py | 134 ------ csrc/moe/gpt_oss_router_gemm.cu | 144 ------ csrc/moe/gpt_oss_router_gemm.cuh | 447 ------------------ csrc/moe/moe_ops.h | 4 - csrc/moe/torch_bindings.cpp | 6 - tests/kernels/moe/test_router_gemm.py | 37 -- vllm/_custom_ops.py | 13 - .../layers/fused_moe/router/gate_linear.py | 92 ++-- 9 files changed, 49 insertions(+), 829 deletions(-) delete mode 100644 benchmarks/kernels/benchmark_router_gemm.py delete mode 100644 csrc/moe/gpt_oss_router_gemm.cu delete mode 100644 csrc/moe/gpt_oss_router_gemm.cuh delete mode 100644 tests/kernels/moe/test_router_gemm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 693070b5f476..bbadfdc5e9e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -999,7 +999,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" - 
"csrc/moe/gpt_oss_router_gemm.cu" "csrc/moe/router_gemm.cu") endif() diff --git a/benchmarks/kernels/benchmark_router_gemm.py b/benchmarks/kernels/benchmark_router_gemm.py deleted file mode 100644 index cc63f8904c27..000000000000 --- a/benchmarks/kernels/benchmark_router_gemm.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.nn.functional as F - -from vllm import _custom_ops as ops -from vllm.platforms import current_platform -from vllm.transformers_utils.config import get_config -from vllm.triton_utils import triton -from vllm.utils.argparse_utils import FlexibleArgumentParser - -# Dimensions supported by the DSV3 specialized kernel -DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] -DSV3_SUPPORTED_HIDDEN_SIZES = [7168] - -# Dimensions supported by the gpt-oss specialized kernel -GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] -GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] - - -def get_batch_size_range(max_batch_size): - return [2**x for x in range(14) if 2**x <= max_batch_size] - - -def get_model_params(config): - if config.architectures[0] in ( - "DeepseekV2ForCausalLM", - "DeepseekV3ForCausalLM", - "DeepseekV32ForCausalLM", - ): - num_experts = config.n_routed_experts - hidden_size = config.hidden_size - elif config.architectures[0] in ("GptOssForCausalLM",): - num_experts = config.num_local_experts - hidden_size = config.hidden_size - else: - raise ValueError(f"Unsupported architecture: {config.architectures}") - return num_experts, hidden_size - - -def get_benchmark(model, max_batch_size, trust_remote_code): - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size"], - x_vals=get_batch_size_range(max_batch_size), - x_log=False, - line_arg="provider", - line_vals=[ - "torch", - "vllm", - ], - line_names=["PyTorch", "vLLM"], - styles=([("blue", "-"), ("red", "-")]), - ylabel="TFLOPs", - plot_name=f"{model} router gemm throughput", - 
args={}, - ) - ) - def benchmark(batch_size, provider): - config = get_config(model=model, trust_remote_code=trust_remote_code) - num_experts, hidden_size = get_model_params(config) - - mat_a = torch.randn( - (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda" - ).contiguous() - mat_b = torch.randn( - (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda" - ).contiguous() - bias = torch.randn( - num_experts, dtype=torch.bfloat16, device="cuda" - ).contiguous() - - is_hopper_or_blackwell = current_platform.is_device_capability( - 90 - ) or current_platform.is_device_capability_family(100) - allow_dsv3_router_gemm = ( - is_hopper_or_blackwell - and num_experts in DSV3_SUPPORTED_NUM_EXPERTS - and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES - ) - allow_gpt_oss_router_gemm = ( - is_hopper_or_blackwell - and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS - and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES - ) - - has_bias = False - if allow_gpt_oss_router_gemm: - has_bias = True - - quantiles = [0.5, 0.2, 0.8] - - if provider == "torch": - - def runner(): - if has_bias: - F.linear(mat_a, mat_b, bias) - else: - F.linear(mat_a, mat_b) - elif provider == "vllm": - - def runner(): - if allow_dsv3_router_gemm: - ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16) - elif allow_gpt_oss_router_gemm: - ops.gpt_oss_router_gemm(mat_a, mat_b, bias) - else: - raise ValueError("Unsupported router gemm") - - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - runner, quantiles=quantiles - ) - - def tflops(t_ms): - flops = 2 * batch_size * hidden_size * num_experts - return flops / (t_ms * 1e-3) / 1e12 - - return tflops(ms), tflops(max_ms), tflops(min_ms) - - return benchmark - - -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument("--model", type=str, default="openai/gpt-oss-20b") - parser.add_argument("--max-batch-size", default=16, type=int) - parser.add_argument("--trust-remote-code", action="store_true") - args = 
parser.parse_args() - - # Get the benchmark function - benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code) - # Run performance benchmark - benchmark.run(print_data=True) diff --git a/csrc/moe/gpt_oss_router_gemm.cu b/csrc/moe/gpt_oss_router_gemm.cu deleted file mode 100644 index 0294cd36aa8f..000000000000 --- a/csrc/moe/gpt_oss_router_gemm.cu +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Adapted from - * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu - * Copyright (c) 2025, The vLLM team. - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. - * All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include -#include "gpt_oss_router_gemm.cuh" - -void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB, - __nv_bfloat16* gC, __nv_bfloat16* bias, - int batch_size, int output_features, - int input_features, cudaStream_t stream) { - static int const WARP_TILE_M = 16; - static int const TILE_M = WARP_TILE_M; - static int const TILE_N = 8; - static int const TILE_K = 64; - static int const STAGES = 16; - static int const STAGE_UNROLL = 4; - static bool const PROFILE = false; - - CUtensorMap weight_map{}; - CUtensorMap activation_map{}; - - constexpr uint32_t rank = 2; - uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features}; - uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)}; - uint32_t box_size[rank] = {TILE_K, TILE_M}; - uint32_t elem_stride[rank] = {1, 1}; - - CUresult res = cuTensorMapEncodeTiled( - &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, - gB, size, stride, box_size, elem_stride, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - TORCH_CHECK(res == CUDA_SUCCESS, - "cuTensorMapEncodeTiled failed for weight_map, error code=", - static_cast(res)); - - size[1] = batch_size; - box_size[1] = TILE_N; - - res = cuTensorMapEncodeTiled( - &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, - rank, gA, size, stride, box_size, elem_stride, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - TORCH_CHECK(res == CUDA_SUCCESS, - "cuTensorMapEncodeTiled failed for activation_map, error code=", - static_cast(res)); - - int smem_size = STAGES * STAGE_UNROLL * - (TILE_M * TILE_K 
* sizeof(__nv_bfloat16) + - TILE_N * TILE_K * sizeof(__nv_bfloat16)); - - gpuErrChk(cudaFuncSetAttribute( - gpt_oss_router_gemm_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - int tiles_m = (output_features + TILE_M - 1) / TILE_M; - int tiles_n = (batch_size + TILE_N - 1) / TILE_N; - - dim3 grid(tiles_m, tiles_n); - dim3 block(384); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.gridDim = grid; - config.blockDim = block; - config.dynamicSmemBytes = smem_size; - config.stream = stream; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = 1; - config.numAttrs = 1; - - cudaLaunchKernelEx( - &config, - &gpt_oss_router_gemm_kernel, - gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, - activation_map, nullptr); -} - -void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output, - torch::Tensor input, torch::Tensor weight, - torch::Tensor bias) { - auto const batch_size = input.size(0); - auto const input_dim = input.size(1); - auto const output_dim = weight.size(0); - - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::BFloat16) { - launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(), - (__nv_bfloat16*)weight.data_ptr(), - (__nv_bfloat16*)output.mutable_data_ptr(), - (__nv_bfloat16*)bias.data_ptr(), batch_size, - output_dim, input_dim, stream); - } else { - throw std::invalid_argument("Unsupported dtype, only supports bfloat16"); - } -} - -void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, - torch::Tensor weight, torch::Tensor bias) { - TORCH_CHECK(input.dim() == 2, "input must be 2D"); - TORCH_CHECK(weight.dim() == 2, "weight must be 2D"); - TORCH_CHECK(bias.dim() == 1, "bias must be 1D"); - TORCH_CHECK(input.sizes()[1] == weight.sizes()[1], - "input.size(1) must match weight.size(1)"); - TORCH_CHECK(weight.sizes()[0] == 
bias.sizes()[0], - "weight.size(0) must match bias.size(0)"); - TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, - "input tensor must be bfloat16"); - TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16, - "weight tensor must be bfloat16"); - TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16, - "bias tensor must be bfloat16"); - gpt_oss_router_gemm_cuda_forward(output, input, weight, bias); -} diff --git a/csrc/moe/gpt_oss_router_gemm.cuh b/csrc/moe/gpt_oss_router_gemm.cuh deleted file mode 100644 index 5cc653f19cfb..000000000000 --- a/csrc/moe/gpt_oss_router_gemm.cuh +++ /dev/null @@ -1,447 +0,0 @@ -/* - * Adapted from - * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh - * Copyright (c) 2025, The vLLM team. - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. - * All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "cuda_bf16.h" -#include -#include -#include - -#include "cuda_pipeline.h" -#include -#include -#include -#include - -using barrier = cuda::barrier; -namespace cde = cuda::device::experimental; -namespace ptx = cuda::ptx; - -#define gpuErrChk(ans) \ - { \ - gpuAssert((ans), __FILE__, __LINE__); \ - } - -inline void gpuAssert(cudaError_t code, char const* file, int line, - bool abort = true) { - if (code != cudaSuccess) { - fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, - line); - if (abort) { - throw std::runtime_error(cudaGetErrorString(code)); - } - } -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -__device__ uint64_t gclock64() { - unsigned long long int rv; - asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv)); - return rv; -} - -__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) { - int dst; - asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" - : "=r"(dst) - : "r"(smem_ptr)); - int* rvi = reinterpret_cast(&rv[0]); - rvi[0] = dst; -} - -__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) { - int x, y; - asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" - : "=r"(x), "=r"(y) - : "r"(smem_ptr)); - - int* rvi = reinterpret_cast(&rv[0]); - rvi[0] = x; - rvi[1] = y; -} - -__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) { - int x, y, z, w; - asm volatile( - "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" - : "=r"(x), "=r"(y), "=r"(z), "=r"(w) - : "r"(smem_ptr)); - int* rvi = reinterpret_cast(&rv[0]); - rvi[0] = x; - rvi[1] = y; - rvi[2] = z; - rvi[3] = w; -} - -__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2], - float c[4]) { - uint32_t const* A = reinterpret_cast(&a[0]); - uint32_t const* B = reinterpret_cast(&b[0]); - float const* C = reinterpret_cast(&c[0]); - float* D = reinterpret_cast(&d[0]); - - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " - 
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), - "f"(C[3])); -} - -__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4], - float c[4]) { - uint32_t const* A = reinterpret_cast(&a[0]); - uint32_t const* B = reinterpret_cast(&b[0]); - float const* C = reinterpret_cast(&c[0]); - float* D = reinterpret_cast(&d[0]); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -} - -__device__ void bar_wait(uint32_t bar_ptr, int phase) { - asm volatile( - "{\n" - ".reg .pred P1;\n" - "LAB_WAIT:\n" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" - "@P1 bra.uni DONE;\n" - "bra.uni LAB_WAIT;\n" - "DONE:\n" - "}\n" ::"r"(bar_ptr), - "r"(phase)); -} - -__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) { - uint32_t success; - #ifdef INTERNAL - asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory"); - #endif - asm volatile( - "{\n\t" - ".reg .pred P1; \n\t" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P1; \n\t" - "}" - : "=r"(success) - : "r"(bar_ptr), "r"(phase)); - return success; -} - -__device__ uint32_t elect_one_sync() { - uint32_t pred = 0; - uint32_t laneid = 0; - asm volatile( - "{\n" - ".reg .b32 %%rx;\n" - ".reg .pred %%px;\n" - " elect.sync %%rx|%%px, %2;\n" - "@%%px mov.s32 %1, 1;\n" - " mov.s32 %0, %%rx;\n" - "}\n" - : "+r"(laneid), "+r"(pred) - : "r"(0xFFFFFFFF)); - return pred; -} -#endif - -struct Profile { - uint64_t start; - uint64_t weight_load_start; - uint64_t act_load_start; - uint64_t compute_start; - uint64_t complete; -}; - -template -__global__ __launch_bounds__(384, 1) void 
gpt_oss_router_gemm_kernel( - __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations, - __nv_bfloat16* bias, int M, int N, int K, - const __grid_constant__ CUtensorMap weight_map, - const __grid_constant__ CUtensorMap activation_map, - Profile* profile = nullptr) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - - if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0) - profile[blockIdx.x].start = gclock64(); - - extern __shared__ __align__(128) char smem[]; - - __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0]; - __nv_bfloat16* sh_activations = - (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K * - sizeof(__nv_bfloat16)]; - - #pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ barrier bar_wt_ready[STAGES]; - __shared__ barrier bar_act_ready[STAGES]; - __shared__ barrier bar_data_consumed[STAGES]; - - __shared__ float4 reduction_buffer[128]; - - __shared__ nv_bfloat16 sh_bias[TILE_M]; - - if (threadIdx.x == 0) { - for (int i = 0; i < STAGES; i++) { - init(&bar_wt_ready[i], 1); - init(&bar_act_ready[i], 1); - init(&bar_data_consumed[i], 32); - } - ptx::fence_proxy_async(ptx::space_shared); - asm volatile("prefetch.tensormap [%0];" - : - : "l"(reinterpret_cast(&weight_map)) - : "memory"); - asm volatile("prefetch.tensormap [%0];" - : - : "l"(reinterpret_cast(&activation_map)) - : "memory"); - } - __syncthreads(); - - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - - int phase = 0; - - int mib = blockIdx.x * TILE_M; - int ni = blockIdx.y * TILE_N; - - float accum[4]; - for (int i = 0; i < 4; i++) accum[i] = 0.f; - - int const K_LOOPS_DMA = - (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL)); - int const K_LOOPS_COMPUTE = K_LOOPS_DMA; - - // Data loading thread - if (warp_id >= 4 && elect_one_sync()) { - int stage = warp_id % 4; - - bool weight_warp = warp_id < 8; - if (!weight_warp) { - cudaGridDependencySynchronize(); - cudaTriggerProgrammaticLaunchCompletion(); - } - - for 
(int ki = 0; ki < K_LOOPS_DMA; ki++) { - int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL; - - uint64_t desc_ptr_wt = reinterpret_cast(&weight_map); - uint64_t desc_ptr_act = reinterpret_cast(&activation_map); - - uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); - uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); - int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16); - int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16); - - bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); - - if (weight_warp) - asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" - : - : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt)); - if (!weight_warp) - asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" - : - : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act)); - - if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp) - profile[blockIdx.x].weight_load_start = gclock64(); - if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp) - profile[blockIdx.x].act_load_start = gclock64(); - - for (int i = 0; i < STAGE_UNROLL; i++) { - uint32_t smem_ptr_wt = __cvta_generic_to_shared( - &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]); - uint32_t crd0 = k + i * TILE_K; - uint32_t crd1 = mib; - if (weight_warp) - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" - "tx::bytes [%0], [%1, {%3,%4}], " - "[%2];" - : - : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0), - "r"(crd1) - : "memory"); - - uint32_t smem_ptr_act = __cvta_generic_to_shared( - &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]); - crd0 = k + i * TILE_K; - crd1 = ni; - if (!weight_warp) - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" - "tx::bytes [%0], [%1, {%3,%4}], " - "[%2];" - : - : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act), - "r"(crd0), "r"(crd1) - : "memory"); - } - - stage += 4; - if (stage >= STAGES) { - 
stage = warp_id % 4; - phase ^= 1; - } - } - // Wait for pending loads to be consumed before exiting, to avoid race - for (int i = 0; i < (STAGES / 4) - 1; i++) { - bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); - stage += 4; - if (stage >= STAGES) { - stage = warp_id % 4; - phase ^= 1; - } - } - } - // Compute threads - else if (warp_id < 4) { - // Sneak the bias load into the compute warps since they're just waiting for - // stuff anyway - if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x]; - - int stage = warp_id; - - int phase = 0; - int lane_id_div8 = lane_id / 8; - int lane_id_mod8 = lane_id % 8; - - int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0; - int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0; - - int row_wt = lane_id_mod8 + lane_row_offset_wt; - int row_act = lane_id_mod8; - - int row_offset_wt = (reinterpret_cast(sh_weights) / 128) % 8; - int row_offset_act = row_offset_wt; - - uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); - uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); - - bool weight_ready = bar_try_wait(bar_ptr_wt, phase); - bool act_ready = bar_try_wait(bar_ptr_act, phase); - - #pragma unroll 2 - for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) { - int next_stage = stage + 4; - int next_phase = phase; - if (next_stage >= STAGES) { - next_stage = warp_id; - next_phase ^= 1; - } - - while (!weight_ready || !act_ready) { - weight_ready = bar_try_wait(bar_ptr_wt, phase); - act_ready = bar_try_wait(bar_ptr_act, phase); - } - - if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0) - profile[blockIdx.x].compute_start = gclock64(); - - if (ki + 1 < K_LOOPS_COMPUTE) { - weight_ready = bar_try_wait( - __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase); - act_ready = bar_try_wait( - __cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase); - } - - #pragma unroll - for (int su = 0; su < STAGE_UNROLL; su++) { - __nv_bfloat16* 
ptr_weights = - &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K]; - __nv_bfloat16* ptr_act = - &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K]; - - #pragma unroll - for (int kii = 0; kii < TILE_K / 16; kii++) { - __nv_bfloat16 a[8]; - __nv_bfloat16 b[4]; - - int col = 2 * kii + lane_col_offset_wt; - int col_sw = ((row_wt + row_offset_wt) % 8) ^ col; - - ldmatrix4(a, __cvta_generic_to_shared( - &ptr_weights[row_wt * TILE_K + col_sw * 8])); - - col = 2 * kii + lane_id_div8; - col_sw = ((row_act + row_offset_act) % 8) ^ col; - - ldmatrix2(b, __cvta_generic_to_shared( - &ptr_act[row_act * TILE_K + 8 * col_sw])); - - HMMA_16816(accum, a, b, accum); - } - } - - uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]); - asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c)); - - stage = next_stage; - phase = next_phase; - } - - float4 accum4; - accum4.x = accum[0]; - accum4.y = accum[1]; - accum4.z = accum[2]; - accum4.w = accum[3]; - reduction_buffer[threadIdx.x] = accum4; - - __syncthreads(); - - if (warp_id == 0) { - int mi = mib + warp_id * WARP_TILE_M; - int tm = mi + lane_id / 4; - int tn = ni + 2 * (lane_id % 4); - - float4 accum1 = reduction_buffer[32 + threadIdx.x]; - float4 accum2 = reduction_buffer[64 + threadIdx.x]; - float4 accum3 = reduction_buffer[96 + threadIdx.x]; - - accum[0] = accum[0] + accum1.x + accum2.x + accum3.x; - accum[1] = accum[1] + accum1.y + accum2.y + accum3.y; - accum[2] = accum[2] + accum1.z + accum2.z + accum3.z; - accum[3] = accum[3] + accum1.w + accum2.w + accum3.w; - - float bias_lo = __bfloat162float(sh_bias[tm - mib]); - float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]); - - if (tn < N && tm < M) - output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo); - if (tn + 1 < N && tm < M) - output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo); - if (tn < N && tm + 8 < M) - output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi); - if (tn + 1 < N && tm + 8 < 
M) - output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi); - - if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) - profile[blockIdx.x].complete = gclock64(); - } - } -#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index de931dc76467..d8d962887dab 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -70,8 +70,4 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); - -// gpt-oss optimized router GEMM kernel for SM90+ -void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, - torch::Tensor weight, torch::Tensor bias); #endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 4cd74366ea4d..7b627a6f8760 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -132,12 +132,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file - - // gpt-oss optimized router GEMM kernel for SM90+ - m.def( - "gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, " - "Tensor bias) -> ()"); - m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm); #endif } diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py deleted file mode 100644 index 906e47708f29..000000000000 --- a/tests/kernels/moe/test_router_gemm.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for optimized router GEMM kernel - -Run `pytest tests/kernels/moe/test_router_gemm.py`. 
-""" - -import pytest -import torch - -import vllm._custom_ops as ops -from vllm.platforms import current_platform -from vllm.utils.torch_utils import set_random_seed - - -@pytest.mark.skipif( - not ( - current_platform.is_cuda() - and ( - current_platform.is_device_capability(90) - or current_platform.is_device_capability_family(100) - ) - ), - reason="This test only runs on Hopper or Blackwell GPUs.", -) -@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) -@pytest.mark.parametrize("output_dim", [32, 64, 128]) -def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim): - set_random_seed(0) - x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16) - weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16) - bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16) - - output = ops.gpt_oss_router_gemm(x, weight, bias) - output_ref = torch.nn.functional.linear(x, weight, bias) - torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a45caac7c9e2..a01f44e1649d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2362,19 +2362,6 @@ def dsv3_router_gemm( return output -def gpt_oss_router_gemm( - hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - output = torch.empty( - hidden_states.shape[0], - weight.shape[0], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias) - return output - - def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index e8ed8a5249d1..a4a833e26b29 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ 
b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -7,16 +7,51 @@ from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer from vllm.utils.torch_utils import direct_register_custom_op +if has_flashinfer(): + + def flashinfer_tinygemm_impl( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + ) -> torch.Tensor: + from flashinfer.gemm.routergemm import tinygemm_bf16 + + if x.shape[0] <= 128: + output = torch.empty( + x.shape[0], + weight.shape[0], + dtype=torch.bfloat16, + device=x.device, + ) + tinygemm_bf16(x, weight, output, bias) + return output + else: + return torch.nn.functional.linear(x, weight, bias) + + def flashinfer_tinygemm_fake( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + ) -> torch.Tensor: + return x.new_empty((x.shape[0], weight.shape[0])) + + direct_register_custom_op( + op_name="flashinfer_tinygemm", + op_func=flashinfer_tinygemm_impl, + fake_impl=flashinfer_tinygemm_fake, + ) + @PluggableLayer.register("gate_linear") class GateLinear(ReplicatedLinear): - """MoE gate linear layer with three-tier GEMM dispatch: + """MoE gate linear layer with four-tier GEMM dispatch: 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims) - 3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) + 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) + 3. Flashinfer tinygemm_bf16 kernel (SM90+, batch<=128, supported dims) 4. 
F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init @@ -28,10 +63,6 @@ class GateLinear(ReplicatedLinear): DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] DSV3_SUPPORTED_HIDDEN_SIZES = [7168] - # Dimensions supported by the gpt-oss specialized kernel - GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] - GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] - def __init__( self, input_size: int, @@ -73,12 +104,14 @@ def __init__( ) # gpt-oss specialized kernel eligibility (SM90+, exact dims) - self.allow_gpt_oss_router_gemm = ( + self.allow_flashinfer_tinygemm = ( self.weight.dtype == torch.bfloat16 + and (self.out_dtype is None or self.out_dtype == torch.bfloat16) and current_platform.is_cuda() and is_hopper_or_blackwell - and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS - and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES + and has_flashinfer() + and input_size % 64 == 0 + and output_size % 16 == 0 ) # cuBLAS bf16→fp32 eligibility @@ -117,16 +150,16 @@ def forward( ) return output, None - # Tier 2: gpt-oss specialized kernel - if self.allow_gpt_oss_router_gemm: - output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) - return output, None - - # Tier 3: cuBLAS bf16→fp32 + # Tier 2: cuBLAS bf16→fp32 if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None + # Tier 3: Flashinfer tinygemm_bf16 + if self.allow_flashinfer_tinygemm and x.dtype == torch.bfloat16: + output = torch.ops.vllm.flashinfer_tinygemm(x, self.weight, self.bias) + return output, None + # Tier 4: F.linear (ReplicatedLinear) if self.out_dtype is not None and x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) @@ -134,30 +167,3 @@ def forward( if self.out_dtype is not None and output.dtype != self.out_dtype: output = output.to(self.out_dtype) return output, output_bias - - -def gpt_oss_router_gemm_impl( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor 
-) -> torch.Tensor: - """ - Dynamically run min-latency gemm if num_tokens <= 128. - This must be wrapped in a custom op because our torch.compile integration - does not support runtime dispatching on num_tokens. - """ - if x.shape[0] <= 128: - return ops.gpt_oss_router_gemm(x, weight, bias) - else: - return torch.nn.functional.linear(x, weight, bias) - - -def gpt_oss_router_gemm_fake( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - return x.new_empty((x.shape[0], weight.shape[0])) - - -direct_register_custom_op( - op_name="gpt_oss_router_gemm", - op_func=gpt_oss_router_gemm_impl, - fake_impl=gpt_oss_router_gemm_fake, -) From 4ff3f4eaee724071e1b7e0358b2b61821ec3a19e Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Tue, 17 Mar 2026 18:40:27 -0700 Subject: [PATCH 11/11] Revert "Use flashinfer tinygemm" This reverts commit ac52c1dbc5a75d05a987b1f1d89e6beee1506cc1. Signed-off-by: Xin Yang --- CMakeLists.txt | 1 + benchmarks/kernels/benchmark_router_gemm.py | 134 ++++++ csrc/moe/gpt_oss_router_gemm.cu | 144 ++++++ csrc/moe/gpt_oss_router_gemm.cuh | 447 ++++++++++++++++++ csrc/moe/moe_ops.h | 4 + csrc/moe/torch_bindings.cpp | 6 + tests/kernels/moe/test_router_gemm.py | 37 ++ vllm/_custom_ops.py | 13 + .../layers/fused_moe/router/gate_linear.py | 92 ++-- 9 files changed, 829 insertions(+), 49 deletions(-) create mode 100644 benchmarks/kernels/benchmark_router_gemm.py create mode 100644 csrc/moe/gpt_oss_router_gemm.cu create mode 100644 csrc/moe/gpt_oss_router_gemm.cuh create mode 100644 tests/kernels/moe/test_router_gemm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index bbadfdc5e9e3..693070b5f476 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -999,6 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" + "csrc/moe/gpt_oss_router_gemm.cu" "csrc/moe/router_gemm.cu") endif() diff --git a/benchmarks/kernels/benchmark_router_gemm.py 
b/benchmarks/kernels/benchmark_router_gemm.py new file mode 100644 index 000000000000..cc63f8904c27 --- /dev/null +++ b/benchmarks/kernels/benchmark_router_gemm.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Dimensions supported by the DSV3 specialized kernel +DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] +DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + +# Dimensions supported by the gpt-oss specialized kernel +GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] +GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] + + +def get_batch_size_range(max_batch_size): + return [2**x for x in range(14) if 2**x <= max_batch_size] + + +def get_model_params(config): + if config.architectures[0] in ( + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "DeepseekV32ForCausalLM", + ): + num_experts = config.n_routed_experts + hidden_size = config.hidden_size + elif config.architectures[0] in ("GptOssForCausalLM",): + num_experts = config.num_local_experts + hidden_size = config.hidden_size + else: + raise ValueError(f"Unsupported architecture: {config.architectures}") + return num_experts, hidden_size + + +def get_benchmark(model, max_batch_size, trust_remote_code): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=get_batch_size_range(max_batch_size), + x_log=False, + line_arg="provider", + line_vals=[ + "torch", + "vllm", + ], + line_names=["PyTorch", "vLLM"], + styles=([("blue", "-"), ("red", "-")]), + ylabel="TFLOPs", + plot_name=f"{model} router gemm throughput", + args={}, + ) + ) + def benchmark(batch_size, provider): + config = get_config(model=model, trust_remote_code=trust_remote_code) + 
num_experts, hidden_size = get_model_params(config) + + mat_a = torch.randn( + (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + mat_b = torch.randn( + (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + bias = torch.randn( + num_experts, dtype=torch.bfloat16, device="cuda" + ).contiguous() + + is_hopper_or_blackwell = current_platform.is_device_capability( + 90 + ) or current_platform.is_device_capability_family(100) + allow_dsv3_router_gemm = ( + is_hopper_or_blackwell + and num_experts in DSV3_SUPPORTED_NUM_EXPERTS + and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES + ) + allow_gpt_oss_router_gemm = ( + is_hopper_or_blackwell + and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS + and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES + ) + + has_bias = False + if allow_gpt_oss_router_gemm: + has_bias = True + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + + def runner(): + if has_bias: + F.linear(mat_a, mat_b, bias) + else: + F.linear(mat_a, mat_b) + elif provider == "vllm": + + def runner(): + if allow_dsv3_router_gemm: + ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16) + elif allow_gpt_oss_router_gemm: + ops.gpt_oss_router_gemm(mat_a, mat_b, bias) + else: + raise ValueError("Unsupported router gemm") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + runner, quantiles=quantiles + ) + + def tflops(t_ms): + flops = 2 * batch_size * hidden_size * num_experts + return flops / (t_ms * 1e-3) / 1e12 + + return tflops(ms), tflops(max_ms), tflops(min_ms) + + return benchmark + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--model", type=str, default="openai/gpt-oss-20b") + parser.add_argument("--max-batch-size", default=16, type=int) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + # Get the benchmark function + benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code) + # 
Run performance benchmark + benchmark.run(print_data=True) diff --git a/csrc/moe/gpt_oss_router_gemm.cu b/csrc/moe/gpt_oss_router_gemm.cu new file mode 100644 index 000000000000..0294cd36aa8f --- /dev/null +++ b/csrc/moe/gpt_oss_router_gemm.cu @@ -0,0 +1,144 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+#include "gpt_oss_router_gemm.cuh"
* sizeof(__nv_bfloat16) + + TILE_N * TILE_K * sizeof(__nv_bfloat16)); + + gpuErrChk(cudaFuncSetAttribute( + gpt_oss_router_gemm_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int tiles_m = (output_features + TILE_M - 1) / TILE_M; + int tiles_n = (batch_size + TILE_N - 1) / TILE_N; + + dim3 grid(tiles_m, tiles_n); + dim3 block(384); + + cudaLaunchConfig_t config; + cudaLaunchAttribute attrs[1]; + config.gridDim = grid; + config.blockDim = block; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + config.attrs = attrs; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.numAttrs = 1; + + cudaLaunchKernelEx( + &config, + &gpt_oss_router_gemm_kernel, + gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, + activation_map, nullptr); +} + +void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output, + torch::Tensor input, torch::Tensor weight, + torch::Tensor bias) { + auto const batch_size = input.size(0); + auto const input_dim = input.size(1); + auto const output_dim = weight.size(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + + if (input.scalar_type() == at::ScalarType::BFloat16) { + launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(), + (__nv_bfloat16*)weight.data_ptr(), + (__nv_bfloat16*)output.mutable_data_ptr(), + (__nv_bfloat16*)bias.data_ptr(), batch_size, + output_dim, input_dim, stream); + } else { + throw std::invalid_argument("Unsupported dtype, only supports bfloat16"); + } +} + +void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias) { + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(weight.dim() == 2, "weight must be 2D"); + TORCH_CHECK(bias.dim() == 1, "bias must be 1D"); + TORCH_CHECK(input.sizes()[1] == weight.sizes()[1], + "input.size(1) must match weight.size(1)"); + TORCH_CHECK(weight.sizes()[0] == 
bias.sizes()[0], + "weight.size(0) must match bias.size(0)"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "input tensor must be bfloat16"); + TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16, + "weight tensor must be bfloat16"); + TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16, + "bias tensor must be bfloat16"); + gpt_oss_router_gemm_cuda_forward(output, input, weight, bias); +} diff --git a/csrc/moe/gpt_oss_router_gemm.cuh b/csrc/moe/gpt_oss_router_gemm.cuh new file mode 100644 index 000000000000..5cc653f19cfb --- /dev/null +++ b/csrc/moe/gpt_oss_router_gemm.cuh @@ -0,0 +1,447 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#include "cuda_bf16.h"
+#include <cstdint>
+#include <cstdio>
+#include <stdexcept>
+
+#include "cuda_pipeline.h"
+#include <cuda.h>
+#include <cuda/barrier>
+#include <cuda/ptx>
+#include <cuda_runtime.h>
+
+using barrier = cuda::barrier<cuda::thread_scope_block>;
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), + "f"(C[3])); +} + +__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4], + float c[4]) { + uint32_t const* A = reinterpret_cast(&a[0]); + uint32_t const* B = reinterpret_cast(&b[0]); + float const* C = reinterpret_cast(&c[0]); + float* D = reinterpret_cast(&d[0]); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); +} + +__device__ void bar_wait(uint32_t bar_ptr, int phase) { + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" ::"r"(bar_ptr), + "r"(phase)); +} + +__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) { + uint32_t success; + #ifdef INTERNAL + asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory"); + #endif + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(success) + : "r"(bar_ptr), "r"(phase)); + return success; +} + +__device__ uint32_t elect_one_sync() { + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +} +#endif + +struct Profile { + uint64_t start; + uint64_t weight_load_start; + uint64_t act_load_start; + uint64_t compute_start; + uint64_t complete; +}; + +template +__global__ __launch_bounds__(384, 1) void 
gpt_oss_router_gemm_kernel( + __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations, + __nv_bfloat16* bias, int M, int N, int K, + const __grid_constant__ CUtensorMap weight_map, + const __grid_constant__ CUtensorMap activation_map, + Profile* profile = nullptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + + if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0) + profile[blockIdx.x].start = gclock64(); + + extern __shared__ __align__(128) char smem[]; + + __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0]; + __nv_bfloat16* sh_activations = + (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K * + sizeof(__nv_bfloat16)]; + + #pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ barrier bar_wt_ready[STAGES]; + __shared__ barrier bar_act_ready[STAGES]; + __shared__ barrier bar_data_consumed[STAGES]; + + __shared__ float4 reduction_buffer[128]; + + __shared__ nv_bfloat16 sh_bias[TILE_M]; + + if (threadIdx.x == 0) { + for (int i = 0; i < STAGES; i++) { + init(&bar_wt_ready[i], 1); + init(&bar_act_ready[i], 1); + init(&bar_data_consumed[i], 32); + } + ptx::fence_proxy_async(ptx::space_shared); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&weight_map)) + : "memory"); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&activation_map)) + : "memory"); + } + __syncthreads(); + + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + int phase = 0; + + int mib = blockIdx.x * TILE_M; + int ni = blockIdx.y * TILE_N; + + float accum[4]; + for (int i = 0; i < 4; i++) accum[i] = 0.f; + + int const K_LOOPS_DMA = + (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL)); + int const K_LOOPS_COMPUTE = K_LOOPS_DMA; + + // Data loading thread + if (warp_id >= 4 && elect_one_sync()) { + int stage = warp_id % 4; + + bool weight_warp = warp_id < 8; + if (!weight_warp) { + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); + } + + for 
(int ki = 0; ki < K_LOOPS_DMA; ki++) { + int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL; + + uint64_t desc_ptr_wt = reinterpret_cast(&weight_map); + uint64_t desc_ptr_act = reinterpret_cast(&activation_map); + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16); + int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16); + + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + + if (weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt)); + if (!weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act)); + + if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp) + profile[blockIdx.x].weight_load_start = gclock64(); + if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp) + profile[blockIdx.x].act_load_start = gclock64(); + + for (int i = 0; i < STAGE_UNROLL; i++) { + uint32_t smem_ptr_wt = __cvta_generic_to_shared( + &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]); + uint32_t crd0 = k + i * TILE_K; + uint32_t crd1 = mib; + if (weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0), + "r"(crd1) + : "memory"); + + uint32_t smem_ptr_act = __cvta_generic_to_shared( + &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]); + crd0 = k + i * TILE_K; + crd1 = ni; + if (!weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act), + "r"(crd0), "r"(crd1) + : "memory"); + } + + stage += 4; + if (stage >= STAGES) { + 
stage = warp_id % 4; + phase ^= 1; + } + } + // Wait for pending loads to be consumed before exiting, to avoid race + for (int i = 0; i < (STAGES / 4) - 1; i++) { + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + stage += 4; + if (stage >= STAGES) { + stage = warp_id % 4; + phase ^= 1; + } + } + } + // Compute threads + else if (warp_id < 4) { + // Sneak the bias load into the compute warps since they're just waiting for + // stuff anyway + if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x]; + + int stage = warp_id; + + int phase = 0; + int lane_id_div8 = lane_id / 8; + int lane_id_mod8 = lane_id % 8; + + int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0; + int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0; + + int row_wt = lane_id_mod8 + lane_row_offset_wt; + int row_act = lane_id_mod8; + + int row_offset_wt = (reinterpret_cast(sh_weights) / 128) % 8; + int row_offset_act = row_offset_wt; + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + + bool weight_ready = bar_try_wait(bar_ptr_wt, phase); + bool act_ready = bar_try_wait(bar_ptr_act, phase); + + #pragma unroll 2 + for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) { + int next_stage = stage + 4; + int next_phase = phase; + if (next_stage >= STAGES) { + next_stage = warp_id; + next_phase ^= 1; + } + + while (!weight_ready || !act_ready) { + weight_ready = bar_try_wait(bar_ptr_wt, phase); + act_ready = bar_try_wait(bar_ptr_act, phase); + } + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0) + profile[blockIdx.x].compute_start = gclock64(); + + if (ki + 1 < K_LOOPS_COMPUTE) { + weight_ready = bar_try_wait( + __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase); + act_ready = bar_try_wait( + __cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase); + } + + #pragma unroll + for (int su = 0; su < STAGE_UNROLL; su++) { + __nv_bfloat16* 
ptr_weights = + &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K]; + __nv_bfloat16* ptr_act = + &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K]; + + #pragma unroll + for (int kii = 0; kii < TILE_K / 16; kii++) { + __nv_bfloat16 a[8]; + __nv_bfloat16 b[4]; + + int col = 2 * kii + lane_col_offset_wt; + int col_sw = ((row_wt + row_offset_wt) % 8) ^ col; + + ldmatrix4(a, __cvta_generic_to_shared( + &ptr_weights[row_wt * TILE_K + col_sw * 8])); + + col = 2 * kii + lane_id_div8; + col_sw = ((row_act + row_offset_act) % 8) ^ col; + + ldmatrix2(b, __cvta_generic_to_shared( + &ptr_act[row_act * TILE_K + 8 * col_sw])); + + HMMA_16816(accum, a, b, accum); + } + } + + uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]); + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c)); + + stage = next_stage; + phase = next_phase; + } + + float4 accum4; + accum4.x = accum[0]; + accum4.y = accum[1]; + accum4.z = accum[2]; + accum4.w = accum[3]; + reduction_buffer[threadIdx.x] = accum4; + + __syncthreads(); + + if (warp_id == 0) { + int mi = mib + warp_id * WARP_TILE_M; + int tm = mi + lane_id / 4; + int tn = ni + 2 * (lane_id % 4); + + float4 accum1 = reduction_buffer[32 + threadIdx.x]; + float4 accum2 = reduction_buffer[64 + threadIdx.x]; + float4 accum3 = reduction_buffer[96 + threadIdx.x]; + + accum[0] = accum[0] + accum1.x + accum2.x + accum3.x; + accum[1] = accum[1] + accum1.y + accum2.y + accum3.y; + accum[2] = accum[2] + accum1.z + accum2.z + accum3.z; + accum[3] = accum[3] + accum1.w + accum2.w + accum3.w; + + float bias_lo = __bfloat162float(sh_bias[tm - mib]); + float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]); + + if (tn < N && tm < M) + output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo); + if (tn + 1 < N && tm < M) + output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo); + if (tn < N && tm + 8 < M) + output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi); + if (tn + 1 < N && tm + 8 < 
M) + output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi); + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) + profile[blockIdx.x].complete = gclock64(); + } + } +#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index d8d962887dab..de931dc76467 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); + +// gpt-oss optimized router GEMM kernel for SM90+ +void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias); #endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7b627a6f8760..4cd74366ea4d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file + + // gpt-oss optimized router GEMM kernel for SM90+ + m.def( + "gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, " + "Tensor bias) -> ()"); + m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm); #endif } diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py new file mode 100644 index 000000000000..906e47708f29 --- /dev/null +++ b/tests/kernels/moe/test_router_gemm.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for optimized router GEMM kernel + +Run `pytest tests/kernels/moe/test_router_gemm.py`. 
+""" + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + + +@pytest.mark.skipif( + not ( + current_platform.is_cuda() + and ( + current_platform.is_device_capability(90) + or current_platform.is_device_capability_family(100) + ) + ), + reason="This test only runs on Hopper or Blackwell GPUs.", +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) +@pytest.mark.parametrize("output_dim", [32, 64, 128]) +def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim): + set_random_seed(0) + x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16) + bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16) + + output = ops.gpt_oss_router_gemm(x, weight, bias) + output_ref = torch.nn.functional.linear(x, weight, bias) + torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a01f44e1649d..a45caac7c9e2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2362,6 +2362,19 @@ def dsv3_router_gemm( return output +def gpt_oss_router_gemm( + hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + output = torch.empty( + hidden_states.shape[0], + weight.shape[0], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias) + return output + + def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index a4a833e26b29..e8ed8a5249d1 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ 
b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -7,51 +7,16 @@ from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer from vllm.utils.torch_utils import direct_register_custom_op -if has_flashinfer(): - - def flashinfer_tinygemm_impl( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - ) -> torch.Tensor: - from flashinfer.gemm.routergemm import tinygemm_bf16 - - if x.shape[0] <= 128: - output = torch.empty( - x.shape[0], - weight.shape[0], - dtype=torch.bfloat16, - device=x.device, - ) - tinygemm_bf16(x, weight, output, bias) - return output - else: - return torch.nn.functional.linear(x, weight, bias) - - def flashinfer_tinygemm_fake( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - ) -> torch.Tensor: - return x.new_empty((x.shape[0], weight.shape[0])) - - direct_register_custom_op( - op_name="flashinfer_tinygemm", - op_func=flashinfer_tinygemm_impl, - fake_impl=flashinfer_tinygemm_fake, - ) - @PluggableLayer.register("gate_linear") class GateLinear(ReplicatedLinear): - """MoE gate linear layer with four-tier GEMM dispatch: + """MoE gate linear layer with three-tier GEMM dispatch: 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) - 3. Flashinfer tinygemm_bf16 kernel (SM90+, batch<=128, supported dims) + 2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims) + 3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) 4. 
F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init @@ -63,6 +28,10 @@ class GateLinear(ReplicatedLinear): DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + # Dimensions supported by the gpt-oss specialized kernel + GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] + GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] + def __init__( self, input_size: int, @@ -104,14 +73,12 @@ def __init__( ) # gpt-oss specialized kernel eligibility (SM90+, exact dims) - self.allow_flashinfer_tinygemm = ( + self.allow_gpt_oss_router_gemm = ( self.weight.dtype == torch.bfloat16 - and (self.out_dtype is None or self.out_dtype == torch.bfloat16) and current_platform.is_cuda() and is_hopper_or_blackwell - and has_flashinfer() - and input_size % 64 == 0 - and output_size % 16 == 0 + and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS + and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES ) # cuBLAS bf16→fp32 eligibility @@ -150,14 +117,14 @@ def forward( ) return output, None - # Tier 2: cuBLAS bf16→fp32 - if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: - output = ops.router_gemm_bf16_fp32(x, self.weight) + # Tier 2: gpt-oss specialized kernel + if self.allow_gpt_oss_router_gemm: + output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) return output, None - # Tier 3: Flashinfer tinygemm_bf16 - if self.allow_flashinfer_tinygemm and x.dtype == torch.bfloat16: - output = torch.ops.vllm.flashinfer_tinygemm(x, self.weight, self.bias) + # Tier 3: cuBLAS bf16→fp32 + if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: + output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None # Tier 4: F.linear (ReplicatedLinear) @@ -167,3 +134,30 @@ def forward( if self.out_dtype is not None and output.dtype != self.out_dtype: output = output.to(self.out_dtype) return output, output_bias + + +def gpt_oss_router_gemm_impl( + x: torch.Tensor, weight: torch.Tensor, bias: 
torch.Tensor +) -> torch.Tensor: + """ + Dynamically run min-latency gemm if num_tokens <= 128. + This must be wrapped in a custom op because our torch.compile integration + does not support runtime dispatching on num_tokens. + """ + if x.shape[0] <= 128: + return ops.gpt_oss_router_gemm(x, weight, bias) + else: + return torch.nn.functional.linear(x, weight, bias) + + +def gpt_oss_router_gemm_fake( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return x.new_empty((x.shape[0], weight.shape[0])) + + +direct_register_custom_op( + op_name="gpt_oss_router_gemm", + op_func=gpt_oss_router_gemm_impl, + fake_impl=gpt_oss_router_gemm_fake, +)