diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5f2c43fea284..2c34a62aad20 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -317,6 +317,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
       "csrc/quantization/fp4/nvfp4_quant_entry.cu"
       "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
+      "csrc/quantization/fp4/nvfp4_blockwise_moe_entry.cu"
       "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
       "csrc/cutlass_extensions/common.cpp"
       "csrc/quantization/fp8/per_token_group_quant.cu")
diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_entry.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_entry.cu
new file mode 100644
index 000000000000..ac067d5da281
--- /dev/null
+++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_entry.cu
@@ -0,0 +1,26 @@
+#include <torch/all.h>
+
+// Forward declaration of the SM100 kernel entry point, compiled only when the
+// build enables NVFP4 for SM100 (requires CUDA 12.8+). The definition lives in
+// nvfp4_blockwise_moe_kernel.cu.
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+void cutlass_fp4_group_mm_sm100(
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
+#endif
+
+// Dispatcher visible to the op registry: forwards to the SM100 kernel when it
+// was compiled in, otherwise raises NotImplementedError via TORCH_CHECK.
+void cutlass_fp4_group_mm(
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+  cutlass_fp4_group_mm_sm100(output, a, b, a_blockscale, b_blockscales, alphas,
+                             problem_sizes, expert_offsets, sf_offsets);
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled cutlass_fp4_group_mm kernel, vLLM must "
+      "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
+      "12.8 or above.");
+#endif
+}
diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
index 2c8df6144bf4..9fa26668e588 100644
--- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
@@ -367,7 +367,7 @@ constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
   CHECK_CONTIGUOUS(x, m);      \
   CHECK_TYPE(x, st, m)
 
-void cutlass_fp4_group_mm(
+void cutlass_fp4_group_mm_sm100(
     torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
     const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
     const torch::Tensor& alphas, const torch::Tensor& problem_sizes,