Add fused RoPE + paged KV cache append op for the non-quant path #24678 #2785

baonudesifeizhai wants to merge 6 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Adds a fused operation rope_append_paged_kv_cache across CUDA headers/kernels, host dispatch, FFI bindings, Python APIs, and tests to apply RoPE to Q/K and append K/V into a paged KV cache (supports FP16/BF16/FP8, layouts, MLA/GQA/MHA variants).
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python API
    participant FFI as FFI Binding
    participant Host as CUDA Host/Dispatch
    participant Kernel as CUDA Kernel
    participant KV as Paged KV Cache (device)
    Py->>FFI: call rope_append_paged_kv_cache(args)
    FFI->>Host: forward TensorViews & params
    Host->>Host: validate shapes/dtypes, resolve layout & templates
    Host->>Kernel: launch RopeAppendPagedKVCacheKernel(stream, params...)
    Kernel->>KV: read existing pages, write appended K/V pages
    Kernel-->>Host: kernel complete
    Host-->>FFI: return q_rope_out, q_nope_out
    FFI-->>Py: return outputs
```
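For reviewers who want a concrete reference for what the RoPE half of the fused op computes, interleaved (non-NeoX) RoPE on a single head can be sketched in plain Python. This is a simplified model for intuition only, not the library's implementation; the `theta` base and the adjacent-pair convention are assumptions.

```python
import math

def rope_rotate(x, position, rope_dim, theta=10000.0):
    """Rotate adjacent pairs of the first rope_dim elements of x by a
    position-dependent angle; elements past rope_dim pass through."""
    out = list(x)
    for i in range(0, rope_dim, 2):
        freq = theta ** (-i / rope_dim)  # per-pair rotation frequency
        c, s = math.cos(position * freq), math.sin(position * freq)
        out[i] = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out
```

At position 0 the rotation is the identity, and each pair's norm is preserved — both are handy sanity checks when validating the fused op against a reference.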
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

This pull request introduces a fused CUDA operation that applies Rotary Positional Embeddings (RoPE) and appends to a paged Key-Value (KV) cache. It streamlines the decode integration path for non-quantized inputs, completes the set of fused operations needed for efficient serving, and retains compatibility with FP8 KV caches.
Activity
Code Review
This pull request introduces a new fused RoPE and paged KV cache append operation for the non-quantized path, which is a valuable addition for serving decode integration. The implementation follows the existing patterns in the codebase. However, I've identified critical out-of-bounds memory access issues in the new CUDA kernel RopeAppendPagedKVCacheKernel. The processing for k_nope, v, and q_nope does not correctly handle partial vector reads/writes at chunk boundaries, which can lead to memory corruption or crashes. Please address these issues to ensure kernel correctness and stability.
```cuda
vec_t<float, vec_size> k_nope_vec;
k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
if constexpr (kNeedsScale) {
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    k_nope_vec[i] = k_nope_vec[i] * kv_scale;
  }
}

CacheType* k_ptr = paged_kv.get_k_ptr(page_iter, k_head_idx, entry_idx,
                                      rope_dim + elem_offset + tx * vec_size);
k_nope_vec.cast_store(k_ptr);
```
```cuda
uint32_t v_elem_offset = j * rope_chunk_size;
if (v_elem_offset + tx * vec_size < head_dim_total) {
  vec_t<float, vec_size> v_vec;
  v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
```
The vectorized load for v_in can read out of bounds. The check at line 1205 (if (v_elem_offset + tx * vec_size < head_dim_total)) is not sufficient to prevent a vectorized read from crossing the boundary of head_dim_total. This needs to be a masked load for threads processing the boundary to avoid memory corruption.
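The failure mode and its fix can be modeled outside CUDA. The sketch below is a host-side Python analogue of the kernel's per-lane copy; names such as `vec_size`, `rope_chunk_size`, and `chunk_valid` mirror the review discussion, and plain lists stand in for device memory. A full-width vector copy would touch `rope_chunk_size` elements regardless of validity; masking each lane to the valid remainder is what a `scale_store_partial_chunk`-style helper provides.

```python
def copy_chunk_masked(src, dst, elem_offset, rope_chunk_size, no_rope_dim, vec_size):
    """Copy one no-RoPE chunk, masking lanes that fall past no_rope_dim."""
    # Elements of this chunk that are actually valid (0 for an all-padding chunk).
    chunk_valid = max(0, min(rope_chunk_size, no_rope_dim - elem_offset))
    for tx in range(rope_chunk_size // vec_size):  # one iteration per "thread"
        lane_off = tx * vec_size
        # A full cast_load/cast_store is only safe when the whole vector fits;
        # otherwise copy just the valid tail elements.
        n = min(vec_size, max(0, chunk_valid - lane_off))
        for i in range(n):
            dst[elem_offset + lane_off + i] = src[elem_offset + lane_off + i]
```

With `rope_chunk_size=4`, `no_rope_dim=6`, and `vec_size=2`, the second chunk has only two valid elements, so its second lane copies nothing instead of overrunning the slice.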
```cuda
vec_t<float, vec_size> q_nope_vec;
q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
```
The vectorized load and store for q_nope can lead to out-of-bounds memory access if no_rope_dim is not a multiple of rope_chunk_size (rope_dim). This happens when processing the last, partial chunk, as the vectorized operations do not account for boundary conditions. This can cause memory corruption.
You should add boundary checks for both the load and store operations. A similar pattern can be found in RopeQuantizeKernel which uses scale_store_partial_chunk for this purpose. The k_nope and v processing sections in this kernel have a similar issue.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/rope.py`:
- Around line 1809-1825: The function rope_append_paged_kv_cache currently only
checks 4D and dtype but must validate that the cache layout and all metadata
lengths match expected dimensions before calling the kernel: verify kv_layout is
a valid TensorLayout key (use TensorLayout[kv_layout]), ensure paged_kv_cache
shapes match the expected axes for the chosen kv_layout/page_size/num_kv_heads
(e.g. confirm num_kv_heads matches the corresponding axis in k_cache and v_cache
and that page_size is consistent with the paging axis), and check that
batch_indices, positions, kv_indices, kv_indptr and kv_last_page_len have
lengths and value ranges consistent with batch size, sequence length and number
of pages (no index exceeds axis sizes); if any check fails raise a descriptive
ValueError. Use the existing symbols paged_kv_cache, k_cache, v_cache,
kv_layout, page_size, num_kv_heads, batch_indices, positions, kv_indices,
kv_indptr and kv_last_page_len to locate and implement these validations in
rope_append_paged_kv_cache before dispatching to the kernel.
In `@include/flashinfer/pos_enc.cuh`:
- Around line 1175-1194: The no-RoPE branch loads/stores full vec_size
unconditionally (k_nope_vec.cast_load / cast_store) which overruns when rope_dim
< bdx; guard the load/store by computing the remaining elements for this chunk
(using rope_dim, elem_offset, tx and vec_size) and if the chunk is partial call
the existing scale_store_partial_chunk(...) helper (as RopeQuantizeKernel does)
to safely load/scale/store only the valid lanes, otherwise perform the full
cast_load/cast_store path; update the k_nope_in pointer use (k_nope_in +
get_elem_offset_impl(...)) and the destination from paged_kv.get_k_ptr(...)
accordingly so partial writes never exceed the slice.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2a5d426d-063d-4a6a-b144-2e5ff38bc61c
📒 Files selected for processing (5)
- csrc/flashinfer_rope_binding.cu
- csrc/rope.cu
- flashinfer/__init__.py
- flashinfer/rope.py
- include/flashinfer/pos_enc.cuh
```python
if len(paged_kv_cache) != 2:
    raise ValueError("paged_kv_cache must be a tuple of (k_cache, v_cache)")
k_cache, v_cache = paged_kv_cache
if k_cache.ndim != 4 or v_cache.ndim != 4:
    raise ValueError("rope_append_paged_kv_cache expects 4D GQA/MHA cache tensors")
if k_cache.dtype != v_cache.dtype:
    raise ValueError("k_cache and v_cache must have the same dtype")

from .utils import TensorLayout

kv_layout_code = TensorLayout[kv_layout].value
batch_indices = batch_indices.int()
positions = positions.int()
kv_indices = kv_indices.int()
kv_indptr = kv_indptr.int()
kv_last_page_len = kv_last_page_len.int()
```
Validate cache layout and metadata sizes before dispatch.
This path only checks that the caches are 4D and same dtype. A wrong kv_layout/page_size/num_kv_heads combination, or short pos_ids/batch_indices/positions, still reaches the kernel and will misaddress page writes. Fail fast here with explicit axis and length checks.
🛡️ Suggested validation
```diff
 if len(paged_kv_cache) != 2:
     raise ValueError("paged_kv_cache must be a tuple of (k_cache, v_cache)")
 k_cache, v_cache = paged_kv_cache
 if k_cache.ndim != 4 or v_cache.ndim != 4:
     raise ValueError("rope_append_paged_kv_cache expects 4D GQA/MHA cache tensors")
 if k_cache.dtype != v_cache.dtype:
     raise ValueError("k_cache and v_cache must have the same dtype")
+head_dim = q_rope.shape[-1] + q_nope.shape[-1]
+if kv_layout == "NHD":
+    expected_tail = (page_size, num_kv_heads, head_dim)
+elif kv_layout == "HND":
+    expected_tail = (num_kv_heads, page_size, head_dim)
+else:
+    raise ValueError(f"unsupported kv_layout: {kv_layout}")
+if k_cache.shape[0] != v_cache.shape[0]:
+    raise ValueError("k_cache and v_cache must have the same number of pages")
+if tuple(k_cache.shape[1:]) != expected_tail or tuple(v_cache.shape[1:]) != expected_tail:
+    raise ValueError(
+        f"cache shape/layout mismatch: expected (*, {expected_tail[0]}, "
+        f"{expected_tail[1]}, {expected_tail[2]}) for kv_layout={kv_layout}"
+    )
+if pos_ids.numel() != nnz or batch_indices.numel() != nnz or positions.numel() != nnz:
+    raise ValueError("pos_ids, batch_indices, and positions must all have length nnz")
+if kv_indptr.numel() != kv_last_page_len.numel() + 1:
+    raise ValueError("kv_indptr must have length batch_size + 1")
 from .utils import TensorLayout
```
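The range checks hinge on how the metadata tensors address pages. A minimal Python model of that addressing (illustrative only; parameter names mirror the API, but the real kernel's address math also folds in the cache layout and head strides):

```python
def page_slot(batch_idx, position, kv_indices, kv_indptr, page_size):
    """Map (request, token position) to a (physical page, entry-in-page) pair."""
    logical_page = position // page_size  # which page of this request
    entry_idx = position % page_size      # slot within that page
    page_iter = kv_indptr[batch_idx] + logical_page
    if page_iter >= kv_indptr[batch_idx + 1]:
        raise IndexError("position exceeds the pages allocated to this request")
    return kv_indices[page_iter], entry_idx
```

This is why short or inconsistent `kv_indptr`/`kv_indices` must be rejected before launch: the device kernel performs no such bounds check, and a bad `page_iter` silently writes into another request's pages.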
```cuda
} else if (by < k_nope_end) {
  uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
  uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
  uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

  DType* k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
                                                          k_nope_in_stride, k_nope_in_stride_h);

  vec_t<float, vec_size> k_nope_vec;
  k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
  if constexpr (kNeedsScale) {
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      k_nope_vec[i] = k_nope_vec[i] * kv_scale;
    }
  }

  CacheType* k_ptr = paged_kv.get_k_ptr(page_iter, k_head_idx, entry_idx,
                                        rope_dim + elem_offset + tx * vec_size);
  k_nope_vec.cast_store(k_ptr);
```
Guard partial no-RoPE chunks before vector load/store.
bdx is sized from rope_dim, but these K/Q no-RoPE branches always load/store a full vec_size. With common splits like rope_dim=128, no_rope_dim=64, the extra threads in the no-RoPE section read and write past the end of the slice. Reuse scale_store_partial_chunk(...) here the same way RopeQuantizeKernel already handles no-RoPE tails.
🛠️ Tail-safe fix

```diff
-  vec_t<float, vec_size> k_nope_vec;
-  k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
-  if constexpr (kNeedsScale) {
-#pragma unroll
-    for (uint32_t i = 0; i < vec_size; ++i) {
-      k_nope_vec[i] = k_nope_vec[i] * kv_scale;
-    }
-  }
-
-  CacheType* k_ptr = paged_kv.get_k_ptr(page_iter, k_head_idx, entry_idx,
-                                        rope_dim + elem_offset + tx * vec_size);
-  k_nope_vec.cast_store(k_ptr);
+  const uint32_t chunk_valid =
+      (elem_offset < no_rope_dim) ? min(rope_chunk_size, no_rope_dim - elem_offset) : 0u;
+  const uint32_t lane_elem_offset = tx * vec_size;
+  CacheType* k_ptr =
+      paged_kv.get_k_ptr(page_iter, k_head_idx, entry_idx, rope_dim + elem_offset);
+  scale_store_partial_chunk<DType, CacheType, vec_size>(
+      k_nope_in_ptr, k_ptr, lane_elem_offset, chunk_valid, kNeedsScale ? kv_scale : 1.f);
@@
-  vec_t<float, vec_size> q_nope_vec;
-  q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
-  q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+  const uint32_t chunk_valid =
+      (elem_offset < no_rope_dim) ? min(rope_chunk_size, no_rope_dim - elem_offset) : 0u;
+  const uint32_t lane_elem_offset = tx * vec_size;
+  scale_store_partial_chunk<DType, DType, vec_size>(
+      q_nope_in_ptr, q_nope_out_ptr, lane_elem_offset, chunk_valid, 1.f);
```

Also applies to: lines 1220-1235.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/attention/test_rope.py (2)
1546-1565: Exercise a non-unit `kv_scale` in the FP8 branch. Both append calls hardcode `kv_scale=1.0`, so `use_fp8_cache=True` only validates a plain cast to FP8. That misses the cache-side scaling behavior this API is adding.

💡 Minimal coverage change

```diff
 kv_cache_dtype = torch.float8_e4m3fn if use_fp8_cache else input_dtype
+kv_scale = 0.5 if use_fp8_cache else 1.0
 ...
 flashinfer.rope.rope_append_paged_kv_cache(
     q_rope_existing,
     k_rope_existing,
@@
-    kv_scale=1.0,
+    kv_scale=kv_scale,
     is_neox=False,
     enable_pdl=enable_pdl,
 )
@@
 q_rope_out_new, q_nope_out_new = flashinfer.rope.rope_append_paged_kv_cache(
     q_rope_new,
     k_rope_new,
@@
-    kv_scale=1.0,
+    kv_scale=kv_scale,
     is_neox=False,
     enable_pdl=enable_pdl,
 )
@@
-k_ref_tokens_new = k_ref_new.to(kv_cache_dtype)
-v_ref_tokens_new = v_new.to(kv_cache_dtype)
+k_ref_tokens_new = (
+    (k_ref_new * kv_scale).to(kv_cache_dtype)
+    if use_fp8_cache
+    else k_ref_new.to(kv_cache_dtype)
+)
+v_ref_tokens_new = (
+    (v_new * kv_scale).to(kv_cache_dtype)
+    if use_fp8_cache
+    else v_new.to(kv_cache_dtype)
+)
```

Also applies to: lines 1631-1650, 1678-1686.
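To see what this nitpick is getting at, here is a toy Python model of cache-side scaling: with an FP8 cache, K/V are multiplied by `kv_scale` before the narrowing cast. The `quantize` helper is a saturating stand-in for an e4m3 cast (clamping only, no rounding model), not the library's implementation.

```python
def quantize(x, lo=-448.0, hi=448.0):
    """Saturating stand-in for an FP8 e4m3 cast (±448 is the e4m3 max)."""
    return max(lo, min(hi, x))

def append_to_cache(values, kv_scale):
    """Scale, then narrow -- the order cache-side scaling applies for FP8 caches."""
    return [quantize(v * kv_scale) for v in values]
```

With `kv_scale=1.0` the scaling multiply is a no-op, so a test that hardcodes it only ever exercises the cast; a value like 0.5 makes a wrong or missing multiply show up immediately.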
1546-1565: Assert the bootstrap append as well. The first `rope_append_paged_kv_cache` call seeds the cache, but the test never checks its returned Q tensors or the written K/V values. The later "unchanged" assertions only prove the second append did not overwrite whatever step 1 produced, so an empty-cache append bug would still pass here.
🤖 Prompt for all review comments with AI agents

Verify each finding against the current code and only fix it if needed.

Inline comments:

In `tests/attention/test_rope.py`:
- Around lines 1393-1420: test_rope_append_paged_kv_cache_decode currently allocates CUDA FP8 caches when use_fp8_cache is True without skipping unsupported GPUs. At the top of the test, check the GPU capability using flashinfer.utils.get_compute_capability() together with flashinfer.utils.is_sm90a_supported() and flashinfer.utils.is_sm100a_supported(), and call pytest.skip(...) when use_fp8_cache is True and the architecture is unsupported, so the FP8 allocation is only attempted on supported architectures. Ensure the check runs before any FP8-specific dtype or tensor allocations.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3d903e76-4ce3-4054-8575-bb92e6aa73f3
📒 Files selected for processing (1)
tests/attention/test_rope.py
/bot run |
[CANCELED] Pipeline #46402900: canceled |
/bot run |
[SUCCESS] Pipeline #46541900: 14/20 passed |
📌 Description
🔍 Related Issues
#24678
vllm-project/vllm#37041
Add `rope_append_paged_kv_cache`, a fused non-quant RoPE + paged KV cache append op. This fills the missing fused `rope + cache` path for serving decode integration, and can still write to FP8 KV cache through `kv_scale`.

Motivation

#24678

We were missing the fused non-quant `rope + cache` op needed by the vLLM integration path.

Validation
🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- `pytest -q tests/attention/test_rope.py -k "test_generalized_rope_quantize_append_kv_cache or test_rope_quantize_fp8_append_paged_kv_cache_decode"` passed: 6528 passed, 23722 deselected in 359.06s (0:05:59)
- Tests are passing (unittest, etc.).

Reviewer Notes