feat: Support padding tokens with seqlen=0 for rope+quant+kv cache update fusion kernel#2792
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the KV cache update mechanism to gracefully handle padding tokens, which is crucial for enabling full CUDA graph functionality in systems like vLLM. By explicitly marking and skipping padding tokens during KV cache writes, the changes prevent data corruption and ensure the integrity of the cache while maintaining the performance benefits of fixed-size batches.
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: defaults | Review profile: CHILL | Plan: Pro
Run ID:
📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 Walkthrough

Adds deterministic initialization for batch_indices/positions and kernel-level guards that skip RoPE, quantization, and paged KV cache append work for tokens marked with a negative batch index.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host (CPU)
    participant Kernel as RopeQuantizeAppendPagedKVCacheKernel (GPU)
    participant KV as Paged KV Cache
    Host->>Host: prepare inputs\n(batch_indices, positions)
    Host->>Kernel: launch kernel(inputs)
    Kernel->>Kernel: idx := thread idx
    alt batch_indices[idx] >= 0
        Kernel->>Kernel: compute page location\napply RoPE, quantize
        Kernel->>KV: append/store K/V/Q to paged cache
    else batch_indices[idx] < 0
        Kernel->>Kernel: skip all RoPE/quantize/cache ops
    end
    Kernel-->>Host: kernel completes
```
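The control flow in the diagram above can be modeled host-side. The following is a minimal Python reference sketch (illustrative names only, not the actual FlashInfer CUDA kernel or API): padding slots with `batch_indices < 0` are skipped entirely, while real tokens are mapped to a page location and written.

```python
def append_tokens(batch_indices, positions, indptr, page_size, kv_cache, tokens):
    """Reference model of the kernel's per-token guard: padding slots
    (batch_indices[idx] < 0) skip RoPE/quant/append entirely; real tokens
    are written to their (page, entry) location. Names are illustrative."""
    written = []
    for idx, b in enumerate(batch_indices):
        if b < 0:  # padding token: skip all work for this slot
            continue
        # flat offset of this token within the request's page span
        flat = indptr[b] * page_size + positions[idx]
        page_iter, entry_idx = divmod(flat, page_size)
        kv_cache.setdefault(page_iter, {})[entry_idx] = tokens[idx]
        written.append(idx)
    return written

# Two real tokens for request 0, one trailing padding slot:
cache = {}
out = append_tokens([0, 0, -1], [0, 1, 0], [0, 1], 16, cache, ["k0", "k1", "pad"])
# out → [0, 1]; the padding token never touches the cache
```

Mirroring the diagram, the padding branch does no cache writes at all, which is what keeps fixed-size (CUDA-graph-friendly) batches from corrupting the cache.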
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ❌ 1 failed (warning) | ✅ 2 passed
Code Review
This pull request adds support for padding tokens in the rope+quant+kv cache update fused kernel, which is useful for cudagraphs. The approach involves modifying get_batch_indices_positions_kernel to mark padding tokens and updating RopeQuantizeAppendPagedKVCacheKernel to skip them. A new test case is added to validate this padding logic. While the implementation changes seem correct, I've identified issues in the new test case where token positions are calculated incorrectly. This could cause the test to pass while not properly verifying the intended behavior, potentially masking bugs. I've provided suggestions to correct the test logic.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/attention/test_rope.py (1)
1390-1590: Add enable_pdl coverage to this new padding regression test.

Lines 1392-1589 only exercise the default path. Please parameterize enable_pdl and pass it into the fused call so padding behavior is validated under the programmatic dependent launch mode too.

Proposed test update

```diff
 @pytest.mark.parametrize("kv_layout", ["NHD", "HND"])
 @pytest.mark.parametrize("page_size", [16])
+@pytest.mark.parametrize("enable_pdl", [True, False])
 def test_rope_quantize_fp8_append_paged_kv_cache_padding(
@@
     kv_layout,
     page_size,
+    enable_pdl,
 ):
@@
     flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(
@@
         quant_scale_kv=1.0,
         is_neox=False,
+        enable_pdl=enable_pdl,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_rope.py` around lines 1390 - 1590, The test test_rope_quantize_fp8_append_paged_kv_cache currently only runs the default path; add a pytest parameterization for enable_pdl (e.g., `@pytest.mark.parametrize`("enable_pdl",[False,True]) and add enable_pdl to the test signature) and pass enable_pdl into the fused call flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(enable_pdl=enable_pdl, ... ) so the padding behavior is validated under programmatic dependent launch mode as well.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/pos_enc.cuh`:
- Around line 862-865: Replace the early-return on the batch_indices check so
all threads reach the PDL epilogue: remove the "if (batch_indices[idx] < 0)
return;" and instead wrap the work body that follows (the block currently
between lines 867–1030) in "if (batch_indices[idx] >= 0) { ... }". Keep the
final epilogue (including the griddepcontrol.launch_dependents instruction)
outside that guard so it executes unconditionally for every thread in the block;
ensure you reference the same batch_indices and idx variables used in the
kernel.
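The structural point of this review comment — guard the work, never early-return past the epilogue — can be sketched in a few lines of Python (an illustrative model; the real epilogue is the CUDA PDL `launch_dependents` instruction, and the names here are invented for the sketch):

```python
def kernel_thread(idx, batch_indices, work_log, epilogue_log):
    """Sketch of the suggested restructure: instead of 'if padding: return',
    wrap the body in a guard so EVERY thread still reaches the epilogue
    (which, in the real kernel, issues the PDL launch-dependents hint)."""
    if batch_indices[idx] >= 0:   # real token: do RoPE/quant/append work
        work_log.append(idx)
    epilogue_log.append(idx)      # must run unconditionally for all threads

work, epi = [], []
for idx in range(3):
    kernel_thread(idx, [0, -1, 1], work, epi)
# work → [0, 2] (padding skipped); epi → [0, 1, 2] (nobody skips the epilogue)
```

With an early `return` instead, the padding thread would never record an epilogue entry — the exact hazard the comment flags for the PDL path.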
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6761f8e1-64cf-4451-8fe3-a2914f272b95
📒 Files selected for processing (4)
- flashinfer/page.py
- flashinfer/triton/page.py
- include/flashinfer/pos_enc.cuh
- tests/attention/test_rope.py
```cpp
// skip padding tokens with batch_indices < 0
if (batch_indices[idx] >= 0) {
```
There was a problem hiding this comment.
The main change is just this line; the rest of the diff is only indentation.
🧹 Nitpick comments (1)
include/flashinfer/pos_enc.cuh (1)
862-868: Make the padding-sentinel check type-safe across PagedKVIdType instantiations.

Line 863 uses batch_indices[idx] >= 0, which is only safe when PagedKVIdType is signed. If it is ever unsigned, the padding sentinel -1 becomes the maximum representable value and this branch incorrectly passes, leading to invalid indptr indexing.

🔧 Proposed fix

```diff
-  // skip padding tokens with batch_indices < 0
-  if (batch_indices[idx] >= 0) {
+  constexpr PagedKVIdType kPaddingSentinel = static_cast<PagedKVIdType>(-1);
+  const PagedKVIdType batch_idx = batch_indices[idx];
+  if (batch_idx != kPaddingSentinel) {
     // Compute page location for this token
     uint32_t page_iter, entry_idx;
     paged_kv_like.page_size.divmod(
-        paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx],
+        paged_kv_like.indptr[batch_idx] * paged_kv_like.page_size + positions[idx],
         page_iter, entry_idx);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/pos_enc.cuh` around lines 862 - 868, The current padding check uses "batch_indices[idx] >= 0" which breaks for unsigned PagedKVIdType; change it to a type-safe sentinel comparison (e.g., compare against a named padding sentinel value) so you explicitly test for the padding marker instead of signedness. Replace the condition in the block using batch_indices, positions and paged_kv_like (the if around batch_indices[idx] >= 0 that precedes paged_kv_like.indptr[...] and page_size.divmod calls) with a check like "batch_indices[idx] != static_cast<PagedKVIdType>(-1)" or, better, introduce a constexpr PagedKVIdType kPaddingSentinel = static_cast<PagedKVIdType>(-1) and use "batch_indices[idx] != kPaddingSentinel" to ensure correct behavior for both signed and unsigned instantiations.
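The signedness hazard described above is easy to demonstrate by emulating 32-bit integer storage in Python (an illustrative model, not the kernel code): storing -1 into an unsigned 32-bit slot yields the maximum value, so a `>= 0` test silently treats padding as a real token, while comparing against the identically-cast sentinel stays correct for either signedness.

```python
UINT32_MASK = 0xFFFFFFFF

def store_u32(x):
    """Emulate storing a value into an unsigned 32-bit slot
    (two's-complement wraparound, as C++ unsigned conversion does)."""
    return x & UINT32_MASK

padding = store_u32(-1)            # becomes 4294967295, not -1

# The '>= 0' padding check breaks under unsigned storage:
broken_check = padding >= 0        # True -> padding would NOT be skipped

# Comparing against the sentinel, cast the same way, remains correct:
sentinel = store_u32(-1)
is_padding = padding == sentinel   # True -> padding IS detected
```

This is exactly why the review suggests `batch_indices[idx] != static_cast<PagedKVIdType>(-1)` over a sign test.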
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a25b44d8-52fc-4514-8eeb-44210f1c0877
📒 Files selected for processing (1)
include/flashinfer/pos_enc.cuh
Force-pushed f0f6c80 to 63197ac
🧹 Nitpick comments (1)
include/flashinfer/pos_enc.cuh (1)
862-863: Consider adding a compile-time assertion to document the signed-type requirement for the padding check.

The code assumes PagedKVIdType is signed so that the check batch_indices[idx] >= 0 correctly detects padding tokens (marked with -1). While the type is hardcoded to int32_t in the JIT binding at csrc/rope.cu (lines 574, 605) and in Python via torch.int(), a static assertion would document this requirement and guard against accidental misuse:

```cpp
static_assert(std::is_signed<PagedKVIdType>::value,
              "PagedKVIdType must be signed for padding token detection (batch_indices < 0)");
```

This defensive check aligns with the kernel's assumptions and improves long-term maintainability without adding runtime overhead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/pos_enc.cuh` around lines 862 - 863, Add a compile-time assertion that PagedKVIdType is a signed type to document and enforce the kernel's assumption used by the padding check `batch_indices[idx] >= 0`; insert a static_assert using `std::is_signed<PagedKVIdType>::value` (e.g., near the typedef/using of PagedKVIdType or at the top of the kernel in pos_enc.cuh before the `batch_indices` usage) with a clear message like "PagedKVIdType must be signed for padding token detection (batch_indices < 0)"; this is purely compile-time and has no runtime overhead but prevents accidental unsigned types from being used.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: de20e129-16a7-4945-80ef-553ab0f8df70
📒 Files selected for processing (3)
- flashinfer/page.py
- include/flashinfer/pos_enc.cuh
- tests/attention/test_rope.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/page.py
- tests/attention/test_rope.py
Hi @yzh119, could you help review this? We need this fix for integrating this kernel into vLLM. Thanks!
cc @kahyunnam for visibility.
/bot run |
[SUCCESS] Pipeline #46584451: 14/20 passed
/bot run |
[FAILED] Pipeline #46776615: 12/20 passed |
Force-pushed 63197ac to 832ac30
📌 Description
vLLM uses seqlen=0 padding tokens to run a full CUDA graph: https://github.com/vllm-project/vllm/blob/95c0f928cdeeaa21c4906e73cee6a156e1b3b995/vllm/v1/worker/gpu/model_runner.py#L652-L654
Update the following functions:
- get_batch_indices_positions_kernel: initialize batch_indices/positions to -1/0 so padding tokens can be recognized
- rope_quantize_fp8_append_paged_kv_cache: skip those padding tokens

Testing:

pytest -v -s tests/attention/test_rope.py::test_rope_quantize_fp8_append_paged_kv_cache_padding
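The initialization described above can be sketched as a host-side Python reference (assumed shapes and semantics; the real implementation is a CUDA/Triton kernel in FlashInfer): pre-fill batch_indices with -1 and positions with 0 over the fixed-size token buffer, then overwrite only the slots backed by real tokens derived from indptr.

```python
def get_batch_indices_positions_ref(indptr, seq_lens, num_tokens):
    """Host-side sketch: indptr[b] is the first token slot of request b,
    seq_lens[b] its current KV length. Slots past the real tokens keep
    the -1/0 padding markers so downstream kernels can skip them."""
    batch_indices = [-1] * num_tokens  # -1 marks padding tokens
    positions = [0] * num_tokens
    for b in range(len(seq_lens)):
        nnz = indptr[b + 1] - indptr[b]  # new tokens for request b
        for i in range(nnz):
            token = indptr[b] + i
            batch_indices[token] = b
            # absolute position of this token within request b's sequence
            positions[token] = seq_lens[b] - nnz + i
    return batch_indices, positions

# Two requests with 2 and 1 new tokens, buffer padded to 4 slots:
bi, pos = get_batch_indices_positions_ref([0, 2, 3], [2, 1], 4)
# bi  → [0, 0, 1, -1]   (last slot is padding)
# pos → [0, 1, 0, 0]
```

Any slot the loops never touch keeps its -1/0 defaults, which is exactly what lets the fused kernel skip padding tokens without reading garbage.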
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit

Bug Fixes
- Padding tokens (seqlen=0) are now marked with batch_indices = -1 and skipped during the fused RoPE + quantization + paged KV cache update, preventing KV cache corruption when running under full CUDA graphs.

Tests
- Added a regression test covering padding tokens in the fused rope_quantize_fp8_append_paged_kv_cache path.