feat: Expose TRT-LLM FMHA style paged KV Cache and page table layout #2770
saltyminty merged 2 commits into flashinfer-ai:main
Conversation
Note: Reviews paused — this branch is under active development, so CodeRabbit automatically paused its review to avoid overwhelming the author with comments from the influx of new commits. No actionable comments were generated in the recent review.

ℹ️ Recent review info — Configuration: defaults | Review profile: CHILL | Plan: Pro
📝 Walkthrough: This PR adds a `uses_shared_paged_kv_idx` option to the TRT-LLM Gen FMHA paged-attention path, letting callers choose between the shared (2-D) and TRT-LLM-style separate K/V (3-D) page-table layouts.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller as Python API
    participant Module as TrtllmGenModule
    participant Op as Torch Op / Fake Op
    participant Launcher as C++ Launcher
    participant Kernel as GPU Kernel
    Caller->>Module: call trtllm_*_with_kv_cache(uses_shared_paged_kv_idx)
    Module->>Op: _paged_run / paged_run (passes flag)
    Op->>Launcher: trtllm_paged_attention_* (uses_shared_paged_kv_idx)
    Launcher->>Kernel: launch kernel with runner params (mUsesSharedPagedKvIdx)
    Kernel-->>Launcher: compute attention with chosen K/V layout
    Launcher-->>Op: return results
    Op-->>Module: return tensor(s)
    Module-->>Caller: return output
```
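The flag named in the diagram selects between two page-table shapes. A minimal sketch of the difference, using NumPy and made-up sizes (the page-id values here are illustrative, not tied to any real KV pool):

```python
import numpy as np

batch_size, max_pages_per_seq = 4, 8

# Shared layout (FlashInfer/vLLM style): K and V reuse one page index per page.
# uses_shared_paged_kv_idx=True -> 2-D table [batch_size, max_pages_per_seq]
shared_table = np.arange(batch_size * max_pages_per_seq, dtype=np.int32).reshape(
    batch_size, max_pages_per_seq
)

# Separate layout (TRT-LLM style): K and V each get their own page indices.
# uses_shared_paged_kv_idx=False -> 3-D table [batch_size, 2, max_pages_per_seq],
# where dim 1 indexes K (0) and V (1).
separate_table = np.stack([shared_table * 2, shared_table * 2 + 1], axis=1)

print(shared_table.shape)    # (4, 8)
print(separate_table.shape)  # (4, 2, 8)
```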
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Summary of Changes (Gemini Code Assist): This pull request enhances the flexibility of FlashInfer by exposing the TRT-LLM paged KV Cache and page table layout. The change allows users to explicitly choose between the FlashInfer/vLLM-style shared page indices and the TRT-LLM-style separate page indices for K and V caches. The integration involved updating kernel interfaces, Python APIs, and comprehensive testing to ensure compatibility and correct behavior across different configurations, facilitating easier integration with systems like TRT-LLM.
Code Review
This pull request successfully exposes TensorRT-LLM's paged KV cache layout by introducing a uses_shared_paged_kv_idx option. The changes are consistently implemented across the C++ backend, Python API, and tests. The new functionality is also well-documented. I have a couple of minor suggestions to remove commented-out code to improve maintainability.
Force-pushed: 81bf4a1 → 3095b2b
Actionable comments posted: 5
🧹 Nitpick comments (1)
flashinfer/prefill.py (1)
3631-3632: Add the backend requirement decorator on this public TRT-LLM API.

This entrypoint is backend/SM-gated, but the expanded public surface is still only decorated with `@flashinfer_api`. Please expose the capability guard consistently here as repository policy requires. As per coding guidelines, "Use `@backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Line 1393: BatchDecodeWithPagedKVCacheWrapper.run() currently hardcodes the
shared-page flag to True which prevents using the TRT-LLM 3-D page-table layout;
change the call so the boolean is computed from the plan's block_tables shape
instead of hardcoded. Specifically, after calling plan() (or wherever
block_tables is produced), detect whether block_tables are 3-D (e.g., check
ndim/shape or a plan-provided attribute) and set a local
uses_shared_paged_kv_idx boolean accordingly, then pass that variable into
BatchDecodeWithPagedKVCacheWrapper.run() instead of the literal True so 3-D
layouts dispatch with the correct flag.
- Line 2160: The new parameter uses_shared_paged_kv_idx is only valid for the
trtllm-gen path but currently can be passed through to XQA backends; add the
same guard used in flashinfer/mla.py before dispatch to validate backend and GPU
capability and raise a fast-fail (clear exception) if uses_shared_paged_kv_idx
is True while backend is "xqa" or backend=="auto" on XQA-capable GPUs, and
update the function/method docstring near the uses_shared_paged_kv_idx parameter
to state explicitly that False/True is trtllm-gen-only (i.e., that True applies
only to trtllm-gen).
In `@flashinfer/mla.py`:
- Around line 176-177: The code currently reads page_table.shape[0] and
shape[-1] without validating rank/layout; add explicit checks in the function
handling page_table to reject invalid ranks and layouts before dispatch: if
uses_shared_paged_kv_idx is True require page_table.ndim == 2 and otherwise for
separate-layout require page_table.ndim == 3 and page_table.shape[1] == 2 (or
the expected second-dimension value); raise a clear exception when the checks
fail. Keep the existing uses of B_block_table = page_table.shape[0] and
block_num = page_table.shape[-1] but only after these validations so XQA/TRT-LLM
cannot misinterpret the tensor.
In `@flashinfer/prefill.py`:
- Around line 2331-2333: BatchPrefillWithPagedKVCacheWrapper.run() is
incorrectly hardcoding the uses_shared_paged_kv_idx argument as True when
calling the wrapped prefill (causing TRT-LLM 3-D block_tables callers to take
the shared-index path); change the call to pass the actual flag instead of True
by reading/passing the wrapper's uses_shared_paged_kv_idx property or inspecting
the planned/cache metadata (e.g., block_tables shape or an existing attribute)
and forward that boolean into the call (replace the literal True with the
appropriate variable) so the correct paged-index path is selected; ensure the
symbols sinks and skip_softmax_threshold_scale_factor remain unchanged.
- Around line 3654-3655: Validate block_tables shape and page id ranges against
uses_shared_paged_kv_idx before invoking the CUDA op
(trtllm_paged_attention_context): ensure that when uses_shared_paged_kv_idx is
False block_tables is 2-D (slices x seq_len) and when True it is 3-D (slices x
num_pages x seq_len); also validate that every page id in block_tables is within
the allowed range (e.g., 0 <= id < num_pages or other kernel-expected bounds)
and raise a clear ValueError if violated. Apply the same guard logic at the
other callsites mentioned (around the blocks that call
trtllm_paged_attention_context at the ranges you noted) so the public contract
for uses_shared_paged_kv_idx is enforced before any CUDA call.
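The validation points above (rank/layout checks per flag, plus the trtllm-gen-only restriction) could be collected into a single fail-fast guard before any CUDA dispatch. The helper below is a hypothetical sketch, not FlashInfer code; the name `validate_page_table` and its exact backend rule are assumptions for illustration:

```python
import numpy as np


def validate_page_table(page_table, uses_shared_paged_kv_idx, backend="trtllm-gen"):
    """Hypothetical pre-dispatch guard mirroring the review findings above."""
    # The separate K/V layout is a trtllm-gen-only feature per this PR.
    if not uses_shared_paged_kv_idx and backend != "trtllm-gen":
        raise ValueError(
            "uses_shared_paged_kv_idx=False is only supported by the trtllm-gen backend"
        )
    if uses_shared_paged_kv_idx:
        # Shared layout: one 2-D table [batch, max_pages] reused for K and V.
        if page_table.ndim != 2:
            raise ValueError("shared layout expects a 2-D page table")
    elif page_table.ndim != 3 or page_table.shape[1] != 2:
        # Separate layout: 3-D table [batch, 2, max_pages], dim 1 = K/V.
        raise ValueError("separate layout expects a [batch, 2, max_pages] page table")
    return page_table.shape[0], page_table.shape[-1]  # batch size, pages per seq


ok = validate_page_table(np.zeros((4, 8), dtype=np.int32), True)
print(ok)  # (4, 8)
```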
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: aadf6006-0f39-4290-b3e8-f280539a079a
📒 Files selected for processing (9)
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/artifacts.py
- flashinfer/decode.py
- flashinfer/mla.py
- flashinfer/prefill.py
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_attention.py
- tests/attention/test_trtllm_gen_mla.py
♻️ Duplicate comments (1)
flashinfer/prefill.py (1)
2324-2335: ⚠️ Potential issue | 🟠 Major — Forward the actual page-index layout instead of hardcoding shared mode.

Line 2334 hardcodes `uses_shared_paged_kv_idx=True`, so a planned 3-D `block_tables` path can still launch as shared layout.

Suggested fix:

```diff
 run_args += [
     @@
     self._qo_indptr_buf,
     self._paged_kv_indptr_buf,
     sinks,
     skip_softmax_threshold_scale_factor,
-    True,  # uses_shared_paged_kv_idx
+    (
+        self._block_tables is None
+        or self._block_tables.ndim == 2
+    ),  # uses_shared_paged_kv_idx
 ]
```
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 56d1955d-6e3a-47bc-89da-8b3a46128dc1
📒 Files selected for processing (4)
- flashinfer/decode.py
- flashinfer/mla.py
- flashinfer/prefill.py
- flashinfer/utils.py
Force-pushed: 3a6099f → 7be2733
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fmha_kernel_launcher.cu (1)
241-277: ⚠️ Potential issue | 🟠 Major — Validate `block_tables` against `uses_shared_paged_kv_idx` at the FFI boundary.

These exported entry points now accept both 2-D and 3-D page tables, but they never check that `block_tables` actually matches the selected mode before handing a raw `int*` to the runner. A direct TVM/FFI caller can therefore launch the kernel with the wrong interpretation and get wrong outputs or invalid reads.

🛡️ Proposed fix:

```diff
 bool const uses_shared_paged_kv_idx_value = uses_shared_paged_kv_idx.value_or(true);
+int const expected_block_tables_ndim =
+    uses_shared_paged_kv_idx_value ? 2 : 3;
+TVM_FFI_ICHECK_EQ(block_tables.ndim(), expected_block_tables_ndim)
+    << "block_tables must be " << expected_block_tables_ndim
+    << "D when uses_shared_paged_kv_idx="
+    << (uses_shared_paged_kv_idx_value ? "true" : "false");
+if (!uses_shared_paged_kv_idx_value) {
+  TVM_FFI_ICHECK_EQ(block_tables.size(1), 2)
+      << "block_tables.shape[1] must be 2 when uses_shared_paged_kv_idx=false";
+}
```

Apply the same check in both `trtllm_paged_attention_decode()` and `trtllm_paged_attention_context()`. Also applies to: 360-385.
♻️ Duplicate comments (1)
flashinfer/decode.py (1)
1425-1433: ⚠️ Potential issue | 🟠 Major — Stop forcing shared page indices in the wrapper path.

`plan()` can preserve a caller-supplied `self._block_tables`, but `run()` still always appends `uses_shared_paged_kv_idx=True`. A planned `[batch, 2, max_pages]` table will therefore be launched in shared mode instead of being handled or rejected.

🔧 Proposed fix:

```diff
+uses_shared_paged_kv_idx = (
+    self._block_tables is None or self._block_tables.ndim == 2
+)
 run_args += [
     None,  # packed_custom_mask
     None,  # mask_indptr_buf
     _get_cache_alibi_slopes_buf(q.shape[1], q.device),
     None,  # maybe_prefix_len_ptr
     @@
     sinks,
     key_block_scales,
     value_block_scales,
     skip_softmax_threshold_scale_factor,
-    True,  # uses_shared_paged_kv_idx
+    uses_shared_paged_kv_idx,
 ]
```

If wrapper support for 3-D tables is not intended yet, fail fast in `plan()`/`run()` instead of silently dispatching the wrong mode.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Around line 2231-2235: Update the docstring for block_tables to explicitly
document the TRT-LLM KV-cache 3-D layout and note that when
uses_shared_paged_kv_idx is False callers/tests must reshape/interleave the
legacy kv_cache and kv_block_scales into the [batch_size, 2,
max_num_pages_per_seq] layout (dim 1 = K/V) before calling this API; ensure the
same explanatory note is added to the other docstring locations referencing
kv_cache/kv_block_scales (the second occurrence that currently describes the
legacy layout).
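Callers migrating from the legacy shared layout need a 3-D `[batch_size, 2, max_num_pages_per_seq]` table when `uses_shared_paged_kv_idx=False`. The sketch below assumes an interleaved KV pool where shared page id `p` stores K at `2*p` and V at `2*p + 1`; that mapping is an assumption for illustration, so verify it against the actual `kv_cache` layout before relying on it:

```python
import numpy as np


def shared_to_separate(shared_table):
    """Expand a 2-D shared page table into the 3-D [batch, 2, max_pages] layout.

    Assumes an interleaved KV pool where shared page id p stores K at 2*p
    and V at 2*p + 1 (an illustrative convention, not a FlashInfer guarantee).
    """
    k_pages = shared_table * 2
    v_pages = shared_table * 2 + 1
    return np.stack([k_pages, v_pages], axis=1)  # dim 1: 0 = K, 1 = V


shared = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.int32)
separate = shared_to_separate(shared)
print(separate.shape)  # (2, 2, 3)
```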
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: fcbd4e57-f7be-4bf8-8fae-5d290ec71de8
📒 Files selected for processing (10)
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/artifacts.py
- flashinfer/decode.py
- flashinfer/mla.py
- flashinfer/prefill.py
- flashinfer/utils.py
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_attention.py
- tests/attention/test_trtllm_gen_mla.py
✅ Files skipped from review due to trivial changes (1)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
🚧 Files skipped from review as they are similar to previous changes (5)
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_mla.py
- flashinfer/artifacts.py
- flashinfer/mla.py
- flashinfer/prefill.py
/bot run

[SUCCESS] Pipeline #46623817: 13/20 passed
Force-pushed: bb144a0 → 2aeba4c
/bot run

[SUCCESS] Pipeline #46644158: 13/20 passed

CI looks good.
📌 Description
We received a request to expose TRT-LLM's paged KV Cache layout and page table style, in order to ease the process of integrating FlashInfer into TRT-LLM.
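To make the requested layout concrete, a caller shapes the page table differently per mode. The commented-out call below is only a hypothetical sketch of where the new flag fits; the real `trtllm_batch_decode_with_kv_cache` signature has more parameters than shown and should be taken from the updated docstrings:

```python
import numpy as np

batch_size, max_pages = 2, 4
uses_shared_paged_kv_idx = False  # request the TRT-LLM-style separate K/V layout

if uses_shared_paged_kv_idx:
    # Shared layout: 2-D [batch_size, max_pages]
    block_tables = np.zeros((batch_size, max_pages), dtype=np.int32)
else:
    # Separate layout: 3-D [batch_size, 2, max_pages], dim 1 = K/V
    block_tables = np.zeros((batch_size, 2, max_pages), dtype=np.int32)

# Hypothetical call site (argument order/elided parameters are placeholders):
# out = trtllm_batch_decode_with_kv_cache(
#     ...,  # query, kv_cache, workspace, seq_lens, etc.
#     block_tables=block_tables,
#     uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
# )

print(block_tables.shape)  # (2, 2, 4)
```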
This pull request exposes this feature by adding `uses_shared_paged_kv_idx` as an option to the TRT-LLM Gen FMHA kernels, updates docstrings, and adds relevant tests.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed, and all tests pass (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes / Validation
Tests