Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions flashinfer/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,13 @@ def get_batch_indices_positions(
dtype = torch.int32

if batch_indices is None:
batch_indices = torch.empty((nnz,), device=device, dtype=dtype)
batch_indices = torch.full((nnz,), -1, device=device, dtype=dtype)
else:
check_shape_dtype_device(batch_indices, (nnz,), dtype, device, "batch_indices")
batch_indices.fill_(-1)

if positions is None:
positions = torch.empty((nnz,), device=device, dtype=dtype)
positions = torch.zeros((nnz,), device=device, dtype=dtype)
else:
check_shape_dtype_device(positions, (nnz,), dtype, device, "positions")

Expand Down
288 changes: 147 additions & 141 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -859,169 +859,175 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
const uint32_t idx = bx * bdy + ty;
const RoPEIdType pos = pos_ids[idx];

// Compute page location for this token
uint32_t page_iter, entry_idx;
paged_kv_like.page_size.divmod(
paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx],
page_iter, entry_idx);

const int half_rope_dim = rope_dim / 2;
// Load cos/sin for RoPE processing blocks only
if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
int sin_offset = rope_dim / 2;
int vec_idx;
if constexpr (interleave) {
vec_idx = (tx * vec_size) / 2; // Force integer division
} else {
vec_idx = (tx * vec_size) % half_rope_dim;
// skip padding tokens with batch_indices < 0
if (batch_indices[idx] >= 0) {
Comment on lines +862 to +863
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main change is just this line. The following is just indent formatting.

// Compute page location for this token
uint32_t page_iter, entry_idx;
paged_kv_like.page_size.divmod(
paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx],
page_iter, entry_idx);

const int half_rope_dim = rope_dim / 2;
// Load cos/sin for RoPE processing blocks only
if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
int sin_offset = rope_dim / 2;
int vec_idx;
if constexpr (interleave) {
vec_idx = (tx * vec_size) / 2; // Force integer division
} else {
vec_idx = (tx * vec_size) % half_rope_dim;
}
cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx));
}
cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx));
}

if (by < q_rope_end) {
// ============ Q RoPE processing ============
uint32_t q_head_idx = by / rope_chunks;
uint32_t rope_chunk_idx = by % rope_chunks;
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
if (by < q_rope_end) {
// ============ Q RoPE processing ============
uint32_t q_head_idx = by / rope_chunks;
uint32_t rope_chunk_idx = by % rope_chunks;
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;

DType* q_rope_in_ptr =
q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n,
q_rope_in_stride_h);
QuantType* q_rope_out_ptr =
q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n,
q_rope_out_stride_h);
DType* q_rope_in_ptr =
q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n,
q_rope_in_stride_h);
QuantType* q_rope_out_ptr =
q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n,
q_rope_out_stride_h);

vec_t<float, vec_size> q_rope_vec;
if constexpr (interleave) {
q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
q_rope_in_ptr, cos, sin, rope_dim);
} else {
q_rope_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_rope_in_ptr, cos, sin, rope_dim);
}
vec_t<float, vec_size> q_rope_vec;
if constexpr (interleave) {
q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
q_rope_in_ptr, cos, sin, rope_dim);
} else {
q_rope_vec =
vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_rope_in_ptr, cos, sin, rope_dim);
}
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);

} else if (by < k_rope_end) {
// ============ K RoPE processing & Cache Append ============
uint32_t k_head_idx = (by - q_rope_end) / rope_chunks;
uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks;
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;

DType* k_rope_in_ptr;
if constexpr (IS_MLA) {
// MLA: 2D K
k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset;
} else {
// GQA/MHA: 3D K
k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
k_rope_in_stride, k_rope_in_stride_h);
}
for (uint32_t i = 0; i < vec_size; ++i) {
q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);

} else if (by < k_rope_end) {
// ============ K RoPE processing & Cache Append ============
uint32_t k_head_idx = (by - q_rope_end) / rope_chunks;
uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks;
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;

DType* k_rope_in_ptr;
if constexpr (IS_MLA) {
// MLA: 2D K
k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset;
} else {
// GQA/MHA: 3D K
k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
k_rope_in_stride, k_rope_in_stride_h);
}

vec_t<float, vec_size> k_rope_vec;
if constexpr (interleave) {
k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
k_rope_in_ptr, cos, sin, rope_dim);
} else {
k_rope_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_rope_in_ptr, cos, sin, rope_dim);
}
vec_t<float, vec_size> k_rope_vec;
if constexpr (interleave) {
k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
k_rope_in_ptr, cos, sin, rope_dim);
} else {
k_rope_vec =
vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_rope_in_ptr, cos, sin, rope_dim);
}
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
}
for (uint32_t i = 0; i < vec_size; ++i) {
k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
}

