diff --git a/csrc/flashinfer_rope_binding.cu b/csrc/flashinfer_rope_binding.cu
index 94809da735..a5b6b57fb0 100644
--- a/csrc/flashinfer_rope_binding.cu
+++ b/csrc/flashinfer_rope_binding.cu
@@ -53,6 +53,15 @@ void rope_quantize_append_paged_kv_cache(
     TensorView positions, int64_t kv_layout_code, int64_t page_size, double quant_scale_q,
     double quant_scale_kv, bool interleave, bool enable_pdl);
 
+void rope_append_paged_kv_cache(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in,
+                                TensorView k_nope_in, TensorView v_in, TensorView q_rope_out,
+                                TensorView q_nope_out, TensorView cos_sin_cache, TensorView pos_ids,
+                                TensorView k_cache, TensorView v_cache, TensorView kv_indices,
+                                TensorView kv_indptr, TensorView kv_last_page_len,
+                                TensorView batch_indices, TensorView positions,
+                                int64_t kv_layout_code, int64_t page_size, double kv_scale,
+                                bool interleave, bool enable_pdl);
+
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids, apply_rope_pos_ids);
@@ -61,3 +70,4 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_i
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize_append_paged_kv_cache,
                               rope_quantize_append_paged_kv_cache);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_append_paged_kv_cache, rope_append_paged_kv_cache);
diff --git a/csrc/rope.cu b/csrc/rope.cu
index 6699675411..1c05d76f76 100644
--- a/csrc/rope.cu
+++ b/csrc/rope.cu
@@ -621,3 +621,161 @@ void rope_quantize_append_paged_kv_cache(
     });
   });
 }
