Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions csrc/flashinfer_rope_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ void rope_quantize_append_paged_kv_cache(
TensorView positions, int64_t kv_layout_code, int64_t page_size, double quant_scale_q,
double quant_scale_kv, bool interleave, bool enable_pdl);

// Forward declaration: fused RoPE application + paged KV cache append (GQA/MHA only).
// Defined in csrc/rope.cu and exported through the TVM FFI function table in this file.
void rope_append_paged_kv_cache(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in,
                                TensorView k_nope_in, TensorView v_in, TensorView q_rope_out,
                                TensorView q_nope_out, TensorView cos_sin_cache, TensorView pos_ids,
                                TensorView k_cache, TensorView v_cache, TensorView kv_indices,
                                TensorView kv_indptr, TensorView kv_last_page_len,
                                TensorView batch_indices, TensorView positions,
                                int64_t kv_layout_code, int64_t page_size, double kv_scale,
                                bool interleave, bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids, apply_rope_pos_ids);
Expand All @@ -61,3 +70,4 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_i
TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize_append_paged_kv_cache,
rope_quantize_append_paged_kv_cache);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_append_paged_kv_cache, rope_append_paged_kv_cache);
138 changes: 138 additions & 0 deletions csrc/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,141 @@ void rope_quantize_append_paged_kv_cache(
});
});
}

/*!
 * \brief TVM FFI binding for fused RoPE + paged KV cache append kernel (GQA/MHA only).
 *
 * Validates shapes, strides and dtypes of the split rope/no-rope Q/K inputs plus V,
 * then dispatches RopeAppendPagedKVCache over three dtype axes:
 *   - input dtype (float16 / bfloat16),
 *   - position-id index dtype,
 *   - cache dtype (float16 / bfloat16 / float8_e4m3fn / float8_e5m2).
 *
 * Rotated queries are written to q_rope_out / q_nope_out; rotated K and V are
 * appended into the paged k_cache / v_cache addressed by kv_indices / kv_indptr /
 * kv_last_page_len, with batch_indices / positions selecting the per-token slot.
 * kv_scale is forwarded to the kernel; per the Python-side docstring it is used
 * for cache-side casting when the cache dtype is FP8 — TODO confirm exact
 * semantics inside RopeAppendPagedKVCache.
 */
