From 1f2db59cfb6b3aa01fc905242fa344b01f72988f Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Sun, 21 Dec 2025 17:45:05 -0800 Subject: [PATCH 1/3] fix pos_ids uint64 accuracy issue --- csrc/rope.cu | 134 +++++++++++++++++---------------- include/flashinfer/pos_enc.cuh | 53 ++++++------- 2 files changed, 97 insertions(+), 90 deletions(-) diff --git a/csrc/rope.cu b/csrc/rope.cu index 14008bcaea..6699675411 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -547,71 +547,77 @@ void rope_quantize_append_paged_kv_cache( DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] { return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] { - cudaError_t status; - - if (is_mla) { - // MLA: Construct paged_kv_mla_t struct - auto ckv_strides = ckv_cache.strides(); - auto kpe_strides = kpe_cache.strides(); - - paged_kv_mla_t paged_kv_mla( - page_size, no_rope_dim, rope_dim, batch_size, - static_cast(ckv_cache.data_ptr()), ckv_strides.data(), - static_cast(kpe_cache.data_ptr()), kpe_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), - nullptr // last_page_len not needed for this kernel - ); - - status = RopeQuantizeAppendPagedMLACache( - static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), - static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), - static_cast(q_rope_out.data_ptr()), - static_cast(q_nope_out.data_ptr()), paged_kv_mla, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), - static_cast(cos_sin_cache.data_ptr()), - static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, - q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, - q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, - k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave, - enable_pdl, stream); - - } else { - // GQA/MHA: Construct paged_kv_t struct - auto k_strides = k_cache.strides(); - auto v_strides = v_cache.strides(); - uint32_t head_dim = rope_dim + no_rope_dim; - - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, - static_cast(k_cache.data_ptr()), - static_cast(v_cache.data_ptr()), k_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), - nullptr // last_page_len not needed for this kernel - ); - - status = RopeQuantizeAppendPagedKVCache( - static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), - static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), - static_cast(v_in.data_ptr()), - static_cast(q_rope_out.data_ptr()), - static_cast(q_nope_out.data_ptr()), paged_kv, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), - static_cast(cos_sin_cache.data_ptr()), - static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, - no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, - q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, - q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, - k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv, - interleave, enable_pdl, stream); - } + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] { + cudaError_t status; + + if (is_mla) { + // MLA: Construct paged_kv_mla_t struct + auto ckv_strides = ckv_cache.strides(); + auto kpe_strides = kpe_cache.strides(); + + paged_kv_mla_t paged_kv_mla( + page_size, no_rope_dim, rope_dim, batch_size, + static_cast(ckv_cache.data_ptr()), ckv_strides.data(), + static_cast(kpe_cache.data_ptr()), kpe_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedMLACache( + static_cast(q_rope_in.data_ptr()), + static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), + static_cast(k_nope_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv_mla, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, + q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave, + enable_pdl, stream); + + } else { + // GQA/MHA: Construct paged_kv_t struct + auto k_strides = k_cache.strides(); + auto v_strides = v_cache.strides(); + uint32_t head_dim = rope_dim + no_rope_dim; + + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, + static_cast(k_cache.data_ptr()), + static_cast(v_cache.data_ptr()), k_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedKVCache( + static_cast(q_rope_in.data_ptr()), + static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), + static_cast(k_nope_in.data_ptr()), static_cast(v_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, + no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, + q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, + q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, + k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv, + interleave, enable_pdl, stream); + } - TVM_FFI_ICHECK(status == cudaSuccess) - << "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status); - return true; + TVM_FFI_ICHECK(status == cudaSuccess) + << "RopeQuantizeAppendPagedKVCache failed with error code " + << cudaGetErrorString(status); + return true; + }); }); }); } diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 60e813e921..4fdd75e0a3 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -803,13 +803,13 @@ __global__ void BatchQKApplyRotaryKernel( * Templated on CacheT to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t). * Cache-only behaviors are selected with constexpr on the CacheT. */ -template +template __global__ void RopeQuantizeAppendPagedKVCacheKernel( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like, - IdType* __restrict__ batch_indices, IdType* __restrict__ positions, - float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, + PagedKVIdType* __restrict__ batch_indices, PagedKVIdType* __restrict__ positions, + float* __restrict__ cos_sin_cache, RoPEIdType* __restrict__ pos_ids, const RopeQuantizeAppendPagedKVCacheParams params) { #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); @@ -852,12 +852,12 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks; // Deduce MLA vs GQA/MHA from CacheT - constexpr bool IS_MLA = std::is_same>::value; + constexpr bool IS_MLA = std::is_same>::value; vec_t cos, sin; if (bx * bdy + ty < nnz) { const uint32_t idx = bx * bdy + ty; - const IdType pos = pos_ids[idx]; + const RoPEIdType pos = pos_ids[idx]; // Compute page location for this token uint32_t page_iter, entry_idx; @@ -1123,18 +1123,18 @@ cudaError_t RopeQuantize( /*! * \brief Host function to apply RoPE, quantize to FP8, and append K/V to paged cache (GQA/MHA) */ -template +template cudaError_t RopeQuantizeAppendPagedKVCache( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, - QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t paged_kv, - IdType* batch_indices, IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim, - size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, - size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, - size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, - size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, - size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv, - bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t paged_kv, + PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache, + RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, + size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, + size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, + size_t k_rope_in_stride, size_t k_rope_in_stride_h, size_t k_nope_in_stride, + size_t k_nope_in_stride_h, size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, + float quant_scale_kv, bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { constexpr uint32_t vec_size = 32 / sizeof(DType); uint32_t bdx = (rope_dim + vec_size - 1) / vec_size; @@ -1163,9 +1163,9 @@ cudaError_t RopeQuantizeAppendPagedKVCache( config.attrs = attribute; config.numAttrs = 1; - auto kernel = - RopeQuantizeAppendPagedKVCacheKernel>; + auto kernel = RopeQuantizeAppendPagedKVCacheKernel>; RopeQuantizeAppendPagedKVCacheParams params; params.nnz = nnz; params.num_qo_heads = num_qo_heads; @@ -1200,12 +1200,13 @@ cudaError_t RopeQuantizeAppendPagedKVCache( /*! * \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache */ -template +template cudaError_t RopeQuantizeAppendPagedMLACache( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out, - QuantType* q_nope_out, paged_kv_mla_t paged_kv_mla, IdType* batch_indices, - IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, - uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, + QuantType* q_nope_out, paged_kv_mla_t paged_kv_mla, + PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache, + RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t rope_dim, + uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv, @@ -1237,9 +1238,9 @@ cudaError_t RopeQuantizeAppendPagedMLACache( config.attrs = attribute; config.numAttrs = 1; - auto kernel = - RopeQuantizeAppendPagedKVCacheKernel>; + auto kernel = RopeQuantizeAppendPagedKVCacheKernel>; DType* v_in_nullptr = nullptr; uint32_t num_kv_heads_1 = 1; size_t k_rope_in_stride_h_dup = k_rope_in_stride; From 0fbf196944daa045d0c0c70c8604bfbbd93d83e0 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 22 Dec 2025 02:06:46 -0500 Subject: [PATCH 2/3] update ut --- csrc/rope.cu | 38 ++++++++++++++++++------- tests/attention/test_rope.py | 54 +++++++++++++++++++++--------------- 2 files changed, 59 insertions(+), 33 deletions(-) diff --git a/csrc/rope.cu b/csrc/rope.cu index 6699675411..f68fb16fcf 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -450,6 +450,24 @@ void rope_quantize_append_paged_kv_cache( CHECK_INPUT(batch_indices); CHECK_INPUT(positions); + // Validate that all index tensors have the same dtype as pos_ids + if (kv_indices.dtype() != pos_ids.dtype()) { + TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indices dtype (" << kv_indices.dtype() + << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; + } + if (kv_indptr.dtype() != pos_ids.dtype()) { + TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indptr dtype (" << kv_indptr.dtype() + << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; + } + if (batch_indices.dtype() != pos_ids.dtype()) { + TVM_FFI_LOG_AND_THROW(TypeError) << "batch_indices dtype (" << batch_indices.dtype() + << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; + } + if (positions.dtype() != pos_ids.dtype()) { + TVM_FFI_LOG_AND_THROW(TypeError) << "positions dtype (" << positions.dtype() + << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; + } + // Extract dimensions uint32_t rope_dim = q_rope_in.size(-1); uint32_t no_rope_dim = q_nope_in.size(-1); @@ -555,12 +573,12 @@ void rope_quantize_append_paged_kv_cache( auto ckv_strides = ckv_cache.strides(); auto kpe_strides = kpe_cache.strides(); - paged_kv_mla_t paged_kv_mla( + paged_kv_mla_t paged_kv_mla( page_size, no_rope_dim, rope_dim, batch_size, static_cast(ckv_cache.data_ptr()), ckv_strides.data(), static_cast(kpe_cache.data_ptr()), kpe_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), nullptr // last_page_len not needed for this kernel ); @@ -571,8 +589,8 @@ void rope_quantize_append_paged_kv_cache( static_cast(k_nope_in.data_ptr()), static_cast(q_rope_out.data_ptr()), static_cast(q_nope_out.data_ptr()), paged_kv_mla, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), static_cast(cos_sin_cache.data_ptr()), static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, @@ -586,12 +604,12 @@ void rope_quantize_append_paged_kv_cache( auto v_strides = v_cache.strides(); uint32_t head_dim = rope_dim + no_rope_dim; - paged_kv_t paged_kv( + paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout, static_cast(k_cache.data_ptr()), static_cast(v_cache.data_ptr()), k_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), nullptr // last_page_len not needed for this kernel ); @@ -602,8 +620,8 @@ void rope_quantize_append_paged_kv_cache( static_cast(k_nope_in.data_ptr()), static_cast(v_in.data_ptr()), static_cast(q_rope_out.data_ptr()), static_cast(q_nope_out.data_ptr()), paged_kv, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), static_cast(cos_sin_cache.data_ptr()), static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index 651f43d5e9..ab5c8f2227 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -516,6 +516,7 @@ def test_generalized_rope_quantize( @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) @pytest.mark.parametrize("page_size", [16, 32]) +@pytest.mark.parametrize("idtype", [torch.int32, torch.int64]) def test_generalized_rope_quantize_append_kv_cache( attention_type, num_qo_heads, @@ -528,6 +529,7 @@ def test_generalized_rope_quantize_append_kv_cache( enable_pdl, kv_layout, page_size, + idtype, ): device = "cuda:0" # Fixed seed for reproducibility @@ -589,36 +591,36 @@ def test_generalized_rope_quantize_append_kv_cache( rope_ref = FlashInferRotaryEmbedding( head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device ) - pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + pos_ids = torch.arange(num_tokens, device=device, dtype=idtype) # Build paged metadata kv_append_length = torch.tensor( - [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + [num_tokens] + [0] * (batch_size - 1), dtype=idtype, device=device ) kv_append_indptr = torch.cat( [ - torch.zeros(1, dtype=torch.int32, device=device), - torch.cumsum(kv_append_length, dim=0), + torch.zeros(1, dtype=idtype, device=device), + torch.cumsum(kv_append_length, dim=0).to(idtype), ] ) num_pages_per_req = torch.tensor( [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), - dtype=torch.int32, + dtype=idtype, device=device, ) kv_page_indptr = torch.cat( [ - torch.zeros(1, dtype=torch.int32, device=device), - torch.cumsum(num_pages_per_req, dim=0), + torch.zeros(1, dtype=idtype, device=device), + torch.cumsum(num_pages_per_req, dim=0).to(idtype), ] ) kv_page_indices = torch.arange( - kv_page_indptr[-1].item(), dtype=torch.int32, device=device + kv_page_indptr[-1].item(), dtype=idtype, device=device ) kv_last_page_len = torch.tensor( [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + [0] * (batch_size - 1), - dtype=torch.int32, + dtype=idtype, device=device, ) # Allocate caches sized by required pages @@ -629,6 +631,9 @@ def test_generalized_rope_quantize_append_kv_cache( batch_indices, positions = flashinfer.get_batch_indices_positions( kv_append_indptr, seq_lens, num_tokens ) + # Convert to idtype to match other index tensors + batch_indices = batch_indices.to(idtype) + positions = positions.to(idtype) # Fused call + cache allocation if attention_type == "mla": @@ -833,6 +838,7 @@ def test_generalized_rope_quantize_append_kv_cache( @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) @pytest.mark.parametrize("page_size", [16, 32]) +@pytest.mark.parametrize("idtype", [torch.int32, torch.int64]) def test_rope_quantize_fp8_append_paged_kv_cache_decode( attention_type, num_qo_heads, @@ -846,6 +852,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( enable_pdl, kv_layout, page_size, + idtype, ): """Test append to non-empty cache (decode/continuation scenario).""" device = "cuda:0" @@ -937,28 +944,26 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( rope_ref = FlashInferRotaryEmbedding( head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device ) - pos_ids_existing = torch.arange( - num_existing_tokens, device=device, dtype=torch.int32 - ) + pos_ids_existing = torch.arange(num_existing_tokens, device=device, dtype=idtype) # Build metadata for existing tokens (single request for simplicity) kv_append_length_existing = torch.tensor( - [num_existing_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + [num_existing_tokens] + [0] * (batch_size - 1), dtype=idtype, device=device ) kv_append_indptr_existing = torch.cat( [ - torch.zeros(1, dtype=torch.int32, device=device), - torch.cumsum(kv_append_length_existing, dim=0), + torch.zeros(1, dtype=idtype, device=device), + torch.cumsum(kv_append_length_existing, dim=0).to(idtype), ] ) num_pages_existing = (num_existing_tokens + page_size - 1) // page_size kv_page_indptr_existing = torch.tensor( [0, num_pages_existing] + [num_pages_existing] * (batch_size - 1), - dtype=torch.int32, + dtype=idtype, device=device, ) kv_page_indices_existing = torch.arange( - num_pages_existing, dtype=torch.int32, device=device + num_pages_existing, dtype=idtype, device=device ) kv_last_page_len_existing = torch.tensor( [ @@ -967,7 +972,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( else page_size ] + [0] * (batch_size - 1), - dtype=torch.int32, + dtype=idtype, device=device, ) seq_lens_existing = flashinfer.get_seq_lens( @@ -976,6 +981,9 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( batch_indices_existing, positions_existing = flashinfer.get_batch_indices_positions( kv_append_indptr_existing, seq_lens_existing, num_existing_tokens ) + # Convert to idtype to match other index tensors + batch_indices_existing = batch_indices_existing.to(idtype) + positions_existing = positions_existing.to(idtype) # Allocate cache sized for existing + new tokens total_tokens = num_existing_tokens + num_new_tokens @@ -1131,26 +1139,26 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( num_existing_tokens, num_existing_tokens + num_new_tokens, device=device, - dtype=torch.int32, + dtype=idtype, ) # Build metadata for new tokens (continue appending to first request) num_pages_new_needed = (total_tokens + page_size - 1) // page_size kv_page_indptr_new = torch.tensor( [0, num_pages_new_needed] + [num_pages_new_needed] * (batch_size - 1), - dtype=torch.int32, + dtype=idtype, device=device, ) kv_page_indices_new = torch.arange( - num_pages_new_needed, dtype=torch.int32, device=device + num_pages_new_needed, dtype=idtype, device=device ) # For continuation, positions start at num_existing_tokens - batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=torch.int32) + batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=idtype) positions_new = torch.arange( num_existing_tokens, num_existing_tokens + num_new_tokens, device=device, - dtype=torch.int32, + dtype=idtype, ) # Snapshot existing cache for later comparison From 0903058ce72d91f0696a60b6e9fa379d99b2f553 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 22 Dec 2025 20:26:41 -0800 Subject: [PATCH 3/3] only int64 support apply on RoPE argument --- csrc/rope.cu | 38 +++++++---------------- flashinfer/rope.py | 5 ++++ tests/attention/test_rope.py | 58 +++++++++++++++++------------------- 3 files changed, 42 insertions(+), 59 deletions(-) diff --git a/csrc/rope.cu b/csrc/rope.cu index f68fb16fcf..6699675411 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -450,24 +450,6 @@ void rope_quantize_append_paged_kv_cache( CHECK_INPUT(batch_indices); CHECK_INPUT(positions); - // Validate that all index tensors have the same dtype as pos_ids - if (kv_indices.dtype() != pos_ids.dtype()) { - TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indices dtype (" << kv_indices.dtype() - << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; - } - if (kv_indptr.dtype() != pos_ids.dtype()) { - TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indptr dtype (" << kv_indptr.dtype() - << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; - } - if (batch_indices.dtype() != pos_ids.dtype()) { - TVM_FFI_LOG_AND_THROW(TypeError) << "batch_indices dtype (" << batch_indices.dtype() - << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; - } - if (positions.dtype() != pos_ids.dtype()) { - TVM_FFI_LOG_AND_THROW(TypeError) << "positions dtype (" << positions.dtype() - << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; - } - // Extract dimensions uint32_t rope_dim = q_rope_in.size(-1); uint32_t no_rope_dim = q_nope_in.size(-1); @@ -573,12 +555,12 @@ void rope_quantize_append_paged_kv_cache( auto ckv_strides = ckv_cache.strides(); auto kpe_strides = kpe_cache.strides(); - paged_kv_mla_t paged_kv_mla( + paged_kv_mla_t paged_kv_mla( page_size, no_rope_dim, rope_dim, batch_size, static_cast(ckv_cache.data_ptr()), ckv_strides.data(), static_cast(kpe_cache.data_ptr()), kpe_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), nullptr // last_page_len not needed for this kernel ); @@ -589,8 +571,8 @@ void rope_quantize_append_paged_kv_cache( static_cast(k_nope_in.data_ptr()), static_cast(q_rope_out.data_ptr()), static_cast(q_nope_out.data_ptr()), paged_kv_mla, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), static_cast(cos_sin_cache.data_ptr()), static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, @@ -604,12 +586,12 @@ void rope_quantize_append_paged_kv_cache( auto v_strides = v_cache.strides(); uint32_t head_dim = rope_dim + no_rope_dim; - paged_kv_t paged_kv( + paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout, static_cast(k_cache.data_ptr()), static_cast(v_cache.data_ptr()), k_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), nullptr // last_page_len not needed for this kernel ); @@ -620,8 +602,8 @@ void rope_quantize_append_paged_kv_cache( static_cast(k_nope_in.data_ptr()), static_cast(v_in.data_ptr()), static_cast(q_rope_out.data_ptr()), static_cast(q_nope_out.data_ptr()), paged_kv, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), static_cast(cos_sin_cache.data_ptr()), static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 1d069e3189..d39d2e07e6 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -1641,6 +1641,11 @@ def rope_quantize_fp8_append_paged_kv_cache( kv_layout_code = TensorLayout[kv_layout].value + batch_indices = batch_indices.int() + positions = positions.int() + kv_indices = kv_indices.int() + kv_indptr = kv_indptr.int() + # Call custom op _rope_quantize_fp8_append_paged_kv_cache( q_rope, diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index ab5c8f2227..75570223ae 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -516,7 +516,7 @@ def test_generalized_rope_quantize( @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) @pytest.mark.parametrize("page_size", [16, 32]) -@pytest.mark.parametrize("idtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("rope_idtype", [torch.int32, torch.int64]) def test_generalized_rope_quantize_append_kv_cache( attention_type, num_qo_heads, @@ -529,7 +529,7 @@ def test_generalized_rope_quantize_append_kv_cache( enable_pdl, kv_layout, page_size, - idtype, + rope_idtype, ): device = "cuda:0" # Fixed seed for reproducibility @@ -591,36 +591,36 @@ def test_generalized_rope_quantize_append_kv_cache( rope_ref = FlashInferRotaryEmbedding( head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device ) - pos_ids = torch.arange(num_tokens, device=device, dtype=idtype) + pos_ids = torch.arange(num_tokens, device=device, dtype=rope_idtype) # Build paged metadata kv_append_length = torch.tensor( - [num_tokens] + [0] * (batch_size - 1), dtype=idtype, device=device + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device ) kv_append_indptr = torch.cat( [ - torch.zeros(1, dtype=idtype, device=device), - torch.cumsum(kv_append_length, dim=0).to(idtype), + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), ] ) num_pages_per_req = torch.tensor( [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), - dtype=idtype, + dtype=torch.int32, device=device, ) kv_page_indptr = torch.cat( [ - torch.zeros(1, dtype=idtype, device=device), - torch.cumsum(num_pages_per_req, dim=0).to(idtype), + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), ] ) kv_page_indices = torch.arange( - kv_page_indptr[-1].item(), dtype=idtype, device=device + kv_page_indptr[-1].item(), dtype=torch.int32, device=device ) kv_last_page_len = torch.tensor( [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + [0] * (batch_size - 1), - dtype=idtype, + dtype=torch.int32, device=device, ) # Allocate caches sized by required pages @@ -631,9 +631,6 @@ def test_generalized_rope_quantize_append_kv_cache( batch_indices, positions = flashinfer.get_batch_indices_positions( kv_append_indptr, seq_lens, num_tokens ) - # Convert to idtype to match other index tensors - batch_indices = batch_indices.to(idtype) - positions = positions.to(idtype) # Fused call + cache allocation if attention_type == "mla": @@ -838,7 +835,7 @@ def test_generalized_rope_quantize_append_kv_cache( @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) @pytest.mark.parametrize("page_size", [16, 32]) -@pytest.mark.parametrize("idtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("rope_idtype", [torch.int32, torch.int64]) def test_rope_quantize_fp8_append_paged_kv_cache_decode( attention_type, num_qo_heads, @@ -852,7 +849,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( enable_pdl, kv_layout, page_size, - idtype, + rope_idtype, ): """Test append to non-empty cache (decode/continuation scenario).""" device = "cuda:0" @@ -944,26 +941,28 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( rope_ref = FlashInferRotaryEmbedding( head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device ) - pos_ids_existing = torch.arange(num_existing_tokens, device=device, dtype=idtype) + pos_ids_existing = torch.arange( + num_existing_tokens, device=device, dtype=rope_idtype + ) # Build metadata for existing tokens (single request for simplicity) kv_append_length_existing = torch.tensor( - [num_existing_tokens] + [0] * (batch_size - 1), dtype=idtype, device=device + [num_existing_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device ) kv_append_indptr_existing = torch.cat( [ - torch.zeros(1, dtype=idtype, device=device), - torch.cumsum(kv_append_length_existing, dim=0).to(idtype), + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length_existing, dim=0), ] ) num_pages_existing = (num_existing_tokens + page_size - 1) // page_size kv_page_indptr_existing = torch.tensor( [0, num_pages_existing] + [num_pages_existing] * (batch_size - 1), - dtype=idtype, + dtype=torch.int32, device=device, ) kv_page_indices_existing = torch.arange( - num_pages_existing, dtype=idtype, device=device + num_pages_existing, dtype=torch.int32, device=device ) kv_last_page_len_existing = torch.tensor( [ @@ -972,7 +971,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( else page_size ] + [0] * (batch_size - 1), - dtype=idtype, + dtype=torch.int32, device=device, ) seq_lens_existing = flashinfer.get_seq_lens( @@ -981,9 +980,6 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( batch_indices_existing, positions_existing = flashinfer.get_batch_indices_positions( kv_append_indptr_existing, seq_lens_existing, num_existing_tokens ) - # Convert to idtype to match other index tensors - batch_indices_existing = batch_indices_existing.to(idtype) - positions_existing = positions_existing.to(idtype) # Allocate cache sized for existing + new tokens total_tokens = num_existing_tokens + num_new_tokens @@ -1139,26 +1135,26 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( num_existing_tokens, num_existing_tokens + num_new_tokens, device=device, - dtype=idtype, + dtype=rope_idtype, ) # Build metadata for new tokens (continue appending to first request) num_pages_new_needed = (total_tokens + page_size - 1) // page_size kv_page_indptr_new = torch.tensor( [0, num_pages_new_needed] + [num_pages_new_needed] * (batch_size - 1), - dtype=idtype, + dtype=torch.int32, device=device, ) kv_page_indices_new = torch.arange( - num_pages_new_needed, dtype=idtype, device=device + num_pages_new_needed, dtype=torch.int32, device=device ) # For continuation, positions start at num_existing_tokens - batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=idtype) + batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=torch.int32) positions_new = torch.arange( num_existing_tokens, num_existing_tokens + num_new_tokens, device=device, - dtype=idtype, + dtype=torch.int32, ) # Snapshot existing cache for later comparison