Commit 7142d6b

upd
1 parent 01fdedd commit 7142d6b

File tree: 3 files changed, +84 -62 lines changed

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh

Lines changed: 77 additions & 56 deletions
@@ -99,23 +99,29 @@ struct FP8SparseCollectiveMainloop {
     DTypeQ const* Q_ptr;
     LayoutT layout_Q;
     DTypeKV const* K_ptr;
-    LayoutT layout_K;
+    int64_t k_stride_n;     // Stride between consecutive KV tokens
+    int64_t k_page_stride;  // Stride between pages
     DTypeKV const* V_ptr;
-    LayoutT layout_V;
+    int64_t v_stride_n;     // Stride between consecutive KV tokens
+    int64_t v_page_stride;  // Stride between pages
     IdType const* kv_indices;
+    uint32_t page_size;  // Size of each page
     int window_left;
     AdditionalParams additional_params;
   };
 
   // Device side kernel params
   struct Params {
     LayoutT layout_Q;
-    LayoutT layout_K;
-    LayoutT layout_V;
     TMA_Q tma_load_Q;
     DTypeKV* K_ptr;
+    int64_t k_stride_n;
+    int64_t k_page_stride;
     DTypeKV* V_ptr;
+    int64_t v_stride_n;
+    int64_t v_page_stride;
     IdType* kv_indices;
+    uint_fastdiv page_size;  // Size of each page (as fastdiv for efficient divmod)
     int window_left;
     AdditionalParams additional_params;
     using DTypeKV = typename Ktraits::DTypeKV;
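Note: the K/V layout objects are gone from both structs; addressing is now explicit strides plus a page table. A logical token index is split by page_size into a page-table slot and an in-page slot, and the element offset follows from the physical page id. A minimal host-side sketch of that address computation (names and the int32_t index type are illustrative, not flashinfer API):

#include <cstdint>

// Resolve the element offset of (kv_idx, d_idx) in a paged KV cache,
// mirroring the offset formula load_kv_tile uses below.
int64_t paged_kv_offset(uint32_t kv_idx, uint32_t page_size,
                        const int32_t* kv_indices,  // per-request page table
                        int64_t page_stride,        // elements between pages
                        int64_t stride_n,           // elements between tokens
                        int d_idx) {                // position along head dim
  uint32_t page_iter = kv_idx / page_size;   // which page-table entry
  uint32_t entry_idx = kv_idx % page_size;   // token slot within that page
  int64_t page_idx = kv_indices[page_iter];  // physical page id
  return page_idx * page_stride + entry_idx * stride_n + d_idx;
}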
@@ -125,15 +131,10 @@ struct FP8SparseCollectiveMainloop {
     Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q);
     TMA_Q tma_load_Q =
         make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{});
-    return {args.layout_Q,
-            args.layout_K,
-            args.layout_V,
-            tma_load_Q,
-            const_cast<DTypeKV*>(args.K_ptr),
-            const_cast<DTypeKV*>(args.V_ptr),
-            const_cast<IdType*>(args.kv_indices),
-            args.window_left,
-            args.additional_params};
+    return {args.layout_Q, tma_load_Q, const_cast<DTypeKV*>(args.K_ptr),
+            args.k_stride_n, args.k_page_stride, const_cast<DTypeKV*>(args.V_ptr),
+            args.v_stride_n, args.v_page_stride, const_cast<IdType*>(args.kv_indices),
+            args.page_size, args.window_left, args.additional_params};
   }
 
   CUTLASS_DEVICE
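Note: Params is brace-initialized, so the argument order here must track the new member order exactly. The plain uint32_t page_size from Arguments also converts into the uint_fastdiv Params field at this point, precomputing a reciprocal once so each per-element divmod on the device becomes a multiply and shift. A self-contained sketch of that technique in the style of Lemire's round-up method (an illustration of the idea, not flashinfer's actual uint_fastdiv):

#include <cstdint>

// For 32-bit n and d >= 2, m = ceil(2^64 / d) makes
// floor((m * n) >> 64) the exact quotient n / d.
struct FastDiv {
  uint32_t d;
  uint64_t m;
  explicit FastDiv(uint32_t d_) : d(d_), m(~0ULL / d_ + 1) {}  // d == 1 needs a special case
  void divmod(uint32_t n, uint32_t& q, uint32_t& r) const {
    q = static_cast<uint32_t>((static_cast<__uint128_t>(m) * n) >> 64);
    r = n - q * d;  // remainder falls out of the quotient
  }
};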
@@ -208,43 +209,71 @@ struct FP8SparseCollectiveMainloop {
 
     constexpr int HEAD_DIM = get<2>(TileShape_QKD{});
     constexpr int CTA_KV = get<1>(TileShape_QKD{});
-    auto indexed_gather = BlockSparseIndexedGather<IdType>(mainloop_params.kv_indices + kv_indptr);
+    IdType const* kv_indices_ptr = mainloop_params.kv_indices + kv_indptr;
 
-    Tensor mK = make_block_sparse_tensor(  // (kv_len, D)
-        make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)),
-        make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_K), indexed_gather);
-    Tensor mV = make_block_sparse_tensor(  // (kv_len, D)
-        make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)),
-        make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_V), indexed_gather);
-
-    Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{}));  // (KV, D, kv)
-    Tensor gV = local_tile(mV, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{}));  // (KV, D, kv)
-    Tensor cKV = cute::make_identity_tensor(gK.shape());
+    // Setup for manual K/V loading with page table
+    DTypeKV* k_base_ptr = mainloop_params.K_ptr;
+    DTypeKV* v_base_ptr = mainloop_params.V_ptr;
+    int64_t k_stride_n = mainloop_params.k_stride_n;
+    int64_t k_page_stride = mainloop_params.k_page_stride;
+    int64_t v_stride_n = mainloop_params.v_stride_n;
+    int64_t v_page_stride = mainloop_params.v_page_stride;
 
     GmemTiledCopyKV gmem_tiled_copy_kv;
     auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
 
-    Tensor tKgK = gmem_thr_copy_kv.partition_S(gK);  // (CPY, CPY_KV, CPY_D, kv)
-    Tensor tKsK = gmem_thr_copy_kv.partition_D(sK);  // (CPY, CPY_KV, CPY_D, PIPE)
-    Tensor tVgV = gmem_thr_copy_kv.partition_S(gV);  // (CPY, CPY_KV, CPY_D, kv)
-    Tensor tVsV = gmem_thr_copy_kv.partition_D(sV);  // (CPY, CPY_KV, CPY_D, PIPE)
+    // Create coordinate tensors for partitioning
+    Tensor cKV = cute::make_identity_tensor(make_shape(CTA_KV, HEAD_DIM));
     Tensor tKVcKV = gmem_thr_copy_kv.partition_D(cKV);  // (CPY, CPY_KV, CPY_D)
     Tensor tKVcKVGroup = flatten_1(tKVcKV);  // (CPY, (CPY_KV, CPY_D))
+    Tensor tKsK = gmem_thr_copy_kv.partition_D(sK);  // (CPY, CPY_KV, CPY_D, PIPE)
+    Tensor tVsV = gmem_thr_copy_kv.partition_D(sV);  // (CPY, CPY_KV, CPY_D, PIPE)
 
-    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
-    auto predicate_fn = [&](auto coords) {
-      auto s_coords = tKVcKVGroup(_0{}, coords);
-      return elem_less(get<0>(s_coords), valid_last_kv_tile_size);
+    // Lambda to load K/V tile with manual offset calculation
+    auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
+                            int tile_idx, int pipe_idx, bool use_predicate) {
+      using VecType = typename GmemTiledCopyKV::ValType;
+      constexpr int VecSize = sizeof(VecType) / sizeof(DTypeKV);
+
+      int kv_base_idx = tile_idx * CTA_KV;
+      int valid_tile_size = use_predicate ? std::min<int>(kv_len - kv_base_idx, CTA_KV) : CTA_KV;
+
+      // Flatten the destination tensor for this pipe stage
+      Tensor tXsXiGroup = flatten_1(tXsX(_, _, _, pipe_idx));  // (CPY, (CPY_KV, CPY_D))
+
+      // Iterate over flattened elements this thread is responsible for
+      CUTE_UNROLL
+      for (int i = 0; i < size(tXsXiGroup); ++i) {
+        auto coord = tKVcKVGroup(_0{}, i);
+        int kv_offset = get<0>(coord);
+        int d_idx = get<1>(coord);
+        int kv_idx = kv_base_idx + kv_offset;
+
+        bool guard = kv_idx < kv_len && kv_offset < valid_tile_size;
+
+        // Compute page and offset within page
+        uint32_t page_iter, entry_idx;
+        mainloop_params.page_size.divmod(kv_idx, page_iter, entry_idx);
+        IdType page_idx = kv_indices_ptr[page_iter];
+
+        // Compute address: base_ptr + page_idx * page_stride + entry_idx * stride_n + d_idx
+        int64_t offset = page_idx * page_stride + entry_idx * stride_n + d_idx;
+        VecType const* src_ptr = reinterpret_cast<VecType const*>(base_ptr + offset);
+        VecType* dst_ptr = reinterpret_cast<VecType*>(&tXsXiGroup(0, i));
+
+        cutlass::arch::cp_async_zfill<sizeof(VecType), cutlass::arch::CacheOperation::Global>(
+            dst_ptr, src_ptr, guard);
+      }
     };
 
+    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
+
     // load last k-tile
     // all threads are issuing as TMA is disabled
     {
       pipeline_k.producer_acquire(smem_pipe_write);
-      Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
-      Tensor tKsKiGroup =
-          flatten_1(tKsK(_, _, _, smem_pipe_write.index()));  // (CPY, (CPY_KV, CPY_D))
-      copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup);
+      load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, tKsK, kv_tile_idx,
+                   smem_pipe_write.index(), true);
       pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
     }
 
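Note: the predicated copy_if over a gathered tensor is replaced by per-vector cp.async transfers whose guard zero-fills out-of-range slots, so a partially valid last tile lands as well-defined zeros in shared memory instead of stale data. A minimal standalone sketch of that guard pattern using CUTLASS's SM80 async-copy helpers (kernel shape and sizes here are illustrative, not the mainloop's):

#include <cutlass/arch/memory_sm80.h>

// Launch with <= 128 threads per block on sm_80 or newer.
__global__ void guarded_tile_load(const float4* __restrict__ gmem, int n_valid) {
  __shared__ float4 smem[128];
  int i = threadIdx.x;
  bool guard = i < n_valid;  // in-range predicate, like `guard` above
  // Copies 16 bytes when guard is true; writes zeros otherwise.
  cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>(
      &smem[i], &gmem[i], guard);
  cutlass::arch::cp_async_fence();    // commit this async-copy group
  cutlass::arch::cp_async_wait<0>();  // wait for it to land
  __syncthreads();
}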
@@ -266,10 +295,8 @@ struct FP8SparseCollectiveMainloop {
     if (kv_tile_idx == swa_begin_kv_tile_idx) {
       // first tile is the last tile
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
-      Tensor tVsViGroup =
-          flatten_1(tVsV(_, _, _, smem_pipe_write.index()));  // (CPY, (CPY_KV, CPY_D))
-      copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, kv_tile_idx,
+                   smem_pipe_write.index(), true);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       // Transpose V
@@ -283,10 +310,8 @@ struct FP8SparseCollectiveMainloop {
     } else {
       // load second last k-tile and last v-tile
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
-      Tensor tVsViGroup =
-          flatten_1(tVsV(_, _, _, smem_pipe_write.index()));  // (CPY, (CPY_KV, CPY_D))
-      copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, kv_tile_idx,
+                   smem_pipe_write.index(), true);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       // Transpose V
@@ -299,9 +324,8 @@ struct FP8SparseCollectiveMainloop {
       ++smem_pipe_write;  // update state, as K is loaded 1 step faster
 
       pipeline_k.producer_acquire(smem_pipe_write);
-      Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1);  // (CPY, CPY_KV, CPY_D)
-      Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index());  // (CPY, CPY_KV, CPY_D)
-      copy(gmem_tiled_copy_kv, tKgKi, tKsKi);
+      load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, tKsK, kv_tile_idx - 1,
+                   smem_pipe_write.index(), false);
       pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       --kv_tile_idx;
@@ -310,9 +334,8 @@ struct FP8SparseCollectiveMainloop {
     #pragma unroll 2
     for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgVi = tVgV(_, _, _, kv_tile_idx);  // (CPY, CPY_KV, CPY_D)
-      Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index());  // (CPY, CPY_KV, CPY_D)
-      copy(gmem_tiled_copy_kv, tVgVi, tVsVi);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, kv_tile_idx,
+                   smem_pipe_write.index(), false);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       // Transpose V
@@ -325,19 +348,17 @@ struct FP8SparseCollectiveMainloop {
       ++smem_pipe_write;  // update state, as K is loaded 1 step faster
 
       pipeline_k.producer_acquire(smem_pipe_write);
-      Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1);  // (CPY, CPY_KV, CPY_D)
-      Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index());  // (CPY, CPY_KV, CPY_D)
-      copy(gmem_tiled_copy_kv, tKgKi, tKsKi);
+      load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, tKsK, kv_tile_idx - 1,
+                   smem_pipe_write.index(), false);
       pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
     }
     scheduler.prefetch_next_work(scheduler_params, work_tile_info);
 
     // load first v tile
     {
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgVi = tVgV(_, _, _, 0);  // (CPY, (CPY_KV, CPY_D))
-      Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index());  // (CPY, (CPY_KV, CPY_D))
-      copy(gmem_tiled_copy_kv, tVgVi, tVsVi);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, 0, smem_pipe_write.index(),
+                   false);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
      // Transpose V
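Note: across these hunks only the copy mechanism changes; the producer schedule is untouched. K stays one pipeline stage ahead of V: the last (possibly partial) K and V tiles load predicated, each subsequent step pairs V[t] with K[t-1], and V[0] closes the loop. A hypothetical host-side trace of that schedule, for intuition only (assumes num_tiles >= 2 and no sliding window; the single-tile case takes the separate branch above):

#include <cstdio>

void trace_producer_order(int num_tiles) {
  int t = num_tiles - 1;
  std::printf("K[%d] (predicated)\n", t);  // last K tile, masked
  std::printf("V[%d] (predicated)\n", t);  // last V tile, masked
  std::printf("K[%d]\n", t - 1);           // K runs one stage ahead
  for (--t; t > 0; --t) {
    std::printf("V[%d]\n", t);
    std::printf("K[%d]\n", t - 1);
  }
  std::printf("V[0]\n");                   // first V tile closes the pipeline
}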

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh

Lines changed: 7 additions & 6 deletions
@@ -334,13 +334,14 @@ cudaError_t BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched(Params& params
           params.q_stride_n,
           params.q_stride_h),  // layout_Q
       params.k_ptr,
-      // NOTE(Zihao): nnz was useless here, we can just pass 0
-      get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM, params.k_stride_n,
-                      params.k_stride_h),  // layout_K
+      params.k_stride_n,     // k_stride_n
+      params.k_page_stride,  // k_page_stride
       params.v_ptr,
-      get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM, params.v_stride_n,
-                      params.v_stride_h),  // layout_V
-      params.kv_indices, params.window_left, params.additional_params});
+      params.v_stride_n,     // v_stride_n
+      params.v_page_stride,  // v_page_stride
+      params.kv_indices,
+      static_cast<uint32_t>(params.page_size),  // page_size
+      params.window_left, params.additional_params});
   typename CollectiveEpilogue::Params epilogue_params =
       CollectiveEpilogue::to_underlying_arguments({
           params.o_ptr,
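Note: on the host side the get_gmem_layout calls give way to raw strides plus the page size. For a paged cache stored as (num_pages, page_size, num_kv_heads, head_dim), the strides passed here could be derived as below (an assumption about the cache layout for illustration; the real values come from the existing Params fields):

#include <cstdint>

struct KVStrides {
  int64_t stride_n;     // elements between consecutive tokens in a page
  int64_t page_stride;  // elements between consecutive pages
};

// One plausible derivation for an NHD-layout paged cache:
// cache[num_pages][page_size][num_kv_heads][head_dim]
KVStrides derive_kv_strides(int64_t page_size, int64_t num_kv_heads, int64_t head_dim) {
  KVStrides s;
  s.stride_n = num_kv_heads * head_dim;    // skip one token (all heads)
  s.page_stride = page_size * s.stride_n;  // skip one full page
  return s;
}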
File renamed without changes.
