feat: support non-contiguous query for trtllm-gen attention backend #2254
Conversation
CodeRabbit Walkthrough

Propagates explicit Q and K/V stride parameters through the TensorRT-LLM paged-attention path: adds `qStrideTokens`/`qStrideHeads` to the runner params, updates the launcher signature and its callers to pass computed strides, refactors stride inference to accept user-specified or layout-derived values, and extends tests to cover non-contiguous queries.
Summary of Changes (Gemini Code Assist)

This pull request implements support for non-contiguous query tensors within the trtllm-gen attention backend.
Code Review
This pull request successfully adds support for non-contiguous query tensors in the trtllm-gen attention backend by passing the query tensor's strides to the kernel. The changes in the CUDA/C++ code are logical, and the new functionality is well-covered by tests. I have provided a few suggestions to enhance code maintainability and clarity, addressing code duplication in csrc/trtllm_fmha_kernel_launcher.cu, a redundant calculation in include/flashinfer/trtllm/fmha/kernelParams.h, and unused parameters in a new test helper function in tests/attention/test_trtllm_gen_attention.py.
```cpp
// Query stride: [num_tokens, num_heads, head_dim]
int q_stride_tokens = query.stride(0);  // stride between tokens
int q_stride_heads = query.stride(1);   // stride between heads
```
This logic to compute `q_stride_tokens` and `q_stride_heads` is also present in `trtllm_paged_attention_decode` at lines 263-265. There is a fair amount of duplicated code between `trtllm_paged_attention_decode` and `trtllm_paged_attention_context` for extracting tensor metadata (shapes, strides, dtypes). To improve maintainability, consider refactoring this common logic into a helper function or struct.
```cpp
int32_t strideHeads{options.qStrideHeads};
if (strideHeads == 0) {
  strideHeads = options.mHeadDimQk;
}
// The stride between grouped heads (consecutive heads within a GQA group).
// Use user-provided stride if available, otherwise use headDimQk.
int32_t strideGroupedHeads{options.qStrideHeads};
if (strideGroupedHeads == 0) {
  strideGroupedHeads = options.mHeadDimQk;
}
```
The logic to determine strideHeads and strideGroupedHeads is identical, which is redundant. You can simplify this by initializing strideGroupedHeads from strideHeads before strideHeads is potentially modified for GQA. This will make the code cleaner and easier to maintain.
```diff
 int32_t strideHeads{options.qStrideHeads};
 if (strideHeads == 0) {
   strideHeads = options.mHeadDimQk;
 }
-// The stride between grouped heads (consecutive heads within a GQA group).
-// Use user-provided stride if available, otherwise use headDimQk.
-int32_t strideGroupedHeads{options.qStrideHeads};
-if (strideGroupedHeads == 0) {
-  strideGroupedHeads = options.mHeadDimQk;
-}
+// The stride between grouped heads (consecutive heads within a GQA group) is
+// the same as the base head stride.
+int32_t strideGroupedHeads{strideHeads};
```
| ) | ||
|
|
||
|
|
||
| def make_query_non_contiguous(q, num_qo_heads, head_dim): |
The parameters `num_qo_heads` and `head_dim` are unused in this function. The shape is inferred directly from the input tensor `q`. Please remove these unused parameters to simplify the function signature.
```diff
-def make_query_non_contiguous(q, num_qo_heads, head_dim):
+def make_query_non_contiguous(q):
```
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
1478-1491: Remove unused parameters from the helper function.

The `num_qo_heads` and `head_dim` parameters are not used in the function body. The function only uses `q.shape` to extract the dimensions (line 1485). These parameters should be removed for clarity.

🔎 Suggested refactor
```diff
-def make_query_non_contiguous(q, num_qo_heads, head_dim):
+def make_query_non_contiguous(q):
     """
     Create a non-contiguous version of the query tensor.
     Create a (N, H, 2*D) tensor and slice the first D dimensions: x[..., :D]
     This produces a non-contiguous view with the same data.
     """
     n, h, d = q.shape
     # Create a larger tensor with 2*D in the last dimension
     large_tensor = torch.zeros(n, h, 2 * d, dtype=q.dtype, device=q.device)
     large_tensor[..., :d] = q
     # Slice to get non-contiguous query (only last dim is contiguous)
     q_non_contiguous = large_tensor[..., :d]
     assert not q_non_contiguous.is_contiguous(), "Query should be non-contiguous"
     return q_non_contiguous
```

And update the call sites on lines 536 and 932:

```diff
-    q_input = make_query_non_contiguous(q, num_qo_heads, head_dim)
+    q_input = make_query_non_contiguous(q)
```
📒 Files selected for processing (4)
- `csrc/trtllm_fmha_kernel_launcher.cu` (6 hunks)
- `include/flashinfer/trtllm/fmha/fmhaRunnerParams.h` (1 hunk)
- `include/flashinfer/trtllm/fmha/kernelParams.h` (1 hunk)
- `tests/attention/test_trtllm_gen_attention.py` (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (2)
- `flashinfer/prefill.py` (1): `trtllm_batch_context_with_kv_cache` (3461-3673)
- `flashinfer/decode.py` (1): `trtllm_batch_decode_with_kv_cache` (2067-2361)
🪛 Ruff (0.14.8)
tests/attention/test_trtllm_gen_attention.py
1478-1478: Unused function argument: num_qo_heads
(ARG001)
1478-1478: Unused function argument: head_dim
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1)
235-238: LGTM! Clean addition of query stride fields.

The new `qStrideTokens` and `qStrideHeads` fields are well-positioned and properly documented. The zero-initialization via `memset` (line 332) will correctly set these to 0, which serves as a sentinel value for "use default stride" in the stride computation logic.

include/flashinfer/trtllm/fmha/kernelParams.h (2)
198-210: LGTM! Stride computation logic is correct.

The token stride computation properly handles both user-provided and inferred strides:

- When `qStrideTokens` is 0, the code derives the stride from the tensor layout
- The packed QKV case correctly adds K/V head dimensions to the hidden dimension
- The check on line 206 prevents an invalid configuration (packed QKV with grouped heads)
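This "0 means not provided" fallback can be sketched in a few lines of Python (an illustrative rendering of the C++ logic; `resolve_stride` is a hypothetical name, not part of the codebase):

```python
def resolve_stride(user_stride: int, layout_stride: int) -> int:
    """Return the user-provided stride, falling back to the
    layout-derived value when the sentinel 0 is passed."""
    return user_stride if user_stride != 0 else layout_stride

# No user stride supplied: fall back to the layout value (e.g. head_dim = 64).
assert resolve_stride(0, 64) == 64
# A non-contiguous query supplies an explicit head stride of 128.
assert resolve_stride(128, 64) == 128
```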
212-229: LGTM! GQA stride handling is correctly implemented.

The stride computation for grouped query attention (GQA) is correct:

- `strideGroupedHeads` represents the stride between consecutive heads within a group
- `strideHeads` represents the stride between groups, which is why it's multiplied by `numGroupedHeads` on line 227
- Both default to `mHeadDimQk` when the user-provided stride is 0 (sentinel value)

The resulting stride vector `[1, strideGroupedHeads, strideHeads, strideTokens]` aligns properly with the shape vector dimensions.

csrc/trtllm_fmha_kernel_launcher.cu (3)
76-88: LGTM! Launcher signature correctly extended with stride parameters.

The new stride parameters (`q_stride_tokens`, `q_stride_heads`, `kv_stride_*`) enable explicit stride specification for non-contiguous tensor support. The parameter placement and types are appropriate.
117-118: LGTM! Stride parameters correctly assigned to runner params.

The new query stride fields are properly propagated to `runner_params`, aligning with the updated `TllmGenFmhaRunnerParams` structure.
263-265: LGTM! Stride computation from tensor layout is correct.

The stride computation using `query.stride(0)` for tokens and `query.stride(1)` for heads correctly extracts the tensor's memory layout information. This automatically handles both contiguous and non-contiguous query tensors without requiring users to manually specify strides.

Query tensor shape is `[num_tokens, num_heads, head_dim]`, so:

- `stride(0)` → stride between consecutive tokens ✓
- `stride(1)` → stride between consecutive heads ✓

Also applies to: 344-346
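As an illustration of what `stride(0)`/`stride(1)` return for a non-contiguous query, here is a NumPy sketch (NumPy reports strides in bytes, so they are divided by the item size to match PyTorch's element-count convention; the shapes are made up for the example):

```python
import numpy as np

num_tokens, num_heads, head_dim = 4, 8, 64
# Allocate [num_tokens, num_heads, 2 * head_dim] and slice out the first
# head_dim elements of the last axis, producing a non-contiguous view.
big = np.zeros((num_tokens, num_heads, 2 * head_dim), dtype=np.float32)
q = big[..., :head_dim]

elem_strides = tuple(s // q.itemsize for s in q.strides)
# Stride between tokens is num_heads * 2 * head_dim = 1024; stride between
# heads is 2 * head_dim = 128 (not head_dim = 64, as a contiguous layout
# would give).
assert elem_strides == (num_heads * 2 * head_dim, 2 * head_dim, 1)
assert not q.flags["C_CONTIGUOUS"]
```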
tests/attention/test_trtllm_gen_attention.py (2)
423-423: LGTM! Good test coverage for non-contiguous query paths.

The addition of the `non_contiguous_query` parameter and its parametrization with `[False, True]` ensures both contiguous and non-contiguous query tensors are tested, validating the new stride support.

Also applies to: 656-656, 672-672, 690-690
534-541: LGTM! Non-contiguous query handling is correctly implemented.

The test logic properly creates a non-contiguous query when requested and uses the same `q_input` for both the direct API call and the wrapper test, ensuring consistent testing of the stride support.

Also applies to: 602-603
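The property these parametrized tests rely on can be checked in isolation with NumPy (illustrative only, independent of the FlashInfer APIs): a computation over a non-contiguous view must match the same computation over the contiguous original.

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, d = 4, 8, 64
q = rng.standard_normal((n, h, d)).astype(np.float32)

# Embed q in a wider buffer and slice it back out -> non-contiguous view
# holding the same values.
big = np.zeros((n, h, 2 * d), dtype=np.float32)
big[..., :d] = q
q_nc = big[..., :d]
assert not q_nc.flags["C_CONTIGUOUS"]

# A stand-in for an attention-score computation: results must be identical
# whether the input view is contiguous or not.
scores = np.einsum("nhd,mhd->hnm", q, q)
scores_nc = np.einsum("nhd,mhd->hnm", q_nc, q_nc)
assert np.allclose(scores, scores_nc)
```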
/bot run |
[SUCCESS] Pipeline #40565938: 12/20 passed |
📌 Description
As requested by @nandor, this PR implements non-contiguous query support for the trtllm-gen attention backend (by passing the query strides to the TMA descriptor constructor).
Similar support could be added to XQA as well, but this PR only changes the trtllm-gen backend.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and all tests are passing (`unittest`, etc.).

Reviewer Notes
cc @PerkzZheng