252 changes: 132 additions & 120 deletions CMakeLists.txt

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions cmake/config.cmake
@@ -23,6 +23,7 @@ set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
+set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
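
The new FLASHINFER_GEN_LOGITS_POST_HOOKS list adds a dispatch axis to the ahead-of-time generated kernels, and every listed value multiplies the number of compiled variants, which is why it sits under the binary-size warning. A minimal sketch of the enum these integer codes presumably map onto (the kSoftCap name and value are assumptions; only 0 appears in the default config):

// Sketch (assumption): integer codes in FLASHINFER_GEN_LOGITS_POST_HOOKS map
// onto an enum like this one, declared in the new logits_post_hook.cuh.
enum class LogitsPostHook { kNone = 0U, kSoftCap = 1U };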
88 changes: 50 additions & 38 deletions include/flashinfer/attention/decode.cuh
@@ -36,6 +36,7 @@
#include "../utils.cuh"
#include "../vec_dtypes.cuh"
#include "cascade.cuh"
#include "logits_post_hook.cuh"
#include "state.cuh"

namespace flashinfer {
@@ -48,6 +49,7 @@ namespace {

/*!
* \brief Load k tile from smem and compute qk
+* \tparam logits_post_hook The logits post hook used in the kernel
* \tparam pos_encoding_mode The positional encoding mode used in the kernel
* \tparam head_dim A template integer indicates the head dimension
* \tparam vec_size A template integer indicates the vector size
@@ -65,8 +67,8 @@ namespace {
* \param s A float indicates the thread-local result of qk
* \param st The self-attention state to be updated
*/
-template <PosEncodingMode pos_encoding_mode, uint32_t vec_size, uint32_t bdx, uint32_t tile_size,
-typename T>
+template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode, uint32_t vec_size,
+uint32_t bdx, uint32_t tile_size, typename T>
__device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx,
const vec_t<float, vec_size>& q_vec,
const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
@@ -96,6 +98,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
s[j] += math::shfl_xor_sync(s[j], offset);
}
s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4;
+s[j] = apply_logits_post_hook<logits_post_hook>(s[j]);
if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) {
s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset);
}
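
The hook applied above lives in the new logits_post_hook.cuh, which this diff does not show. A minimal sketch of what it plausibly contains, reusing the LogitsPostHook enum sketched in the cmake/config.cmake note and assuming a tanh soft cap at 30 (everything except the names LogitsPostHook and apply_logits_post_hook is an assumption):

// Hypothetical sketch of logits_post_hook.cuh; math::tanh stands in for a fast
// device tanh, substitute tanhf if the project's math namespace lacks one.
template <LogitsPostHook hook>
__device__ __forceinline__ float apply_logits_post_hook(float x) {
  if constexpr (hook == LogitsPostHook::kSoftCap) {
    // x already carries sm_scale * (1/30) (see the kernels below); re-apply the
    // cap and convert to base-2 space, since these kernels exponentiate with exp2.
    return math::log2e * 30.f * math::tanh(x);
  } else {
    // kNone: sm_scale has already absorbed log2e, so the logit passes through.
    return x;
  }
}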
@@ -178,6 +181,7 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f

/*!
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request
+* \tparam logits_post_hook The logits post hook used in the kernel
* \tparam kv_layout The layout of k/v matrices (NHD or HND)
* \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
@@ -202,9 +206,10 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
* of "theta" used in RoPE (Rotary Positional Embeddings)
* \param kv_chunk_size An integer indicating the kv-chunk size
*/
-template <QKVLayout kv_layout, bool partition_kv, PosEncodingMode pos_encoding_mode,
-uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
-uint32_t bdy, uint32_t bdz, typename DTypeQ, typename DTypeKV, typename DTypeOut>
+template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, bool partition_kv,
+PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem, uint32_t tile_size_per_bdx,
+uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeQ,
+typename DTypeKV, typename DTypeOut>
__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp,
@@ -213,7 +218,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
float rope_rcp_theta, uint32_t kv_chunk_size) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
-sm_scale *= math::log2e;
+sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
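
Why the scale changes: these kernels evaluate softmax with exp2 rather than exp, so with no hook the log2e conversion is folded into sm_scale up front. With a hook active, the conversion has to happen after the hook's nonlinearity instead, so sm_scale absorbs the hook's inner divisor here. A worked sketch, assuming a soft cap of the form 30 * tanh(x / 30) (the cap value matches the 1.f / 30.f above; the tanh form is an assumption, since logits_post_hook.cuh is not shown in this diff):

// no hook:   exp2((q . k) * sm_scale * log2e)  ==  exp((q . k) * sm_scale)
// with hook: s       = (q . k) * sm_scale * (1/30)   // scaling applied here
//            hook(s) = log2e * 30 * tanh(s)          // applied in compute_qk
//            exp2(hook(s)) == exp(30 * tanh((q . k) * sm_scale / 30))
// i.e. the folded 1/30 plus the hook yield a logit capped to (-30, 30).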
@@ -297,7 +302,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-compute_qk<pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
+compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size,
seq_len - 1, alibi_slope, s, st_local);
@@ -356,16 +361,16 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
}
}

-template <QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
-uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeQ,
-typename DTypeKV, typename DTypeOut>
+template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
+uint32_t num_stages_smem, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
+typename DTypeQ, typename DTypeKV, typename DTypeOut>
__global__ void BatchDecodeWithPaddedKVCacheKernel(
DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v,
DTypeOut* __restrict__ o, float* __restrict__ lse,
tensor_info_t<kv_layout, bdy, bdx * vec_size> info, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
auto block = cg::this_thread_block();
-sm_scale *= math::log2e;
+sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
@@ -438,7 +443,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-compute_qk<pos_encoding_mode, vec_size, bdx, bdy>(
+compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy>(
k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local);
block.sync();
@@ -489,6 +494,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(

/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests
+* \tparam logits_post_hook The logits post hook used in the kernel
* \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
* \tparam vec_size A template integer indicates the vector size
@@ -512,10 +518,10 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
* \param rope_rcp_theta A floating-point number indicating the reciprocal
* of "theta" used in RoPE (Rotary Positional Embeddings)
*/
-template <bool partition_kv, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
-uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
-PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV,
-typename DTypeOut, typename IdType>
+template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
+uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
+uint32_t bdy, uint32_t bdz, PageStorage page_storage, QKVLayout kv_layout,
+typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
@@ -524,7 +530,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
auto block = cg::this_thread_block();
-sm_scale *= math::log2e;
+sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

constexpr uint32_t head_dim = bdx * vec_size;
const uint32_t batch_idx = blockIdx.x;
@@ -649,7 +655,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-compute_qk<pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
+compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
freq,
(paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
@@ -760,8 +766,8 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
-template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
-PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
+template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
+QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
typename DTypeOut>
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, uint32_t num_kv_heads,
@@ -786,9 +792,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
auto kernel =
-SingleDecodeWithKVCacheKernel<KV_LAYOUT, /*partition_kv=*/false, POS_ENCODING_MODE,
-num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
-DTypeQ, DTypeKV, DTypeOut>;
+SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, /*partition_kv=*/false,
+POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -807,9 +813,10 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
-auto kernel = SingleDecodeWithKVCacheKernel<KV_LAYOUT, /*partition_kv=*/true, POS_ENCODING_MODE,
-num_stages_smem, tile_size_per_bdx, vec_size, bdx,
-bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
+auto kernel =
+SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, /*partition_kv=*/true,
+POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -848,8 +855,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
return cudaSuccess;
}
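
Every dispatcher now takes LOGITS_POST_HOOK as an extra compile-time parameter next to GROUP_SIZE and HEAD_DIM. A host-side sketch of lifting a runtime hook selection onto that template parameter (this helper is hypothetical, not part of the diff):

// Hypothetical helper: maps a runtime LogitsPostHook value to a compile-time
// constant so the templated dispatchers above can be instantiated.
#include <type_traits>

template <typename Func>
cudaError_t DispatchLogitsPostHook(LogitsPostHook hook, Func&& f) {
  switch (hook) {
    case LogitsPostHook::kNone:
      return f(std::integral_constant<LogitsPostHook, LogitsPostHook::kNone>{});
    case LogitsPostHook::kSoftCap:
      return f(std::integral_constant<LogitsPostHook, LogitsPostHook::kSoftCap>{});
    default:
      return cudaErrorInvalidValue;
  }
}

// Usage sketch:
//   DispatchLogitsPostHook(hook, [&](auto hook_const) {
//     return SingleDecodeWithKVCacheDispatched<GROUP_SIZE, HEAD_DIM,
//         decltype(hook_const)::value, KV_LAYOUT, POS_ENCODING_MODE,
//         DTypeQ, DTypeKV, DTypeOut>(/* runtime arguments as above */);
//   });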

-template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
-PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
+template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage,
+LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE,
+typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
@@ -877,9 +885,10 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
auto kernel =
-BatchDecodeWithPagedKVCacheKernel</*partition_kv=*/false, POS_ENCODING_MODE,
-num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy,
-bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
+BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false,
+POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+vec_size, bdx, bdy, bdz, page_storage, kv_layout, DTypeQ,
+DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
@@ -898,9 +907,10 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
} else {
// use partition-kv kernel
auto partition_kv_kernel =
-BatchDecodeWithPagedKVCacheKernel</*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem,
-tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
-kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
+BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true,
+POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+vec_size, bdx, bdy, bdz, page_storage, kv_layout, DTypeQ,
+DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
@@ -946,8 +956,9 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
-template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
-PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
+template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
+QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
+typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
@@ -970,8 +981,9 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeK

dim3 nblks(batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
-auto kernel = BatchDecodeWithPaddedKVCacheKernel<KV_LAYOUT, POS_ENCODING_MODE, num_stages_smem,
-vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
+auto kernel = BatchDecodeWithPaddedKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
+num_stages_smem, vec_size, bdx, bdy, bdz, DTypeQ,
+DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
tensor_info_t<KV_LAYOUT, GROUP_SIZE, HEAD_DIM> info(1, padded_kv_len, num_kv_heads);
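
A closing numerical sanity check of the folded scaling (host-only and standalone; the cap of 30 comes from the 1.f / 30.f factors above, the tanh soft-cap form is an assumption):

#include <cmath>
#include <cstdio>

int main() {
  const float log2e = 1.4426950408889634f;
  const float qk = 57.3f, sm_scale = 0.0883883f;  // arbitrary example values
  // reference: capped logit in natural-log space
  const float ref = 30.f * std::tanh(qk * sm_scale / 30.f);
  // kernel path: 1/30 folded into sm_scale, hook emits log2e * 30 * tanh(x)
  const float x = qk * (sm_scale * (1.f / 30.f));
  const float hooked = log2e * 30.f * std::tanh(x);
  // exp2 of the hooked logit must agree with exp of the reference logit
  std::printf("exp2(hooked) = %f, exp(ref) = %f\n", std::exp2(hooked), std::exp(ref));
  return 0;
}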