
feat: Expose TRT-LLM FMHA style paged KV Cache and page table layout #2770

Merged
saltyminty merged 2 commits into flashinfer-ai:main from DomBrown:dev/trtllm_page_layout
Mar 23, 2026

Conversation

@DomBrown (Contributor) commented Mar 12, 2026

📌 Description

We received a request to expose TRT-LLM's paged KV Cache layout and page table style, in order to ease the process of integrating FlashInfer into TRT-LLM.

This pull request exposes the feature by adding uses_shared_paged_kv_idx as an option to the TRT-LLM Gen FMHA kernels, updating the docstrings, and adding relevant tests.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a public boolean parameter uses_shared_paged_kv_idx to prefill/decode/MLA APIs to select shared (2D) or separate K/V (3D) paged-KV layouts; docs updated.
  • Bug Fixes / Validation

    • Added input validation with clear errors for mismatched block_tables shapes; backend-specific rejection when a layout isn't supported.
  • Tests

    • Expanded tests to cover both shared and separate paged-KV layouts for prefill and decode paths.
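To make the two page-table layouts concrete, here is a small sketch of the block_tables shapes the flag selects between. NumPy stands in for the torch tensors the real API takes; the shapes follow the summary above (2D shared vs. 3D separate K/V):

```python
import numpy as np

batch_size, max_num_pages_per_seq = 4, 8

# uses_shared_paged_kv_idx=True (FlashInfer/vLLM style): K and V share
# one page-index table of shape [batch_size, max_num_pages_per_seq].
shared_block_tables = np.arange(
    batch_size * max_num_pages_per_seq, dtype=np.int32
).reshape(batch_size, max_num_pages_per_seq)

# uses_shared_paged_kv_idx=False (TRT-LLM style): separate K and V page
# indices, shape [batch_size, 2, max_num_pages_per_seq] (dim 1 = K/V).
separate_block_tables = np.stack(
    [shared_block_tables, shared_block_tables], axis=1
)

print(shared_block_tables.shape)    # (4, 8)
print(separate_block_tables.shape)  # (4, 2, 8)
```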

@coderabbitai
Contributor

coderabbitai bot commented Mar 12, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f92e5ec9-fc49-44df-b416-7978bd0ca7ba

📥 Commits

Reviewing files that changed from the base of the PR and between bb144a0 and 2aeba4c.

📒 Files selected for processing (1)
  • tests/attention/test_trtllm_gen_attention.py
✅ Files skipped from review due to trivial changes (1)
  • tests/attention/test_trtllm_gen_attention.py

📝 Walkthrough

Walkthrough

This PR adds a uses_shared_paged_kv_idx boolean flag across the TRT-LLM paged-attention stack, threading it from public Python APIs through validation helpers into C++ kernel/runner parameters and kernel launcher signatures, and updates tests to cover both shared (2D) and separate (3D) K/V page-index layouts.

Changes

Cohort / File(s) Summary
C++ Kernel Launcher Core
csrc/trtllm_fmha_kernel_launcher.cu
Inserted bool uses_shared_paged_kv_idx parameter into trtllm_paged_attention_launcher (before sm_count) and threaded Optional<bool> uses_shared_paged_kv_idx through trtllm_paged_attention_decode and trtllm_paged_attention_context, updating call sites.
Kernel Parameter Structures
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h, include/flashinfer/trtllm/fmha/kernelParams.h
Added mUsesSharedPagedKvIdx to TllmGenFmhaRunnerParams; default-initialized KernelParams::mUsesSharedPagedKvIdx{true} and changed setKernelParams() to propagate the option.
Python Decode API
flashinfer/decode.py
Exposed uses_shared_paged_kv_idx: bool = True on trtllm_batch_decode_with_kv_cache(), threaded through TrtllmGenDecodeModule paths, added validation (reject False for xqa), and passed flag into the trtllm-gen kernel call.
Python Prefill API
flashinfer/prefill.py
Added uses_shared_paged_kv_idx: bool = True to trtllm_batch_context_with_kv_cache() and internal paged-run wrappers, wired through the module/op call and validation before launch.
Python MLA API
flashinfer/mla.py
Added uses_shared_paged_kv_idx: bool = True to MLA decode API, extended _check_trtllm_gen_mla_shape() to accept the flag and validate/normalize block_tables, and forwarded flag through backend dispatch.
Utility Validation
flashinfer/utils.py
Added _check_block_tables_shape() to enforce 2D shape for shared layout and 3D-with-dim1==2 for separate K/V layout, raising descriptive ValueErrors on mismatch.
Tests
tests/attention/test_trtllm_gen_attention.py, tests/attention/test_trtllm_gen_mla.py
Added prepare_paged_kv_for_kernel() to convert inputs for separate-KV layout, parametrized prefill/decode/MLA tests over uses_shared_paged_kv_idx in [True, False], and added conditional backend skips/shape adjustments.

Sequence Diagram(s)

sequenceDiagram
  participant Caller as Python API
  participant Module as TrtllmGenModule
  participant Op as Torch Op / Fake Op
  participant Launcher as C++ Launcher
  participant Kernel as GPU Kernel

  Caller->>Module: call trtllm_*_with_kv_cache(uses_shared_paged_kv_idx)
  Module->>Op: _paged_run / paged_run (passes flag)
  Op->>Launcher: trtllm_paged_attention_* (uses_shared_paged_kv_idx)
  Launcher->>Kernel: launch kernel with runner params (mUsesSharedPagedKvIdx)
  Kernel-->>Launcher: compute attention with chosen K/V layout
  Launcher-->>Op: return results
  Op-->>Module: return tensor(s)
  Module-->>Caller: return output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • aleozlx
  • cyx-6
  • yzh119
  • nvmbreughe
  • jiahanc
  • jimmyzho
  • bkryu
  • Anerudhan

Poem

🐰 Hops through paged KV lands anew,
Shared or split, I bound and view,
Pages shuffled, indices told,
Two shapes now fit both brave and bold—
A little hop, a kernel true 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 34.15%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): the title accurately describes the main change, exposing TRT-LLM's paged KV Cache and page table layout configuration via the uses_shared_paged_kv_idx option.
  • Description check (✅ Passed): the description explains the purpose (exposing TRT-LLM's paged KV Cache layout) and the implementation approach (adding the uses_shared_paged_kv_idx option), and confirms completion of pre-commit checks and tests.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the flexibility of FlashInfer by exposing the TRT-LLM paged KV Cache and page table layout. This change allows users to explicitly choose between the FlashInfer/vLLM-style shared page indices and the TRT-LLM-style separate page indices for K and V caches. The integration involved updating kernel interfaces, Python APIs, and comprehensive testing to ensure compatibility and correct behavior across different configurations, facilitating easier integration with systems like TRT-LLM.

Highlights

  • TRT-LLM Paged KV Cache Layout Exposure: Introduced a new boolean parameter, uses_shared_paged_kv_idx, across TRT-LLM Gen FMHA kernels and Python APIs to control the paged KV Cache and page table layout. This allows switching between FlashInfer/vLLM style (shared K/V page indices) and TRT-LLM style (separate K/V page indices).
  • API and Kernel Updates: Modified core C++ kernel launchers (trtllm_paged_attention_launcher, trtllm_paged_attention_decode, trtllm_paged_attention_context) and their corresponding Python wrappers (flashinfer.decode.trtllm_batch_decode_with_kv_cache, flashinfer.prefill.trtllm_batch_context_with_kv_cache, flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla) to accept and utilize the new uses_shared_paged_kv_idx parameter.
  • Documentation and Type Hinting: Updated docstrings and type hints in Python APIs to clearly explain the uses_shared_paged_kv_idx parameter, its default behavior (FlashInfer/vLLM layout), and the implications for the block_tables tensor shape when using the TRT-LLM layout.
  • Comprehensive Testing: Added extensive test cases in test_trtllm_gen_attention.py and test_trtllm_gen_mla.py to validate both shared and non-shared paged KV index layouts across various configurations, including new helper functions for KV cache preparation and specific skips for unsupported backend/layout combinations (e.g., XQA backend with non-shared indices).
  • Kernel Artifact Recompilation: Updated artifact paths and checksums for TRTLLM_GEN_FMHA in flashinfer/artifacts.py, indicating a recompilation of the underlying CUDA kernels to incorporate the new functionality.


Changelog
  • csrc/trtllm_fmha_kernel_launcher.cu
    • Added uses_shared_paged_kv_idx parameter to trtllm_paged_attention_launcher, trtllm_paged_attention_decode, and trtllm_paged_attention_context functions.
    • Assigned uses_shared_paged_kv_idx value to runner_params.mUsesSharedPagedKvIdx.
  • flashinfer/artifacts.py
    • Updated the TRTLLM_GEN_FMHA artifact path to a new hash.
    • Updated the TRTLLM_GEN_FMHA checksum hash.
  • flashinfer/decode.py
    • Added uses_shared_paged_kv_idx parameter to run and trtllm_batch_decode_with_kv_cache functions.
    • Updated docstrings for block_tables to describe its shape based on uses_shared_paged_kv_idx.
    • Passed uses_shared_paged_kv_idx to internal kernel calls.
  • flashinfer/mla.py
    • Adjusted _check_trtllm_gen_mla_shape to handle page_table with potentially more dimensions.
    • Added uses_shared_paged_kv_idx parameter to trtllm_batch_decode_with_kv_cache_mla.
    • Updated docstrings for block_tables to describe its shape based on uses_shared_paged_kv_idx.
    • Added a ValueError check for XQA backend when uses_shared_paged_kv_idx is false.
    • Passed uses_shared_paged_kv_idx to internal kernel calls.
  • flashinfer/prefill.py
    • Added uses_shared_paged_kv_idx parameter to _paged_run and trtllm_batch_context_with_kv_cache functions.
    • Updated docstrings for block_tables to describe its shape based on uses_shared_paged_kv_idx.
    • Passed uses_shared_paged_kv_idx to internal kernel calls.
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
    • Added mUsesSharedPagedKvIdx boolean member to TllmGenFmhaRunnerParams struct with a descriptive comment.
  • include/flashinfer/trtllm/fmha/kernelParams.h
    • Added mUsesSharedPagedKvIdx boolean member to KernelParams struct, initialized to false.
    • Assigned options.mUsesSharedPagedKvIdx to params.mUsesSharedPagedKvIdx in create_kernel_params.
  • tests/attention/test_trtllm_gen_attention.py
    • Imported Union from typing.
    • Added prepare_paged_kv_for_kernel helper function to transform KV cache and page table based on uses_shared_paged_kv_idx.
    • Added uses_shared_paged_kv_idx parameter to _test_trtllm_batch_prefill and _test_trtllm_batch_decode.
    • Modified _test_trtllm_batch_prefill and _test_trtllm_batch_decode to use prepare_paged_kv_for_kernel.
    • Added pytest.skip for XQA backend when uses_shared_paged_kv_idx is false in _test_trtllm_batch_decode.
    • Updated wrapper API test conditions to include uses_shared_paged_kv_idx.
    • Parameterized test_trtllm_batch_prefill, test_trtllm_batch_prefill_bs1, test_trtllm_batch_decode, test_trtllm_batch_decode_bs1, test_trtllm_batch_decode_head_dim_256, and test_trtllm_batch_decode_long_sequence_length with uses_shared_paged_kv_idx.
  • tests/attention/test_trtllm_gen_mla.py
    • Added uses_shared_paged_kv_idx parameter to trtllm_batch_decode_mla.
    • Added pytest.skip for XQA backend when uses_shared_paged_kv_idx is false in trtllm_batch_decode_mla.
    • Implemented logic to duplicate page table rows for separate KV page indices when uses_shared_paged_kv_idx is false.
    • Passed uses_shared_paged_kv_idx to the flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla call.
    • Parameterized test_trtllm_batch_decode_mla with uses_shared_paged_kv_idx.
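The row-duplication step described above can be sketched like this. The helper name is hypothetical (it is not the test suite's actual prepare_paged_kv_for_kernel) and it shows only the page-table half of the conversion:

```python
import numpy as np

def to_separate_kv_page_table(shared_block_tables):
    """Turn a shared [batch, max_pages] page table into the separate-K/V
    layout [batch, 2, max_pages] by duplicating each row along a new
    axis, so K and V initially reference the same pages."""
    return np.repeat(shared_block_tables[:, np.newaxis, :], 2, axis=1)

shared = np.arange(12, dtype=np.int32).reshape(3, 4)
separate = to_separate_kv_page_table(shared)
print(separate.shape)  # (3, 2, 4)
```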
Activity
  • The pull request introduces a new feature to expose TRT-LLM's paged KV Cache layout.
  • Core C++ kernel launchers and Python APIs have been modified to support the new uses_shared_paged_kv_idx parameter.
  • Documentation for the new parameter, including its impact on block_tables shape, has been added.
  • New test cases have been implemented to cover the different KV cache page table layouts.
  • Kernel artifacts have been updated, indicating a recompile to include the new functionality.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment (Contributor)

Code Review

This pull request successfully exposes TensorRT-LLM's paged KV cache layout by introducing a uses_shared_paged_kv_idx option. The changes are consistently implemented across the C++ backend, Python API, and tests. The new functionality is also well-documented. I have a couple of minor suggestions to remove commented-out code to improve maintainability.

@coderabbitai bot left a comment (Contributor)

Actionable comments posted: 5

🧹 Nitpick comments (1)
flashinfer/prefill.py (1)

3631-3632: Add the backend requirement decorator on this public TRT-LLM API.

This entrypoint is backend/SM-gated, but the expanded public surface is still only decorated with @flashinfer_api. Please expose the capability guard consistently here as repository policy requires.

As per coding guidelines, "Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3631 - 3632, The public TRT-LLM
entrypoint trtllm_batch_context_with_kv_cache is missing the backend requirement
guard; add the `@backend_requirement` decorator above the function (in addition to
the existing `@flashinfer_api`) and implement its predicate using the module/class
methods is_compute_capability_supported(cc) and is_backend_supported() so the
API is gated by both compute capability and backend support as per policy;
ensure the decorator references those methods and preserve the existing function
signature and behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/decode.py`:
- Line 1393: BatchDecodeWithPagedKVCacheWrapper.run() currently hardcodes the
shared-page flag to True which prevents using the TRT-LLM 3-D page-table layout;
change the call so the boolean is computed from the plan's block_tables shape
instead of hardcoded. Specifically, after calling plan() (or wherever
block_tables is produced), detect whether block_tables are 3-D (e.g., check
ndim/shape or a plan-provided attribute) and set a local
uses_shared_paged_kv_idx boolean accordingly, then pass that variable into
BatchDecodeWithPagedKVCacheWrapper.run() instead of the literal True so 3-D
layouts dispatch with the correct flag.
- Line 2160: The new parameter uses_shared_paged_kv_idx is only valid for the
trtllm-gen path but currently can be passed through to XQA backends; add the
same guard used in flashinfer/mla.py before dispatch to validate backend and GPU
capability and raise a fast-fail (clear exception) if uses_shared_paged_kv_idx
is True while backend is "xqa" or backend=="auto" on XQA-capable GPUs, and
update the function/method docstring near the uses_shared_paged_kv_idx parameter
to state explicitly that False/True is trtllm-gen-only (i.e., that True applies
only to trtllm-gen).

In `@flashinfer/mla.py`:
- Around line 176-177: The code currently reads page_table.shape[0] and
shape[-1] without validating rank/layout; add explicit checks in the function
handling page_table to reject invalid ranks and layouts before dispatch: if
uses_shared_paged_kv_idx is True require page_table.ndim == 2 and otherwise for
separate-layout require page_table.ndim == 3 and page_table.shape[1] == 2 (or
the expected second-dimension value); raise a clear exception when the checks
fail. Keep the existing uses of B_block_table = page_table.shape[0] and
block_num = page_table.shape[-1] but only after these validations so XQA/TRT-LLM
cannot misinterpret the tensor.

In `@flashinfer/prefill.py`:
- Around line 2331-2333: BatchPrefillWithPagedKVCacheWrapper.run() is
incorrectly hardcoding the uses_shared_paged_kv_idx argument as True when
calling the wrapped prefill (causing TRT-LLM 3-D block_tables callers to take
the shared-index path); change the call to pass the actual flag instead of True
by reading/passing the wrapper's uses_shared_paged_kv_idx property or inspecting
the planned/cache metadata (e.g., block_tables shape or an existing attribute)
and forward that boolean into the call (replace the literal True with the
appropriate variable) so the correct paged-index path is selected; ensure the
symbols sinks and skip_softmax_threshold_scale_factor remain unchanged.
- Around line 3654-3655: Validate block_tables shape and page id ranges against
uses_shared_paged_kv_idx before invoking the CUDA op
(trtllm_paged_attention_context): ensure that when uses_shared_paged_kv_idx is
False block_tables is 2-D (slices x seq_len) and when True it is 3-D (slices x
num_pages x seq_len); also validate that every page id in block_tables is within
the allowed range (e.g., 0 <= id < num_pages or other kernel-expected bounds)
and raise a clear ValueError if violated. Apply the same guard logic at the
other callsites mentioned (around the blocks that call
trtllm_paged_attention_context at the ranges you noted) so the public contract
for uses_shared_paged_kv_idx is enforced before any CUDA call.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: aadf6006-0f39-4290-b3e8-f280539a079a

📥 Commits

Reviewing files that changed from the base of the PR and between 043bc43 and 35b68e0.

📒 Files selected for processing (9)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/artifacts.py
  • flashinfer/decode.py
  • flashinfer/mla.py
  • flashinfer/prefill.py
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py

@coderabbitai bot left a comment (Contributor)

♻️ Duplicate comments (1)
flashinfer/prefill.py (1)

2324-2335: ⚠️ Potential issue | 🟠 Major

Forward the actual page-index layout instead of hardcoding shared mode.

Line 2334 hardcodes uses_shared_paged_kv_idx=True, so a planned 3D block_tables path can still launch as shared layout.

Suggested fix
                 run_args += [
@@
                     self._qo_indptr_buf,
                     self._paged_kv_indptr_buf,
                     sinks,
                     skip_softmax_threshold_scale_factor,
-                    True,  # uses_shared_paged_kv_idx
+                    (
+                        self._block_tables is None
+                        or self._block_tables.ndim == 2
+                    ),  # uses_shared_paged_kv_idx
                 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 2324 - 2335, The call currently passes a
hardcoded True for uses_shared_paged_kv_idx, which forces shared page-index
layout even when self._block_tables may represent a 3D/non-shared layout;
compute or obtain the correct flag (e.g., a local variable
uses_shared_paged_kv_idx derived from self._block_tables or from the existing
layout/config helper) and pass that variable instead of True so the actual
page-index layout is forwarded to the callee.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 56d1955d-6e3a-47bc-89da-8b3a46128dc1

📥 Commits

Reviewing files that changed from the base of the PR and between 35b68e0 and 3a6099f.

📒 Files selected for processing (4)
  • flashinfer/decode.py
  • flashinfer/mla.py
  • flashinfer/prefill.py
  • flashinfer/utils.py

@coderabbitai bot left a comment (Contributor)

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fmha_kernel_launcher.cu (1)

241-277: ⚠️ Potential issue | 🟠 Major

Validate block_tables against uses_shared_paged_kv_idx at the FFI boundary.

These exported entry points now accept both 2-D and 3-D page tables, but they never check that block_tables actually matches the selected mode before handing a raw int* to the runner. A direct TVM/FFI caller can therefore launch the kernel with the wrong interpretation and get wrong outputs or invalid reads.

🛡️ Proposed fix
   bool const uses_shared_paged_kv_idx_value = uses_shared_paged_kv_idx.value_or(true);
+  int const expected_block_tables_ndim =
+      uses_shared_paged_kv_idx_value ? 2 : 3;
+  TVM_FFI_ICHECK_EQ(block_tables.ndim(), expected_block_tables_ndim)
+      << "block_tables must be " << expected_block_tables_ndim
+      << "D when uses_shared_paged_kv_idx="
+      << (uses_shared_paged_kv_idx_value ? "true" : "false");
+  if (!uses_shared_paged_kv_idx_value) {
+    TVM_FFI_ICHECK_EQ(block_tables.size(1), 2)
+        << "block_tables.shape[1] must be 2 when uses_shared_paged_kv_idx=false";
+  }

Apply the same check in both trtllm_paged_attention_decode() and trtllm_paged_attention_context().

Also applies to: 360-385

♻️ Duplicate comments (1)
flashinfer/decode.py (1)

1425-1433: ⚠️ Potential issue | 🟠 Major

Stop forcing shared page indices in the wrapper path.

plan() can preserve a caller-supplied self._block_tables, but run() still always appends uses_shared_paged_kv_idx=True. A planned [batch, 2, max_pages] table will therefore be launched in shared mode instead of being handled or rejected.

🔧 Proposed fix
+                uses_shared_paged_kv_idx = (
+                    self._block_tables is None or self._block_tables.ndim == 2
+                )
                 run_args += [
                     None,  # packed_custom_mask
                     None,  # mask_indptr_buf
                     _get_cache_alibi_slopes_buf(q.shape[1], q.device),
                     None,  # maybe_prefix_len_ptr
@@
                     sinks,
                     key_block_scales,
                     value_block_scales,
                     skip_softmax_threshold_scale_factor,
-                    True,  # uses_shared_paged_kv_idx
+                    uses_shared_paged_kv_idx,
                 ]

If wrapper support for 3-D tables is not intended yet, fail fast in plan()/run() instead of silently dispatching the wrong mode.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1425 - 1433, The wrapper currently forces
uses_shared_paged_kv_idx=True when calling the kernel (the call site passing
True in run()), which mismatches a caller-supplied self._block_tables (e.g., 3-D
[batch,2,max_pages]); instead either stop forcing that flag or fail fast: modify
the run()/plan() flow so you do not hardcode True—derive
uses_shared_paged_kv_idx from the existing metadata (or a new boolean on the
object) and pass that through, and add a validation in plan() (and/or run())
that inspects self._block_tables shape and raises an error if a 3-D table is
provided when shared-page support is not implemented; update any call sites that
assumed the hardcoded True to use the new derived flag.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/decode.py`:
- Around line 2231-2235: Update the docstring for block_tables to explicitly
document the TRT-LLM KV-cache 3-D layout and note that when
uses_shared_paged_kv_idx is False callers/tests must reshape/interleave the
legacy kv_cache and kv_block_scales into the [batch_size, 2,
max_num_pages_per_seq] layout (dim 1 = K/V) before calling this API; ensure the
same explanatory note is added to the other docstring locations referencing
kv_cache/kv_block_scales (the second occurrence that currently describes the
legacy layout).


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fcbd4e57-f7be-4bf8-8fae-5d290ec71de8

📥 Commits

Reviewing files that changed from the base of the PR and between 3a6099f and 7be2733.

📒 Files selected for processing (10)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/artifacts.py
  • flashinfer/decode.py
  • flashinfer/mla.py
  • flashinfer/prefill.py
  • flashinfer/utils.py
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py
✅ Files skipped from review due to trivial changes (1)
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
🚧 Files skipped from review as they are similar to previous changes (5)
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_mla.py
  • flashinfer/artifacts.py
  • flashinfer/mla.py
  • flashinfer/prefill.py

@saltyminty left a comment (Collaborator)

Approved pending CI. Just left one minor comment.

@saltyminty
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !441 has been created, and the CI pipeline #46623817 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46623817: 13/20 passed

@DomBrown force-pushed the dev/trtllm_page_layout branch 2 times, most recently from bb144a0 to 2aeba4c on March 20, 2026
@saltyminty
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !441 has been updated with latest changes, and the CI pipeline #46644158 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46644158: 13/20 passed

@yzh119 added the run-ci label on Mar 23, 2026
@saltyminty
Collaborator

CI looks good.

@saltyminty merged commit 625c1c6 into flashinfer-ai:main on Mar 23, 2026
45 of 47 checks passed