Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 0 additions & 44 deletions cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,6 @@ struct Gmem_tile_contiguous_kv
template <typename Smem_tile>
inline __device__ void load(Smem_tile& smem_tile)
{
// TODO(perkzz): add remap_kv_row for sliding window attention.
uint32_t preds[LDGS];
#pragma unroll
for (int ii = 0; ii < LDGS; ++ii)
Expand Down Expand Up @@ -1091,42 +1090,6 @@ struct Gmem_tile_paged_kv
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Remap the row to the one in cyclic kv cache.
// `row` is an in/out argument: it holds the logical kv row on entry and is rewritten in place.
// Nothing happens unless SLIDING_WINDOW_ATTENTION is enabled at compile time.
inline __device__ void remap_kv_row(int& row)
{
    if constexpr (SLIDING_WINDOW_ATTENTION)
    {
        // Background (sliding window attention combined with chunked context, i.e. separate
        // q and kv layout): the kv cache may already have been overwritten once the last chunk
        // was processed. To cope with that, the kv of the new tokens is first appended behind
        // the kv cache and only written into the cache after FMHA has finished, so the kv input
        // is laid out as [cyclic kv cache] + [new tokens' kv].
        //   Case 1: earlier chunks did not overwrite the cache -> the full kv cache is present
        //           and rows are used as they are.
        //   Case 2: earlier chunks did overwrite the cache -> cache tokens must be masked
        //           according to the sliding window size, and the position of the last cache
        //           token has to be followed in a circular way.

        // Case 1: the cache has not wrapped around, keep the row untouched.
        if (past_seqlen_ <= sliding_window_size_)
        {
            return;
        }

        // Case 2, row inside the cyclic kv cache: wrap it around the window.
        if (row < past_seqlen_)
        {
            row %= sliding_window_size_;
            return;
        }

        // Case 2, row belongs to the new tokens' kv stored right behind the cyclic cache.
        row = sliding_window_size_ + (row - past_seqlen_);
    }
}

// Load data from memory.
template <typename Smem_tile>
inline __device__ void load(Smem_tile& smem_tile)
Expand All @@ -1144,13 +1107,6 @@ struct Gmem_tile_paged_kv
for (int ii = 0; ii < LDGS; ++ii)
{
int row_idx = row_ + ii * (int) ROWS_PER_LDG;

// Remap row_idx if sliding window attention is used.
// This will be removed later as the remapping will be handled by the kvCacheManager in TRTLLM.
#ifdef GENERATE_CUBIN
remap_kv_row(row_idx);
#endif

int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_);
char const* local_kv_ptr = reinterpret_cast<char*>(paged_kv_block_pool_ptr_
+ params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx]);
Expand Down
4 changes: 2 additions & 2 deletions cpp/kernels/fmha_v2/src/fmha/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ struct Mask<Traits, Cta_tile, 4> : public Mask<Traits, Cta_tile, 3>
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence, i.e. are we in the lower triangle?
return (row >= col) && (col >= max(0, row - sliding_window_size_));
return (row >= col) && (col >= max(0, row + 1 - sliding_window_size_));
}

// The sliding window size.
Expand Down Expand Up @@ -946,7 +946,7 @@ struct Mask_hopper<Traits, Cta_tile, 4> : public Mask_hopper<Traits, Cta_tile, 3
inline __device__ bool is_valid(int row, int col) const
{
// Is it a valid position in the sequence?
return col <= row && col >= max(0, row - sliding_window_size_);
return col <= row && col >= max(0, row + 1 - sliding_window_size_);
}

// The sliding window size for attention.
Expand Down
2 changes: 1 addition & 1 deletion cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ struct Compute
// The kv_left_mask_end is the start of the chunk.
kv_left_mask_end = div_up(is_chunked_attention
? ((tile_offset_end >> params.log2_chunked_attention_size) << params.log2_chunked_attention_size)
: (tile_offset_end - params.sliding_window_size),
: (tile_offset_end + 1 - params.sliding_window_size),
STEP_KV);
}

Expand Down
57 changes: 4 additions & 53 deletions cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ struct DMA
// The kv_offset_start.
int kv_offset_start = is_chunked_attention
? ((q_step_offset >> params.log2_chunked_attention_size) << params.log2_chunked_attention_size)
: max(0, q_step_offset - params.sliding_window_size);
: max(0, q_step_offset + 1 - params.sliding_window_size);
kv_idx_start = kv_offset_start / STEP_KV;
}

Expand Down Expand Up @@ -388,51 +388,6 @@ struct DMA
elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1});
}

// Remap a kv tile idx to its position in the [cyclic kv cache] + [new tokens' kv] layout.
// Arguments:
//   kv_tile_idx          logical kv tile index to translate.
//   num_kv_cache_tiles   used as the modulus for tiles inside the cyclic kv cache and as the
//                        index of the first tile of the new tokens' kv.
//   past_kv_length       compared with sliding_window_size to detect an overwritten cache, and
//                        with kv_tile_idx * STEP_KV to tell cache tiles from new-token tiles.
//   sliding_window_size  see past_kv_length.
// Returns kv_tile_idx unchanged unless GENERATE_CUBIN is defined, SLIDING_OR_CHUNKED_ATTENTION
// holds and past_kv_length > sliding_window_size.
inline __device__ int remap_kv_tile_idx(
    int kv_tile_idx, int num_kv_cache_tiles, int past_kv_length, int sliding_window_size)
{
    // This will be removed later as the remapping will be handled by the kvCacheManager in TRTLLM.
#ifdef GENERATE_CUBIN
    // Sliding window attention + chunked context needs special handling.
    if constexpr (SLIDING_OR_CHUNKED_ATTENTION)
    {
        // Background: with chunked context (separate q and kv layout) the kv cache may already
        // have been overwritten once the last chunk was processed. The kv of the new tokens is
        // therefore appended behind the kv cache first and only written into the cache after
        // FMHA is done, giving the kv input layout [cyclic kv cache] + [new tokens' kv].
        //   Case 1: earlier chunks did not overwrite the cache -> full kv cache, tiles are
        //           taken as they are.
        //   Case 2: earlier chunks did overwrite the cache -> cache tokens must be masked
        //           according to the sliding window size and the position of the last cache
        //           token is followed in a circular way.

        // Only case 2 (cache overwritten in a circular way) requires a remap.
        if (past_kv_length > sliding_window_size)
        {
            // Token offset of the first row covered by this tile.
            auto const kv_offset = kv_tile_idx * STEP_KV;

            // Tile lies inside the cyclic kv cache: wrap it around.
            if (kv_offset < past_kv_length)
            {
                return kv_tile_idx % num_kv_cache_tiles;
            }

            // Tile lies in the new tokens' kv, stored right behind the cache tiles.
            return num_kv_cache_tiles + int((kv_offset - past_kv_length) / STEP_KV);
        }
    }
#endif
    // No remapping applies: hand the tile idx back as it came in.
    return kv_tile_idx;
}

// Support contiguous Q + contiguous/paged KV separate cache.
inline __device__ void run_separate_q_and_kv(
bert::Fused_multihead_attention_params_v2 const& params, Shared* shared)
Expand Down Expand Up @@ -560,24 +515,20 @@ struct DMA
// Iterate over the kv tiles for this q step.
for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++)
{
// Remap the kv tile idx if sliding window attention is enabled.
// Sliding_window_size should be multiple of STEP_KV.
int remapped_kv_step_idx = remap_kv_tile_idx(kv_step_idx, params.sliding_window_size / STEP_KV,
past_kv_length, params.sliding_window_size);
// The barrier id.
int bar_id;
// Load paged kv input.
if constexpr (PAGED_KV_INPUT)
{
bar_id = load_paged_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, num_valid_kv_blocks,
bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
params.paged_kv_cache.mTokensPerBlockLog2, params.blocks_per_tma_load,
params.blocks_per_tma_load_log2, params.paged_kv_cache.mMaxBlocksPerSeq,
paged_block_offsets, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
}
else
{
bar_id = load_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k,
cbw_v, cbw_v_scratch);
bar_id = load_kv(
bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
}

// Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor
Expand Down
2 changes: 1 addition & 1 deletion cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ struct Softmax_base
else
{
// The sliding window start is the max of 0 and row + 1 - sliding_window_size.
return max(0, row - sliding_window_size_);
return max(0, row + 1 - sliding_window_size_);
}
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ int main(int argc, char** argv)
}
else
{
valid = valid && (si >= std::max(int(so - sliding_window_size), 0));
valid = valid && (si >= std::max(int(so + 1 - sliding_window_size), 0));
}
}
if (is_mtp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ inline __device__ void device_flash_attention_nl(Params const& params)

int const kv_loop_end = ((valid_seqlen + Cta_tile_p::N - 1) / Cta_tile_p::N) * Cta_tile_p::N;
int const kv_loop_start = mask_sliding_window
? (max(0, q_sequence_start - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
? (max(0, q_sequence_start + 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
: 0;
int const sliding_window_mask_end = mask_sliding_window
? (max(0, q_sequence_start + Cta_tile_p::M - 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
? (max(0, q_sequence_start + Cta_tile_p::M - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
: 0;

static_assert(Cta_tile_p::M >= Cta_tile_p::N, "");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ inline __device__ void device_flash_attention_nl_tiled(Params const& params)

int const kv_loop_end = ((valid_seqlen + Cta_tile_p::N - 1) / Cta_tile_p::N) * Cta_tile_p::N;
int const kv_loop_start = mask_sliding_window
? (max(0, q_sequence_start - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
? (max(0, q_sequence_start + 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
: 0;
int const sliding_window_mask_end = mask_sliding_window
? (max(0, q_sequence_start + Cta_tile_p::M - 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
? (max(0, q_sequence_start + Cta_tile_p::M - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
: 0;

// Move K and V tiles.
Expand Down
4 changes: 4 additions & 0 deletions cpp/kernels/xqa/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ static_assert(CACHE_ELEM_ENUM != 0);
#define OPTIMIZE_FOR_LATENCY 1
#endif

// Fallback value for IS_SPEC_DEC_TREE, used only when it has not already been defined.
#ifndef IS_SPEC_DEC_TREE
#define IS_SPEC_DEC_TREE 1 // by default, SPEC_DEC expects a tree-based draft token structure
#endif

// DBG_* constants (presumably debug-only problem sizes; confirm against their users).
#define DBG_BATCH_SIZE 2
// The replacement list is parenthesized so the macro behaves as a single value inside larger
// expressions: DBG_SEQ_LEN * 2 must evaluate to 2054, not to 256 * 4 + 3 * 2.
#define DBG_SEQ_LEN (256 * 4 + 3)
#define DBG_NB_CTAS_PER_SEQ 8
Expand Down
1 change: 0 additions & 1 deletion cpp/kernels/xqa/mha.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1592,7 +1592,6 @@ CUBIN_EXPORT __global__
#endif

uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
static_assert(!(allowSlidingWindow && useSpecDec), "Sliding window is not yet supported in spec-dec mode");
#if SLIDING_WINDOW
bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
Expand Down
Loading