Commit e56ddad

perf: accelerate gqa performance (#356)
Changes:
1. Prefetch page indices. We had already applied this optimization to the decode kernels, but not to the append/prefill kernels that GQA uses.
2. Unlock the 1x4 warp layout from #322. We had not enabled it previously because it made the binary too large; we should still trim some unnecessary template arguments to reduce binary size.
3. Optimize `threadblock_sync_mdo_states`, which merges the attention states of multiple warps in a threadblock. The previous implementation assumed a small shared memory budget and interleaved shared memory reads/writes with computation, which is less efficient than bulk shared memory access.

After this PR, the GQA kernel execution time on H100 for `batch_size=128, seq_len=1024, num_qo_heads=32, num_kv_heads=4, head_dim=128` improves from 133us to 103us.
1 parent 2e64a65 commit e56ddad
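
For readers unfamiliar with the term in item 3: an "attention state" here is the usual online-softmax triple (running max m, denominator d, unnormalized partial output o), and merging two of them is the log-sum-exp-style combine sketched below. This is a scalar illustration of the math only, not the kernel code; in the kernel it is applied per output element with each warp's (m, d) pair staged in shared memory.

```cuda
#include <cmath>

// Scalar sketch of merging two partial attention states.
struct AttnState {
  float m;  // running max of the attention logits
  float d;  // softmax denominator: sum of exp(logit - m)
  float o;  // unnormalized partial output (one element of o_frag)
};

inline AttnState merge(AttnState a, AttnState b) {
  AttnState out;
  out.m = std::fmax(a.m, b.m);
  float scale_a = std::exp(a.m - out.m);
  float scale_b = std::exp(b.m - out.m);
  out.d = a.d * scale_a + b.d * scale_b;
  out.o = a.o * scale_a + b.o * scale_b;
  return out;
}
```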

File tree

5 files changed: +84 -52 lines

Diff for: include/flashinfer/attention/handler.cuh

+6-1
@@ -560,7 +560,12 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
   if (avg_packed_qo_len > 64 && head_dim < 256) {
     warp_layout = WarpLayout::k4x1x2;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2)
   } else {
-    warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
+    if (avg_packed_qo_len > 16) {
+      warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
+    } else {
+      // avg_packed_qo_len <= 16
+      warp_layout = WarpLayout::k1x4x1;  // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1)
+    }
   }
   const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout);
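
The new branch above keys off avg_packed_qo_len, the average query length after folding in the GQA group size (num_qo_heads / num_kv_heads). A standalone restatement of the three-way choice, with illustrative values in the comments; this helper itself is not part of the library:

```cuda
#include <cstdint>

// Illustrative restatement of the layout heuristic above; the enum mirrors
// warp_layout.cuh, but this function is only a sketch.
enum class WarpLayout { k4x1x2, k4x1x1, k1x4x1 };

WarpLayout choose_warp_layout(uint32_t avg_qo_len, uint32_t num_qo_heads,
                              uint32_t num_kv_heads, uint32_t head_dim) {
  uint32_t group_size = num_qo_heads / num_kv_heads;     // e.g. 32 / 4 = 8 for GQA
  uint32_t avg_packed_qo_len = avg_qo_len * group_size;  // query rows per KV head
  if (avg_packed_qo_len > 64 && head_dim < 256) {
    return WarpLayout::k4x1x2;  // long queries: all 4 warps tile the query dimension
  } else if (avg_packed_qo_len > 16) {
    return WarpLayout::k4x1x1;  // moderate queries: 4 warps on the query dimension
  } else {
    return WarpLayout::k1x4x1;  // short (decode-like) queries: put the 4 warps on the
                                // KV dimension instead, so none of them sit idle
  }
}
```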

Diff for: include/flashinfer/attention/prefill.cuh

+60-36
@@ -53,7 +53,7 @@ constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags
                                         uint32_t num_warps_z) {
   return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) ||
           (num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) ||
-          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 200));
+          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256));
 }
 
 /*!
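
The 200 -> 256 change above loosens the cutoff that marks fragment configurations as invalid (the expression is a rough per-warp resource estimate). As a purely arithmetic illustration of what moves from "invalid" to "valid", with example values that are not tied to any particular dispatch:

```cuda
#include <cstddef>

// Example evaluation of the budget expression above, with illustrative values:
// num_frags_x = 2, num_frags_y = 8 (head_dim = 128), num_frags_z = 6, DTypeQKAccum = float.
constexpr size_t budget = 2 * (8 * 8 + 2 * sizeof(float) * 6);  // = 224
static_assert(budget >= 200 && budget < 256,
              "invalid under the old >= 200 cutoff, valid under the new >= 256 cutoff");
```
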
@@ -207,30 +207,20 @@ template <bool produce_v, uint32_t num_warps_x, uint32_t num_warps_z, uint32_t n
 __device__ __forceinline__ void page_produce_kv(
     smem_t smem, uint32_t* smem_offset,
     paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
-    const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
+    const size_t* kv_offset, const uint32_t kv_len) {
   constexpr SharedMemFillMode fill_mode =
       produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t num_warps = num_warps_x * num_warps_z;
   constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
   const uint32_t warp_idx = get_warp_idx<num_warps_x, num_warps_z>(), lane_idx = threadIdx.x;
-  const uint32_t kv_head_idx = blockIdx.z;
   uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
   // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps
   static_assert(num_frags_z * 4 % num_warps_x == 0);
 #pragma unroll
   for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
-    uint32_t page_iter, entry_idx;
-    paged_kv.page_size.divmod(
-        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps * i, page_iter,
-        entry_idx);
-    DType* gptr = produce_v
-                      ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr)
-                      : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr);
+    DType* gptr = produce_v ? paged_kv.data + paged_kv.kv_offset_delta() + kv_offset[i]
+                            : paged_kv.data + kv_offset[i];
 #pragma unroll
     for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
       smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
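
This is the "prefetch page indices" item from the commit message: page-table resolution (fastdiv divmod plus the protective indices lookup) is hoisted out of the copy routine into the caller, which now hands the producer a small array of precomputed element offsets (see the kernel hunks further down). A stripped-down sketch of the hoisted step, with hypothetical names and a simplified offset formula; the real kernel uses `paged_kv.get_k_elem_offset` and `uint_fastdiv::divmod`:

```cuda
#include <cstddef>
#include <cstdint>

// Illustrative sketch only: resolve page indices once per KV tile, before the
// asynchronous global->shared copies that will consume the offsets.
__device__ void precompute_kv_offsets(const int32_t* __restrict__ page_indices,
                                      uint32_t page_size, uint32_t packed_base,
                                      uint32_t num_slots, size_t page_stride_elems,
                                      size_t entry_stride_elems, size_t* kv_offset) {
  for (uint32_t i = 0; i < num_slots; ++i) {
    const uint32_t packed = packed_base + i;
    const uint32_t page_iter = packed / page_size;  // real kernel: uint_fastdiv::divmod
    const uint32_t entry_idx = packed % page_size;
    // read the page id through the read-only cache, then turn it into an element offset
    kv_offset[i] = static_cast<size_t>(__ldg(page_indices + page_iter)) * page_stride_elems +
                   static_cast<size_t>(entry_idx) * entry_stride_elems;
  }
}
```
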
@@ -800,9 +790,21 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
                                                             const uint32_t lane_idx) {
   // only necessary when blockDim.z > 1
   if constexpr (num_warps_z > 1) {
-    float2* smem_md = (float2*)smem_workspace;
-    // o: [num_warps, warp_size, 8]
-    // md: [num_warps, num_frags_x, 2, warp_size, 2 (m/d)]
+    float2* smem_md = (float2*)(smem_workspace + num_frags_x * num_frags_y * num_warps_x *
+                                                     num_warps_z * warp_size * 8);
+    // o: [num_warps, num_frags_x, num_frags_y, warp_size(32), 8]
+    // md: [num_warps, num_frags_x, 2, warp_size(32), 2 (m/d)]
+#pragma unroll
+    for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
+#pragma unroll
+      for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
+        vec_t<float, 8>::memcpy(
+            smem_workspace +
+                (((warp_idx * num_frags_x + fx) * num_frags_y + fy) * warp_size + lane_idx) * 8,
+            o_frag[fx][fy]);
+      }
+    }
+
 #pragma unroll
     for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
@@ -851,23 +853,22 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
       }
     }
 
-    __syncthreads();
-
     // the following code saves shared memory usage.
 #pragma unroll
     for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
       for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
         vec_t<float, 8> o_new;
         o_new.fill(0.f);
-        vec_t<float, 8>::memcpy(smem_workspace + (warp_idx * warp_size + lane_idx) * 8,
-                                o_frag[fx][fy]);
-        __syncthreads();
 #pragma unroll
         for (uint32_t i = 0; i < num_warps_z; ++i) {
           vec_t<float, 8> oi;
           oi.load(smem_workspace +
-                  ((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * warp_size +
+                  ((((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * num_frags_x +
+                     fx) *
+                        num_frags_y +
+                    fy) *
+                       warp_size +
                    lane_idx) *
                       8);
 #pragma unroll
@@ -876,7 +877,6 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
           }
         }
         o_new.store(o_frag[fx][fy]);
-        __syncthreads();
       }
     }
   }
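
Taken together, the three hunks above implement item 3 of the commit message: every warp now writes all of its output fragments to shared memory up front and the block synchronizes once, instead of staging one fragment at a time with a `__syncthreads()` around each read-modify-write. A stripped-down sketch of the access pattern; fragment sizes, names, and the plain summation are illustrative, and the real code also rescales each warp's contribution by its m/d softmax state:

```cuda
// Sketch: bulk stage-through-shared-memory with a single barrier per exchange.
template <int NUM_FRAGS, int FRAG_ELEMS>
__device__ void merge_warps_bulk(float (*frag)[FRAG_ELEMS], float* smem_workspace,
                                 unsigned warp_idx, unsigned lane_idx, unsigned num_warps) {
  // 1. every warp writes all of its fragments to its own region of shared memory
  for (int f = 0; f < NUM_FRAGS; ++f)
    for (int e = 0; e < FRAG_ELEMS; ++e)
      smem_workspace[((warp_idx * NUM_FRAGS + f) * 32 + lane_idx) * FRAG_ELEMS + e] =
          frag[f][e];
  __syncthreads();  // one barrier covers the whole exchange
  // 2. each warp reads every warp's copy back and reduces it into registers
  for (int f = 0; f < NUM_FRAGS; ++f)
    for (int e = 0; e < FRAG_ELEMS; ++e) {
      float acc = 0.f;
      for (unsigned w = 0; w < num_warps; ++w)
        acc += smem_workspace[((w * NUM_FRAGS + f) * 32 + lane_idx) * FRAG_ELEMS + e];
      frag[f][e] = acc;
    }
}
```
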
@@ -1592,6 +1592,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)),
       v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim *
                         sizeof(DTypeIn));
+  size_t kv_offset[num_frags_z * 4 / num_warps_x];
 
   uint32_t k_smem_offset_r = smem_t::get_permuted_offset<channel_size_128b_in>(
       get_warp_idx_z<num_warps_x, num_warps_z>() * num_frags_z * 16 + 8 * (lane_idx / 16) +
@@ -1605,13 +1606,22 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];
 
   uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start;
+  for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(
+        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+        page_iter, entry_idx);
+    kv_offset[i] =
+        page_iter < last_indptr
+            ? paged_kv.get_k_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
+                                         entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+            : 0;
+  }
   page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
   cp_async::commit_group();
   page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
   cp_async::commit_group();
 
   const uint32_t num_iterations =
@@ -1631,8 +1641,20 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
                                      : chunk_end - chunk_start) /
       (16 * num_warps_z * num_frags_z);
 
-#pragma unroll
+#pragma unroll 1
   for (uint32_t iter = 0; iter < num_iterations; ++iter) {
+    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
+    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+      uint32_t page_iter, entry_idx;
+      paged_kv.page_size.divmod(
+          packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+          page_iter, entry_idx);
+      kv_offset[i] = page_iter < last_indptr
+                         ? paged_kv.get_k_elem_offset(
+                               __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
+                               (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+                         : 0;
+    }
     cp_async::wait_group<1>();
     block.sync();
 
@@ -1677,11 +1699,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);
 
     block.sync();
-    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
     page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
         k_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
     cp_async::commit_group();
     cp_async::wait_group<1>();
     block.sync();
@@ -1693,8 +1713,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     block.sync();
     page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
         v_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
     cp_async::commit_group();
   }
   cp_async::wait_group<0>();
@@ -1764,10 +1783,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
   const uint_fastdiv group_size_fastdiv(group_size);
   constexpr uint32_t num_frags_y = HEAD_DIM / 16;
   WarpLayout warp_layout;
-  if (qo_len * group_size > 64 && HEAD_DIM < 256) {
+  int64_t unpacked_qo_len = qo_len * group_size;
+  if (unpacked_qo_len > 64 && HEAD_DIM < 256) {
     warp_layout = WarpLayout::k4x1x2;
   } else {
-    warp_layout = WarpLayout::k4x1x1;
+    if (unpacked_qo_len > 16) {
+      warp_layout = WarpLayout::k4x1x1;
+    } else {
+      warp_layout = WarpLayout::k1x4x1;
+    }
   }
 
   DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {

Diff for: include/flashinfer/attention/warp_layout.cuh

+16-13
@@ -26,7 +26,7 @@ namespace flashinfer {
 enum class WarpLayout {
   k4x1x2 = 0U,
   k4x1x1 = 1U,
-  // k1x4x1 = 2U,
+  k1x4x1 = 2U,
 };
 
 template <WarpLayout warp_layout>
@@ -44,10 +44,10 @@ constexpr uint32_t get_num_warps_x<WarpLayout::k4x1x1>() {
   return 4;
 }
 
-// template <>
-// constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
-//   return 1;
-// }
+template <>
+constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
+  return 1;
+}
 
 template <WarpLayout warp_layout>
 constexpr uint32_t get_num_warps_z() {
@@ -64,10 +64,10 @@ constexpr uint32_t get_num_warps_z<WarpLayout::k4x1x1>() {
   return 1;
 }
 
-// template <>
-// constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
-//   return 4;
-// }
+template <>
+constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
+  return 4;
+}
 
 template <WarpLayout warp_layout>
 constexpr uint32_t get_num_frags_x() {
@@ -84,10 +84,10 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k4x1x1>() {
   return 1;
 }
 
-// template <>
-// constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
-//   return 1;
-// }
+template <>
+constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
+  return 1;
+}
 
 #define DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, ...) \
   if (warp_layout == WarpLayout::k4x1x2) { \
@@ -96,6 +96,9 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k4x1x1>() {
   } else if (warp_layout == WarpLayout::k4x1x1) { \
     constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1x1; \
     __VA_ARGS__ \
+  } else if (warp_layout == WarpLayout::k1x4x1) { \
+    constexpr WarpLayout WARP_LAYOUT = WarpLayout::k1x4x1; \
+    __VA_ARGS__ \
   } else { \
     std::ostringstream err_msg; \
     err_msg << "Unsupported warp layout: " << int(warp_layout); \
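
For context, a sketch of how the dispatch macro is consumed once the new branch exists; it mirrors the call pattern in SinglePrefillWithKVCacheDispatched shown above, and the body here is illustrative rather than the actual call site:

```cuda
// Pick compile-time warp-layout constants from a runtime WarpLayout value.
DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
  constexpr uint32_t num_warps_x = get_num_warps_x<WARP_LAYOUT>();
  constexpr uint32_t num_warps_z = get_num_warps_z<WARP_LAYOUT>();
  constexpr uint32_t num_frags_x = get_num_frags_x<WARP_LAYOUT>();
  // ... instantiate and launch the kernel specialized on these compile-time constants ...
});
```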

Diff for: python/generate_batch_paged_prefill_inst.py

+1-1
@@ -40,7 +40,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    warp_layout_choice = [0, 1]
+    warp_layout_choice = [0, 1, 2]
     insts = "\n".join(
         [
             """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<page_storage, {warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(

Diff for: python/generate_batch_ragged_prefill_inst.py

+1-1
@@ -39,7 +39,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    warp_layout_choice = [0, 1]
+    warp_layout_choice = [0, 1, 2]
    insts = "\n".join(
         [
             """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