void rope_append_paged_kv_cache(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in,
                                TensorView k_nope_in, TensorView v_in, TensorView q_rope_out,
                                TensorView q_nope_out, TensorView cos_sin_cache, TensorView pos_ids,
                                TensorView k_cache, TensorView v_cache, TensorView kv_indices,
                                TensorView kv_indptr, TensorView kv_last_page_len,
                                TensorView batch_indices, TensorView positions,
                                int64_t kv_layout_code, int64_t page_size, double kv_scale,
                                bool interleave, bool enable_pdl) {
  // Q/K/V inputs and the Q outputs only require the innermost dim to be
  // contiguous; their outer strides are read explicitly below.
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_nope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_out);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_out);
  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  // Caches only need to live on a CUDA device; arbitrary strides are allowed
  // (paged_kv_t receives k_cache's stride vector below).
  CHECK_CUDA(k_cache);
  CHECK_CUDA(v_cache);
  CHECK_INPUT(kv_indices);
  CHECK_INPUT(kv_indptr);
  CHECK_INPUT(kv_last_page_len);
  CHECK_INPUT(batch_indices);
  CHECK_INPUT(positions);

  // GQA/MHA only: inputs are (nnz, heads, dim) and caches are 4-D paged tensors.
  CHECK_DIM(3, q_rope_in);
  CHECK_DIM(3, k_rope_in);
  CHECK_DIM(3, q_nope_in);
  CHECK_DIM(3, k_nope_in);
  CHECK_DIM(3, v_in);
  CHECK_DIM(3, q_rope_out);
  CHECK_DIM(3, q_nope_out);
  CHECK_DIM(4, k_cache);
  CHECK_DIM(4, v_cache);
  CHECK_DIM(1, kv_last_page_len);

  // Head dim is split into a rotated part (rope_dim) and a pass-through part
  // (no_rope_dim); the cache stores the full head_dim = rope_dim + no_rope_dim.
  uint32_t rope_dim = q_rope_in.size(-1);
  uint32_t no_rope_dim = q_nope_in.size(-1);
  uint32_t head_dim = rope_dim + no_rope_dim;
  uint32_t nnz = q_rope_in.size(0);
  uint32_t num_qo_heads = q_rope_in.size(1);
  uint32_t num_kv_heads = k_rope_in.size(1);
  // kv_indptr has batch_size + 1 entries (CSR-style page offsets per request).
  uint32_t batch_size = kv_indptr.size(0) - 1;
  QKVLayout kv_layout = QKVLayout(kv_layout_code);

  // Cross-tensor shape consistency. NOTE(review): k_rope_in.size(0) and
  // k_nope_in.size(0) are not checked against nnz here — confirm callers
  // guarantee it (the Python wrapper derives all of these from one batch).
  TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim);
  TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim);
  TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads);
  TVM_FFI_ICHECK_EQ(v_in.size(0), nnz);
  TVM_FFI_ICHECK_EQ(v_in.size(1), num_kv_heads);
  TVM_FFI_ICHECK_EQ(v_in.size(2), head_dim);
  TVM_FFI_ICHECK_EQ(q_rope_out.size(0), nnz);
  TVM_FFI_ICHECK_EQ(q_rope_out.size(1), num_qo_heads);
  TVM_FFI_ICHECK_EQ(q_rope_out.size(2), rope_dim);
  TVM_FFI_ICHECK_EQ(q_nope_out.size(0), nnz);
  TVM_FFI_ICHECK_EQ(q_nope_out.size(1), num_qo_heads);
  TVM_FFI_ICHECK_EQ(q_nope_out.size(2), no_rope_dim);
  TVM_FFI_ICHECK_EQ(k_cache.size(0), v_cache.size(0));
  TVM_FFI_ICHECK_EQ(k_cache.size(1), v_cache.size(1));
  TVM_FFI_ICHECK_EQ(k_cache.size(2), v_cache.size(2));
  TVM_FFI_ICHECK_EQ(k_cache.size(3), v_cache.size(3));
  TVM_FFI_ICHECK_EQ(kv_last_page_len.size(0), batch_size);

  // All floating inputs/outputs share one dtype; the cache may additionally
  // be FP8 (handled by the two-way dispatch at the bottom).
  TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16)
      << "Input dtype must be float16 or bfloat16";
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype());
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype());
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype());
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), v_in.dtype());
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_rope_out.dtype());
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_out.dtype());
  TVM_FFI_ICHECK_EQ(k_cache.dtype(), v_cache.dtype());
  TVM_FFI_ICHECK(k_cache.dtype() == dl_float16 || k_cache.dtype() == dl_bfloat16 ||
                 k_cache.dtype() == dl_float8_e4m3fn || k_cache.dtype() == dl_float8_e5m2)
      << "Cache dtype must be float16, bfloat16, float8_e4m3fn, or float8_e5m2";

  // Element strides passed to the kernel. NOTE(review): these narrow int64
  // strides to uint32_t — confirm tensors stay below 2^32 elements.
  const uint32_t q_rope_in_stride_n = q_rope_in.stride(0);
  const uint32_t q_rope_in_stride_h = q_rope_in.stride(1);
  const uint32_t q_nope_in_stride_n = q_nope_in.stride(0);
  const uint32_t q_nope_in_stride_h = q_nope_in.stride(1);
  const uint32_t q_rope_out_stride_n = q_rope_out.stride(0);
  const uint32_t q_rope_out_stride_h = q_rope_out.stride(1);
  const uint32_t q_nope_out_stride_n = q_nope_out.stride(0);
  const uint32_t q_nope_out_stride_h = q_nope_out.stride(1);
  const uint32_t k_rope_in_stride = k_rope_in.stride(0);
  const uint32_t k_rope_in_stride_h = k_rope_in.stride(1);
  const uint32_t k_nope_in_stride = k_nope_in.stride(0);
  const uint32_t k_nope_in_stride_h = k_nope_in.stride(1);
  const uint32_t v_in_stride = v_in.stride(0);
  const uint32_t v_in_stride_h = v_in.stride(1);

  // Launch on q_rope_in's device and its current stream.
  ffi::CUDADeviceGuard device_guard(q_rope_in.device().device_id);
  const cudaStream_t stream = get_stream(q_rope_in.device());

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] {
    return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] {
      // `launch` is generic over the cache element type: paged_kv_t must be
      // rebuilt per cache dtype, so it lives inside the dispatch.
      auto launch = [&](auto cache_dtype_tag) -> bool {
        using c_cache_type = decltype(cache_dtype_tag);
        auto k_strides = k_cache.strides();
        paged_kv_t<c_cache_type, int32_t> paged_kv(
            num_kv_heads, page_size, head_dim, batch_size, kv_layout,
            static_cast<c_cache_type*>(k_cache.data_ptr()),
            static_cast<c_cache_type*>(v_cache.data_ptr()), k_strides.data(),
            static_cast<int32_t*>(kv_indices.data_ptr()),
            static_cast<int32_t*>(kv_indptr.data_ptr()),
            static_cast<int32_t*>(kv_last_page_len.data_ptr()));
        cudaError_t status = RopeAppendPagedKVCache(
            static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
            static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
            static_cast<c_type*>(v_in.data_ptr()), static_cast<c_type*>(q_rope_out.data_ptr()),
            static_cast<c_type*>(q_nope_out.data_ptr()), paged_kv,
            static_cast<int32_t*>(batch_indices.data_ptr()),
            static_cast<int32_t*>(positions.data_ptr()),
            static_cast<float*>(cos_sin_cache.data_ptr()),
            static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
            no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
            q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
            q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
            k_nope_in_stride_h, v_in_stride, v_in_stride_h, kv_scale, interleave, enable_pdl,
            stream);
        TVM_FFI_ICHECK(status == cudaSuccess)
            << "RopeAppendPagedKVCache failed with error code " << cudaGetErrorString(status);
        return true;
      };

      // Two-way cache-dtype dispatch: half-precision caches vs FP8 caches.
      if (k_cache.dtype() == dl_float16 || k_cache.dtype() == dl_bfloat16) {
        return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(k_cache.dtype(), c_cache_type,
                                                   [&] { return launch(c_cache_type{}); });
      }
      return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(k_cache.dtype(), c_cache_type,
                                                [&] { return launch(c_cache_type{}); });
    });
  });
}
1 change: 1 addition & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import rope_append_paged_kv_cache as rope_append_paged_kv_cache
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
Expand Down
174 changes: 174 additions & 0 deletions flashinfer/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,85 @@ def _fake_rope_quantize_fp8_append_paged_kv_cache(
pass


@register_custom_op(
    "flashinfer::rope_append_paged_kv_cache",
    mutates_args=("q_rope_out", "q_nope_out", "k_cache", "v_cache"),
)
def _rope_append_paged_kv_cache(
    q_rope_in: torch.Tensor,
    k_rope_in: torch.Tensor,
    q_nope_in: torch.Tensor,
    k_nope_in: torch.Tensor,
    v_in: torch.Tensor,
    q_rope_out: torch.Tensor,
    q_nope_out: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    kv_indices: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_last_page_len: torch.Tensor,
    batch_indices: torch.Tensor,
    positions: torch.Tensor,
    kv_layout_code: int,
    page_size: int,
    kv_scale: float,
    interleave: bool,
    enable_pdl: bool,
) -> None:
    """Forward all arguments to the FFI ``rope_append_paged_kv_cache`` kernel.

    In-place op: the kernel writes into ``q_rope_out``/``q_nope_out`` and the
    paged ``k_cache``/``v_cache`` (see ``mutates_args`` above); nothing is
    returned.
    """
    ffi_args = (
        q_rope_in,
        k_rope_in,
        q_nope_in,
        k_nope_in,
        v_in,
        q_rope_out,
        q_nope_out,
        cos_sin_cache,
        pos_ids,
        k_cache,
        v_cache,
        kv_indices,
        kv_indptr,
        kv_last_page_len,
        batch_indices,
        positions,
        kv_layout_code,
        page_size,
        kv_scale,
        interleave,
        enable_pdl,
    )
    get_rope_module().rope_append_paged_kv_cache(*ffi_args)


@register_fake_op("flashinfer::rope_append_paged_kv_cache")
def _fake_rope_append_paged_kv_cache(
    q_rope_in: torch.Tensor,
    k_rope_in: torch.Tensor,
    q_nope_in: torch.Tensor,
    k_nope_in: torch.Tensor,
    v_in: torch.Tensor,
    q_rope_out: torch.Tensor,
    q_nope_out: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    kv_indices: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_last_page_len: torch.Tensor,
    batch_indices: torch.Tensor,
    positions: torch.Tensor,
    kv_layout_code: int,
    page_size: int,
    kv_scale: float,
    interleave: bool,
    enable_pdl: bool,
) -> None:
    """Fake (shape-only) implementation registered for the custom op.

    The real op mutates its output/cache arguments in place and returns
    nothing, so no shape inference is needed here — presumably consumed by
    torch.compile/meta tracing, like the other ``register_fake_op`` entries
    in this file.
    """
    pass


@register_custom_op(
"flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope")
)
Expand Down Expand Up @@ -1674,3 +1753,98 @@ def rope_quantize_fp8_append_paged_kv_cache(
)

