@@ -99,23 +99,29 @@ struct FP8SparseCollectiveMainloop {
     DTypeQ const* Q_ptr;
     LayoutT layout_Q;
     DTypeKV const* K_ptr;
-    LayoutT layout_K;
+    int64_t k_stride_n;     // Stride between consecutive KV tokens
+    int64_t k_page_stride;  // Stride between pages
     DTypeKV const* V_ptr;
-    LayoutT layout_V;
+    int64_t v_stride_n;     // Stride between consecutive KV tokens
+    int64_t v_page_stride;  // Stride between pages
     IdType const* kv_indices;
+    uint32_t page_size;     // Size of each page
     int window_left;
     AdditionalParams additional_params;
   };
 
   // Device side kernel params
   struct Params {
     LayoutT layout_Q;
-    LayoutT layout_K;
-    LayoutT layout_V;
     TMA_Q tma_load_Q;
     DTypeKV* K_ptr;
+    int64_t k_stride_n;
+    int64_t k_page_stride;
     DTypeKV* V_ptr;
+    int64_t v_stride_n;
+    int64_t v_page_stride;
     IdType* kv_indices;
+    uint_fastdiv page_size;  // Size of each page (as fastdiv for efficient divmod)
     int window_left;
     AdditionalParams additional_params;
     using DTypeKV = typename Ktraits::DTypeKV;
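
Note on the new Arguments/Params fields: the single CuTe layouts for K and V are replaced by explicit strides plus a page size, because the producer now walks a paged KV cache directly. A hypothetical host-side helper like the one below could derive the stride arguments, assuming the cache is a contiguous [num_pages, page_size, num_kv_heads, head_dim] (NHD) tensor; the actual layout convention is not visible in this hunk, so treat the formulas as an assumption rather than the kernel's contract.

// Hypothetical host-side helper (not part of this patch): derive stride
// arguments for a paged KV cache stored as a contiguous
// [num_pages, page_size, num_kv_heads, head_dim] tensor (NHD layout assumed).
#include <cstdint>

struct PagedKVStrides {
  int64_t stride_n;     // elements between token t and t+1 within a page
  int64_t page_stride;  // elements between page p and p+1
};

inline PagedKVStrides make_paged_kv_strides(int64_t page_size, int64_t num_kv_heads,
                                            int64_t head_dim) {
  PagedKVStrides s;
  s.stride_n = num_kv_heads * head_dim;    // one token spans all heads' head_dim values
  s.page_stride = page_size * s.stride_n;  // one page spans page_size tokens
  return s;
}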
@@ -125,15 +131,10 @@ struct FP8SparseCollectiveMainloop {
     Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q);
     TMA_Q tma_load_Q =
         make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{});
-    return {args.layout_Q,
-            args.layout_K,
-            args.layout_V,
-            tma_load_Q,
-            const_cast<DTypeKV*>(args.K_ptr),
-            const_cast<DTypeKV*>(args.V_ptr),
-            const_cast<IdType*>(args.kv_indices),
-            args.window_left,
-            args.additional_params};
+    return {args.layout_Q, tma_load_Q, const_cast<DTypeKV*>(args.K_ptr),
+            args.k_stride_n, args.k_page_stride, const_cast<DTypeKV*>(args.V_ptr),
+            args.v_stride_n, args.v_page_stride, const_cast<IdType*>(args.kv_indices),
+            args.page_size, args.window_left, args.additional_params};
   }
 
   CUTLASS_DEVICE
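
The host still supplies page_size as a plain uint32_t; the device-side Params stores it as uint_fastdiv (filled here presumably through its converting constructor), which precomputes a magic multiplier so the per-element divmod in the producer loop avoids a hardware integer divide. Semantically it is just a quotient/remainder split, as in this reference sketch (helper and struct names are ours, not part of the kernel):

// Plain-arithmetic reference for what page_size.divmod(kv_idx, page_iter, entry_idx)
// computes on the device; uint_fastdiv only makes this cheaper, not different.
#include <cstdint>

struct PageCoord {
  uint32_t page_iter;  // index into kv_indices for this request
  uint32_t entry_idx;  // token slot within that page
};

inline PageCoord locate_kv_token(uint32_t kv_idx, uint32_t page_size) {
  return PageCoord{kv_idx / page_size, kv_idx % page_size};
}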
@@ -208,43 +209,71 @@ struct FP8SparseCollectiveMainloop {
 
     constexpr int HEAD_DIM = get<2>(TileShape_QKD{});
     constexpr int CTA_KV = get<1>(TileShape_QKD{});
-    auto indexed_gather = BlockSparseIndexedGather<IdType>(mainloop_params.kv_indices + kv_indptr);
+    IdType const* kv_indices_ptr = mainloop_params.kv_indices + kv_indptr;
 
-    Tensor mK = make_block_sparse_tensor(  // (kv_len, D)
-        make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)),
-        make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_K), indexed_gather);
-    Tensor mV = make_block_sparse_tensor(  // (kv_len, D)
-        make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)),
-        make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_V), indexed_gather);
-
-    Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{}));  // (KV, D, kv)
-    Tensor gV = local_tile(mV, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{}));  // (KV, D, kv)
-    Tensor cKV = cute::make_identity_tensor(gK.shape());
+    // Setup for manual K/V loading with page table
+    DTypeKV* k_base_ptr = mainloop_params.K_ptr;
+    DTypeKV* v_base_ptr = mainloop_params.V_ptr;
+    int64_t k_stride_n = mainloop_params.k_stride_n;
+    int64_t k_page_stride = mainloop_params.k_page_stride;
+    int64_t v_stride_n = mainloop_params.v_stride_n;
+    int64_t v_page_stride = mainloop_params.v_page_stride;
 
     GmemTiledCopyKV gmem_tiled_copy_kv;
     auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
 
-    Tensor tKgK = gmem_thr_copy_kv.partition_S(gK);  // (CPY, CPY_KV, CPY_D, kv)
-    Tensor tKsK = gmem_thr_copy_kv.partition_D(sK);  // (CPY, CPY_KV, CPY_D, PIPE)
-    Tensor tVgV = gmem_thr_copy_kv.partition_S(gV);  // (CPY, CPY_KV, CPY_D, kv)
-    Tensor tVsV = gmem_thr_copy_kv.partition_D(sV);  // (CPY, CPY_KV, CPY_D, PIPE)
+    // Create coordinate tensors for partitioning
+    Tensor cKV = cute::make_identity_tensor(make_shape(CTA_KV, HEAD_DIM));
     Tensor tKVcKV = gmem_thr_copy_kv.partition_D(cKV);  // (CPY, CPY_KV, CPY_D)
     Tensor tKVcKVGroup = flatten_1(tKVcKV);             // (CPY, (CPY_KV, CPY_D))
+    Tensor tKsK = gmem_thr_copy_kv.partition_D(sK);  // (CPY, CPY_KV, CPY_D, PIPE)
+    Tensor tVsV = gmem_thr_copy_kv.partition_D(sV);  // (CPY, CPY_KV, CPY_D, PIPE)
 
-    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
-    auto predicate_fn = [&](auto coords) {
-      auto s_coords = tKVcKVGroup(_0{}, coords);
-      return elem_less(get<0>(s_coords), valid_last_kv_tile_size);
+    // Lambda to load K/V tile with manual offset calculation
+    auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
+                            int tile_idx, int pipe_idx, bool use_predicate) {
+      using VecType = typename GmemTiledCopyKV::ValType;
+      constexpr int VecSize = sizeof(VecType) / sizeof(DTypeKV);
+
+      int kv_base_idx = tile_idx * CTA_KV;
+      int valid_tile_size = use_predicate ? std::min<int>(kv_len - kv_base_idx, CTA_KV) : CTA_KV;
+
+      // Flatten the destination tensor for this pipe stage
+      Tensor tXsXiGroup = flatten_1(tXsX(_, _, _, pipe_idx));  // (CPY, (CPY_KV, CPY_D))
+
+      // Iterate over flattened elements this thread is responsible for
+      CUTE_UNROLL
+      for (int i = 0; i < size(tXsXiGroup); ++i) {
+        auto coord = tKVcKVGroup(_0{}, i);
+        int kv_offset = get<0>(coord);
+        int d_idx = get<1>(coord);
+        int kv_idx = kv_base_idx + kv_offset;
+
+        bool guard = kv_idx < kv_len && kv_offset < valid_tile_size;
+
+        // Compute page and offset within page
+        uint32_t page_iter, entry_idx;
+        mainloop_params.page_size.divmod(kv_idx, page_iter, entry_idx);
+        IdType page_idx = kv_indices_ptr[page_iter];
+
+        // Compute address: base_ptr + page_idx * page_stride + entry_idx * stride_n + d_idx
+        int64_t offset = page_idx * page_stride + entry_idx * stride_n + d_idx;
+        VecType const* src_ptr = reinterpret_cast<VecType const*>(base_ptr + offset);
+        VecType* dst_ptr = reinterpret_cast<VecType*>(&tXsXiGroup(0, i));
+
+        cutlass::arch::cp_async_zfill<sizeof(VecType), cutlass::arch::CacheOperation::Global>(
+            dst_ptr, src_ptr, guard);
+      }
     };
 
+    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
+
     // load last k-tile
     // all threads are issuing as TMA is disabled
     {
       pipeline_k.producer_acquire(smem_pipe_write);
-      Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
-      Tensor tKsKiGroup =
-          flatten_1(tKsK(_, _, _, smem_pipe_write.index()));  // (CPY, (CPY_KV, CPY_D))
-      copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup);
+      load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, tKsK, kv_tile_idx,
+                   smem_pipe_write.index(), true);
       pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
     }
 
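For reference, the gather performed by load_kv_tile for one (CTA_KV x HEAD_DIM) tile can be written as the plain host-side loop below: the flat KV index is split into a page lookup plus an in-page offset, and rows past kv_len are left zeroed, mirroring the guard handed to cp_async_zfill. This is only an illustrative sketch under those assumptions; the function and parameter names are not part of the patch.

// Host-side reference of the per-tile gather done by load_kv_tile.
// Out-of-range rows stay zero-filled, like cp_async_zfill with guard == false.
#include <algorithm>
#include <cstdint>
#include <vector>

template <typename DTypeKV, typename IdType>
void gather_kv_tile_reference(const DTypeKV* base_ptr, const IdType* kv_indices_ptr,
                              int64_t stride_n, int64_t page_stride, uint32_t page_size,
                              int tile_idx, int cta_kv, int head_dim, int kv_len,
                              std::vector<DTypeKV>& tile_out) {
  tile_out.assign(static_cast<size_t>(cta_kv) * head_dim, DTypeKV(0));
  int kv_base_idx = tile_idx * cta_kv;
  for (int kv_offset = 0; kv_offset < cta_kv; ++kv_offset) {
    int kv_idx = kv_base_idx + kv_offset;
    if (kv_idx >= kv_len) continue;               // leave this row zero-filled
    uint32_t page_iter = kv_idx / page_size;      // which entry of the page table
    uint32_t entry_idx = kv_idx % page_size;      // token slot within that page
    IdType page_idx = kv_indices_ptr[page_iter];  // physical page id
    const DTypeKV* src = base_ptr + page_idx * page_stride + entry_idx * stride_n;
    std::copy(src, src + head_dim,
              tile_out.begin() + static_cast<size_t>(kv_offset) * head_dim);
  }
}
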
@@ -266,10 +295,8 @@ struct FP8SparseCollectiveMainloop {
     if (kv_tile_idx == swa_begin_kv_tile_idx) {
       // first tile is the last tile
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
-      Tensor tVsViGroup =
-          flatten_1(tVsV(_, _, _, smem_pipe_write.index()));  // (CPY, (CPY_KV, CPY_D))
-      copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, kv_tile_idx,
+                   smem_pipe_write.index(), true);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       // Transpose V
@@ -283,10 +310,8 @@ struct FP8SparseCollectiveMainloop {
     } else {
       // load second last k-tile and last v-tile
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
-      Tensor tVsViGroup =
-          flatten_1(tVsV(_, _, _, smem_pipe_write.index()));  // (CPY, (CPY_KV, CPY_D))
-      copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, kv_tile_idx,
+                   smem_pipe_write.index(), true);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       // Transpose V
@@ -299,9 +324,8 @@ struct FP8SparseCollectiveMainloop {
       ++smem_pipe_write;  // update state, as K is loaded 1 step faster
 
       pipeline_k.producer_acquire(smem_pipe_write);
-      Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1);          // (CPY, CPY_KV, CPY_D)
-      Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index());  // (CPY, CPY_KV, CPY_D)
-      copy(gmem_tiled_copy_kv, tKgKi, tKsKi);
+      load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, tKsK, kv_tile_idx - 1,
+                   smem_pipe_write.index(), false);
       pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       --kv_tile_idx;
@@ -310,9 +334,8 @@ struct FP8SparseCollectiveMainloop {
 #pragma unroll 2
     for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgVi = tVgV(_, _, _, kv_tile_idx);              // (CPY, CPY_KV, CPY_D)
-      Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index());  // (CPY, CPY_KV, CPY_D)
-      copy(gmem_tiled_copy_kv, tVgVi, tVsVi);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, kv_tile_idx,
+                   smem_pipe_write.index(), false);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       // Transpose V
@@ -325,19 +348,17 @@ struct FP8SparseCollectiveMainloop {
       ++smem_pipe_write;  // update state, as K is loaded 1 step faster
 
       pipeline_k.producer_acquire(smem_pipe_write);
-      Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1);          // (CPY, CPY_KV, CPY_D)
-      Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index());  // (CPY, CPY_KV, CPY_D)
-      copy(gmem_tiled_copy_kv, tKgKi, tKsKi);
+      load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, tKsK, kv_tile_idx - 1,
+                   smem_pipe_write.index(), false);
       pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
     }
     scheduler.prefetch_next_work(scheduler_params, work_tile_info);
 
     // load first v tile
     {
       pipeline_v.producer_acquire(smem_pipe_write);
-      Tensor tVgVi = tVgV(_, _, _, 0);                        // (CPY, (CPY_KV, CPY_D))
-      Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index());  // (CPY, (CPY_KV, CPY_D))
-      copy(gmem_tiled_copy_kv, tVgVi, tVsVi);
+      load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, tVsV, 0, smem_pipe_write.index(),
+                   false);
       pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
 
       // Transpose V