read real strides for kv and block scale #2844

sychen52 wants to merge 1 commit into flashinfer-ai:main
Conversation
📝 Walkthrough

Extended the paged-attention kernel launcher and runner parameter structs to accept and propagate four explicit stride fields for the key/value block-scale tensors, updated the KernelParams TMA stride computation to use these scale-stride fields, and adjusted the Python docstrings to document the trtllm-gen NVFP4 contiguity/layout constraints for KV and block scales.
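The motivation for threading explicit stride fields can be illustrated with a pure-Python sketch. The shapes, numbers, and the doubled-stride scenario below are assumed for illustration only (not the actual FlashInfer layout): strides derived from a shape describe a fully contiguous tensor, so a non-contiguous scale tensor needs its real strides read and propagated.

```python
# Pure-Python sketch (layout and numbers assumed for illustration):
# strides derived from a shape only describe a fully contiguous
# tensor, which is why the launcher now reads the real strides.

def contiguous_strides(shape):
    """Row-major strides implied by a shape alone."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

# NVFP4 block scales: one scale per 16-element group along head_dim,
# e.g. shape (num_pages, num_heads, page_size, head_dim // 16).
sf_shape = (8, 4, 16, 128 // 16)
derived = contiguous_strides(sf_shape)

# If the scale tensor is a view into a larger buffer (say, K and V
# scales interleaved per page), its page stride doubles while the
# shape stays the same -- strides derived from the shape would then
# address the wrong memory.
actual = (2 * derived[0], derived[1], derived[2], derived[3])
assert derived == (512, 128, 8, 1)
assert actual != derived
```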
Code Review
This pull request correctly adds support for reading the actual strides from key/value block scale tensors instead of deriving them, which is a good improvement for handling non-contiguous tensors. The changes are well-implemented across the CUDA and Python files. My main feedback is to refactor a duplicated block of code responsible for extracting these strides in csrc/trtllm_fmha_kernel_launcher.cu into a helper function to improve code maintainability.
Actionable comments posted: 1
🧹 Nitpick comments (2)
flashinfer/prefill.py (1)
3712-3715: Consider enforcing these documented stride constraints at runtime. These are hard kernel/TMA requirements; adding explicit checks would fail fast with actionable errors instead of surfacing low-level kernel failures.
Proposed guard (example)

```python
def _validate_trtllm_gen_kv_stride_requirements(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    key_block_scales: Optional[torch.Tensor],
    value_block_scales: Optional[torch.Tensor],
) -> None:
    if k_cache.stride(-1) != 1 or v_cache.stride(-1) != 1:
        raise ValueError("trtllm-gen requires KV cache last-dim stride == 1.")
    for name, sf in (
        ("key_block_scales", key_block_scales),
        ("value_block_scales", value_block_scales),
    ):
        if sf is None:
            continue
        expected_stride_m2 = sf.size(-1)  # head_dim // 16
        if sf.stride(-1) != 1 or sf.stride(-2) != expected_stride_m2:
            raise ValueError(
                f"{name} must satisfy stride[-1] == 1 and stride[-2] == size(-1) "
                f"(got {sf.stride()})."
            )
```

Also applies to: 3754-3757, 3766-3772
csrc/trtllm_fmha_kernel_launcher.cu (1)
301-312: Consider extracting shared SF-stride parsing into a helper. Decode/context now duplicate identical stride extraction logic; a small helper would prevent future drift.
♻️ Optional refactor sketch

```cpp
static inline std::tuple<int, int, int, int> get_kv_sf_strides(
    Optional<TensorView> key_block_scales, Optional<TensorView> value_block_scales) {
  int k_sf_stride_heads = 0, k_sf_stride_batch = 0;
  int v_sf_stride_heads = 0, v_sf_stride_batch = 0;
  if (key_block_scales.has_value()) {
    k_sf_stride_heads = key_block_scales.value().stride(-3);
    k_sf_stride_batch = key_block_scales.value().stride(0);
  }
  if (value_block_scales.has_value()) {
    v_sf_stride_heads = value_block_scales.value().stride(-3);
    v_sf_stride_batch = value_block_scales.value().stride(0);
  }
  return {k_sf_stride_heads, k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch};
}
```

Also applies to: 419-429
⚠️ Actionable comments (1)

include/flashinfer/trtllm/fmha/kernelParams.h (1)

479-482: The V scale-factor TMA descriptor is currently built by reusing the K stride: makeTmaShapeStrideKvSf(..., /*isK=*/true) is called for both tmaKSf_ and tmaVSf_, which yields incorrect V dequant reads whenever the K and V strides differ. Compute separate shapes and strides for V from the V-specific stride fields (options.vSfStrideHeads/options.vSfStrideBatch, or sfStrideHeads/sfStrideBatch when isK is false) and call makeTmaShapeStrideKvSf(..., /*isK=*/false) (or an equivalent overload) when constructing tmaVSf_, so that tmaKSf_ and tmaVSf_ each use their own per-tensor stride values.
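Why reusing K's strides for V is harmful can be shown with a toy offset calculation. This is a pure-Python model with assumed stride values, not the kernel code: when the two tensors have different physical layouts, indexing V through K's strides lands on the wrong element.

```python
# Toy model (assumed stride values) of the bug above: addressing the
# V scale tensor through K's strides reads the wrong offset whenever
# the two layouts differ.

def flat_offset(idx, strides):
    """Linear offset of a multi-dimensional index under given strides."""
    return sum(i * s for i, s in zip(idx, strides))

# Same logical shape, different physical layouts (e.g. the V scales
# were sliced from a wider buffer along the heads dimension).
k_sf_strides = (1024, 128, 8, 1)  # (batch, heads, tokens, groups)
v_sf_strides = (2048, 256, 8, 1)

idx = (1, 2, 3, 4)
correct = flat_offset(idx, v_sf_strides)  # 2048 + 512 + 24 + 4 = 2588
wrong = flat_offset(idx, k_sf_strides)    # 1024 + 256 + 24 + 4 = 1308
assert wrong != correct
```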
📒 Files selected for processing (5)
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/prefill.py
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
saltyminty left a comment:

Approved, conditional on the merge conflict being resolved and CI passing.
/bot run
[FAILED] Pipeline #46897259: 1/20 passed
Acked-by: Shiyang Chen <shiychen@nvidia.com>
Force-pushed e7bdbb5 to 30ec0c7 (Compare)
/bot run
@sychen52 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
/bot run
[FAILED] Pipeline #46914864: 11/20 passed
📌 Description

I realized that the current kv block scale factor's strides are derived from the kv cache instead of being read from the actual tensor. This puts a lot of unnecessary assumptions/constraints on the scale factor, so this change reads the real strides from the tensors.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- Installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests added/updated as needed (unittest, etc.).

Reviewer Notes