135 changes: 85 additions & 50 deletions sgl-kernel/CMakeLists.txt
@@ -1,10 +1,12 @@
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)

# We only want to download the third-party dependencies, not build them;
# FetchContent_MakeAvailable would also build them, so FetchContent_Populate is used instead.
cmake_policy(SET CMP0169 OLD)

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

set(BUILD_FA3 OFF)

find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

enable_language(CUDA)
@@ -22,6 +24,8 @@ elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
endif()

find_package(Torch REQUIRED)
# Clear the -gencode flags that Torch appended to CMAKE_CUDA_FLAGS
clear_cuda_arches(CMAKE_FLAG)

include(FetchContent)

@@ -53,8 +57,8 @@ FetchContent_Populate(repo-flashinfer)
FetchContent_Declare(
repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)

@@ -92,14 +96,13 @@ set(SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_90,code=sm_90"
"-std=c++17"
"-DFLASHINFER_ENABLE_F16"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
@@ -122,6 +125,7 @@ else()
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
set(BUILD_FA3 ON)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_90a,code=sm_90a"
)
@@ -152,30 +156,6 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

# set flash-attention sources file
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})

set(SOURCES
"csrc/allreduce/trt_reduce_internal.cu"
"csrc/allreduce/trt_reduce_kernel.cu"
@@ -202,39 +182,94 @@ set(SOURCES
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/speculative/packbit.cu"
"csrc/torch_extension.cc"
"csrc/common_extension.cc"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)

# Support abi3 for build
# Set flash-attention source files
# BF16 source files
if (BUILD_FA3)
set(SGL_FLASH_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_90a,code=sm_90a"
"-std=c++17"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)

file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})

set(FLASH_SOURCES
"csrc/flash_extension.cc"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)

Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})

target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)

install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")

target_compile_definitions(flash_ops PRIVATE
FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_VARLEN_ONLY
)
endif()

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)

target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})

target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)

install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")

# Add flash-attention custom flags for inference
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_VARLEN_ONLY
)

# JIT Logic
# DeepGEMM

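The comment near the top of CMakeLists.txt explains the dependency handling: CMP0169 is kept at OLD so the single-argument FetchContent_Populate() can download flashinfer and sgl-attn at configure time without entering their own build systems (FetchContent_MakeAvailable would also add their targets). Below is a minimal standalone sketch of that download-only pattern, not taken from this PR; the project name, dependency name, URL, and source files are placeholders.

cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(fetch-only-demo LANGUAGES CXX)

if(POLICY CMP0169)
  # Newer CMake (3.30+) deprecates the single-argument FetchContent_Populate(); keep it usable.
  cmake_policy(SET CMP0169 OLD)
endif()
include(FetchContent)

FetchContent_Declare(
  example-dep                                              # placeholder name
  GIT_REPOSITORY https://github.com/example/example-dep    # placeholder URL
  GIT_TAG        main
  GIT_SHALLOW    ON
)
# Downloads the sources at configure time and sets example-dep_SOURCE_DIR;
# no add_subdirectory() is called, so nothing from the dependency is built.
FetchContent_Populate(example-dep)

# Consume the downloaded sources directly, the way the PR uses
# ${repo-flash-attention_SOURCE_DIR}/hopper/*.cu in its own targets.
add_library(demo STATIC demo.cpp)                          # demo.cpp is a placeholder
target_include_directories(demo PRIVATE ${example-dep_SOURCE_DIR}/include)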
21 changes: 21 additions & 0 deletions sgl-kernel/cmake/utils.cmake
@@ -0,0 +1,21 @@
# Adapted from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
# clear_cuda_arches(CUDA_ARCH_FLAGS)
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
# CMAKE_CUDA_FLAGS="-Wall"
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})

# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
endmacro()
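The comment block above documents what clear_cuda_arches() does to CMAKE_CUDA_FLAGS; the sketch below shows how it is typically wired in, mirroring the clear_cuda_arches(CMAKE_FLAG) call this PR adds right after find_package(Torch). The project name, target name, source file, variable name, and the single -gencode flag are illustrative assumptions, not part of the PR.

cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(arch-flags-demo LANGUAGES CXX CUDA)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)  # provides clear_cuda_arches()

find_package(Torch REQUIRED)          # Torch's CMake config typically appends -gencode flags to CMAKE_CUDA_FLAGS
clear_cuda_arches(TORCH_ARCH_FLAGS)   # strip them; the stripped flags are saved in TORCH_ARCH_FLAGS

add_library(example_ops MODULE example_kernel.cu)
# Disable CMake's automatic architecture flags for this target; rely only on
# the explicit -gencode option below, as the PR does via SGL_KERNEL_CUDA_FLAGS.
set_target_properties(example_ops PROPERTIES CUDA_ARCHITECTURES OFF)
target_compile_options(example_ops PRIVATE
  $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90,code=sm_90>)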
sgl-kernel/csrc/common_extension.cc
@@ -18,7 +18,7 @@ limitations under the License.

#include "sgl_kernel_ops.h"

TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/allreduce
*/
@@ -202,45 +202,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

/*
* From flash-attention
*/
m.def(
"fwd(Tensor! q,"
" Tensor k,"
" Tensor v,"
" Tensor? k_new,"
" Tensor? v_new,"
" Tensor? q_v,"
" Tensor!? out,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" Tensor? page_table,"
" Tensor? kv_batch_idx,"
" Tensor? leftpad_k,"
" Tensor? rotary_cos,"
" Tensor? rotary_sin,"
" Tensor? seqlens_rotary,"
" Tensor? q_descale,"
" Tensor? k_descale,"
" Tensor? v_descale,"
" float softmax_scale,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" float softcap,"
" bool is_rotary_interleaved,"
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor[]");
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

REGISTER_EXTENSION(common_ops)
62 changes: 62 additions & 0 deletions sgl-kernel/csrc/flash_extension.cc
@@ -0,0 +1,62 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "sgl_flash_kernel_ops.h"

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From flash-attention
*/
m.def(
"fwd(Tensor! q,"
" Tensor k,"
" Tensor v,"
" Tensor? k_new,"
" Tensor? v_new,"
" Tensor? q_v,"
" Tensor!? out,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" Tensor? page_table,"
" Tensor? kv_batch_idx,"
" Tensor? leftpad_k,"
" Tensor? rotary_cos,"
" Tensor? rotary_sin,"
" Tensor? seqlens_rotary,"
" Tensor? q_descale,"
" Tensor? k_descale,"
" Tensor? v_descale,"
" float softmax_scale,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" float softcap,"
" bool is_rotary_interleaved,"
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor[]");
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

REGISTER_EXTENSION(flash_ops)