@@ -53,7 +53,7 @@ constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags
                                         uint32_t num_warps_z) {
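+  // NOTE: the product below is a rough per-thread register estimate for the o/s fragments;
+  // the new bound of 256 likely reflects the 255-registers-per-thread hardware limit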
   return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) ||
           (num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) ||
-          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 200));
+          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256));
 }

 /*!
@@ -207,30 +207,20 @@ template <bool produce_v, uint32_t num_warps_x, uint32_t num_warps_z, uint32_t n
 __device__ __forceinline__ void page_produce_kv(
     smem_t smem, uint32_t* smem_offset,
     paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
-    const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
+    const size_t* kv_offset, const uint32_t kv_len) {
   constexpr SharedMemFillMode fill_mode =
       produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t num_warps = num_warps_x * num_warps_z;
   constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
   const uint32_t warp_idx = get_warp_idx<num_warps_x, num_warps_z>(), lane_idx = threadIdx.x;
-  const uint32_t kv_head_idx = blockIdx.z;
   uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
   // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps
   static_assert(num_frags_z * 4 % num_warps_x == 0);
 #pragma unroll
   for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
-    uint32_t page_iter, entry_idx;
-    paged_kv.page_size.divmod(
-        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps * i, page_iter,
-        entry_idx);
-    DType* gptr = produce_v
-                      ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr)
-                      : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr);
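+    // kv_offset[i] is precomputed by the caller from get_k_elem_offset; for V, kv_offset_delta()
+    // shifts the offset from the K region to the V region of paged_kv.data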
+    DType* gptr = produce_v ? paged_kv.data + paged_kv.kv_offset_delta() + kv_offset[i]
+                            : paged_kv.data + kv_offset[i];
 #pragma unroll
     for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
       smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
@@ -800,9 +790,21 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
                                                             const uint32_t lane_idx) {
   // only necessary when blockDim.z > 1
   if constexpr (num_warps_z > 1) {
-    float2* smem_md = (float2*)smem_workspace;
-    // o: [num_warps, warp_size, 8]
-    // md: [num_warps, num_frags_x, 2, warp_size, 2 (m/d)]
+    float2* smem_md = (float2*)(smem_workspace + num_frags_x * num_frags_y * num_warps_x *
+                                                     num_warps_z * warp_size * 8);
+    // o: [num_warps, num_frags_x, num_frags_y, warp_size(32), 8]
+    // md: [num_warps, num_frags_x, 2, warp_size(32), 2 (m/d)]
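+    // each warp first writes its partial o fragments to shared memory so they can be merged
+    // across the num_warps_z warps below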
+#pragma unroll
+    for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
+#pragma unroll
+      for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
+        vec_t<float, 8>::memcpy(
+            smem_workspace +
+                (((warp_idx * num_frags_x + fx) * num_frags_y + fy) * warp_size + lane_idx) * 8,
+            o_frag[fx][fy]);
+      }
+    }
+
 #pragma unroll
     for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
@@ -851,23 +853,22 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
       }
     }

-    __syncthreads();
-
     // the following code saves shared memory usage.
 #pragma unroll
     for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
       for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
         vec_t<float, 8> o_new;
         o_new.fill(0.f);
-        vec_t<float, 8>::memcpy(smem_workspace + (warp_idx * warp_size + lane_idx) * 8,
-                                o_frag[fx][fy]);
-        __syncthreads();
 #pragma unroll
         for (uint32_t i = 0; i < num_warps_z; ++i) {
           vec_t<float, 8> oi;
           oi.load(smem_workspace +
-                  ((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * warp_size +
+                  ((((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * num_frags_x +
+                     fx) *
+                        num_frags_y +
+                    fy) *
+                       warp_size +
                    lane_idx) *
                       8);
 #pragma unroll
@@ -876,7 +877,6 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
           }
         }
         o_new.store(o_frag[fx][fy]);
-        __syncthreads();
       }
     }
   }
@@ -1592,6 +1592,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)),
       v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim *
                  sizeof(DTypeIn));
+  size_t kv_offset[num_frags_z * 4 / num_warps_x];
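+  // per-thread element offsets into paged_kv.data for the K/V tile currently being prefetched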
 
   uint32_t k_smem_offset_r = smem_t::get_permuted_offset<channel_size_128b_in>(
       get_warp_idx_z<num_warps_x, num_warps_z>() * num_frags_z * 16 + 8 * (lane_idx / 16) +
@@ -1605,13 +1606,22 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

   uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start;
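+  // precompute element offsets for the first K/V tile; pages at or beyond last_indptr fall back
+  // to offset 0, and the corresponding loads are predicated on kv_idx < kv_len in page_produce_kv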
+  for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(
+        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+        page_iter, entry_idx);
+    kv_offset[i] =
+        page_iter < last_indptr
+            ? paged_kv.get_k_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
+                                         entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+            : 0;
+  }
   page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
   cp_async::commit_group();
   page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
   cp_async::commit_group();

   const uint32_t num_iterations =
@@ -1631,8 +1641,20 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
                                 : chunk_end - chunk_start) /
       (16 * num_warps_z * num_frags_z);

-#pragma unroll
+#pragma unroll 1
   for (uint32_t iter = 0; iter < num_iterations; ++iter) {
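+    // compute the element offsets for the next K/V tile before waiting on the outstanding copies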
+    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
+    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+      uint32_t page_iter, entry_idx;
+      paged_kv.page_size.divmod(
+          packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+          page_iter, entry_idx);
+      kv_offset[i] = page_iter < last_indptr
+                         ? paged_kv.get_k_elem_offset(
+                               __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
+                               (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+                         : 0;
+    }
     cp_async::wait_group<1>();
     block.sync();
@@ -1677,11 +1699,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);

     block.sync();
-    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
     page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
         k_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
     cp_async::commit_group();
     cp_async::wait_group<1>();
     block.sync();
@@ -1693,8 +1713,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     block.sync();
     page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
         v_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
     cp_async::commit_group();
   }
   cp_async::wait_group<0>();
@@ -1764,10 +1783,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
   const uint_fastdiv group_size_fastdiv(group_size);
   constexpr uint32_t num_frags_y = HEAD_DIM / 16;
   WarpLayout warp_layout;
-  if (qo_len * group_size > 64 && HEAD_DIM < 256) {
+  int64_t unpacked_qo_len = qo_len * group_size;
+  if (unpacked_qo_len > 64 && HEAD_DIM < 256) {
     warp_layout = WarpLayout::k4x1x2;
   } else {
-    warp_layout = WarpLayout::k4x1x1;
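+    // NOTE: for very short query tiles (unpacked_qo_len <= 16) the k1x4x1 layout is selected,
+    // which appears to place the warps along the KV dimension rather than the query dimension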
+    if (unpacked_qo_len > 16) {
+      warp_layout = WarpLayout::k4x1x1;
+    } else {
+      warp_layout = WarpLayout::k1x4x1;
+    }
   }

   DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {