Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,11 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
// Find batch index within a block
__shared__ int batch_idx[BLOCK_Y_SIZE];
if (threadIdx.x == 0) {
batch_idx[threadIdx.y] = -1;
}
__syncthreads();

for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
iter++) {
int tid = iter * blockDim.x + threadIdx.x;
Expand All @@ -611,16 +616,18 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
}
}

#ifndef USE_ROCM
__syncwarp();
#endif
__syncthreads();

if (head_idx >= head_dim || token_idx >= num_tokens) {
// num_tokens may be an allocation upper bound when Python avoids a D2H sync.
// Only tokens covered by the exact device-side cu_seq_lens are valid to
// gather.
const int batch = batch_idx[threadIdx.y];
if (head_idx >= head_dim || token_idx >= num_tokens || batch < 0) {
return;
}
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
inbatch_seq_idx / cache_block_size];
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch];
const int block_idx =
block_table[batch * num_blocks + inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_idx * block_stride;
const int64_t cache_inblock_offset =
(inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
Expand Down
Loading