diff --git a/CMakeLists.txt b/CMakeLists.txt index bbadfdc5e9e3..693070b5f476 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -999,6 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" + "csrc/moe/gpt_oss_router_gemm.cu" "csrc/moe/router_gemm.cu") endif() diff --git a/benchmarks/kernels/benchmark_router_gemm.py b/benchmarks/kernels/benchmark_router_gemm.py new file mode 100644 index 000000000000..cc63f8904c27 --- /dev/null +++ b/benchmarks/kernels/benchmark_router_gemm.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Dimensions supported by the DSV3 specialized kernel +DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] +DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + +# Dimensions supported by the gpt-oss specialized kernel +GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] +GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] + + +def get_batch_size_range(max_batch_size): + return [2**x for x in range(14) if 2**x <= max_batch_size] + + +def get_model_params(config): + if config.architectures[0] in ( + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "DeepseekV32ForCausalLM", + ): + num_experts = config.n_routed_experts + hidden_size = config.hidden_size + elif config.architectures[0] in ("GptOssForCausalLM",): + num_experts = config.num_local_experts + hidden_size = config.hidden_size + else: + raise ValueError(f"Unsupported architecture: {config.architectures}") + return num_experts, hidden_size + + +def get_benchmark(model, max_batch_size, trust_remote_code): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + 
x_vals=get_batch_size_range(max_batch_size), + x_log=False, + line_arg="provider", + line_vals=[ + "torch", + "vllm", + ], + line_names=["PyTorch", "vLLM"], + styles=([("blue", "-"), ("red", "-")]), + ylabel="TFLOPs", + plot_name=f"{model} router gemm throughput", + args={}, + ) + ) + def benchmark(batch_size, provider): + config = get_config(model=model, trust_remote_code=trust_remote_code) + num_experts, hidden_size = get_model_params(config) + + mat_a = torch.randn( + (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + mat_b = torch.randn( + (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + bias = torch.randn( + num_experts, dtype=torch.bfloat16, device="cuda" + ).contiguous() + + is_hopper_or_blackwell = current_platform.is_device_capability( + 90 + ) or current_platform.is_device_capability_family(100) + allow_dsv3_router_gemm = ( + is_hopper_or_blackwell + and num_experts in DSV3_SUPPORTED_NUM_EXPERTS + and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES + ) + allow_gpt_oss_router_gemm = ( + is_hopper_or_blackwell + and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS + and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES + ) + + has_bias = False + if allow_gpt_oss_router_gemm: + has_bias = True + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + + def runner(): + if has_bias: + F.linear(mat_a, mat_b, bias) + else: + F.linear(mat_a, mat_b) + elif provider == "vllm": + + def runner(): + if allow_dsv3_router_gemm: + ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16) + elif allow_gpt_oss_router_gemm: + ops.gpt_oss_router_gemm(mat_a, mat_b, bias) + else: + raise ValueError("Unsupported router gemm") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + runner, quantiles=quantiles + ) + + def tflops(t_ms): + flops = 2 * batch_size * hidden_size * num_experts + return flops / (t_ms * 1e-3) / 1e12 + + return tflops(ms), tflops(max_ms), tflops(min_ms) + + return benchmark + + +if __name__ == 
"__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--model", type=str, default="openai/gpt-oss-20b") + parser.add_argument("--max-batch-size", default=16, type=int) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + # Get the benchmark function + benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code) + # Run performance benchmark + benchmark.run(print_data=True) diff --git a/csrc/moe/gpt_oss_router_gemm.cu b/csrc/moe/gpt_oss_router_gemm.cu new file mode 100644 index 000000000000..0294cd36aa8f --- /dev/null +++ b/csrc/moe/gpt_oss_router_gemm.cu @@ -0,0 +1,144 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <stdexcept>
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include "gpt_oss_router_gemm.cuh"
+
+void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
+                                __nv_bfloat16* gC, __nv_bfloat16* bias,
+                                int batch_size, int output_features,
+                                int input_features, cudaStream_t stream) {
+  static int const WARP_TILE_M = 16;
+  static int const TILE_M = WARP_TILE_M;
+  static int const TILE_N = 8;
+  static int const TILE_K = 64;
+  static int const STAGES = 16;
+  static int const STAGE_UNROLL = 4;
+  static bool const PROFILE = false;
+
+  CUtensorMap weight_map{};
+  CUtensorMap activation_map{};
+
+  constexpr uint32_t rank = 2;
+  uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
+  uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
+  uint32_t box_size[rank] = {TILE_K, TILE_M};
+  uint32_t elem_stride[rank] = {1, 1};
+
+  CUresult res = cuTensorMapEncodeTiled(
+      &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
+      gB, size, stride, box_size, elem_stride,
+      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
+      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
+      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
+      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+  TORCH_CHECK(res == CUDA_SUCCESS,
+              "cuTensorMapEncodeTiled failed for weight_map, error code=",
+              static_cast<int>(res));
+
+  size[1] = batch_size;
+  box_size[1] = TILE_N;
+
+  res = cuTensorMapEncodeTiled(
+      &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
+      rank, gA, size, stride, box_size, elem_stride,
+      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
+      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
+      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
+      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+  TORCH_CHECK(res == CUDA_SUCCESS,
+              "cuTensorMapEncodeTiled failed for activation_map, error code=",
+              static_cast<int>(res));
+
+  int smem_size = STAGES * STAGE_UNROLL *
+                  (TILE_M * TILE_K
* sizeof(__nv_bfloat16) +
+                   TILE_N * TILE_K * sizeof(__nv_bfloat16));
+
+  gpuErrChk(cudaFuncSetAttribute(
+      gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
+                                 STAGE_UNROLL, PROFILE>,
+      cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+
+  int tiles_m = (output_features + TILE_M - 1) / TILE_M;
+  int tiles_n = (batch_size + TILE_N - 1) / TILE_N;
+
+  dim3 grid(tiles_m, tiles_n);
+  dim3 block(384);
+
+  cudaLaunchConfig_t config;
+  cudaLaunchAttribute attrs[1];
+  config.gridDim = grid;
+  config.blockDim = block;
+  config.dynamicSmemBytes = smem_size;
+  config.stream = stream;
+  config.attrs = attrs;
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = 1;
+  config.numAttrs = 1;
+
+  cudaLaunchKernelEx(
+      &config,
+      &gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
+                                  STAGE_UNROLL, PROFILE>,
+      gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
+      activation_map, nullptr);
+}
+
+void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
+                                      torch::Tensor input, torch::Tensor weight,
+                                      torch::Tensor bias) {
+  auto const batch_size = input.size(0);
+  auto const input_dim = input.size(1);
+  auto const output_dim = weight.size(0);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (input.scalar_type() == at::ScalarType::BFloat16) {
+    launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
+                               (__nv_bfloat16*)weight.data_ptr(),
+                               (__nv_bfloat16*)output.mutable_data_ptr(),
+                               (__nv_bfloat16*)bias.data_ptr(), batch_size,
+                               output_dim, input_dim, stream);
+  } else {
+    throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
+  }
+}
+
+void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
+                         torch::Tensor weight, torch::Tensor bias) {
+  TORCH_CHECK(input.dim() == 2, "input must be 2D");
+  TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
+  TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
+  TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
+              "input.size(1) must match weight.size(1)");
+  TORCH_CHECK(weight.sizes()[0] ==
bias.sizes()[0], + "weight.size(0) must match bias.size(0)"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "input tensor must be bfloat16"); + TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16, + "weight tensor must be bfloat16"); + TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16, + "bias tensor must be bfloat16"); + gpt_oss_router_gemm_cuda_forward(output, input, weight, bias); +} diff --git a/csrc/moe/gpt_oss_router_gemm.cuh b/csrc/moe/gpt_oss_router_gemm.cuh new file mode 100644 index 000000000000..5cc653f19cfb --- /dev/null +++ b/csrc/moe/gpt_oss_router_gemm.cuh @@ -0,0 +1,447 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "cuda_bf16.h"
+#include <cuda_runtime.h>
+#include <cstdio>
+#include <stdexcept>
+
+#include "cuda_pipeline.h"
+#include <cuda.h>
+#include <cuda/barrier>
+#include <cuda/ptx>
+#include <cstdint>
+
+using barrier = cuda::barrier<cuda::thread_scope_block>;
+namespace cde = cuda::device::experimental;
+namespace ptx = cuda::ptx;
+
+#define gpuErrChk(ans)                    \
+  {                                       \
+    gpuAssert((ans), __FILE__, __LINE__); \
+  }
+
+inline void gpuAssert(cudaError_t code, char const* file, int line,
+                      bool abort = true) {
+  if (code != cudaSuccess) {
+    fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
+            line);
+    if (abort) {
+      throw std::runtime_error(cudaGetErrorString(code));
+    }
+  }
+}
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+__device__ uint64_t gclock64() {
+  unsigned long long int rv;
+  asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
+  return rv;
+}
+
+__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
+  int dst;
+  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
+               : "=r"(dst)
+               : "r"(smem_ptr));
+  int* rvi = reinterpret_cast<int*>(&rv[0]);
+  rvi[0] = dst;
+}
+
+__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
+  int x, y;
+  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
+               : "=r"(x), "=r"(y)
+               : "r"(smem_ptr));
+
+  int* rvi = reinterpret_cast<int*>(&rv[0]);
+  rvi[0] = x;
+  rvi[1] = y;
+}
+
+__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
+  int x, y, z, w;
+  asm volatile(
+      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
+      : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
+      : "r"(smem_ptr));
+  int* rvi = reinterpret_cast<int*>(&rv[0]);
+  rvi[0] = x;
+  rvi[1] = y;
+  rvi[2] = z;
+  rvi[3] = w;
+}
+
+__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
+                          float c[4]) {
+  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
+  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
+  float const* C = reinterpret_cast<float const*>(&c[0]);
+  float* D = reinterpret_cast<float*>(&d[0]);
+
+  asm volatile(
+      "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+      : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
+        "f"(C[3]));
+}
+
+__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
+                           float c[4]) {
+  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
+  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
+  float const* C = reinterpret_cast<float const*>(&c[0]);
+  float* D = reinterpret_cast<float*>(&d[0]);
+
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
+}
+
+__device__ void bar_wait(uint32_t bar_ptr, int phase) {
+  asm volatile(
+      "{\n"
+      ".reg .pred P1;\n"
+      "LAB_WAIT:\n"
+      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
+      "@P1 bra.uni DONE;\n"
+      "bra.uni LAB_WAIT;\n"
+      "DONE:\n"
+      "}\n" ::"r"(bar_ptr),
+      "r"(phase));
+}
+
+__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
+  uint32_t success;
+  #ifdef INTERNAL
+  asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
+  #endif
+  asm volatile(
+      "{\n\t"
+      ".reg .pred P1; \n\t"
+      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
+      "selp.b32 %0, 1, 0, P1; \n\t"
+      "}"
+      : "=r"(success)
+      : "r"(bar_ptr), "r"(phase));
+  return success;
+}
+
+__device__ uint32_t elect_one_sync() {
+  uint32_t pred = 0;
+  uint32_t laneid = 0;
+  asm volatile(
+      "{\n"
+      ".reg .b32 %%rx;\n"
+      ".reg .pred %%px;\n"
+      " elect.sync %%rx|%%px, %2;\n"
+      "@%%px mov.s32 %1, 1;\n"
+      " mov.s32 %0, %%rx;\n"
+      "}\n"
+      : "+r"(laneid), "+r"(pred)
+      : "r"(0xFFFFFFFF));
+  return pred;
+}
+#endif
+
+struct Profile {
+  uint64_t start;
+  uint64_t weight_load_start;
+  uint64_t act_load_start;
+  uint64_t compute_start;
+  uint64_t complete;
+};
+
+template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
+          int STAGE_UNROLL, bool PROFILE>
+__global__ __launch_bounds__(384, 1) void
gpt_oss_router_gemm_kernel( + __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations, + __nv_bfloat16* bias, int M, int N, int K, + const __grid_constant__ CUtensorMap weight_map, + const __grid_constant__ CUtensorMap activation_map, + Profile* profile = nullptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + + if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0) + profile[blockIdx.x].start = gclock64(); + + extern __shared__ __align__(128) char smem[]; + + __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0]; + __nv_bfloat16* sh_activations = + (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K * + sizeof(__nv_bfloat16)]; + + #pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ barrier bar_wt_ready[STAGES]; + __shared__ barrier bar_act_ready[STAGES]; + __shared__ barrier bar_data_consumed[STAGES]; + + __shared__ float4 reduction_buffer[128]; + + __shared__ nv_bfloat16 sh_bias[TILE_M]; + + if (threadIdx.x == 0) { + for (int i = 0; i < STAGES; i++) { + init(&bar_wt_ready[i], 1); + init(&bar_act_ready[i], 1); + init(&bar_data_consumed[i], 32); + } + ptx::fence_proxy_async(ptx::space_shared); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&weight_map)) + : "memory"); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&activation_map)) + : "memory"); + } + __syncthreads(); + + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + int phase = 0; + + int mib = blockIdx.x * TILE_M; + int ni = blockIdx.y * TILE_N; + + float accum[4]; + for (int i = 0; i < 4; i++) accum[i] = 0.f; + + int const K_LOOPS_DMA = + (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL)); + int const K_LOOPS_COMPUTE = K_LOOPS_DMA; + + // Data loading thread + if (warp_id >= 4 && elect_one_sync()) { + int stage = warp_id % 4; + + bool weight_warp = warp_id < 8; + if (!weight_warp) { + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); + } + + for 
(int ki = 0; ki < K_LOOPS_DMA; ki++) { + int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL; + + uint64_t desc_ptr_wt = reinterpret_cast(&weight_map); + uint64_t desc_ptr_act = reinterpret_cast(&activation_map); + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16); + int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16); + + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + + if (weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt)); + if (!weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act)); + + if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp) + profile[blockIdx.x].weight_load_start = gclock64(); + if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp) + profile[blockIdx.x].act_load_start = gclock64(); + + for (int i = 0; i < STAGE_UNROLL; i++) { + uint32_t smem_ptr_wt = __cvta_generic_to_shared( + &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]); + uint32_t crd0 = k + i * TILE_K; + uint32_t crd1 = mib; + if (weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0), + "r"(crd1) + : "memory"); + + uint32_t smem_ptr_act = __cvta_generic_to_shared( + &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]); + crd0 = k + i * TILE_K; + crd1 = ni; + if (!weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act), + "r"(crd0), "r"(crd1) + : "memory"); + } + + stage += 4; + if (stage >= STAGES) { + 
stage = warp_id % 4; + phase ^= 1; + } + } + // Wait for pending loads to be consumed before exiting, to avoid race + for (int i = 0; i < (STAGES / 4) - 1; i++) { + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + stage += 4; + if (stage >= STAGES) { + stage = warp_id % 4; + phase ^= 1; + } + } + } + // Compute threads + else if (warp_id < 4) { + // Sneak the bias load into the compute warps since they're just waiting for + // stuff anyway + if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x]; + + int stage = warp_id; + + int phase = 0; + int lane_id_div8 = lane_id / 8; + int lane_id_mod8 = lane_id % 8; + + int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0; + int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0; + + int row_wt = lane_id_mod8 + lane_row_offset_wt; + int row_act = lane_id_mod8; + + int row_offset_wt = (reinterpret_cast(sh_weights) / 128) % 8; + int row_offset_act = row_offset_wt; + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + + bool weight_ready = bar_try_wait(bar_ptr_wt, phase); + bool act_ready = bar_try_wait(bar_ptr_act, phase); + + #pragma unroll 2 + for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) { + int next_stage = stage + 4; + int next_phase = phase; + if (next_stage >= STAGES) { + next_stage = warp_id; + next_phase ^= 1; + } + + while (!weight_ready || !act_ready) { + weight_ready = bar_try_wait(bar_ptr_wt, phase); + act_ready = bar_try_wait(bar_ptr_act, phase); + } + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0) + profile[blockIdx.x].compute_start = gclock64(); + + if (ki + 1 < K_LOOPS_COMPUTE) { + weight_ready = bar_try_wait( + __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase); + act_ready = bar_try_wait( + __cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase); + } + + #pragma unroll + for (int su = 0; su < STAGE_UNROLL; su++) { + __nv_bfloat16* 
ptr_weights = + &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K]; + __nv_bfloat16* ptr_act = + &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K]; + + #pragma unroll + for (int kii = 0; kii < TILE_K / 16; kii++) { + __nv_bfloat16 a[8]; + __nv_bfloat16 b[4]; + + int col = 2 * kii + lane_col_offset_wt; + int col_sw = ((row_wt + row_offset_wt) % 8) ^ col; + + ldmatrix4(a, __cvta_generic_to_shared( + &ptr_weights[row_wt * TILE_K + col_sw * 8])); + + col = 2 * kii + lane_id_div8; + col_sw = ((row_act + row_offset_act) % 8) ^ col; + + ldmatrix2(b, __cvta_generic_to_shared( + &ptr_act[row_act * TILE_K + 8 * col_sw])); + + HMMA_16816(accum, a, b, accum); + } + } + + uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]); + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c)); + + stage = next_stage; + phase = next_phase; + } + + float4 accum4; + accum4.x = accum[0]; + accum4.y = accum[1]; + accum4.z = accum[2]; + accum4.w = accum[3]; + reduction_buffer[threadIdx.x] = accum4; + + __syncthreads(); + + if (warp_id == 0) { + int mi = mib + warp_id * WARP_TILE_M; + int tm = mi + lane_id / 4; + int tn = ni + 2 * (lane_id % 4); + + float4 accum1 = reduction_buffer[32 + threadIdx.x]; + float4 accum2 = reduction_buffer[64 + threadIdx.x]; + float4 accum3 = reduction_buffer[96 + threadIdx.x]; + + accum[0] = accum[0] + accum1.x + accum2.x + accum3.x; + accum[1] = accum[1] + accum1.y + accum2.y + accum3.y; + accum[2] = accum[2] + accum1.z + accum2.z + accum3.z; + accum[3] = accum[3] + accum1.w + accum2.w + accum3.w; + + float bias_lo = __bfloat162float(sh_bias[tm - mib]); + float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]); + + if (tn < N && tm < M) + output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo); + if (tn + 1 < N && tm < M) + output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo); + if (tn < N && tm + 8 < M) + output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi); + if (tn + 1 < N && tm + 8 < 
M) + output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi); + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) + profile[blockIdx.x].complete = gclock64(); + } + } +#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index d8d962887dab..de931dc76467 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); + +// gpt-oss optimized router GEMM kernel for SM90+ +void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias); #endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7b627a6f8760..4cd74366ea4d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file + + // gpt-oss optimized router GEMM kernel for SM90+ + m.def( + "gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, " + "Tensor bias) -> ()"); + m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm); #endif } diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py new file mode 100644 index 000000000000..906e47708f29 --- /dev/null +++ b/tests/kernels/moe/test_router_gemm.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for optimized router GEMM kernel + +Run `pytest tests/kernels/moe/test_router_gemm.py`. 
+""" + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + + +@pytest.mark.skipif( + not ( + current_platform.is_cuda() + and ( + current_platform.is_device_capability(90) + or current_platform.is_device_capability_family(100) + ) + ), + reason="This test only runs on Hopper or Blackwell GPUs.", +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) +@pytest.mark.parametrize("output_dim", [32, 64, 128]) +def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim): + set_random_seed(0) + x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16) + bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16) + + output = ops.gpt_oss_router_gemm(x, weight, bias) + output_ref = torch.nn.functional.linear(x, weight, bias) + torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a01f44e1649d..a45caac7c9e2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2362,6 +2362,19 @@ def dsv3_router_gemm( return output +def gpt_oss_router_gemm( + hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + output = torch.empty( + hidden_states.shape[0], + weight.shape[0], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias) + return output + + def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py index 1f3fdea2cdaf..235f40b73852 100644 --- a/vllm/lora/layers/__init__.py +++ b/vllm/lora/layers/__init__.py @@ -13,6 +13,7 @@ QKVParallelLinearWithShardedLoRA, ) from vllm.lora.layers.fused_moe import 
FusedMoE3DWithLoRA, FusedMoEWithLoRA +from vllm.lora.layers.gate_linear import GateLinearWithLoRA from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.row_parallel_linear import ( @@ -38,6 +39,7 @@ "RowParallelLinearWithLoRA", "RowParallelLinearWithShardedLoRA", "ReplicatedLinearWithLoRA", + "GateLinearWithLoRA", "LoRAMapping", "LoRAMappingType", "FusedMoEWithLoRA", diff --git a/vllm/lora/layers/gate_linear.py b/vllm/lora/layers/gate_linear.py new file mode 100644 index 000000000000..9bcaaa5b8e20 --- /dev/null +++ b/vllm/lora/layers/gate_linear.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.custom_op import maybe_get_oot_by_class +from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear + +from .replicated_linear import ReplicatedLinearWithLoRA + + +class GateLinearWithLoRA(ReplicatedLinearWithLoRA): + def __init__(self, base_layer: GateLinear) -> None: + super().__init__( + base_layer, + ) + + # GateLinearWithLoRA should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. 
+ @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is maybe_get_oot_by_class(GateLinear) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 2349ace70846..75ed9674af56 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -21,6 +21,7 @@ ColumnParallelLinearWithShardedLoRA, FusedMoE3DWithLoRA, FusedMoEWithLoRA, + GateLinearWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearVariableSliceWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -81,6 +82,7 @@ def get_lora_id(): MergedQKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, ReplicatedLinearWithLoRA, + GateLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 77d8e756026d..e8ed8a5249d1 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -3,9 +3,11 @@ import torch from torch.nn.parameter import Parameter +import vllm._custom_ops as ops from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op @PluggableLayer.register("gate_linear") @@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear): """MoE gate linear layer with three-tier GEMM dispatch: 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) - 3. F.linear via ReplicatedLinear (ultimate fallback) + 2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims) + 3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) + 4. 
F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init (e.g. when the required dtype depends on the expert quantization @@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear): DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + # Dimensions supported by the gpt-oss specialized kernel + GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] + GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] + def __init__( self, input_size: int, @@ -65,6 +72,15 @@ def __init__( and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES ) + # gpt-oss specialized kernel eligibility (SM90+, exact dims) + self.allow_gpt_oss_router_gemm = ( + self.weight.dtype == torch.bfloat16 + and current_platform.is_cuda() + and is_hopper_or_blackwell + and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS + and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES + ) + # cuBLAS bf16→fp32 eligibility self.allow_cublas_router_gemm = ( self.allow_specialized_router_gemm @@ -92,8 +108,6 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: def forward( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: - import vllm._custom_ops as ops - # Tier 1: DSV3 specialized kernel if self.allow_dsv3_router_gemm and x.shape[0] <= 16: output = ops.dsv3_router_gemm( @@ -103,15 +117,47 @@ def forward( ) return output, None - # Tier 2: cuBLAS bf16→fp32 + # Tier 2: gpt-oss specialized kernel + if self.allow_gpt_oss_router_gemm: + output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) + return output, None + + # Tier 3: cuBLAS bf16→fp32 if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None - # Tier 3: F.linear (ReplicatedLinear) + # Tier 4: F.linear (ReplicatedLinear) if self.out_dtype is not None and x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) output, output_bias = super().forward(x) if self.out_dtype is not None and 
output.dtype != self.out_dtype: output = output.to(self.out_dtype) return output, output_bias + + +def gpt_oss_router_gemm_impl( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + """ + Dynamically run min-latency gemm if num_tokens <= 128. + This must be wrapped in a custom op because our torch.compile integration + does not support runtime dispatching on num_tokens. + """ + if x.shape[0] <= 128: + return ops.gpt_oss_router_gemm(x, weight, bias) + else: + return torch.nn.functional.linear(x, weight, bias) + + +def gpt_oss_router_gemm_fake( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return x.new_empty((x.shape[0], weight.shape[0])) + + +direct_register_custom_op( + op_name="gpt_oss_router_gemm", + op_func=gpt_oss_router_gemm_impl, + fake_impl=gpt_oss_router_gemm_fake, +) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index c3111489c0ca..482056250a1e 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -20,12 +20,11 @@ tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -175,13 +174,11 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = ReplicatedLinear( + self.router = GateLinear( config.hidden_size, config.num_local_experts, bias=True, - quant_config=None, prefix=f"{prefix}.router", - 
return_bias=False, ) assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( @@ -209,7 +206,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: - g = self.router(x) + g, _ = self.router(x) x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: @@ -273,7 +270,6 @@ def __init__( self.config = vllm_config.model_config.hf_config self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config - self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size,