+
+/*!
+ * TVM FFI binding for fused RoPE + paged KV cache append kernel (GQA/MHA only).
+ *
+ * Checks layout, rank, shape and dtype of every tensor, then dispatches on the input dtype,
+ * the pos_ids dtype and the cache dtype and launches RopeAppendPagedKVCache on the stream
+ * returned by get_stream() for q_rope_in's device. Q/K/V inputs are 3-D with the token
+ * dimension (nnz) first; pos_ids, batch_indices and positions hold one entry per token.
+ */
+void rope_append_paged_kv_cache(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in,
+                                TensorView k_nope_in, TensorView v_in, TensorView q_rope_out,
+                                TensorView q_nope_out, TensorView cos_sin_cache, TensorView pos_ids,
+                                TensorView k_cache, TensorView v_cache, TensorView kv_indices,
+                                TensorView kv_indptr, TensorView kv_last_page_len,
+                                TensorView batch_indices, TensorView positions,
+                                int64_t kv_layout_code, int64_t page_size, double kv_scale,
+                                bool interleave, bool enable_pdl) {
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_nope_in);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_in);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_out);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_out);
+  CHECK_INPUT(cos_sin_cache);
+  CHECK_INPUT(pos_ids);
+  CHECK_CUDA(k_cache);
+  CHECK_CUDA(v_cache);
+  CHECK_INPUT(kv_indices);
+  CHECK_INPUT(kv_indptr);
+  CHECK_INPUT(kv_last_page_len);
+  CHECK_INPUT(batch_indices);
+  CHECK_INPUT(positions);
+
+  CHECK_DIM(3, q_rope_in);
+  CHECK_DIM(3, k_rope_in);
+  CHECK_DIM(3, q_nope_in);
+  CHECK_DIM(3, k_nope_in);
+  CHECK_DIM(3, v_in);
+  CHECK_DIM(3, q_rope_out);
+  CHECK_DIM(3, q_nope_out);
+  CHECK_DIM(4, k_cache);
+  CHECK_DIM(4, v_cache);
+  CHECK_DIM(1, kv_last_page_len);
+  CHECK_DIM(1, pos_ids);
+  CHECK_DIM(1, batch_indices);
+  CHECK_DIM(1, positions);
+
+  uint32_t rope_dim = q_rope_in.size(-1);
+  uint32_t no_rope_dim = q_nope_in.size(-1);
+  uint32_t head_dim = rope_dim + no_rope_dim;
+  uint32_t nnz = q_rope_in.size(0);
+  uint32_t num_qo_heads = q_rope_in.size(1);
+  uint32_t num_kv_heads = k_rope_in.size(1);
+  uint32_t batch_size = kv_indptr.size(0) - 1;
+  QKVLayout kv_layout = QKVLayout(kv_layout_code);
+
+  // rope_dim is the kernel's chunk size and is used as a divisor when sizing the grid.
+  TVM_FFI_ICHECK(rope_dim > 0) << "rope_dim must be positive";
+  // The kernel indexes every per-token tensor with the same token id in [0, nnz).
+  TVM_FFI_ICHECK_EQ(k_rope_in.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(q_nope_in.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(q_nope_in.size(1), num_qo_heads);
+  TVM_FFI_ICHECK_EQ(k_nope_in.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(pos_ids.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(batch_indices.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(positions.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim);
+  TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim);
+  TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads);
+  TVM_FFI_ICHECK_EQ(v_in.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(v_in.size(1), num_kv_heads);
+  TVM_FFI_ICHECK_EQ(v_in.size(2), head_dim);
+  TVM_FFI_ICHECK_EQ(q_rope_out.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(q_rope_out.size(1), num_qo_heads);
+  TVM_FFI_ICHECK_EQ(q_rope_out.size(2), rope_dim);
+  TVM_FFI_ICHECK_EQ(q_nope_out.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(q_nope_out.size(1), num_qo_heads);
+  TVM_FFI_ICHECK_EQ(q_nope_out.size(2), no_rope_dim);
+  TVM_FFI_ICHECK_EQ(k_cache.size(0), v_cache.size(0));
+  TVM_FFI_ICHECK_EQ(k_cache.size(1), v_cache.size(1));
+  TVM_FFI_ICHECK_EQ(k_cache.size(2), v_cache.size(2));
+  TVM_FFI_ICHECK_EQ(k_cache.size(3), v_cache.size(3));
+  TVM_FFI_ICHECK_EQ(kv_last_page_len.size(0), batch_size);
+
+  TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16)
+      << "Input dtype must be float16 or bfloat16";
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), v_in.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_rope_out.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_out.dtype());
+  TVM_FFI_ICHECK_EQ(k_cache.dtype(), v_cache.dtype());
+  TVM_FFI_ICHECK(k_cache.dtype() == dl_float16 || k_cache.dtype() == dl_bfloat16 ||
+                 k_cache.dtype() == dl_float8_e4m3fn || k_cache.dtype() == dl_float8_e5m2)
+      << "Cache dtype must be float16, bfloat16, float8_e4m3fn, or float8_e5m2";
+
+  // Keep strides in size_t, the type the launcher takes; a uint32_t temporary would silently
+  // truncate strides of very large tensors.
+  const size_t q_rope_in_stride_n = q_rope_in.stride(0);
+  const size_t q_rope_in_stride_h = q_rope_in.stride(1);
+  const size_t q_nope_in_stride_n = q_nope_in.stride(0);
+  const size_t q_nope_in_stride_h = q_nope_in.stride(1);
+  const size_t q_rope_out_stride_n = q_rope_out.stride(0);
+  const size_t q_rope_out_stride_h = q_rope_out.stride(1);
+  const size_t q_nope_out_stride_n = q_nope_out.stride(0);
+  const size_t q_nope_out_stride_h = q_nope_out.stride(1);
+  const size_t k_rope_in_stride = k_rope_in.stride(0);
+  const size_t k_rope_in_stride_h = k_rope_in.stride(1);
+  const size_t k_nope_in_stride = k_nope_in.stride(0);
+  const size_t k_nope_in_stride_h = k_nope_in.stride(1);
+  const size_t v_in_stride = v_in.stride(0);
+  const size_t v_in_stride_h = v_in.stride(1);
+
+  ffi::CUDADeviceGuard device_guard(q_rope_in.device().device_id);
+  const cudaStream_t stream = get_stream(q_rope_in.device());
+
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] {
+    return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] {
+      auto launch = [&](auto cache_dtype_tag) -> bool {
+        using c_cache_type = decltype(cache_dtype_tag);
+        auto k_strides = k_cache.strides();
+        paged_kv_t paged_kv(
+            num_kv_heads, page_size, head_dim, batch_size, kv_layout,
+            static_cast(k_cache.data_ptr()),
+            static_cast(v_cache.data_ptr()), k_strides.data(),
+            static_cast(kv_indices.data_ptr()),
+            static_cast(kv_indptr.data_ptr()),
+            static_cast(kv_last_page_len.data_ptr()));
+        cudaError_t status = RopeAppendPagedKVCache(
+            static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()),
+            static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()),
+            static_cast(v_in.data_ptr()), static_cast(q_rope_out.data_ptr()),
+            static_cast(q_nope_out.data_ptr()), paged_kv,
+            static_cast(batch_indices.data_ptr()),
+            static_cast(positions.data_ptr()),
+            static_cast(cos_sin_cache.data_ptr()),
+            static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
+            no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
+            q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
+            q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
+            k_nope_in_stride_h, v_in_stride, v_in_stride_h, kv_scale, interleave, enable_pdl,
+            stream);
+        TVM_FFI_ICHECK(status == cudaSuccess)
+            << "RopeAppendPagedKVCache failed with error code " << cudaGetErrorString(status);
+        return true;
+      };
+
+      if (k_cache.dtype() == dl_float16 || k_cache.dtype() == dl_bfloat16) {
+        return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(k_cache.dtype(), c_cache_type,
+                                                   [&] { return launch(c_cache_type{}); });
+      }
+      return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(k_cache.dtype(), c_cache_type,
+                                                [&] { return launch(c_cache_type{}); });
+    });
+  });
+}
diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py
index 8fa98adb62..0851e93b16 100644
--- a/flashinfer/__init__.py
+++ b/flashinfer/__init__.py
@@ -143,6 +143,7 @@
 from .rope import apply_rope_inplace as apply_rope_inplace
 from .rope import apply_rope_pos_ids as apply_rope_pos_ids
 from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
