ci: further reduce binary size (#436)
Move `partition_kv` flag from compile time to runtime.

This PR also fixes a minor issue where the sparse module was broken in #428.
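
For context, a minimal sketch of the pattern this change follows. These are toy kernels with illustrative names, not the FlashInfer decode kernels: a `bool` template parameter forces every kernel configuration to be compiled twice, whereas a runtime argument keeps a single instantiation and branches inside the kernel.

```cuda
// Illustrative pattern only -- toy kernels, not the FlashInfer decode kernels.

// Before: a bool template parameter doubles the number of instantiations,
// since each configuration is compiled once per value of partition_kv.
template <bool partition_kv, uint32_t vec_size>
__global__ void ToyDecodeCompileTime(const float* __restrict__ q,
                                     float* __restrict__ o,
                                     float* __restrict__ tmp) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if constexpr (partition_kv) {
    tmp[i] = q[i];  // partial result, merged by a separate kernel later
  } else {
    o[i] = q[i];    // final result written directly
  }
}

// After: the flag is an ordinary kernel argument, so there is one
// instantiation per configuration and the (cheap) branch happens at runtime.
template <uint32_t vec_size>
__global__ void ToyDecodeRuntime(const float* __restrict__ q,
                                 float* __restrict__ o,
                                 float* __restrict__ tmp, bool partition_kv) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (partition_kv) {
    tmp[i] = q[i];  // partial result, merged by a separate kernel later
  } else {
    o[i] = q[i];    // final result written directly
  }
}
```

Since the decode kernels in this diff are already instantiated over head dimension, data types, positional encoding mode, and logits post hook, dropping the `partition_kv` template parameter roughly halves the number of decode kernel instantiations, which is where the binary-size reduction comes from.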
yzh119 authored Aug 10, 2024
1 parent 2c9d1c3 commit 9ca04e4
Showing 5 changed files with 144 additions and 161 deletions.
1 change: 1 addition & 0 deletions .github/workflows/release_wheel.yml
@@ -52,6 +52,7 @@ jobs:
-e FLASHINFER_CI_TORCH_VERSION=${{ matrix.torch }} \
-e FLASHINFER_BUILD_VERSION=$version \
-e TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \
-e MAX_JOBS=224 \
--user $CI_UID:$CI_GID \
pytorch/manylinux-builder:cuda${{ matrix.cuda }} \
bash /app/scripts/run-ci-build-wheel.sh
59 changes: 22 additions & 37 deletions include/flashinfer/attention/decode.cuh
@@ -186,7 +186,6 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
/*!
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request
* \tparam logits_post_hook The logits post hook used in the kernel
* \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
* \tparam vec_size A template integer indicates the vector size
* \tparam bdx A template integer indicates the block size in x dimension
@@ -208,7 +207,7 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
* of "theta" used in RoPE (Rotary Positional Embeddings)
* \param kv_chunk_size A integer indicates the kv-chunk size
*/
template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, typename DTypeQ, typename DTypeKV, typename DTypeOut>
__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
@@ -362,7 +361,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests
* \tparam logits_post_hook The logits post hook used in the kernel
* \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
* \tparam vec_size A template integer indicates the vector size
* \tparam bdx A template integer indicates the block size in x dimension
@@ -385,16 +383,17 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
* \param rope_rcp_theta A floating number indicate the reciprocal
* of "theta" used in RoPE (Rotary Positional Embeddings)
*/
template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, PageStorage page_storage, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
float* __restrict__ lse, bool* __restrict__ block_valid_mask, int32_t window_left,
float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
float* __restrict__ lse, bool* __restrict__ block_valid_mask, bool partition_kv,
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));
@@ -653,15 +652,13 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
const uint32_t smem_size =
2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) +
2U * bdy * bdz * sizeof(float);
auto kernel = SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE,
num_stages_smem, tile_size_per_bdx, vec_size, bdx,
bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
auto kernel =
SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false, POS_ENCODING_MODE,
num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

dim3 nblks = dim3(1, num_kv_heads);
dim3 nthrs = dim3(bdx, bdy, bdz);
float* lse = nullptr;
@@ -680,13 +677,6 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
auto kernel =
SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true, POS_ENCODING_MODE,
num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
@@ -751,25 +741,26 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
const uint32_t smem_size =
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));

auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, num_stages_smem,
tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (tmp_v == nullptr) {
// do not use partition-kv kernel
bool partition_kv = false;
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false,
POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
vec_size, bdx, bdy, bdz, page_storage, DTypeQ, DTypeKV,
DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
@@ -778,29 +769,23 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
auto partition_kv_kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true,
POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
vec_size, bdx, bdy, bdz, page_storage, DTypeQ, DTypeKV,
DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
bool partition_kv = true;
void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse,
kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream));
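A side note on the launch path above: with the flag now a kernel argument, both branches launch the same `kernel` and only the `args` array differs. A self-contained toy example of passing a runtime bool through `cudaLaunchKernel` (hypothetical `ToyKernel`/`LaunchToy` names, not code from this diff): each `args` entry is a pointer to the corresponding kernel argument, so `&partition_kv` simply points at a local bool.

```cuda
#include <cuda_runtime.h>

__global__ void ToyKernel(const float* __restrict__ in, float* __restrict__ out,
                          bool partition_kv, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // branch on the runtime flag instead of a template parameter
    out[i] = partition_kv ? in[i] * 0.5f : in[i];
  }
}

cudaError_t LaunchToy(const float* in, float* out, bool partition_kv, int n,
                      cudaStream_t stream) {
  dim3 nblks((n + 255) / 256), nthrs(256);
  // cudaLaunchKernel takes an array of pointers to the kernel arguments,
  // in declaration order; the runtime bool is passed like any other argument.
  void* args[] = {(void*)&in, (void*)&out, (void*)&partition_kv, (void*)&n};
  return cudaLaunchKernel((void*)ToyKernel, nblks, nthrs, args,
                          /*sharedMem=*/0, stream);
}
```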
16 changes: 8 additions & 8 deletions include/flashinfer/attention/handler.cuh
@@ -34,16 +34,17 @@

namespace flashinfer {

template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, PageStorage page_storage, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
float* __restrict__ lse, bool* __restrict__ block_valid_mask, int maybe_window_left,
float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta);
float* __restrict__ lse, bool* __restrict__ block_valid_mask, bool partition_kv,
int maybe_window_left, float logits_soft_cap, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta);

/*!
* \brief Compute the maximum number of pages per batch and the new batch size
Expand Down Expand Up @@ -156,18 +157,17 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));

auto partition_kv_kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK,
/*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem,
auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, num_stages_smem,
tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
DTypeQ, DTypeKV, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * num_kv_heads >= max_grid_size) {
split_kv = false;
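The work-estimation path above keeps using the CUDA occupancy API to decide whether splitting the KV cache is worthwhile; the only change is that it now queries the single, non-templated kernel. A self-contained sketch of that decision logic, with a toy kernel and a function name of my own (not the FlashInfer API):

```cuda
#include <cuda_runtime.h>

// Toy kernel standing in for the decode kernel whose occupancy is queried.
__global__ void ToyDecode(float* out) { out[blockIdx.x] = 1.0f; }

cudaError_t EstimateGrid(uint32_t batch_size, uint32_t num_kv_heads,
                         uint32_t num_threads, size_t smem_size,
                         bool& split_kv, uint32_t& max_grid_size) {
  int dev_id = 0, num_sm = 0, num_blocks_per_sm = 0;
  cudaError_t err;
  if ((err = cudaGetDevice(&dev_id)) != cudaSuccess) return err;
  if ((err = cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount,
                                    dev_id)) != cudaSuccess)
    return err;
  // Maximum co-resident blocks per SM for this kernel, block size, and
  // dynamic shared-memory footprint.
  if ((err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
           &num_blocks_per_sm, ToyDecode, num_threads, smem_size)) !=
      cudaSuccess)
    return err;
  max_grid_size = num_blocks_per_sm * num_sm;
  // If the un-split grid already saturates the GPU, splitting the KV cache
  // only adds merge overhead, so skip it.
  split_kv = batch_size * num_kv_heads < max_grid_size;
  return cudaSuccess;
}
```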