
read real strides for kv and block scale#2844

Open
sychen52 wants to merge 1 commit into flashinfer-ai:main from sychen52:nvfp4_kv_stride_fix

Conversation

@sychen52
Contributor

@sychen52 sychen52 commented Mar 20, 2026

Acked-by: Shiyang Chen shiychen@nvidia.com

📌 Description

I realized that the current KV block scale factor's strides are derived from the KV cache instead of being read from the actual tensor. This puts unnecessary assumptions/constraints on the scale factor's layout.
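To illustrate the problem (a hypothetical sketch in plain Python, not code from this PR): strides derived from a shape under a row-major contiguity assumption can disagree with a tensor's real strides, for example when the scale tensor is a view of a larger allocation.

```python
# Hypothetical illustration (not code from this PR): strides derived from a
# shape alone vs. the real strides of a non-contiguous view.

def contiguous_strides(shape):
    """Row-major strides implied by the shape under a contiguity assumption."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

# A [batch=2, heads=4, tokens=8] scale tensor sliced to half its heads keeps
# the parent's strides, so deriving strides from the sliced shape is wrong.
parent_strides = contiguous_strides((2, 4, 8))   # (32, 8, 1)
derived_strides = contiguous_strides((2, 2, 8))  # (16, 8, 1), not the real layout
assert parent_strides != derived_strides
```

Reading `tensor.stride()` directly, as this PR does, avoids baking that contiguity assumption into the kernel launch.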

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores
    • Kernel launcher and runtime updated to accept explicit stride parameters for key/value scale tensors, improving stride handling for KV caches.
  • Documentation
    • Expanded backend-specific docs describing layout, contiguity and stride constraints for KV data and per-block scale tensors; clarified automatic transpose/contiguous copy behavior and expected shapes/dtypes.

@coderabbitai
Contributor

coderabbitai bot commented Mar 20, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4de56ecc-32fa-4ac7-9c90-124237b5e091

📥 Commits

Reviewing files that changed from the base of the PR and between 80444f7 and 30ec0c7.

📒 Files selected for processing (5)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/prefill.py
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/decode.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • flashinfer/prefill.py
  • csrc/trtllm_fmha_kernel_launcher.cu
  • include/flashinfer/trtllm/fmha/kernelParams.h

📝 Walkthrough

Walkthrough

Extended the paged-attention kernel launcher and runner parameter structs to accept and propagate four explicit stride fields for key/value block-scale tensors; updated KernelParams TMA stride computation to use these scale-stride fields and adjusted Python docstrings to document trtllm-gen NVFP4 contiguity/layout constraints for KV and block-scales.

Changes

Cohort / File(s) Summary
Kernel Launcher Implementation
csrc/trtllm_fmha_kernel_launcher.cu
Launcher signature changed: added k_sf_stride_heads, k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch; trtllm_paged_attention_decode and ..._context now compute these from optional block-scale tensors and pass them to the launcher.
Runner Parameters Structure
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
Added four new int fields (kSfStrideHeads, kSfStrideBatch, vSfStrideHeads, vSfStrideBatch) to TllmGenFmhaRunnerParams.
Kernel Parameters & TMA Configuration
include/flashinfer/trtllm/fmha/kernelParams.h
makeTmaShapeStrideKvSf now uses explicit scale-tensor stride fields from runner options instead of deriving them from KV data; updated doc comments describing required HND layouts and contiguity/stride constraints for TMA.
Python API Documentation
flashinfer/decode.py, flashinfer/prefill.py
Docstrings expanded with trtllm-gen NVFP4 requirements for KV layouts and kv_block_scales formatting/contiguity; documents automatic transpose+.contiguous() behavior when kv_layout="NHD".

Sequence Diagram(s)

(Skipped — changes are localized stride-threading and TMA setup without a new multi-component control flow requiring visualization.)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • aleozlx
  • yzh119
  • cyx-6
  • bkryu
  • nv-yunzheq
  • jimmyzho
  • sricketts
  • yyihuang

Poem

🐰 I nudged four strides down the kernel glade,
Keys and values in tidy parade,
TMA boxes line up neat and small,
Scale hops fit, one stride for all,
A hop, a copy — kernels sing in spade.

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 30.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check: ❓ Inconclusive. The description explains the core issue and solution, but the Tests section is incomplete, with no tests marked as added/updated or passing. Resolution: clarify the testing status: either add/update tests and confirm they pass, or document why no tests are needed for these changes.
✅ Passed checks (1 passed)
  • Title check: ✅ Passed. The title accurately summarizes the main change: reading real strides for KV and block scale tensors instead of deriving them.



@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refines the handling of NVFP4 KV cache block scales within the trtllm-gen backend by introducing explicit stride management. By reading and utilizing the actual strides of key and value block scale tensors, the system gains greater flexibility in memory layout and ensures correct data access patterns, particularly for Tensor Memory Accelerator (TMA) operations. The changes also include updated documentation to guide users on optimal KV cache configurations and contiguity requirements for improved performance.

Highlights

  • Explicit Scale Factor Strides: Introduced explicit stride parameters for K and V block scales in the trtllm_paged_attention_launcher function, enabling more precise memory access for NVFP4 KV cache.
  • Stride Calculation and Propagation: Implemented the calculation and passing of these explicit K and V block scale strides in trtllm_paged_attention_decode and trtllm_paged_attention_context functions, ensuring they are correctly propagated to the kernel.
  • Kernel Parameter Updates: Updated the TllmGenFmhaRunnerParams structure to include dedicated fields for K and V block scale strides, and modified the kernel parameter generation logic to utilize these actual strides for TMA shape/stride creation.
  • Documentation Enhancements: Enhanced Python API documentation for trtllm_batch_decode_with_kv_cache and trtllm_batch_context_with_kv_cache to clarify contiguity requirements for KV cache data and block scales, and to advise on optimal kv_layout for NVFP4 performance.
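The stride extraction and propagation described above can be sketched in plain Python (hypothetical helper; the actual implementation is C++ in csrc/trtllm_fmha_kernel_launcher.cu). Per the review discussion, strides fall back to 0 when no block scales are supplied:

```python
# Hypothetical Python sketch of the optional block-scale stride extraction;
# the real code lives in csrc/trtllm_fmha_kernel_launcher.cu (C++). A tensor
# here is modeled as a (shape, strides) pair.

def sf_strides(block_scales):
    """Return (stride_heads, stride_batch) for an optional block-scale tensor,
    using 0 as the fallback when no scales are provided (non-NVFP4 path)."""
    if block_scales is None:
        return 0, 0
    shape, strides = block_scales
    return strides[-3], strides[0]

# Contiguous [batch=2, heads=4, tokens=8, head_dim // 16 = 4] scales:
assert sf_strides(((2, 4, 8, 4), (128, 32, 4, 1))) == (32, 128)
assert sf_strides(None) == (0, 0)
```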


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly adds support for reading the actual strides from key/value block scale tensors instead of deriving them, which is a good improvement for handling non-contiguous tensors. The changes are well-implemented across the CUDA and Python files. My main feedback is to refactor a duplicated block of code responsible for extracting these strides in csrc/trtllm_fmha_kernel_launcher.cu into a helper function to improve code maintainability.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
flashinfer/prefill.py (1)

3712-3715: Consider enforcing these documented stride constraints at runtime.

These are hard kernel/TMA requirements; adding explicit checks would fail fast with actionable errors instead of surfacing low-level kernel failures.

Proposed guard (example)
+def _validate_trtllm_gen_kv_stride_requirements(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    key_block_scales: Optional[torch.Tensor],
+    value_block_scales: Optional[torch.Tensor],
+) -> None:
+    if k_cache.stride(-1) != 1 or v_cache.stride(-1) != 1:
+        raise ValueError("trtllm-gen requires KV cache last-dim stride == 1.")
+    for name, sf in (("key_block_scales", key_block_scales), ("value_block_scales", value_block_scales)):
+        if sf is None:
+            continue
+        expected_stride_m2 = sf.size(-1)  # head_dim // 16
+        if sf.stride(-1) != 1 or sf.stride(-2) != expected_stride_m2:
+            raise ValueError(
+                f"{name} must satisfy stride[-1] == 1 and stride[-2] == size(-1) "
+                f"(got {sf.stride()})."
+            )

Also applies to: 3754-3757, 3766-3772

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3712 - 3715, Add explicit runtime guards
in the code paths that prepare/launch the trtllm-gen TMA kernels (where the
comment about contiguity appears) to validate tensor strides: assert that
tensor.stride(-1) for the head_dim is 1 and raise a clear ValueError mentioning
"trtllm-gen requires head_dim stride 1", and optionally include the offending
tensor shape/strides in the message; also validate expected contiguity/stride
patterns for batch/head/page dims where those prep points occur. Apply the same
checks at the other preparation sites referenced near the same comment (the
other trtllm-gen buffer/prep functions) so all code paths fail fast with
actionable errors instead of letting low‑level kernel failures surface.
csrc/trtllm_fmha_kernel_launcher.cu (1)

301-312: Consider extracting shared SF-stride parsing into a helper.

Decode/context now duplicate identical stride extraction logic; a small helper would prevent future drift.

♻️ Optional refactor sketch
+static inline std::tuple<int, int, int, int> get_kv_sf_strides(
+    Optional<TensorView> key_block_scales, Optional<TensorView> value_block_scales) {
+  int k_sf_stride_heads = 0, k_sf_stride_batch = 0;
+  int v_sf_stride_heads = 0, v_sf_stride_batch = 0;
+  if (key_block_scales.has_value()) {
+    k_sf_stride_heads = key_block_scales.value().stride(-3);
+    k_sf_stride_batch = key_block_scales.value().stride(0);
+  }
+  if (value_block_scales.has_value()) {
+    v_sf_stride_heads = value_block_scales.value().stride(-3);
+    v_sf_stride_batch = value_block_scales.value().stride(0);
+  }
+  return {k_sf_stride_heads, k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch};
+}

Also applies to: 419-429

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fmha_kernel_launcher.cu` around lines 301 - 312, Extract the
duplicated stride-extraction logic into a small helper (e.g., static inline void
parse_sf_strides(const OptionalTensor& sf, int &stride_heads, int &stride_batch)
or similar) and replace the two blocks that set
k_sf_stride_heads/k_sf_stride_batch and v_sf_stride_heads/v_sf_stride_batch with
calls to that helper using key_block_scales and value_block_scales; the helper
should set strides to 0 when sf.has_value() is false and otherwise use
sf.value().stride(-3) and sf.value().stride(0) so both the first block (k_sf_*)
and the later block (v_sf_*) reuse the same logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 479-482: The V scale-factor TMA descriptor is currently built
reusing the K stride by only calling makeTmaShapeStrideKvSf(..., /*isK=*/true)
for both tmaKSf_ and tmaVSf_, causing incorrect V dequant reads when K/V strides
differ; fix by computing separate strides and shapes for V using the V-specific
stride variables (options.vSfStrideHeads/options.vSfStrideBatch or
sfStrideHeads/sfStrideBatch when isK is false) and call
makeTmaShapeStrideKvSf(..., /*isK=*/false) (or an equivalent overload) when
constructing tmaVSf_ so tmaKSf_ and tmaVSf_ use the correct per-tensor stride
values.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7fe4edca-16b2-4424-8f7a-ca52da1dd1d5

📥 Commits

Reviewing files that changed from the base of the PR and between af1e02d and 80444f7.

📒 Files selected for processing (5)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/prefill.py
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h

Collaborator

@saltyminty saltyminty left a comment


Approved conditional on merge conflict and CI

@saltyminty
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !458 has been created, and the CI pipeline #46897259 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46897259: 1/20 passed

Acked-by: Shiyang Chen <shiychen@nvidia.com>
@sychen52 sychen52 force-pushed the nvfp4_kv_stride_fix branch from e7bdbb5 to 30ec0c7 on March 24, 2026 20:52
@sychen52 sychen52 requested a review from samuellees as a code owner March 24, 2026 20:52
@sychen52
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@sychen52 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@saltyminty
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !458 has been updated with latest changes, and the CI pipeline #46914864 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46914864: 11/20 passed

@saltyminty saltyminty enabled auto-merge (squash) March 25, 2026 18:07