Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions csrc/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,24 @@ void rope_quantize_append_paged_kv_cache(
CHECK_INPUT(batch_indices);
CHECK_INPUT(positions);

// Validate that all index tensors have the same dtype as pos_ids
if (kv_indices.dtype() != pos_ids.dtype()) {
TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indices dtype (" << kv_indices.dtype()
<< ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
}
if (kv_indptr.dtype() != pos_ids.dtype()) {
TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indptr dtype (" << kv_indptr.dtype()
<< ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
}
if (batch_indices.dtype() != pos_ids.dtype()) {
TVM_FFI_LOG_AND_THROW(TypeError) << "batch_indices dtype (" << batch_indices.dtype()
<< ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
}
if (positions.dtype() != pos_ids.dtype()) {
TVM_FFI_LOG_AND_THROW(TypeError) << "positions dtype (" << positions.dtype()
<< ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
}

@elvischenv elvischenv Dec 22, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yzh119 This won't work.
pos_ids is for RoPE part:

flashinfer/flashinfer/rope.py

Lines 1298 to 1314 in df82616

def rope_quantize_fp8(
q_rope: torch.Tensor,
k_rope: torch.Tensor,
q_nope: Optional[torch.Tensor],
k_nope: Optional[torch.Tensor],
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
is_neox: bool = True,
quantize_dtype: Optional[torch.dtype] = None,
quant_scale_q: float = 1.0,
quant_scale_kv: float = 1.0,
q_rope_out: Optional[torch.Tensor] = None,
k_rope_out: Optional[torch.Tensor] = None,
q_nope_out: Optional[torch.Tensor] = None,
k_nope_out: Optional[torch.Tensor] = None,
enable_pdl: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

batch_indices, positions, kv_indices, kv_indptr are for KV cache update part:

def append_paged_kv_cache(
append_key: torch.Tensor,
append_value: torch.Tensor,
batch_indices: torch.Tensor,
positions: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
kv_indices: torch.Tensor,
kv_indptr: torch.Tensor,
kv_last_page_len: torch.Tensor,
kv_layout: str = "NHD",
) -> None:

rope_quantize_fp8_append_paged_kv_cache is a merged version that has pos_ids, batch_indices, positions, kv_indices, kv_indptr:

flashinfer/flashinfer/rope.py

Lines 1438 to 1460 in df82616

def rope_quantize_fp8_append_paged_kv_cache(
q_rope: torch.Tensor,
k_rope: torch.Tensor,
q_nope: Optional[torch.Tensor],
k_nope: Optional[torch.Tensor],
v: Optional[torch.Tensor],
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
paged_kv_cache: Tuple[torch.Tensor, torch.Tensor],
kv_indices: torch.Tensor,
kv_indptr: torch.Tensor,
batch_indices: torch.Tensor,
positions: torch.Tensor,
is_neox: bool = True,
quantize_dtype: Optional[torch.dtype] = None,
quant_scale_q: float = 1.0,
quant_scale_kv: float = 1.0,
page_size: int = 16,
kv_layout: str = "NHD",
q_rope_out: Optional[torch.Tensor] = None,
q_nope_out: Optional[torch.Tensor] = None,
enable_pdl: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:

  • From my observation from SGLang, RoPE and KV update part can have different integer types.
  • That is, for RoPE part, we need a standalone IdType typename RoPEIdType for pos_ids. For KV update part, we need another standalone IdType typename PagedKVIdType for batch_indices, positions, kv_indices, kv_indptr.
  • So that we could support the combination like pos_ids in int64, and batch_indices, positions, kv_indices, kv_indptr in int32.

// Extract dimensions
uint32_t rope_dim = q_rope_in.size(-1);
uint32_t no_rope_dim = q_nope_in.size(-1);
Expand Down Expand Up @@ -547,71 +565,77 @@ void rope_quantize_append_paged_kv_cache(

DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] {
cudaError_t status;

if (is_mla) {
// MLA: Construct paged_kv_mla_t struct
auto ckv_strides = ckv_cache.strides();
auto kpe_strides = kpe_cache.strides();

paged_kv_mla_t<c_quant_type, int32_t> paged_kv_mla(
page_size, no_rope_dim, rope_dim, batch_size,
static_cast<c_quant_type*>(ckv_cache.data_ptr()), ckv_strides.data(),
static_cast<c_quant_type*>(kpe_cache.data_ptr()), kpe_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedMLACache(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv_mla,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim,
q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave,
enable_pdl, stream);

} else {
// GQA/MHA: Construct paged_kv_t struct
auto k_strides = k_cache.strides();
auto v_strides = v_cache.strides();
uint32_t head_dim = rope_dim + no_rope_dim;

paged_kv_t<c_quant_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_quant_type*>(k_cache.data_ptr()),
static_cast<c_quant_type*>(v_cache.data_ptr()), k_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedKVCache(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_type*>(v_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv,
interleave, enable_pdl, stream);
}
return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] {
cudaError_t status;

if (is_mla) {
// MLA: Construct paged_kv_mla_t struct
auto ckv_strides = ckv_cache.strides();
auto kpe_strides = kpe_cache.strides();

paged_kv_mla_t<c_quant_type, c_idtype> paged_kv_mla(

@elvischenv elvischenv Dec 22, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • For RoPE part, I have added a DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype()... for dispatching the idtype for RoPE part integer type.
  • If we also want to dispatch a idtype for KV update part, we need another nest dispatcher like DISPATCH_DLPACK_IDTYPE_TO_CTYPE(kv_indptr.dtype()... for integer type in KV update code path.

page_size, no_rope_dim, rope_dim, batch_size,
static_cast<c_quant_type*>(ckv_cache.data_ptr()), ckv_strides.data(),
static_cast<c_quant_type*>(kpe_cache.data_ptr()), kpe_strides.data(),
static_cast<c_idtype*>(kv_indices.data_ptr()),
static_cast<c_idtype*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);
Comment thread
elvischenv marked this conversation as resolved.

status = RopeQuantizeAppendPagedMLACache(
static_cast<c_type*>(q_rope_in.data_ptr()),
static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()),
static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv_mla,
static_cast<c_idtype*>(batch_indices.data_ptr()),
static_cast<c_idtype*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim,
q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave,
enable_pdl, stream);

} else {
// GQA/MHA: Construct paged_kv_t struct
auto k_strides = k_cache.strides();
auto v_strides = v_cache.strides();
uint32_t head_dim = rope_dim + no_rope_dim;

paged_kv_t<c_quant_type, c_idtype> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_quant_type*>(k_cache.data_ptr()),
static_cast<c_quant_type*>(v_cache.data_ptr()), k_strides.data(),
static_cast<c_idtype*>(kv_indices.data_ptr()),
static_cast<c_idtype*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);
Comment thread
elvischenv marked this conversation as resolved.

status = RopeQuantizeAppendPagedKVCache(
static_cast<c_type*>(q_rope_in.data_ptr()),
static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()),
static_cast<c_type*>(k_nope_in.data_ptr()), static_cast<c_type*>(v_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv,
static_cast<c_idtype*>(batch_indices.data_ptr()),
static_cast<c_idtype*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv,
interleave, enable_pdl, stream);
}

TVM_FFI_ICHECK(status == cudaSuccess)
<< "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status);
return true;
TVM_FFI_ICHECK(status == cudaSuccess)
<< "RopeQuantizeAppendPagedKVCache failed with error code "
<< cudaGetErrorString(status);
return true;
});
});
});
}
53 changes: 27 additions & 26 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -803,13 +803,13 @@ __global__ void BatchQKApplyRotaryKernel(
* Templated on CacheT to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t).
* Cache-only behaviors are selected with constexpr on the CacheT.
*/
template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType,
typename QuantType, typename CacheT>
template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename RoPEIdType,
typename PagedKVIdType, typename QuantType, typename CacheT>
__global__ void RopeQuantizeAppendPagedKVCacheKernel(
DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in,
QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like,
IdType* __restrict__ batch_indices, IdType* __restrict__ positions,
float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids,
PagedKVIdType* __restrict__ batch_indices, PagedKVIdType* __restrict__ positions,
float* __restrict__ cos_sin_cache, RoPEIdType* __restrict__ pos_ids,
const RopeQuantizeAppendPagedKVCacheParams params) {
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
Expand Down Expand Up @@ -852,12 +852,12 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks;

// Deduce MLA vs GQA/MHA from CacheT
constexpr bool IS_MLA = std::is_same<CacheT, paged_kv_mla_t<QuantType, IdType>>::value;
constexpr bool IS_MLA = std::is_same<CacheT, paged_kv_mla_t<QuantType, PagedKVIdType>>::value;

vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
const uint32_t idx = bx * bdy + ty;
const IdType pos = pos_ids[idx];
const RoPEIdType pos = pos_ids[idx];

// Compute page location for this token
uint32_t page_iter, entry_idx;
Expand Down Expand Up @@ -1123,18 +1123,18 @@ cudaError_t RopeQuantize(
/*!
* \brief Host function to apply RoPE, quantize to FP8, and append K/V to paged cache (GQA/MHA)
*/
template <typename DType, typename IdType, typename QuantType>
template <typename DType, typename RoPEIdType, typename PagedKVIdType, typename QuantType>
cudaError_t RopeQuantizeAppendPagedKVCache(
DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in,
QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t<QuantType, IdType> paged_kv,
IdType* batch_indices, IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim,
size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n,
size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h,
size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride,
size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h,
size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv,
bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) {
QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t<QuantType, PagedKVIdType> paged_kv,
PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache,
RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n,
size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h,
size_t k_rope_in_stride, size_t k_rope_in_stride_h, size_t k_nope_in_stride,
size_t k_nope_in_stride_h, size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q,
float quant_scale_kv, bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) {
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
constexpr uint32_t vec_size = 32 / sizeof(DType);
uint32_t bdx = (rope_dim + vec_size - 1) / vec_size;
Expand Down Expand Up @@ -1163,9 +1163,9 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
config.attrs = attribute;
config.numAttrs = 1;

auto kernel =
RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType, IdType,
QuantType, paged_kv_t<QuantType, IdType>>;
auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType,
RoPEIdType, PagedKVIdType, QuantType,
paged_kv_t<QuantType, PagedKVIdType>>;
RopeQuantizeAppendPagedKVCacheParams params;
params.nnz = nnz;
params.num_qo_heads = num_qo_heads;
Expand Down Expand Up @@ -1200,12 +1200,13 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
/*!
* \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache
*/
template <typename DType, typename IdType, typename QuantType>
template <typename DType, typename RoPEIdType, typename PagedKVIdType, typename QuantType>
cudaError_t RopeQuantizeAppendPagedMLACache(
DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out,
QuantType* q_nope_out, paged_kv_mla_t<QuantType, IdType> paged_kv_mla, IdType* batch_indices,
IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads,
uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
QuantType* q_nope_out, paged_kv_mla_t<QuantType, PagedKVIdType> paged_kv_mla,
PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache,
RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t rope_dim,
uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n,
size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h,
size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv,
Expand Down Expand Up @@ -1237,9 +1238,9 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
config.attrs = attribute;
config.numAttrs = 1;

auto kernel =
RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType, IdType,
QuantType, paged_kv_mla_t<QuantType, IdType>>;
auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType,
RoPEIdType, PagedKVIdType, QuantType,
paged_kv_mla_t<QuantType, PagedKVIdType>>;
DType* v_in_nullptr = nullptr;
uint32_t num_kv_heads_1 = 1;
size_t k_rope_in_stride_h_dup = k_rope_in_stride;
Expand Down
Loading