diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 6bea5abc3dfb..486b39f23ff7 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -599,6 +599,11 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel( const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; // Find batch index within a block __shared__ int batch_idx[BLOCK_Y_SIZE]; + if (threadIdx.x == 0) { + batch_idx[threadIdx.y] = -1; + } + __syncthreads(); + for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); iter++) { int tid = iter * blockDim.x + threadIdx.x; @@ -611,16 +616,18 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel( } } -#ifndef USE_ROCM - __syncwarp(); -#endif + __syncthreads(); - if (head_idx >= head_dim || token_idx >= num_tokens) { + // num_tokens may be an allocation upper bound when Python avoids a D2H sync. + // Only tokens covered by the exact device-side cu_seq_lens are valid to + // gather. + const int batch = batch_idx[threadIdx.y]; + if (head_idx >= head_dim || token_idx >= num_tokens || batch < 0) { return; } - const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; - const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks + - inbatch_seq_idx / cache_block_size]; + const int inbatch_seq_idx = token_idx - cu_seq_lens[batch]; + const int block_idx = + block_table[batch * num_blocks + inbatch_seq_idx / cache_block_size]; const int64_t src_block_offset = block_idx * block_stride; const int64_t cache_inblock_offset = (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;