return q_rope_out, q_nope_out


@flashinfer_api
def rope_append_paged_kv_cache(
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    q_nope: Optional[torch.Tensor],
    k_nope: Optional[torch.Tensor],
    v: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    paged_kv_cache: Tuple[torch.Tensor, torch.Tensor],
    kv_indices: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_last_page_len: torch.Tensor,
    batch_indices: torch.Tensor,
    positions: torch.Tensor,
    is_neox: bool = True,
    kv_scale: float = 1.0,
    page_size: int = 16,
    kv_layout: str = "NHD",
    q_rope_out: Optional[torch.Tensor] = None,
    q_nope_out: Optional[torch.Tensor] = None,
    enable_pdl: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply RoPE to Q/K and append K/V to paged KV cache.

    This primitive keeps query outputs in the input dtype and only uses
    ``kv_scale`` for cache-side casting when the cache dtype is FP8.

    Parameters
    ----------
    q_rope, k_rope : torch.Tensor
        Rotated portions of Q/K, shape ``(nnz, num_heads, rope_dim)``.
    q_nope, k_nope : Optional[torch.Tensor]
        Pass-through portions; ``None`` means an empty (zero-width) split.
    v : torch.Tensor
        Values to append alongside K.
    cos_sin_cache : torch.Tensor
        Precomputed cos/sin table; must be float32.
    paged_kv_cache : Tuple[torch.Tensor, torch.Tensor]
        ``(k_cache, v_cache)``, each a 4D paged tensor laid out per ``kv_layout``.
    kv_indices, kv_indptr, kv_last_page_len : torch.Tensor
        CSR-style page table metadata (``kv_indptr`` has ``batch_size + 1`` entries).
    batch_indices, positions : torch.Tensor
        Per-token destination request index and in-sequence position, length ``nnz``.
    is_neox : bool
        NeoX (non-interleaved) rotation when True; interleaved when False.
    q_rope_out, q_nope_out : Optional[torch.Tensor]
        Optional preallocated query outputs; allocated here when ``None``.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        ``(q_rope_out, q_nope_out)`` with RoPE applied.

    Raises
    ------
    ValueError
        On dtype, rank, layout, cache-shape, or metadata-length mismatches.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")
    if k_rope.ndim != 3:
        raise ValueError("rope_append_paged_kv_cache only supports GQA/MHA inputs")

    nnz = q_rope.shape[0]
    num_qo_heads = q_rope.shape[1]
    num_kv_heads = k_rope.shape[1]

    # Zero-width placeholders so the kernel sees a uniform rope/no-rope split.
    if q_nope is None:
        q_nope = torch.empty(
            nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device
        )
    if k_nope is None:
        k_nope = torch.empty(
            nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device
        )

    if q_rope_out is None:
        q_rope_out = torch.empty_like(q_rope)
    if q_nope_out is None:
        q_nope_out = torch.empty_like(q_nope)

    if len(paged_kv_cache) != 2:
        raise ValueError("paged_kv_cache must be a tuple of (k_cache, v_cache)")
    k_cache, v_cache = paged_kv_cache
    if k_cache.ndim != 4 or v_cache.ndim != 4:
        raise ValueError("rope_append_paged_kv_cache expects 4D GQA/MHA cache tensors")
    if k_cache.dtype != v_cache.dtype:
        raise ValueError("k_cache and v_cache must have the same dtype")

    # Fail fast on layout/shape/metadata mismatches: a wrong combination of
    # kv_layout/page_size/num_kv_heads, or short index tensors, would otherwise
    # reach the kernel and misaddress page writes.
    head_dim = q_rope.shape[-1] + q_nope.shape[-1]
    if kv_layout == "NHD":
        expected_tail = (page_size, num_kv_heads, head_dim)
    elif kv_layout == "HND":
        expected_tail = (num_kv_heads, page_size, head_dim)
    else:
        raise ValueError(f"unsupported kv_layout: {kv_layout}")
    if k_cache.shape[0] != v_cache.shape[0]:
        raise ValueError("k_cache and v_cache must have the same number of pages")
    if tuple(k_cache.shape[1:]) != expected_tail or tuple(v_cache.shape[1:]) != expected_tail:
        raise ValueError(
            f"cache shape/layout mismatch: expected (*, {expected_tail[0]}, "
            f"{expected_tail[1]}, {expected_tail[2]}) for kv_layout={kv_layout}"
        )
    if pos_ids.numel() != nnz or batch_indices.numel() != nnz or positions.numel() != nnz:
        raise ValueError("pos_ids, batch_indices, and positions must all have length nnz")
    if kv_indptr.numel() != kv_last_page_len.numel() + 1:
        raise ValueError("kv_indptr must have length batch_size + 1")

    from .utils import TensorLayout

    kv_layout_code = TensorLayout[kv_layout].value
    # The FFI kernel indexes with int32; normalize all metadata dtypes.
    batch_indices = batch_indices.int()
    positions = positions.int()
    kv_indices = kv_indices.int()
    kv_indptr = kv_indptr.int()
    kv_last_page_len = kv_last_page_len.int()

    _rope_append_paged_kv_cache(
        q_rope,
        k_rope,
        q_nope,
        k_nope,
        v,
        q_rope_out,
        q_nope_out,
        cos_sin_cache,
        pos_ids,
        k_cache,
        v_cache,
        kv_indices,
        kv_indptr,
        kv_last_page_len,
        batch_indices,
        positions,
        kv_layout_code,
        page_size,
        kv_scale,
        not is_neox,  # kernel takes `interleave`, the inverse of NeoX style
        enable_pdl,
    )

    return q_rope_out, q_nope_out
Loading
Loading