@@ -199,7 +199,7 @@ struct DMA
     // The kv_offset_start.
     int kv_offset_start = is_chunked_attention
         ? ((q_step_offset >> params.log2_chunked_attention_size) << params.log2_chunked_attention_size)
-        : max(0, q_step_offset - params.sliding_window_size);
+        : max(0, q_step_offset + 1 - params.sliding_window_size);
     kv_idx_start = kv_offset_start / STEP_KV;
 }
 
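The fix above is an off-by-one correction: under the usual sliding-window convention, the query at position q attends to keys in [max(0, q + 1 - W), q], i.e. at most W keys including itself, whereas the old formula max(0, q - W) kept W + 1 keys in the window. A minimal host-side sketch of the difference (plain C++, with made-up example values for STEP_KV, sliding_window_size, and q_step_offset):

```cpp
// Illustrative sketch only, not kernel code; the values below are example inputs.
#include <algorithm>
#include <cassert>

int main()
{
    constexpr int STEP_KV = 128;           // kv tile size (example value)
    int const sliding_window_size = 1024;  // W
    int const q_step_offset = 2048;        // first query position of this q step

    // Old formula: max(0, q - W) keeps W + 1 keys in the window.
    int const old_start = std::max(0, q_step_offset - sliding_window_size);
    // New formula: max(0, q + 1 - W) keeps exactly W keys.
    int const new_start = std::max(0, q_step_offset + 1 - sliding_window_size);

    assert(q_step_offset - old_start + 1 == sliding_window_size + 1);
    assert(q_step_offset - new_start + 1 == sliding_window_size);

    // The first kv tile to load for this q step.
    int const kv_idx_start = new_start / STEP_KV;
    assert(kv_idx_start == 8);  // 1025 / 128 == 8
    return 0;
}
```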
@@ -388,51 +388,6 @@ struct DMA
         elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1});
     }
 
-    // Calculate the start tile idx.
-    inline __device__ int remap_kv_tile_idx(
-        int kv_tile_idx, int num_kv_cache_tiles, int past_kv_length, int sliding_window_size)
-    {
-
-        // The remapped kv tile idx.
-        int remapped_kv_tile_idx = kv_tile_idx;
-        // This will be removed later as the remapping will be handled by the kvCacheManager in TRTLLM.
-#ifdef GENERATE_CUBIN
-        // Sliding window attention + chunked context needs special handling.
-        if constexpr (SLIDING_OR_CHUNKED_ATTENTION)
-        {
-            // For chunked context (i.e. separate q and kv layout), the kv cache might be
-            // overwritten after the last chunk is processed.
-            // To deal with this issue, the new tokens' kv will be appended to the kv cache first,
-            // and overwrite the kv cache after FMHA is done.
-            // The kv input layout is like: [cyclic kv cache] + [new tokens' kv].
-            // There are two possible cases:
-            // 1. The kv cache hasn't been overwritten while processing previous chunks, so we can
-            //    take it normally, where we have the full kv cache.
-            // 2. The kv cache has been overwritten while processing previous chunks. We need to
-            //    mask out the tokens in the kv cache based on the sliding window size. It needs
-            //    to track the last kv cache token's position in a circular way.
-
-            // Remap the kv tile index when the kv cache has been overwritten in a circular way.
-            if (past_kv_length > sliding_window_size)
-            {
-                // Map the kv tile index to the new tokens' kv.
-                if (kv_tile_idx * STEP_KV >= past_kv_length)
-                {
-                    remapped_kv_tile_idx
-                        = num_kv_cache_tiles + int((kv_tile_idx * STEP_KV - past_kv_length) / STEP_KV);
-                }
-                else
-                {
-                    // Map the kv tile index to the cyclic kv cache.
-                    remapped_kv_tile_idx = kv_tile_idx % num_kv_cache_tiles;
-                }
-            }
-        }
-#endif
-        // Return the remapped kv tile idx.
-        return remapped_kv_tile_idx;
-    }
-
     // Support contiguous Q + contiguous/paged KV separate cache.
     inline __device__ void run_separate_q_and_kv(
         bert::Fused_multihead_attention_params_v2 const& params, Shared* shared)
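For reference, the arithmetic of the deleted remap_kv_tile_idx can be exercised on the host. This is a minimal sketch rather than the kernel code: it assumes STEP_KV = 128, a cyclic cache sized to the sliding window, and the [cyclic kv cache] + [new tokens' kv] layout described in the deleted comments; the concrete numbers are made up.

```cpp
// Standalone re-expression of the removed remapping logic, for illustration only.
#include <cassert>

constexpr int STEP_KV = 128;

// Layout assumed by the deleted code: [cyclic kv cache] + [new tokens' kv].
int remap_kv_tile_idx(int kv_tile_idx, int num_kv_cache_tiles, int past_kv_length, int sliding_window_size)
{
    int remapped = kv_tile_idx;
    // Remapping is only needed once the cyclic cache has wrapped around.
    if (past_kv_length > sliding_window_size)
    {
        if (kv_tile_idx * STEP_KV >= past_kv_length)
        {
            // The tile belongs to the new tokens appended after the cyclic cache.
            remapped = num_kv_cache_tiles + (kv_tile_idx * STEP_KV - past_kv_length) / STEP_KV;
        }
        else
        {
            // The tile belongs to the cyclic cache: wrap modulo its size in tiles.
            remapped = kv_tile_idx % num_kv_cache_tiles;
        }
    }
    return remapped;
}

int main()
{
    // Example: a 1024-token window (8 tiles) with 2048 past tokens already cached.
    int const num_kv_cache_tiles = 1024 / STEP_KV;  // 8
    int const past_kv_length = 2048;

    // Tile 17 (token 2176) lies past past_kv_length, so it maps into the appended new tokens.
    assert(remap_kv_tile_idx(17, num_kv_cache_tiles, past_kv_length, 1024) == 8 + 1);
    // Tile 10 (token 1280) is inside the wrapped cyclic cache: 10 % 8 == 2.
    assert(remap_kv_tile_idx(10, num_kv_cache_tiles, past_kv_length, 1024) == 2);
    return 0;
}
```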
@@ -560,24 +515,20 @@ struct DMA
     // Iterate over the kv tiles for this q step.
     for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++)
     {
-        // Remap the kv tile idx if sliding window attention is enabled.
-        // Sliding_window_size should be a multiple of STEP_KV.
-        int remapped_kv_step_idx = remap_kv_tile_idx(kv_step_idx, params.sliding_window_size / STEP_KV,
-            past_kv_length, params.sliding_window_size);
         // The barrier id.
         int bar_id;
         // Load paged kv input.
         if constexpr (PAGED_KV_INPUT)
         {
-            bar_id = load_paged_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, num_valid_kv_blocks,
+            bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
                 params.paged_kv_cache.mTokensPerBlockLog2, params.blocks_per_tma_load,
                 params.blocks_per_tma_load_log2, params.paged_kv_cache.mMaxBlocksPerSeq,
                 paged_block_offsets, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
         }
         else
         {
-            bar_id = load_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k,
-                cbw_v, cbw_v_scratch);
+            bar_id = load_kv(
+                bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
         }
 
     // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor
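With the remapping gone, the loop simply walks logical kv tiles and hands the unmapped token offset kv_step_idx * STEP_KV to load_paged_kv / load_kv; any cyclic reuse of the cache is left to the kv cache manager, as the deleted comment anticipated. A rough host-side sketch of the resulting iteration (the kv_idx_end derivation and the STEP_Q value are assumptions for illustration, not taken from this file):

```cpp
// Illustrative sketch of the kv tile loop bounds after this change (example values).
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int STEP_KV = 128;
    constexpr int STEP_Q = 64;  // assumed q tile size, for illustration only
    int const sliding_window_size = 1024;
    int const q_step_offset = 2048;

    // Start tile from the corrected sliding-window formula.
    int const kv_offset_start = std::max(0, q_step_offset + 1 - sliding_window_size);
    int const kv_idx_start = kv_offset_start / STEP_KV;
    // Assumed end bound: cover keys up to the last query of this q step (causal attention).
    int const kv_idx_end = (q_step_offset + STEP_Q + STEP_KV - 1) / STEP_KV;

    for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++)
    {
        // The loader now receives the unmapped token offset directly.
        std::printf("load kv tokens [%d, %d)\n", kv_step_idx * STEP_KV, (kv_step_idx + 1) * STEP_KV);
    }
    return 0;
}
```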