From e85f927703e345de20cc38122b5975fbe25b421e Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Tue, 26 Aug 2025 02:48:19 +0000 Subject: [PATCH 1/3] support cuda 13.0 and cuda 12.8 --- sgl-kernel/CMakeLists.txt | 23 +++++++++------- .../moe/marlin_moe_wna16/generate_kernels.py | 27 +++++++++++++++++-- sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h | 1 + ...kernel_bf16_ku4.cu => kernel_bf16_ku4.cuh} | 1 + ...el_bf16_ku4b8.cu => kernel_bf16_ku4b8.cuh} | 1 + ...f16_ku8b128.cu => kernel_bf16_ku8b128.cuh} | 1 + ...kernel_fp16_ku4.cu => kernel_fp16_ku4.cuh} | 1 + ...el_fp16_ku4b8.cu => kernel_fp16_ku4b8.cuh} | 1 + ...p16_ku8b128.cu => kernel_fp16_ku8b128.cuh} | 1 + .../moe/marlin_moe_wna16/kernel_marlin.cuh | 10 +++++++ .../moe/marlin_moe_wna16/marlin_template.h | 2 ++ sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu | 1 + .../csrc/moe/moe_topk_softmax_kernels.cu | 16 ++++++++--- 13 files changed, 71 insertions(+), 15 deletions(-) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_bf16_ku4.cu => kernel_bf16_ku4.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_bf16_ku4b8.cu => kernel_bf16_ku4b8.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_bf16_ku8b128.cu => kernel_bf16_ku8b128.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_fp16_ku4.cu => kernel_fp16_ku4.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_fp16_ku4b8.cu => kernel_fp16_ku4b8.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_fp16_ku8b128.cu => kernel_fp16_ku8b128.cuh} (99%) create mode 100644 sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 09ec8b00fe3..40d696cd976 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -45,7 +45,7 @@ include(FetchContent) FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass - GIT_TAG 664c4f7b3ed1959414905025728eef5568209479 + GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918 GIT_SHALLOW 
OFF ) FetchContent_Populate(repo-cutlass) @@ -57,6 +57,9 @@ if("${CUDA_VERSION}" VERSION_EQUAL "12.8") elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9") set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") set(DeepGEMM_TAG "blackwell") +elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0") + set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") + set(DeepGEMM_TAG "e38c2e31033dc6880d92eff4977c40f2eb6cff4a") else() set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM") set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0") @@ -83,7 +86,7 @@ FetchContent_Populate(repo-triton) FetchContent_Declare( repo-flashinfer GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git - GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7 + GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flashinfer) @@ -162,6 +165,7 @@ set(SGL_KERNEL_CUDA_FLAGS ) option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF) +option(SGL_KERNEL_ENABLE_SM101A "Enable SM101A" OFF) option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF) option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON) option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON) @@ -175,12 +179,17 @@ if (ENABLE_BELOW_SM90) ) endif() +if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" AND SGL_KERNEL_ENABLE_SM101A) + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_101,code=sm_101" + "-gencode=arch=compute_101a,code=sm_101a" + ) +endif() + if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_100,code=sm_100" "-gencode=arch=compute_100a,code=sm_100a" - "-gencode=arch=compute_101,code=sm_101" - "-gencode=arch=compute_101a,code=sm_101a" "-gencode=arch=compute_120,code=sm_120" "-gencode=arch=compute_120a,code=sm_120a" ) @@ -266,12 +275,6 @@ set(SOURCES "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" "csrc/moe/marlin_moe_wna16/ops.cu" - 
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu" - "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu" - "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu" - "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu" - "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu" - "csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu" "csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_topk_softmax_kernels.cu" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py index 833d074ea30..b3ed863a3a1 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -9,6 +9,7 @@ FILE_HEAD = """ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" @@ -33,6 +34,17 @@ "( MARLIN_KERNEL_PARAMS );" ) +KERNEL_FILE_TEMPLATE = ( + "// auto generated by generate.py\n" + "// clang-format off\n" + "#pragma once\n\n" + "{% for kernel_file in kernel_files %}" + '#include "{{ kernel_file }}"\n' + "{% endfor %}" +) + +KERNEL_FILE_NAME = "kernel_marlin.cuh" + # int8 with zero point case (sglang::kU8) is also supported, # we don't add it to reduce wheel size. 
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"] @@ -48,11 +60,12 @@ def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"): subprocess.call(["rm", "-f", filename]) def generate_new_kernels(): + kernel_files = set() for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): has_zp = "B" not in scalar_type all_template_str_list = [] @@ -95,10 +108,20 @@ def generate_new_kernels(): file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu" + filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh" with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + kernel_files.add(filename) + + kernel_files = list(kernel_files) + kernel_files.sort() + + file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render( + kernel_files=kernel_files + ) + with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f: + f.write(file_content) if __name__ == "__main__": diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h index 88d157507a0..afa7c377b17 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h @@ -1,3 +1,4 @@ +#pragma once #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh index 1e3d923aee0..7e83bed8f2f 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh @@ -1,5 +1,6 @@ // auto generated by 
generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh index 513ddc2ed1e..60e2dea3199 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh index eebe9d3daa1..7eb6b18de6f 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh index 9adc6623a5e..ec41e018b41 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh 
similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh index 66ca7e36a2b..7df28701b04 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh index 21fdf0c1a21..1150844e235 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh new file mode 100644 index 00000000000..bb828dc5b3d --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh @@ -0,0 +1,10 @@ +// auto generated by generate.py +// clang-format off +#pragma once + +#include "kernel_bf16_ku4.cuh" +#include "kernel_bf16_ku4b8.cuh" +#include "kernel_bf16_ku8b128.cuh" +#include "kernel_fp16_ku4.cuh" +#include "kernel_fp16_ku4b8.cuh" +#include "kernel_fp16_ku8b128.cuh" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h index 71c91839dcc..ade562af64d 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -18,6 +18,8 @@ /* * Adapted from https://github.com/IST-DASLab/marlin */ +#pragma once + #ifndef 
MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index f430390d148..b249f64156d 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -24,6 +24,7 @@ #endif #include "kernel.h" +#include "kernel_marlin.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert( \ diff --git a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu index 050e8d52be9..c9bc8a628de 100644 --- a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu +++ b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu @@ -23,6 +23,7 @@ limitations under the License. #ifndef USE_ROCM #include #include +#include #else #include #include @@ -33,6 +34,16 @@ limitations under the License. #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) +// Define reduction operators based on CUDA version +// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum +#if CUDA_VERSION >= 12090 +using MaxReduceOp = cuda::maximum<>; +using MinReduceOp = cuda::minimum<>; +#else +using MaxReduceOp = cub::Max; +using MinReduceOp = cub::Min; +#endif + /// Aligned array type template < typename T, @@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__ const int thread_row_offset = blockIdx.x * num_cols; - cub::Sum sum; float threadData(-FLT_MAX); // Don't touch finished rows. 
@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__ threadData = max(convert_to_float(input[idx]), threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp()); if (threadIdx.x == 0) { float_max = maxElem; @@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__ threadData += exp((convert_to_float(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + const auto Z = BlockReduce(tmpStorage).Sum(threadData); if (threadIdx.x == 0) { normalizing_factor = 1.f / Z; From fc62038c04de07e8325f10c6f300a855d0ad5f10 Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Tue, 26 Aug 2025 17:42:16 +0000 Subject: [PATCH 2/3] update the nvcc flags and the arch 110 121 support on cuda 130 --- sgl-kernel/CMakeLists.txt | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 40d696cd976..5dced93bf0e 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -165,7 +165,6 @@ set(SGL_KERNEL_CUDA_FLAGS ) option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF) -option(SGL_KERNEL_ENABLE_SM101A "Enable SM101A" OFF) option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF) option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON) option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON) @@ -179,13 +178,6 @@ if (ENABLE_BELOW_SM90) ) endif() -if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" AND SGL_KERNEL_ENABLE_SM101A) - list(APPEND SGL_KERNEL_CUDA_FLAGS - "-gencode=arch=compute_101,code=sm_101" - "-gencode=arch=compute_101a,code=sm_101a" - ) -endif() - if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_100,code=sm_100" @@ -193,6 +185,23 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) "-gencode=arch=compute_120,code=sm_120" 
"-gencode=arch=compute_120a,code=sm_120a" ) + + # refer to the sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176 + if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0") + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_110,code=sm_110" + "-gencode=arch=compute_110a,code=sm_110a" + "-gencode=arch=compute_121,code=sm_121" + "-gencode=arch=compute_121a,code=sm_121a" + "--compress-mode=size" + ) + else() + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_101,code=sm_101" + "-gencode=arch=compute_101a,code=sm_101a" + ) + endif() + else() list(APPEND SGL_KERNEL_CUDA_FLAGS "-use_fast_math" From b59d071a8446f57433ea3d9328f9ba6743d5fe48 Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Wed, 27 Aug 2025 06:00:22 +0000 Subject: [PATCH 3/3] rollback deep_gemm and add B300 support --- sgl-kernel/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 5dced93bf0e..9752914356f 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -59,7 +59,7 @@ elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9") set(DeepGEMM_TAG "blackwell") elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0") set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") - set(DeepGEMM_TAG "e38c2e31033dc6880d92eff4977c40f2eb6cff4a") + set(DeepGEMM_TAG "blackwell") else() set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM") set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0") @@ -182,6 +182,8 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_100,code=sm_100" "-gencode=arch=compute_100a,code=sm_100a" + "-gencode=arch=compute_103,code=sm_103" + "-gencode=arch=compute_103a,code=sm_103a" "-gencode=arch=compute_120,code=sm_120" "-gencode=arch=compute_120a,code=sm_120a" )