15 changes: 8 additions & 7 deletions sgl-kernel/CMakeLists.txt
@@ -91,10 +91,10 @@ set(SGL_KERNEL_CUDA_FLAGS
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_75,code=sm_75"
"-gencode=arch=compute_80,code=sm_80"
# "-gencode=arch=compute_75,code=sm_75"
# "-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_89,code=sm_89"
"-gencode=arch=compute_90,code=sm_90"
# "-gencode=arch=compute_90,code=sm_90"
"-std=c++17"
"-DFLASHINFER_ENABLE_F16"
"-DCUTE_USE_PACKED_TUPLE=1"
@@ -130,10 +130,10 @@ else()
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
set(SGL_KERNEL_ENABLE_FA3 ON)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_90a,code=sm_90a"
)
# set(SGL_KERNEL_ENABLE_FA3 ON)
# list(APPEND SGL_KERNEL_CUDA_FLAGS
# "-gencode=arch=compute_90a,code=sm_90a"
# )
endif()

if (SGL_KERNEL_ENABLE_BF16)
@@ -164,6 +164,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
set(SOURCES
"csrc/allreduce/custom_all_reduce.cu"
"csrc/attention/cascade.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/cutlass_mla_kernel.cu"
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/elementwise/activation.cu"
191 changes: 191 additions & 0 deletions sgl-kernel/csrc/attention/merge_attn_states.cu
@@ -0,0 +1,191 @@
// Adapted from https://github.com/vllm-project/vllm/pull/16173
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>
#include <optional>

#include "pytorch_extension_utils.h"

// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
inline __device__ float to_float(float u) {
return u;
}
inline __device__ float to_float(half u) {
return __half2float(u);
}
inline __device__ float to_float(__nv_bfloat16 u) {
return __bfloat162float(u);
}
inline __device__ void from_float(float& d, float s) {
d = s;
}
inline __device__ void from_float(half& d, float s) {
d = __float2half(s);
}
inline __device__ void from_float(__nv_bfloat16& d, float s) {
d = __float2bfloat16(s);
}

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005.
// Can be used to combine partial attention results (in the split-KV case).
// Reference blog: https://zhuanlan.zhihu.com/p/1892966682634473987
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
scalar_t* output,
float* output_lse,
const scalar_t* prefix_output,
const float* prefix_lse,
const scalar_t* suffix_output,
const float* suffix_lse,
const uint num_tokens,
const uint num_heads,
const uint head_size) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;

const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
const uint token_head_threads = num_tokens * num_heads * threads_per_head;

if (global_idx >= token_head_threads) return;

// global_idx -> token_idx + head_idx + pack_idx
const uint token_head_idx = global_idx / threads_per_head;
const uint pack_idx = global_idx % threads_per_head;

const uint token_idx = token_head_idx / num_heads;
const uint head_idx = token_head_idx % num_heads;

const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
scalar_t* output_head_ptr = output + head_offset;

float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

const float max_lse = fmaxf(p_lse, s_lse);
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);
const float s_se = expf(s_lse);
const float out_se = p_se + s_se;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;

if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;

#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
}

// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
float out_lse = logf(out_se) + max_lse;
output_lse[head_idx * num_tokens + token_idx] = out_lse;
}
}

// The following macro dispatches on the output data type. `fn` is a macro
// that invokes a function templated on <typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(half); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), \
output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
num_tokens, \
num_heads, \
head_size); \
}

/* @brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention states.
 * @param output_lse [h,n] Optional tensor to store the merged log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
 * states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
 * states.
 */
template <typename scalar_t>
void merge_attn_states_launcher(
at::Tensor& output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
std::optional<at::Tensor> output_lse, // [NUM_HEADS, NUM_TOKENS]
const at::Tensor& prefix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
const at::Tensor& prefix_lse, // [NUM_HEADS, NUM_TOKENS]
const at::Tensor& suffix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
const at::Tensor& suffix_lse // [NUM_HEADS, NUM_TOKENS]
) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0, "head_size must be a multiple of pack_size: ", pack_size);
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
// Each thread processes one pack of elements: float -> 4, half/bf16 -> 8.
const uint threads_per_head = head_size / pack_size;
const uint total_threads = num_tokens * num_heads * threads_per_head;

dim3 block(NUM_THREADS);
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse); }

void merge_attn_states(
at::Tensor& output,
std::optional<at::Tensor> output_lse,
const at::Tensor& prefix_output,
const at::Tensor& prefix_lse,
const at::Tensor& suffix_output,
const at::Tensor& suffix_lse) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
4 changes: 4 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
@@ -47,6 +47,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state", torch::kCUDA, &merge_state);
m.def(
"merge_attn_states(Tensor! output, Tensor!? output_lse, Tensor prefix_output, Tensor prefix_lse, Tensor "
"suffix_output, Tensor suffix_lse) -> ()");
m.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"page_table, Tensor workspace) -> ()");
7 changes: 7 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
@@ -89,6 +89,13 @@ void lightning_attention_decode(
torch::Tensor new_kv);
void merge_state(
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void merge_attn_states(
at::Tensor& output,
std::optional<at::Tensor> output_lse,
const at::Tensor& prefix_output,
const at::Tensor& prefix_lse,
const at::Tensor& suffix_output,
const at::Tensor& suffix_lse);
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
1 change: 1 addition & 0 deletions sgl-kernel/python/sgl_kernel/__init__.py
@@ -15,6 +15,7 @@
cutlass_mla_decode,
cutlass_mla_get_workspace_size,
lightning_attention_decode,
merge_attn_states,
merge_state,
)
from sgl_kernel.elementwise import (
41 changes: 37 additions & 4 deletions sgl-kernel/python/sgl_kernel/attention.py
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Optional, Tuple

import torch

@@ -10,16 +10,49 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):


def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
v_a: torch.Tensor,
s_a: torch.Tensor,
v_b: torch.Tensor,
s_b: torch.Tensor,
v_merged: Optional[torch.Tensor] = None,
s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
v_merged = torch.empty_like(v_a)
s_merged = torch.empty_like(s_a)
# Avoid creating new tensors if they are already provided
if v_merged is None:
v_merged = torch.empty_like(v_a)
if s_merged is None:
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged


def merge_attn_states(
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output: Optional[torch.Tensor] = None,
output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Compare with merge_state function:
# prefix_output: v_a, prefix_lse: s_a
# suffix_output: v_b, suffix_lse: s_b
# output: v_merged, output_lse: s_merged
# TODO(DefTruth): Currently, the custom merge_attn_states kernel
# does not support the FP8 data type or non-CUDA devices.
# It may be necessary to fall back to the Triton kernel in those cases.
if output is None:
output = torch.empty_like(prefix_output)
if output_lse is None:
output_lse = torch.empty_like(prefix_lse)
torch.ops.sgl_kernel.merge_attn_states.default(
output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse
)
return output, output_lse


def cutlass_mla_decode(
q_nope_and_q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
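
A minimal usage sketch of the new Python wrapper (illustrative only; assumes a CUDA build of sgl-kernel, fp16 partial outputs of shape [num_tokens, num_heads, head_size], and float32 LSEs of shape [num_heads, num_tokens]; tensor names and sizes are made up):

# Illustrative call of the new merge_attn_states wrapper.
import torch
from sgl_kernel import merge_attn_states

n, h, d = 4, 8, 128
prefix_out = torch.randn(n, h, d, dtype=torch.float16, device="cuda")
suffix_out = torch.randn(n, h, d, dtype=torch.float16, device="cuda")
prefix_lse = torch.randn(h, n, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(h, n, dtype=torch.float32, device="cuda")

# Output buffers are allocated by the wrapper when not supplied.
output, output_lse = merge_attn_states(prefix_out, prefix_lse, suffix_out, suffix_lse)
assert output.shape == (n, h, d) and output_lse.shape == (h, n)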