if constexpr (IS_MLA) {
QuantType* kpe_ptr =
paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
k_rope_vec.cast_store(kpe_ptr);
} else {
QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size);
k_rope_vec.cast_store(k_ptr);
}
if constexpr (IS_MLA) {
QuantType* kpe_ptr =
paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
k_rope_vec.cast_store(kpe_ptr);
} else {
QuantType* k_ptr =
paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size);
k_rope_vec.cast_store(k_ptr);
}

} else if (by < k_nope_end) {
// ============ K Non-RoPE processing & Cache Append ============
uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
} else if (by < k_nope_end) {
// ============ K Non-RoPE processing & Cache Append ============
uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

DType* k_nope_in_ptr;
if constexpr (IS_MLA) {
k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
} else {
k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
k_nope_in_stride, k_nope_in_stride_h);
}
DType* k_nope_in_ptr;
if constexpr (IS_MLA) {
k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
} else {
k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
k_nope_in_stride, k_nope_in_stride_h);
}

vec_t<float, vec_size> k_nope_vec;
k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
vec_t<float, vec_size> k_nope_vec;
k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
}
for (uint32_t i = 0; i < vec_size; ++i) {
k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
}

if constexpr (IS_MLA) {
QuantType* ckv_ptr =
paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
k_nope_vec.cast_store(ckv_ptr);
} else {
QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
rope_dim + elem_offset + tx * vec_size);
k_nope_vec.cast_store(k_ptr);
}
if constexpr (IS_MLA) {
QuantType* ckv_ptr =
paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
k_nope_vec.cast_store(ckv_ptr);
} else {
QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
rope_dim + elem_offset + tx * vec_size);
k_nope_vec.cast_store(k_ptr);
}

} else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) {
// ============ V processing & Cache Append (GQA/MHA only) ============
if constexpr (!IS_MLA) {
uint32_t kv_head_idx = by - k_nope_end;
DType* v_in_ptr =
v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h);
// Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size
uint32_t head_dim_total = rope_dim + no_rope_dim;
uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size;
} else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) {
// ============ V processing & Cache Append (GQA/MHA only) ============
if constexpr (!IS_MLA) {
uint32_t kv_head_idx = by - k_nope_end;
DType* v_in_ptr =
v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h);
// Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size
uint32_t head_dim_total = rope_dim + no_rope_dim;
uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size;
#pragma unroll 1
for (uint32_t j = 0; j < v_chunks; ++j) {
uint32_t v_elem_offset = j * rope_chunk_size;
if (v_elem_offset + tx * vec_size < head_dim_total) {
vec_t<float, vec_size> v_vec;
v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
for (uint32_t j = 0; j < v_chunks; ++j) {
uint32_t v_elem_offset = j * rope_chunk_size;
if (v_elem_offset + tx * vec_size < head_dim_total) {
vec_t<float, vec_size> v_vec;
v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
v_vec[i] = v_vec[i] * quant_scale_kv;
for (uint32_t i = 0; i < vec_size; ++i) {
v_vec[i] = v_vec[i] * quant_scale_kv;
}
QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
v_elem_offset + tx * vec_size);
v_vec.cast_store(v_ptr);
}
QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
v_elem_offset + tx * vec_size);
v_vec.cast_store(v_ptr);
}
}
}

} else {
// ============ Q Non-RoPE processing ============
// MLA has no V section, so Q-nope starts immediately after K-nope.
// GQA/MHA has a V section of length num_kv_heads blocks.
uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

DType* q_nope_in_ptr =
q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
q_nope_in_stride_h);
QuantType* q_nope_out_ptr =
q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
q_nope_out_stride_h);

vec_t<float, vec_size> q_nope_vec;
q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
} else {
// ============ Q Non-RoPE processing ============
// MLA has no V section, so Q-nope starts immediately after K-nope.
// GQA/MHA has a V section of length num_kv_heads blocks.
uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

DType* q_nope_in_ptr =
q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
q_nope_in_stride_h);
QuantType* q_nope_out_ptr =
q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
q_nope_out_stride_h);

vec_t<float, vec_size> q_nope_vec;
q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
for (uint32_t i = 0; i < vec_size; ++i) {
q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
}
q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
}
q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
}
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
Expand Down
Loading
Loading