Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 70 additions & 64 deletions csrc/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -547,71 +547,77 @@ void rope_quantize_append_paged_kv_cache(

DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] {
cudaError_t status;

if (is_mla) {
// MLA: Construct paged_kv_mla_t struct
auto ckv_strides = ckv_cache.strides();
auto kpe_strides = kpe_cache.strides();

paged_kv_mla_t<c_quant_type, int32_t> paged_kv_mla(
page_size, no_rope_dim, rope_dim, batch_size,
static_cast<c_quant_type*>(ckv_cache.data_ptr()), ckv_strides.data(),
static_cast<c_quant_type*>(kpe_cache.data_ptr()), kpe_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedMLACache(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv_mla,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim,
q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave,
enable_pdl, stream);

} else {
// GQA/MHA: Construct paged_kv_t struct
auto k_strides = k_cache.strides();
auto v_strides = v_cache.strides();
uint32_t head_dim = rope_dim + no_rope_dim;

paged_kv_t<c_quant_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_quant_type*>(k_cache.data_ptr()),
static_cast<c_quant_type*>(v_cache.data_ptr()), k_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedKVCache(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_type*>(v_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv,
interleave, enable_pdl, stream);
}
return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] {
cudaError_t status;

if (is_mla) {
// MLA: Construct paged_kv_mla_t struct
auto ckv_strides = ckv_cache.strides();
auto kpe_strides = kpe_cache.strides();

paged_kv_mla_t<c_quant_type, int32_t> paged_kv_mla(
page_size, no_rope_dim, rope_dim, batch_size,
static_cast<c_quant_type*>(ckv_cache.data_ptr()), ckv_strides.data(),
static_cast<c_quant_type*>(kpe_cache.data_ptr()), kpe_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedMLACache(
static_cast<c_type*>(q_rope_in.data_ptr()),
static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()),
static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv_mla,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim,
q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave,
enable_pdl, stream);

} else {
// GQA/MHA: Construct paged_kv_t struct
auto k_strides = k_cache.strides();
auto v_strides = v_cache.strides();
uint32_t head_dim = rope_dim + no_rope_dim;

paged_kv_t<c_quant_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_quant_type*>(k_cache.data_ptr()),
static_cast<c_quant_type*>(v_cache.data_ptr()), k_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedKVCache(
static_cast<c_type*>(q_rope_in.data_ptr()),
static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()),
static_cast<c_type*>(k_nope_in.data_ptr()), static_cast<c_type*>(v_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv,
interleave, enable_pdl, stream);
}

TVM_FFI_ICHECK(status == cudaSuccess)
<< "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status);
return true;
TVM_FFI_ICHECK(status == cudaSuccess)
<< "RopeQuantizeAppendPagedKVCache failed with error code "
<< cudaGetErrorString(status);
return true;
});
});
});
}
5 changes: 5 additions & 0 deletions flashinfer/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,11 @@ def rope_quantize_fp8_append_paged_kv_cache(

kv_layout_code = TensorLayout[kv_layout].value

batch_indices = batch_indices.int()
positions = positions.int()
kv_indices = kv_indices.int()
kv_indptr = kv_indptr.int()

# Call custom op
_rope_quantize_fp8_append_paged_kv_cache(
q_rope,
Expand Down
53 changes: 27 additions & 26 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -803,13 +803,13 @@ __global__ void BatchQKApplyRotaryKernel(
* Templated on CacheT to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t).
* Cache-only behaviors are selected with constexpr on the CacheT.
*/
template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType,
typename QuantType, typename CacheT>
template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename RoPEIdType,
typename PagedKVIdType, typename QuantType, typename CacheT>
__global__ void RopeQuantizeAppendPagedKVCacheKernel(
DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in,
QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like,
IdType* __restrict__ batch_indices, IdType* __restrict__ positions,
float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids,
PagedKVIdType* __restrict__ batch_indices, PagedKVIdType* __restrict__ positions,
float* __restrict__ cos_sin_cache, RoPEIdType* __restrict__ pos_ids,
const RopeQuantizeAppendPagedKVCacheParams params) {
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
Expand Down Expand Up @@ -852,12 +852,12 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks;

// Deduce MLA vs GQA/MHA from CacheT
constexpr bool IS_MLA = std::is_same<CacheT, paged_kv_mla_t<QuantType, IdType>>::value;
constexpr bool IS_MLA = std::is_same<CacheT, paged_kv_mla_t<QuantType, PagedKVIdType>>::value;

vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
const uint32_t idx = bx * bdy + ty;
const IdType pos = pos_ids[idx];
const RoPEIdType pos = pos_ids[idx];

// Compute page location for this token
uint32_t page_iter, entry_idx;
Expand Down Expand Up @@ -1123,18 +1123,18 @@ cudaError_t RopeQuantize(
/*!
* \brief Host function to apply RoPE, quantize to FP8, and append K/V to paged cache (GQA/MHA)
*/
template <typename DType, typename IdType, typename QuantType>
template <typename DType, typename RoPEIdType, typename PagedKVIdType, typename QuantType>
cudaError_t RopeQuantizeAppendPagedKVCache(
DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in,
QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t<QuantType, IdType> paged_kv,
IdType* batch_indices, IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim,
size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n,
size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h,
size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride,
size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h,
size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv,
bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) {
QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t<QuantType, PagedKVIdType> paged_kv,
PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache,
RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n,
size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h,
size_t k_rope_in_stride, size_t k_rope_in_stride_h, size_t k_nope_in_stride,
size_t k_nope_in_stride_h, size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q,
float quant_scale_kv, bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) {
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
constexpr uint32_t vec_size = 32 / sizeof(DType);
uint32_t bdx = (rope_dim + vec_size - 1) / vec_size;
Expand Down Expand Up @@ -1163,9 +1163,9 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
config.attrs = attribute;
config.numAttrs = 1;

auto kernel =
RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType, IdType,
QuantType, paged_kv_t<QuantType, IdType>>;
auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType,
RoPEIdType, PagedKVIdType, QuantType,
paged_kv_t<QuantType, PagedKVIdType>>;
RopeQuantizeAppendPagedKVCacheParams params;
params.nnz = nnz;
params.num_qo_heads = num_qo_heads;
Expand Down Expand Up @@ -1200,12 +1200,13 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
/*!
* \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache
*/
template <typename DType, typename IdType, typename QuantType>
template <typename DType, typename RoPEIdType, typename PagedKVIdType, typename QuantType>
cudaError_t RopeQuantizeAppendPagedMLACache(
DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out,
QuantType* q_nope_out, paged_kv_mla_t<QuantType, IdType> paged_kv_mla, IdType* batch_indices,
IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads,
uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
QuantType* q_nope_out, paged_kv_mla_t<QuantType, PagedKVIdType> paged_kv_mla,
PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache,
RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t rope_dim,
uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n,
size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h,
size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv,
Expand Down Expand Up @@ -1237,9 +1238,9 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
config.attrs = attribute;
config.numAttrs = 1;

auto kernel =
RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType, IdType,
QuantType, paged_kv_mla_t<QuantType, IdType>>;
auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType,
RoPEIdType, PagedKVIdType, QuantType,
paged_kv_mla_t<QuantType, PagedKVIdType>>;
DType* v_in_nullptr = nullptr;
uint32_t num_kv_heads_1 = 1;
size_t k_rope_in_stride_h_dup = k_rope_in_stride;
Expand Down
10 changes: 7 additions & 3 deletions tests/attention/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def test_generalized_rope_quantize(
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("kv_layout", ["NHD", "HND"])
@pytest.mark.parametrize("page_size", [16, 32])
@pytest.mark.parametrize("rope_idtype", [torch.int32, torch.int64])
def test_generalized_rope_quantize_append_kv_cache(
attention_type,
num_qo_heads,
Expand All @@ -528,6 +529,7 @@ def test_generalized_rope_quantize_append_kv_cache(
enable_pdl,
kv_layout,
page_size,
rope_idtype,
):
device = "cuda:0"
# Fixed seed for reproducibility
Expand Down Expand Up @@ -589,7 +591,7 @@ def test_generalized_rope_quantize_append_kv_cache(
rope_ref = FlashInferRotaryEmbedding(
head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device
)
pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32)
pos_ids = torch.arange(num_tokens, device=device, dtype=rope_idtype)

# Build paged metadata
kv_append_length = torch.tensor(
Expand Down Expand Up @@ -833,6 +835,7 @@ def test_generalized_rope_quantize_append_kv_cache(
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("kv_layout", ["NHD", "HND"])
@pytest.mark.parametrize("page_size", [16, 32])
@pytest.mark.parametrize("rope_idtype", [torch.int32, torch.int64])
def test_rope_quantize_fp8_append_paged_kv_cache_decode(
attention_type,
num_qo_heads,
Expand All @@ -846,6 +849,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode(
enable_pdl,
kv_layout,
page_size,
rope_idtype,
):
"""Test append to non-empty cache (decode/continuation scenario)."""
device = "cuda:0"
Expand Down Expand Up @@ -938,7 +942,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode(
head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device
)
pos_ids_existing = torch.arange(
num_existing_tokens, device=device, dtype=torch.int32
num_existing_tokens, device=device, dtype=rope_idtype
)

# Build metadata for existing tokens (single request for simplicity)
Expand Down Expand Up @@ -1131,7 +1135,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode(
num_existing_tokens,
num_existing_tokens + num_new_tokens,
device=device,
dtype=torch.int32,
dtype=rope_idtype,
)

# Build metadata for new tokens (continue appending to first request)
Expand Down