diff --git a/csrc/rope.cu b/csrc/rope.cu index 14008bcaea..6699675411 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -547,71 +547,77 @@ void rope_quantize_append_paged_kv_cache( DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] { return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] { - cudaError_t status; - - if (is_mla) { - // MLA: Construct paged_kv_mla_t struct - auto ckv_strides = ckv_cache.strides(); - auto kpe_strides = kpe_cache.strides(); - - paged_kv_mla_t paged_kv_mla( - page_size, no_rope_dim, rope_dim, batch_size, - static_cast(ckv_cache.data_ptr()), ckv_strides.data(), - static_cast(kpe_cache.data_ptr()), kpe_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), - nullptr // last_page_len not needed for this kernel - ); - - status = RopeQuantizeAppendPagedMLACache( - static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), - static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), - static_cast(q_rope_out.data_ptr()), - static_cast(q_nope_out.data_ptr()), paged_kv_mla, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), - static_cast(cos_sin_cache.data_ptr()), - static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, - q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, - q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, - k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave, - enable_pdl, stream); - - } else { - // GQA/MHA: Construct paged_kv_t struct - auto k_strides = k_cache.strides(); - auto v_strides = v_cache.strides(); - uint32_t head_dim = rope_dim + no_rope_dim; - - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, - static_cast(k_cache.data_ptr()), - static_cast(v_cache.data_ptr()), k_strides.data(), - static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), - nullptr // last_page_len not needed for this kernel - ); - - status = RopeQuantizeAppendPagedKVCache( - static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), - static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), - static_cast(v_in.data_ptr()), - static_cast(q_rope_out.data_ptr()), - static_cast(q_nope_out.data_ptr()), paged_kv, - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), - static_cast(cos_sin_cache.data_ptr()), - static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, - no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, - q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, - q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, - k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv, - interleave, enable_pdl, stream); - } + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] { + cudaError_t status; + + if (is_mla) { + // MLA: Construct paged_kv_mla_t struct + auto ckv_strides = ckv_cache.strides(); + auto kpe_strides = kpe_cache.strides(); + + paged_kv_mla_t paged_kv_mla( + page_size, no_rope_dim, rope_dim, batch_size, + static_cast(ckv_cache.data_ptr()), ckv_strides.data(), + static_cast(kpe_cache.data_ptr()), kpe_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedMLACache( + static_cast(q_rope_in.data_ptr()), + static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), + static_cast(k_nope_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv_mla, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, + q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave, + enable_pdl, stream); + + } else { + // GQA/MHA: Construct paged_kv_t struct + auto k_strides = k_cache.strides(); + auto v_strides = v_cache.strides(); + uint32_t head_dim = rope_dim + no_rope_dim; + + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, + static_cast(k_cache.data_ptr()), + static_cast(v_cache.data_ptr()), k_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedKVCache( + static_cast(q_rope_in.data_ptr()), + static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), + static_cast(k_nope_in.data_ptr()), static_cast(v_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, + no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, + q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, + q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, + k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv, + interleave, enable_pdl, stream); + } - TVM_FFI_ICHECK(status == cudaSuccess) - << "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status); - return true; + TVM_FFI_ICHECK(status == cudaSuccess) + << "RopeQuantizeAppendPagedKVCache failed with error code " + << cudaGetErrorString(status); + return true; + }); }); }); } diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 1d069e3189..d39d2e07e6 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -1641,6 +1641,11 @@ def rope_quantize_fp8_append_paged_kv_cache( kv_layout_code = TensorLayout[kv_layout].value + batch_indices = batch_indices.int() + positions = positions.int() + kv_indices = kv_indices.int() + kv_indptr = kv_indptr.int() + # Call custom op _rope_quantize_fp8_append_paged_kv_cache( q_rope, diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 60e813e921..4fdd75e0a3 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -803,13 +803,13 @@ __global__ void BatchQKApplyRotaryKernel( * Templated on CacheT to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t). * Cache-only behaviors are selected with constexpr on the CacheT. */ -template +template __global__ void RopeQuantizeAppendPagedKVCacheKernel( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like, - IdType* __restrict__ batch_indices, IdType* __restrict__ positions, - float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, + PagedKVIdType* __restrict__ batch_indices, PagedKVIdType* __restrict__ positions, + float* __restrict__ cos_sin_cache, RoPEIdType* __restrict__ pos_ids, const RopeQuantizeAppendPagedKVCacheParams params) { #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); @@ -852,12 +852,12 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks; // Deduce MLA vs GQA/MHA from CacheT - constexpr bool IS_MLA = std::is_same>::value; + constexpr bool IS_MLA = std::is_same>::value; vec_t cos, sin; if (bx * bdy + ty < nnz) { const uint32_t idx = bx * bdy + ty; - const IdType pos = pos_ids[idx]; + const RoPEIdType pos = pos_ids[idx]; // Compute page location for this token uint32_t page_iter, entry_idx; @@ -1123,18 +1123,18 @@ cudaError_t RopeQuantize( /*! * \brief Host function to apply RoPE, quantize to FP8, and append K/V to paged cache (GQA/MHA) */ -template +template cudaError_t RopeQuantizeAppendPagedKVCache( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, - QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t paged_kv, - IdType* batch_indices, IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim, - size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, - size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, - size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, - size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, - size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv, - bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t paged_kv, + PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache, + RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, + size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, + size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, + size_t k_rope_in_stride, size_t k_rope_in_stride_h, size_t k_nope_in_stride, + size_t k_nope_in_stride_h, size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, + float quant_scale_kv, bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { constexpr uint32_t vec_size = 32 / sizeof(DType); uint32_t bdx = (rope_dim + vec_size - 1) / vec_size; @@ -1163,9 +1163,9 @@ cudaError_t RopeQuantizeAppendPagedKVCache( config.attrs = attribute; config.numAttrs = 1; - auto kernel = - RopeQuantizeAppendPagedKVCacheKernel>; + auto kernel = RopeQuantizeAppendPagedKVCacheKernel>; RopeQuantizeAppendPagedKVCacheParams params; params.nnz = nnz; params.num_qo_heads = num_qo_heads; @@ -1200,12 +1200,13 @@ cudaError_t RopeQuantizeAppendPagedKVCache( /*! * \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache */ -template +template cudaError_t RopeQuantizeAppendPagedMLACache( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out, - QuantType* q_nope_out, paged_kv_mla_t paged_kv_mla, IdType* batch_indices, - IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, - uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, + QuantType* q_nope_out, paged_kv_mla_t paged_kv_mla, + PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache, + RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t rope_dim, + uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv, @@ -1237,9 +1238,9 @@ cudaError_t RopeQuantizeAppendPagedMLACache( config.attrs = attribute; config.numAttrs = 1; - auto kernel = - RopeQuantizeAppendPagedKVCacheKernel>; + auto kernel = RopeQuantizeAppendPagedKVCacheKernel>; DType* v_in_nullptr = nullptr; uint32_t num_kv_heads_1 = 1; size_t k_rope_in_stride_h_dup = k_rope_in_stride; diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index 651f43d5e9..75570223ae 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -516,6 +516,7 @@ def test_generalized_rope_quantize( @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) @pytest.mark.parametrize("page_size", [16, 32]) +@pytest.mark.parametrize("rope_idtype", [torch.int32, torch.int64]) def test_generalized_rope_quantize_append_kv_cache( attention_type, num_qo_heads, @@ -528,6 +529,7 @@ def test_generalized_rope_quantize_append_kv_cache( enable_pdl, kv_layout, page_size, + rope_idtype, ): device = "cuda:0" # Fixed seed for reproducibility @@ -589,7 +591,7 @@ def test_generalized_rope_quantize_append_kv_cache( rope_ref = FlashInferRotaryEmbedding( head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device ) - pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + pos_ids = torch.arange(num_tokens, device=device, dtype=rope_idtype) # Build paged metadata kv_append_length = torch.tensor( @@ -833,6 +835,7 @@ def test_generalized_rope_quantize_append_kv_cache( @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) @pytest.mark.parametrize("page_size", [16, 32]) +@pytest.mark.parametrize("rope_idtype", [torch.int32, torch.int64]) def test_rope_quantize_fp8_append_paged_kv_cache_decode( attention_type, num_qo_heads, @@ -846,6 +849,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( enable_pdl, kv_layout, page_size, + rope_idtype, ): """Test append to non-empty cache (decode/continuation scenario).""" device = "cuda:0" @@ -938,7 +942,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device ) pos_ids_existing = torch.arange( - num_existing_tokens, device=device, dtype=torch.int32 + num_existing_tokens, device=device, dtype=rope_idtype ) # Build metadata for existing tokens (single request for simplicity) @@ -1131,7 +1135,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( num_existing_tokens, num_existing_tokens + num_new_tokens, device=device, - dtype=torch.int32, + dtype=rope_idtype, ) # Build metadata for new tokens (continue appending to first request)