+from .rope import rope_append_paged_kv_cache as rope_append_paged_kv_cache
 from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
 from .rope import (
     apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index d39d2e07e6..58c3342155 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -326,6 +326,85 @@ def _fake_rope_quantize_fp8_append_paged_kv_cache(
     pass
 
 
+@register_custom_op(
+    "flashinfer::rope_append_paged_kv_cache",
+    mutates_args=("q_rope_out", "q_nope_out", "k_cache", "v_cache"),
+)
+def _rope_append_paged_kv_cache(
+    q_rope_in: torch.Tensor,
+    k_rope_in: torch.Tensor,
+    q_nope_in: torch.Tensor,
+    k_nope_in: torch.Tensor,
+    v_in: torch.Tensor,
+    q_rope_out: torch.Tensor,
+    q_nope_out: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_indices: torch.Tensor,
+    kv_indptr: torch.Tensor,
+    kv_last_page_len: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
+    kv_layout_code: int,
+    page_size: int,
+    kv_scale: float,
+    interleave: bool,
+    enable_pdl: bool,
+) -> None:
+    get_rope_module().rope_append_paged_kv_cache(
+        q_rope_in,
+        k_rope_in,
+        q_nope_in,
+        k_nope_in,
+        v_in,
+        q_rope_out,
+        q_nope_out,
+        cos_sin_cache,
+        pos_ids,
+        k_cache,
+        v_cache,
+        kv_indices,
+        kv_indptr,
+        kv_last_page_len,
+        batch_indices,
+        positions,
+        kv_layout_code,
+        page_size,
+        kv_scale,
+        interleave,
+        enable_pdl,
+    )
+
+
+@register_fake_op("flashinfer::rope_append_paged_kv_cache")
+def _fake_rope_append_paged_kv_cache(
+    q_rope_in: torch.Tensor,
+    k_rope_in: torch.Tensor,
+    q_nope_in: torch.Tensor,
+    k_nope_in: torch.Tensor,
+    v_in: torch.Tensor,
+    q_rope_out: torch.Tensor,
+    q_nope_out: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_indices: torch.Tensor,
+    kv_indptr: torch.Tensor,
+    kv_last_page_len: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
+    kv_layout_code: int,
+    page_size: int,
+    kv_scale: float,
+    interleave: bool,
+    enable_pdl: bool,
+) -> None:
+    pass
+
+
 @register_custom_op(
     "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope")
 )
@@ -1674,3 +1753,143 @@ def rope_quantize_fp8_append_paged_kv_cache(
     )
 
     return q_rope_out, q_nope_out
+
+
+@flashinfer_api
+def rope_append_paged_kv_cache(
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    q_nope: Optional[torch.Tensor],
+    k_nope: Optional[torch.Tensor],
+    v: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    paged_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+    kv_indices: torch.Tensor,
+    kv_indptr: torch.Tensor,
+    kv_last_page_len: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
+    is_neox: bool = True,
+    kv_scale: float = 1.0,
+    page_size: int = 16,
+    kv_layout: str = "NHD",
+    q_rope_out: Optional[torch.Tensor] = None,
+    q_nope_out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Apply RoPE to Q/K and append K/V to paged KV cache.
+
+    This primitive keeps query outputs in the input dtype and only uses
+    ``kv_scale`` for cache-side casting when the cache dtype is FP8.
+
+    Parameters
+    ----------
+    q_rope, k_rope : torch.Tensor
+        Rotary parts of query and key, ``(nnz, num_heads, rope_dim)``.
+    q_nope, k_nope : Optional[torch.Tensor]
+        Non-rotary parts of query and key, ``(nnz, num_heads, no_rope_dim)``.
+        ``None`` is replaced by a zero-width tensor.
+    v : torch.Tensor
+        Values, ``(nnz, num_kv_heads, rope_dim + no_rope_dim)``.
+    cos_sin_cache : torch.Tensor
+        float32 cos/sin table, looked up with ``pos_ids``.
+    pos_ids : torch.Tensor
+        RoPE position of every token.
+    paged_kv_cache : Tuple[torch.Tensor, torch.Tensor]
+        ``(k_cache, v_cache)``: 4-D, same dtype, written in place.
+    kv_indices, kv_indptr, kv_last_page_len : torch.Tensor
+        Page table of the cache; converted to int32 before the call.
+    batch_indices, positions : torch.Tensor
+        Request index and in-sequence position of every appended token;
+        converted to int32 before the call.
+    is_neox : bool
+        ``True`` selects the non-interleaved rotation, ``False`` the
+        interleaved one.
+    kv_scale : float
+        Factor applied to K and V before they are stored in an FP8 cache.
+    page_size : int
+        Page size of the paged cache.
+    kv_layout : str
+        Name of a ``TensorLayout`` member (e.g. ``"NHD"`` or ``"HND"``).
+    q_rope_out, q_nope_out : Optional[torch.Tensor]
+        Preallocated outputs; allocated like the inputs when ``None``.
+    enable_pdl : bool
+        Whether to allow programmatic stream serialization for the launch.
+
+    Returns
+    -------
+    Tuple[torch.Tensor, torch.Tensor]
+        ``(q_rope_out, q_nope_out)``.
+
+    Raises
+    ------
+    ValueError
+        If ``cos_sin_cache`` is not float32, ``q_rope`` or ``k_rope`` is not
+        3-D, or ``paged_kv_cache`` is not a pair of 4-D tensors of one dtype.
+    """
+    if cos_sin_cache.dtype != torch.float32:
+        raise ValueError("cos_sin_cache should be float32")
+    if q_rope.ndim != 3 or k_rope.ndim != 3:
+        raise ValueError("rope_append_paged_kv_cache only supports GQA/MHA inputs")
+
+    nnz = q_rope.shape[0]
+    num_qo_heads = q_rope.shape[1]
+    num_kv_heads = k_rope.shape[1]
+
+    if q_nope is None:
+        q_nope = torch.empty(
+            nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device
+        )
+    if k_nope is None:
+        k_nope = torch.empty(
+            nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device
+        )
+
+    if q_rope_out is None:
+        q_rope_out = torch.empty_like(q_rope)
+    if q_nope_out is None:
+        q_nope_out = torch.empty_like(q_nope)
+
+    if len(paged_kv_cache) != 2:
+        raise ValueError("paged_kv_cache must be a tuple of (k_cache, v_cache)")
+    k_cache, v_cache = paged_kv_cache
+    if k_cache.ndim != 4 or v_cache.ndim != 4:
+        raise ValueError("rope_append_paged_kv_cache expects 4D GQA/MHA cache tensors")
+    if k_cache.dtype != v_cache.dtype:
+        raise ValueError("k_cache and v_cache must have the same dtype")
+
+    from .utils import TensorLayout
+
+    kv_layout_code = TensorLayout[kv_layout].value
+    batch_indices = batch_indices.int()
+    positions = positions.int()
+    kv_indices = kv_indices.int()
+    kv_indptr = kv_indptr.int()
+    kv_last_page_len = kv_last_page_len.int()
+
+    _rope_append_paged_kv_cache(
+        q_rope,
+        k_rope,
+        q_nope,
+        k_nope,
+        v,
+        q_rope_out,
+        q_nope_out,
+        cos_sin_cache,
+        pos_ids,
+        k_cache,
+        v_cache,
+        kv_indices,
+        kv_indptr,
+        kv_last_page_len,
+        batch_indices,
+        positions,
+        kv_layout_code,
+        page_size,
+        kv_scale,
+        not is_neox,
+        enable_pdl,
+    )
+
+    return q_rope_out, q_nope_out
diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh
index 4fdd75e0a3..afee5e67c9 100644
--- a/include/flashinfer/pos_enc.cuh
+++ b/include/flashinfer/pos_enc.cuh
@@ -54,6 +54,29 @@ struct RopeQuantizeAppendPagedKVCacheParams {
   float quant_scale_kv;
 };
 
+struct RopeAppendPagedKVCacheParams {
+  uint32_t nnz;
+  uint32_t num_qo_heads;
+  uint32_t num_kv_heads;
+  uint32_t rope_dim;
+  uint32_t no_rope_dim;
+  size_t q_rope_in_stride_n;
+  size_t q_rope_in_stride_h;
+  size_t q_nope_in_stride_n;
+  size_t q_nope_in_stride_h;
+  size_t q_rope_out_stride_n;
+  size_t q_rope_out_stride_h;
+  size_t q_nope_out_stride_n;
+  size_t q_nope_out_stride_h;
+  size_t k_rope_in_stride;
+  size_t k_rope_in_stride_h;
+  size_t k_nope_in_stride;
+  size_t k_nope_in_stride_h;
+  size_t v_in_stride;
+  size_t v_in_stride_h;
+  float kv_scale;
+};
+
 /*!
  * \brief An enumeration class that defines different modes for applying RoPE
  * (Rotary Positional Embeddings).
@@ -1029,6 +1052,206 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
 #endif
 }
 
+/*!
+ * \brief CUDA kernel to apply RoPE to Q/K and append K/V to paged cache.
+ *
+ * This stage-1 kernel keeps Q outputs in the input dtype while allowing the
+ * paged KV cache to use its own dtype (e.g. FP8).
+ *
+ * Launch layout (set up by RopeAppendPagedKVCache): blockIdx.x * blockDim.y + threadIdx.y is the
+ * token index and threadIdx.x selects a vec_size-wide slice of one rope_dim-sized chunk.
+ * blockIdx.y enumerates, in order, the Q RoPE heads, K RoPE heads, K no-RoPE chunks, V heads
+ * and Q no-RoPE chunks.
+ */
+template
+__global__ void RopeAppendPagedKVCacheKernel(
+    DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in,
+    DType* q_rope_out, DType* q_nope_out, paged_kv_t paged_kv,
+    PagedKVIdType* __restrict__ batch_indices, PagedKVIdType* __restrict__ positions,
+    float* __restrict__ cos_sin_cache, RoPEIdType* __restrict__ pos_ids,
+    const RopeAppendPagedKVCacheParams params) {
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  uint32_t by = blockIdx.y;
+  uint32_t bdy = blockDim.y;
+
+  const uint32_t nnz = params.nnz;
+  const uint32_t num_qo_heads = params.num_qo_heads;
+  const uint32_t num_kv_heads = params.num_kv_heads;
+  const uint32_t rope_dim = params.rope_dim;
+  const uint32_t no_rope_dim = params.no_rope_dim;
+  const size_t q_rope_in_stride_n = params.q_rope_in_stride_n;
+  const size_t q_rope_in_stride_h = params.q_rope_in_stride_h;
+  const size_t q_nope_in_stride_n = params.q_nope_in_stride_n;
+  const size_t q_nope_in_stride_h = params.q_nope_in_stride_h;
+  const size_t q_rope_out_stride_n = params.q_rope_out_stride_n;
+  const size_t q_rope_out_stride_h = params.q_rope_out_stride_h;
+  const size_t q_nope_out_stride_n = params.q_nope_out_stride_n;
+  const size_t q_nope_out_stride_h = params.q_nope_out_stride_h;
+  const size_t k_rope_in_stride = params.k_rope_in_stride;
+  const size_t k_rope_in_stride_h = params.k_rope_in_stride_h;
+  const size_t k_nope_in_stride = params.k_nope_in_stride;
+  const size_t k_nope_in_stride_h = params.k_nope_in_stride_h;
+  const size_t v_in_stride = params.v_in_stride;
+  const size_t v_in_stride_h = params.v_in_stride_h;
+  const float kv_scale = params.kv_scale;
+  constexpr bool kNeedsScale =
+      std::is_same_v || std::is_same_v;
+
+  uint32_t rope_chunk_size = rope_dim;
+  uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size;
+  uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size;
+
+  uint32_t q_rope_end = num_qo_heads * rope_chunks;
+  uint32_t k_rope_end = q_rope_end + num_kv_heads * rope_chunks;
+  uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks;
+
+  vec_t cos, sin;
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const RoPEIdType pos = pos_ids[idx];
+
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(
+        paged_kv.indptr[batch_indices[idx]] * paged_kv.page_size + positions[idx], page_iter,
+        entry_idx);
+
+    const int half_rope_dim = rope_dim / 2;
+    if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
+      int sin_offset = rope_dim / 2;
+      int vec_idx;
+      if constexpr (interleave) {
+        vec_idx = (tx * vec_size) / 2;
+      } else {
+        vec_idx = (tx * vec_size) % half_rope_dim;
+      }
+      cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx);
+      sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx));
+    }
+
+    if (by < q_rope_end) {
+      uint32_t q_head_idx = by / rope_chunks;
+      uint32_t rope_chunk_idx = by % rope_chunks;
+      uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
+
+      DType* q_rope_in_ptr =
+          q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n,
+                                           q_rope_in_stride_h);
+      DType* q_rope_out_ptr =
+          q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n,
+                                            q_rope_out_stride_h);
+
+      vec_t q_rope_vec;
+      if constexpr (interleave) {
+        q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(
+            q_rope_in_ptr, cos, sin, rope_dim);
+      } else {
+        q_rope_vec = vec_apply_llama_rope_cos_sin(q_rope_in_ptr, cos, sin, rope_dim);
+      }
+      q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
+
+    } else if (by < k_rope_end) {
+      uint32_t k_head_idx = (by - q_rope_end) / rope_chunks;
+      uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks;
+      uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
+
+      DType* k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
+                                                              k_rope_in_stride, k_rope_in_stride_h);
+
+      vec_t k_rope_vec;
+      if constexpr (interleave) {
+        k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(
+            k_rope_in_ptr, cos, sin, rope_dim);
+      } else {
+        k_rope_vec = vec_apply_llama_rope_cos_sin(k_rope_in_ptr, cos, sin, rope_dim);
+      }
+      if constexpr (kNeedsScale) {
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          k_rope_vec[i] = k_rope_vec[i] * kv_scale;
+        }
+      }
+      CacheType* k_ptr = paged_kv.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size);
+      k_rope_vec.cast_store(k_ptr);
+
+    } else if (by < k_nope_end) {
+      uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
+      uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
+
+      DType* k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
+                                                              k_nope_in_stride, k_nope_in_stride_h);
+
+      // The last chunk overhangs no_rope_dim when no_rope_dim is not a multiple of rope_dim.
+      // Skip those lanes so they neither read past the input row nor write past the head's
+      // slot in the cache (same tail guard as the V branch below).
+      if (elem_offset + tx * vec_size < no_rope_dim) {
+        vec_t k_nope_vec;
+        k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+        if constexpr (kNeedsScale) {
+#pragma unroll
+          for (uint32_t i = 0; i < vec_size; ++i) {
+            k_nope_vec[i] = k_nope_vec[i] * kv_scale;
+          }
+        }
+
+        CacheType* k_ptr = paged_kv.get_k_ptr(page_iter, k_head_idx, entry_idx,
+                                              rope_dim + elem_offset + tx * vec_size);
+        k_nope_vec.cast_store(k_ptr);
+      }
+
+    } else if (by < k_nope_end + num_kv_heads) {
+      uint32_t kv_head_idx = by - k_nope_end;
+      DType* v_in_ptr =
+          v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h);
+      uint32_t head_dim_total = rope_dim + no_rope_dim;
+      uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size;
+#pragma unroll 1
+      for (uint32_t j = 0; j < v_chunks; ++j) {
+        uint32_t v_elem_offset = j * rope_chunk_size;
+        if (v_elem_offset + tx * vec_size < head_dim_total) {
+          vec_t v_vec;
+          v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
+          if constexpr (kNeedsScale) {
+#pragma unroll
+            for (uint32_t i = 0; i < vec_size; ++i) {
+              v_vec[i] = v_vec[i] * kv_scale;
+            }
+          }
+          CacheType* v_ptr =
+              paged_kv.get_v_ptr(page_iter, kv_head_idx, entry_idx, v_elem_offset + tx * vec_size);
+          v_vec.cast_store(v_ptr);
+        }
+      }
+
+    } else {
+      uint32_t q_nope_start = k_nope_end + num_kv_heads;
+      uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
+      uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
+
+      DType* q_nope_in_ptr =
+          q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
+                                           q_nope_in_stride_h);
+      DType* q_nope_out_ptr =
+          q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
+                                            q_nope_out_stride_h);
+
+      // Same tail guard for the last, possibly partial, Q no-RoPE chunk.
+      if (elem_offset + tx * vec_size < no_rope_dim) {
+        vec_t q_nope_vec;
+        q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+        q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+      }
+    }
+  }
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
 template
 cudaError_t RopeQuantize(
     DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out,
@@ -1197,6 +1420,88 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
   return cudaSuccess;
 }
 
+/*!
+ * \brief Host function to apply RoPE and append K/V to paged cache (GQA/MHA).
+ *
+ * Launches RopeAppendPagedKVCacheKernel on \p stream: threadIdx.x covers one rope_dim-sized
+ * chunk in vec_size-wide slices, threadIdx.y/blockIdx.x cover the tokens and blockIdx.y the
+ * per-head work items. Returns cudaSuccess without launching when nnz is 0 and
+ * cudaErrorInvalidValue when rope_dim is 0.
+ */
+template
+cudaError_t RopeAppendPagedKVCache(
+    DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in,
+    DType* q_rope_out, DType* q_nope_out, paged_kv_t paged_kv,
+    PagedKVIdType* batch_indices, PagedKVIdType* positions, float* cos_sin_cache,
+    RoPEIdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads,
+    uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
+    size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n,
+    size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h,
+    size_t k_rope_in_stride, size_t k_rope_in_stride_h, size_t k_nope_in_stride,
+    size_t k_nope_in_stride_h, size_t v_in_stride, size_t v_in_stride_h, float kv_scale,
+    bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) {
+  // rope_dim is the chunk size and a divisor below, and an empty grid cannot be launched.
+  if (rope_dim == 0) return cudaErrorInvalidValue;
+  if (nnz == 0) return cudaSuccess;
+
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    constexpr uint32_t vec_size = 32 / sizeof(DType);
+    uint32_t bdx = (rope_dim + vec_size - 1) / vec_size;
+    bdx = std::max(1u, bdx);
+    uint32_t num_threads = std::max(128U, bdx);
+    uint32_t bdy = std::max(1u, num_threads / bdx);
+    uint32_t nblks_x = (nnz + bdy - 1) / bdy;
+    uint32_t rope_chunks = 1;
+    uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim;
+    uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks +
+                              num_kv_heads * no_rope_chunks + num_kv_heads +
+                              num_qo_heads * no_rope_chunks;
+
+    dim3 nblks(nblks_x, total_blocks_y);
+    dim3 nthrs(bdx, bdy);
+
+    cudaLaunchAttribute attribute[1];
+    attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0;
+    cudaLaunchConfig_t config;
+    config.gridDim = nblks;
+    config.blockDim = nthrs;
+    config.stream = stream;
+    config.dynamicSmemBytes = 0;
+    config.attrs = attribute;
+    config.numAttrs = 1;
+
+    auto kernel = RopeAppendPagedKVCacheKernel;
+    RopeAppendPagedKVCacheParams params;
+    params.nnz = nnz;
+    params.num_qo_heads = num_qo_heads;
+    params.num_kv_heads = num_kv_heads;
+    params.rope_dim = rope_dim;
+    params.no_rope_dim = no_rope_dim;
+    params.q_rope_in_stride_n = q_rope_in_stride_n;
+    params.q_rope_in_stride_h = q_rope_in_stride_h;
+    params.q_nope_in_stride_n = q_nope_in_stride_n;
+    params.q_nope_in_stride_h = q_nope_in_stride_h;
+    params.q_rope_out_stride_n = q_rope_out_stride_n;
+    params.q_rope_out_stride_h = q_rope_out_stride_h;
+    params.q_nope_out_stride_n = q_nope_out_stride_n;
+    params.q_nope_out_stride_h = q_nope_out_stride_h;
+    params.k_rope_in_stride = k_rope_in_stride;
+    params.k_rope_in_stride_h = k_rope_in_stride_h;
+    params.k_nope_in_stride = k_nope_in_stride;
+    params.k_nope_in_stride_h = k_nope_in_stride_h;
+    params.v_in_stride = v_in_stride;
+    params.v_in_stride_h = v_in_stride_h;
+    params.kv_scale = kv_scale;
+
+    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(
+        &config, kernel, q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in, q_rope_out, q_nope_out,
+        paged_kv, batch_indices, positions, cos_sin_cache, pos_ids, params));
+  });
+
+  return cudaSuccess;
+}
+
 /*!
  * \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache
  */
diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py
index 75570223ae..535d71e55b 100644
--- a/tests/attention/test_rope.py
+++ b/tests/attention/test_rope.py
@@ -19,6 +19,11 @@
 from tests.test_helpers.rope_reference import *
 
 import flashinfer
+from flashinfer.utils import (
+    get_compute_capability,
+    is_sm90a_supported,
+    is_sm100a_supported,
+)
 
 
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@@ -1380,6 +1385,392 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode(
     )
 
 
+@pytest.mark.parametrize(
+    "num_qo_heads,num_kv_heads,rope_dim,no_rope_dim",
+    [
+        (32, 8, 128, 0),  # Llama/Qwen style GQA
+        (32, 32, 64, 64),  # MHA with q_nope/k_nope present
+        (32, 8, 64, 96),  # partial last no-RoPE chunk (96 is not a multiple of 64)
+    ],
+)
+@pytest.mark.parametrize("num_existing_tokens", [10])
+@pytest.mark.parametrize("num_new_tokens", [1, 8])
+@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("use_fp8_cache", [False, True])
+@pytest.mark.parametrize("enable_pdl", [True, False])
+@pytest.mark.parametrize("kv_layout", ["NHD", "HND"])
+@pytest.mark.parametrize("page_size", [16])
+@pytest.mark.parametrize("rope_idtype", [torch.int32, torch.int64])
+def test_rope_append_paged_kv_cache_decode(
+    num_qo_heads,
+    num_kv_heads,
+    rope_dim,
+    no_rope_dim,
+    num_existing_tokens,
+    num_new_tokens,
+    input_dtype,
+    use_fp8_cache,
+    enable_pdl,
+    kv_layout,
+    page_size,
+    rope_idtype,
+):
+    """Test the non-quant fused paged decode append path."""
+    device = "cuda:0"
+    if use_fp8_cache:
+        device_obj = torch.device(device)
+        cc = get_compute_capability(device_obj)
+        if cc[0] < 9:
+            pytest.skip(
+                f"FP8 KV cache requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}"
+            )
+        if not is_sm90a_supported(device_obj) and not is_sm100a_supported(device_obj):
+            pytest.skip("FP8 KV cache requires SM90a or SM100a support on this device")
+    torch.manual_seed(43)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(43)
+
+    head_dim = rope_dim + no_rope_dim
+    batch_size = 2
+    kv_cache_dtype = torch.float8_e4m3fn if use_fp8_cache else input_dtype
+
+    q_rope_existing = torch.randn(
+        num_existing_tokens,
+        num_qo_heads,
+        rope_dim,
+        dtype=input_dtype,
+        device=device,
+    )
+    q_nope_existing = (
+        None
+        if no_rope_dim == 0
+        else torch.randn(
+            num_existing_tokens,
+            num_qo_heads,
+            no_rope_dim,
+            dtype=input_dtype,
+            device=device,
+        )
+    )
+    k_rope_existing = torch.randn(
+        num_existing_tokens,
+        num_kv_heads,
+        rope_dim,
+        dtype=input_dtype,
+        device=device,
+    )
+    k_nope_existing = (
+        None
+        if no_rope_dim == 0
+        else torch.randn(
+            num_existing_tokens,
+            num_kv_heads,
+            no_rope_dim,
+            dtype=input_dtype,
+            device=device,
+        )
+    )
+    v_existing = torch.randn(
+        num_existing_tokens,
+        num_kv_heads,
+        head_dim,
+        dtype=input_dtype,
+        device=device,
+    )
+
+    max_seq_len = 4096
+    rope_ref = FlashInferRotaryEmbedding(
+        head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device
+    )
+    pos_ids_existing = torch.arange(
+        num_existing_tokens, device=device, dtype=rope_idtype
+    )
+
+    kv_append_length_existing = torch.tensor(
+        [num_existing_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device
+    )
+    kv_append_indptr_existing = torch.cat(
+        [
+            torch.zeros(1, dtype=torch.int32, device=device),
+            torch.cumsum(kv_append_length_existing, dim=0),
+        ]
+    )
+    num_pages_existing = (num_existing_tokens + page_size - 1) // page_size
+    kv_page_indptr_existing = torch.tensor(
+        [0, num_pages_existing] + [num_pages_existing] * (batch_size - 1),
+        dtype=torch.int32,
+        device=device,
+    )
+    kv_page_indices_existing = torch.arange(
+        num_pages_existing, dtype=torch.int32, device=device
+    )
+    kv_last_page_len_existing = torch.tensor(
+        [
+            num_existing_tokens % page_size
+            if num_existing_tokens % page_size != 0
+            else page_size
+        ]
+        + [0] * (batch_size - 1),
+        dtype=torch.int32,
+        device=device,
+    )
+    seq_lens_existing = flashinfer.get_seq_lens(
+        kv_page_indptr_existing, kv_last_page_len_existing, page_size
+    )
+    batch_indices_existing, positions_existing = flashinfer.get_batch_indices_positions(
+        kv_append_indptr_existing, seq_lens_existing, num_existing_tokens
+    )
+
+    total_tokens = num_existing_tokens + num_new_tokens
+    max_pages = (total_tokens + page_size - 1) // page_size
+    if kv_layout == "NHD":
+        k_cache = torch.zeros(
+            max_pages,
+            page_size,
+            num_kv_heads,
+            head_dim,
+            dtype=kv_cache_dtype,
+            device=device,
+        )
+        v_cache = torch.zeros(
+            max_pages,
+            page_size,
+            num_kv_heads,
+            head_dim,
+            dtype=kv_cache_dtype,
+            device=device,
+        )
+    else:
+        k_cache = torch.zeros(
+            max_pages,
+            num_kv_heads,
+            page_size,
+            head_dim,
+            dtype=kv_cache_dtype,
+            device=device,
+        )
+        v_cache = torch.zeros(
+            max_pages,
+            num_kv_heads,
+            page_size,
+            head_dim,
+            dtype=kv_cache_dtype,
+            device=device,
+        )
+
+    flashinfer.rope.rope_append_paged_kv_cache(
+        q_rope_existing,
+        k_rope_existing,
+        q_nope_existing,
+        k_nope_existing,
+        v_existing,
+        rope_ref.cos_sin_cache,
+        pos_ids_existing,
+        (k_cache, v_cache),
+        kv_page_indices_existing,
+        kv_page_indptr_existing,
+        kv_last_page_len_existing,
+        batch_indices_existing,
+        positions_existing,
+        page_size=page_size,
+        kv_layout=kv_layout,
+        kv_scale=1.0,
+        is_neox=False,
+        enable_pdl=enable_pdl,
+    )
+
+    q_rope_new = torch.randn(
+        num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device
+    )
+    q_nope_new = (
+        None
+        if no_rope_dim == 0
+        else torch.randn(
+            num_new_tokens,
+            num_qo_heads,
+            no_rope_dim,
+            dtype=input_dtype,
+            device=device,
+        )
+    )
+    k_rope_new = torch.randn(
+        num_new_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device
+    )
+    k_nope_new = (
+        None
+        if no_rope_dim == 0
+        else torch.randn(
+            num_new_tokens,
+            num_kv_heads,
+            no_rope_dim,
+            dtype=input_dtype,
+            device=device,
+        )
+    )
+    v_new = torch.randn(
+        num_new_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device
+    )
+    pos_ids_new = torch.arange(
+        num_existing_tokens,
+        num_existing_tokens + num_new_tokens,
+        device=device,
+        dtype=rope_idtype,
+    )
+
+    num_pages_new_needed = (total_tokens + page_size - 1) // page_size
+    kv_page_indptr_new = torch.tensor(
+        [0, num_pages_new_needed] + [num_pages_new_needed] * (batch_size - 1),
+        dtype=torch.int32,
+        device=device,
+    )
+    kv_page_indices_new = torch.arange(
+        num_pages_new_needed, dtype=torch.int32, device=device
+    )
+    kv_last_page_len_new = torch.tensor(
+        [total_tokens % page_size if total_tokens % page_size != 0 else page_size]
+        + [0] * (batch_size - 1),
+        dtype=torch.int32,
+        device=device,
+    )
+    batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=torch.int32)
+    positions_new = torch.arange(
+        num_existing_tokens,
+        num_existing_tokens + num_new_tokens,
+        device=device,
+        dtype=torch.int32,
+    )
+
+    k_cache_before = k_cache.clone()
+    v_cache_before = v_cache.clone()
+
+    q_rope_out_new, q_nope_out_new = flashinfer.rope.rope_append_paged_kv_cache(
+        q_rope_new,
+        k_rope_new,
+        q_nope_new,
+        k_nope_new,
+        v_new,
+        rope_ref.cos_sin_cache,
+        pos_ids_new,
+        (k_cache, v_cache),
+        kv_page_indices_new,
+        kv_page_indptr_new,
+        kv_last_page_len_new,
+        batch_indices_new,
+        positions_new,
+        page_size=page_size,
+        kv_layout=kv_layout,
+        kv_scale=1.0,
+        is_neox=False,
+        enable_pdl=enable_pdl,
+    )
+
+    q_in_new = (
+        q_rope_new
+        if q_nope_new is None
+        else torch.cat([q_rope_new, q_nope_new], dim=-1)
+    )
+    k_in_new = (
+        k_rope_new
+        if k_nope_new is None
+        else torch.cat([k_rope_new, k_nope_new], dim=-1)
+    )
+    q_ref_new, k_ref_new = rope_ref.forward_native(pos_ids_new, q_in_new, k_in_new)
+
+    q_rtol, q_atol = (1e-3, 1e-3) if input_dtype == torch.float16 else (2e-2, 2e-2)
+    torch.testing.assert_close(
+        q_ref_new[..., :rope_dim].float(),
+        q_rope_out_new.float(),
+        rtol=q_rtol,
+        atol=q_atol,
+    )
+    torch.testing.assert_close(
+        q_ref_new[..., rope_dim:].float(),
+        q_nope_out_new.float(),
+        rtol=q_rtol,
+        atol=q_atol,
+    )
+
+    if use_fp8_cache:
+        cache_rtol, cache_atol = 0.25, 0.5
+    else:
+        cache_rtol, cache_atol = (
+            (1e-3, 1e-3) if input_dtype == torch.float16 else (2e-2, 2e-2)
+        )
+
+    k_ref_tokens_new = k_ref_new.to(kv_cache_dtype)
+    v_ref_tokens_new = v_new.to(kv_cache_dtype)
+
+    for i in range(num_existing_tokens):
+        b = batch_indices_existing[i].item()
+        pos = positions_existing[i].item()
+        page_iter = (kv_page_indptr_existing[b].item() * page_size + pos) // page_size
+        entry_idx = (kv_page_indptr_existing[b].item() * page_size + pos) % page_size
+        page_idx = kv_page_indices_existing[page_iter].item()
+        if kv_layout == "NHD":
+            torch.testing.assert_close(
+                k_cache[page_idx, entry_idx, :, :].float(),
+                k_cache_before[page_idx, entry_idx, :, :].float(),
+                rtol=0,
+                atol=0,
+                msg=f"Existing K cache entry {i} was modified",
+            )
+            torch.testing.assert_close(
+                v_cache[page_idx, entry_idx, :, :].float(),
+                v_cache_before[page_idx, entry_idx, :, :].float(),
+                rtol=0,
+                atol=0,
+                msg=f"Existing V cache entry {i} was modified",
+            )
+        else:
+            torch.testing.assert_close(
+                k_cache[page_idx, :, entry_idx, :].float(),
+                k_cache_before[page_idx, :, entry_idx, :].float(),
+                rtol=0,
+                atol=0,
+                msg=f"Existing K cache entry {i} was modified",
+            )
+            torch.testing.assert_close(
+                v_cache[page_idx, :, entry_idx, :].float(),
+                v_cache_before[page_idx, :, entry_idx, :].float(),
+                rtol=0,
+                atol=0,
+                msg=f"Existing V cache entry {i} was modified",
+            )
+
+    for i in range(num_new_tokens):
+        b = batch_indices_new[i].item()
+        pos = positions_new[i].item()
+        page_iter = (kv_page_indptr_new[b].item() * page_size + pos) // page_size
+        entry_idx = (kv_page_indptr_new[b].item() * page_size + pos) % page_size
+        page_idx = kv_page_indices_new[page_iter].item()
+        if kv_layout == "NHD":
+            torch.testing.assert_close(
+                k_cache[page_idx, entry_idx, :, :].float(),
+                k_ref_tokens_new[i].float(),
+                rtol=cache_rtol,
+                atol=cache_atol,
+            )
+            torch.testing.assert_close(
+                v_cache[page_idx, entry_idx, :, :].float(),
+                v_ref_tokens_new[i].float(),
+                rtol=cache_rtol,
+                atol=cache_atol,
+            )
+        else:
+            torch.testing.assert_close(
+                k_cache[page_idx, :, entry_idx, :].float(),
+                k_ref_tokens_new[i].float(),
+                rtol=cache_rtol,
+                atol=cache_atol,
+            )
+            torch.testing.assert_close(
+                v_cache[page_idx, :, entry_idx, :].float(),
+                v_ref_tokens_new[i].float(),
+                rtol=cache_rtol,
+                atol=cache_atol,
+            )
+
+
 @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047])
 @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])