
feat: support non-contiguous query for trtllm-gen attention backend #2254

Merged
yzh119 merged 3 commits into flashinfer-ai:main from yzh119:trtllm-gen-non-contiguous-query
Dec 22, 2025

Conversation


@yzh119 yzh119 commented Dec 21, 2025

📌 Description

As requested by @nandor, this PR implements non-contiguous query support for the trtllm-gen attention backend (by passing the query strides to the TMA descriptor constructor).

We could add similar support to xqa as well, but this PR only changes the trtllm-gen backend.
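For background, a query tensor becomes non-contiguous when, for example, it is a slice of a wider buffer: its element strides no longer match the packed defaults, so the kernel must be told the real strides. A minimal NumPy sketch of the situation (illustrative only; the shapes and dtype are made up, and this is not the FlashInfer API):

```python
import numpy as np

num_tokens, num_heads, head_dim = 4, 8, 64

# A query sliced out of a buffer twice as wide is a valid view,
# but it is no longer C-contiguous.
buffer = np.zeros((num_tokens, num_heads, 2 * head_dim), dtype=np.float16)
q = buffer[..., :head_dim]
assert not q.flags["C_CONTIGUOUS"]

# NumPy reports strides in bytes; divide by itemsize for element strides.
stride_tokens = q.strides[0] // q.itemsize  # 8 * 128 = 1024 (not 8 * 64 = 512)
stride_heads = q.strides[1] // q.itemsize   # 128 (not 64)
```

Passing strides like these alongside the data pointer is what lets a descriptor address such a view correctly instead of assuming a packed layout.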

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or via my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

cc @PerkzZheng

Summary by CodeRabbit

  • New Features

    • Support for non-contiguous query tensors in attention operations, enabling more flexible memory layouts.
  • Performance Improvements

    • More explicit stride handling for attention to better accommodate varied tensor layouts and improve memory use.
  • Tests

    • Extended tests to cover non-contiguous query inputs and validate attention behavior across layouts.



coderabbitai bot commented Dec 21, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Propagates explicit Q and K/V stride parameters through the TensorRT LLM paged-attention path: adds qStrideTokens/qStrideHeads to runner params, updates launcher signature and callers to pass computed strides, refactors stride inference to accept user-specified or layout-derived values, and extends tests to cover non-contiguous queries.

Changes

  • Launcher signature & callers — csrc/trtllm_fmha_kernel_launcher.cu: Updated the trtllm_paged_attention_launcher signature to accept q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, and kv_stride_batch; assigned qStrideTokens and qStrideHeads into runner_params; updated trtllm_paged_attention_decode and trtllm_paged_attention_context to compute and pass these strides.
  • Runner params struct — include/flashinfer/trtllm/fmha/fmhaRunnerParams.h: Added int qStrideTokens and int qStrideHeads members to TllmGenFmhaRunnerParams.
  • Kernel stride inference — include/flashinfer/trtllm/fmha/kernelParams.h: Use user-provided Q strides (tokens/heads) when non-zero; otherwise derive token/head strides from the layout and QKV packing; handle grouped-head (GQA) adjustments for the head stride.
  • Tests: non-contiguous queries — tests/attention/test_trtllm_gen_attention.py: Added a make_query_non_contiguous() helper and a non_contiguous_query test parameter; branched tests to exercise non-contiguous query tensors and skip incompatible backends.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Pay extra attention to:
    • csrc/trtllm_fmha_kernel_launcher.cu: parameter reordering and correct mapping into runner_params.
    • include/flashinfer/trtllm/fmha/kernelParams.h: conditional logic for default stride derivation and GQA multiplier correctness.
    • Tests: correctness of non-contiguous query creation and proper test coverage/skip conditions.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • cyx-6
  • wenscarl
  • PerkzZheng
  • nvmbreughe
  • jiahanc

Poem

🐰 I bounded through kernels, counting each head and token,
I stitched explicit strides where once they were broken.
Non-contiguous hops now tested with glee,
Heads and tokens march true, as clear as can be.
A tiny rabbit's nod — precise, subtle, and woken.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 33.33%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title 'feat: support non-contiguous query for trtllm-gen attention backend' directly and clearly describes the main change: adding support for non-contiguous queries to the trtllm-gen backend, which aligns with the primary objective and the changes throughout the codebase.
  • Description check ✅ Passed — The description explains the purpose (supporting non-contiguous queries for the trtllm-gen backend as requested), mentions the implementation approach (passing strides to the TMA descriptor), and notes scope limitations (only trtllm-gen, not xqa yet). However, the pre-commit and test checklist items are listed but not marked as complete.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76a66e1 and 032eace.

📒 Files selected for processing (1)
  • tests/attention/test_trtllm_gen_attention.py (13 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
tests/attention/test_trtllm_gen_attention.py

1482-1482: Unused function argument: num_qo_heads

(ARG001)


1482-1482: Unused function argument: head_dim

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
tests/attention/test_trtllm_gen_attention.py (3)

423-423: LGTM! Prefill test changes are well-structured.

The addition of the non_contiguous_query parameter and conditional logic to create either contiguous or non-contiguous query tensors is clean and provides comprehensive test coverage for stride handling in the prefill path.

Also applies to: 533-541, 603-603, 656-656, 672-672, 690-690


766-766: LGTM! Decode test changes mirror prefill implementation effectively.

The decode path follows the same pattern as prefill for non-contiguous query testing. The conditional skip for the xqa backend (Lines 1122-1124) is correct, as that backend doesn't yet support non-contiguous queries.

Also applies to: 929-937, 1016-1016, 1103-1103, 1120-1124, 1144-1144


1482-1495: Well-designed test helper for stride validation.

The make_query_non_contiguous function creates a non-contiguous tensor by allocating a larger buffer (2×D in the last dimension) and slicing it, which correctly produces non-default strides. The assertion on Line 1494 validates the non-contiguity. This approach effectively tests the stride-aware path through the FMHA launcher.



@gemini-code-assist

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements support for non-contiguous query tensors within the trtllm-gen attention backend, a feature specifically requested by a collaborator. This enhancement allows the attention mechanism to process query inputs that are not stored contiguously in memory, which is crucial for optimizing memory usage and improving compatibility with diverse data representations. The core modifications involve updating the kernel's parameter structures and logic to correctly interpret and utilize query stride information, ensuring robust performance across various memory access patterns.

Highlights

  • Non-Contiguous Query Support: The trtllm-gen attention backend now supports non-contiguous query tensors, enhancing flexibility in memory layouts for query inputs.
  • Stride Parameter Integration: New q_stride_tokens and q_stride_heads parameters have been introduced in the trtllm_paged_attention_launcher function and TllmGenFmhaRunnerParams struct to explicitly handle query strides.
  • Dynamic Stride Calculation: The KernelParams constructor intelligently uses provided query strides or calculates them based on the tensor layout if not explicitly given, ensuring correct memory access.
  • Expanded Test Coverage: New test cases have been added for both prefill and decode operations to validate the correct functioning with non-contiguous query inputs, including a utility function to create such tensors.
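The sentinel-based fallback described in the highlights (zero strides mean "derive from the layout") can be sketched as follows. This is hypothetical Python with made-up names; the actual logic is C++ in kernelParams.h, and packed_qkv_extra is an invented knob standing in for the packed-QKV adjustment:

```python
def resolve_q_strides(q_stride_tokens, q_stride_heads, num_heads, head_dim,
                      packed_qkv_extra=0):
    """Resolve query strides, treating 0 as "derive from the layout".

    packed_qkv_extra models the extra K/V head dimensions folded into the
    hidden size when Q, K, and V share one packed buffer (hypothetical knob).
    """
    if q_stride_heads == 0:
        q_stride_heads = head_dim
    if q_stride_tokens == 0:
        q_stride_tokens = num_heads * head_dim + packed_qkv_extra
    return q_stride_tokens, q_stride_heads

# Contiguous default: (8 * 64, 64) == (512, 64)
assert resolve_q_strides(0, 0, num_heads=8, head_dim=64) == (512, 64)
# Explicit strides from a non-contiguous view pass through unchanged.
assert resolve_q_strides(1024, 128, num_heads=8, head_dim=64) == (1024, 128)
```

The zero sentinel keeps existing callers working unchanged while letting new callers forward the real strides of a non-contiguous view.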



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully adds support for non-contiguous query tensors in the trtllm-gen attention backend by passing the query tensor's strides to the kernel. The changes in the CUDA/C++ code are logical, and the new functionality is well-covered by tests. I have provided a few suggestions to enhance code maintainability and clarity, addressing code duplication in csrc/trtllm_fmha_kernel_launcher.cu, a redundant calculation in include/flashinfer/trtllm/fmha/kernelParams.h, and unused parameters in a new test helper function in tests/attention/test_trtllm_gen_attention.py.

Comment on lines +344 to +346
// Query stride: [num_tokens, num_heads, head_dim]
int q_stride_tokens = query.stride(0); // stride between tokens
int q_stride_heads = query.stride(1); // stride between heads

medium

This logic to compute q_stride_tokens and q_stride_heads is also present in trtllm_paged_attention_decode at lines 263-265. There is a fair amount of duplicated code between trtllm_paged_attention_decode and trtllm_paged_attention_context for extracting tensor metadata (shapes, strides, dtypes). To improve maintainability, consider refactoring this common logic into a helper function or struct.

Comment on lines +214 to +223
int32_t strideHeads{options.qStrideHeads};
if (strideHeads == 0) {
strideHeads = options.mHeadDimQk;
}
// The stride between grouped heads (consecutive heads within a GQA group).
// Use user-provided stride if available, otherwise use headDimQk.
int32_t strideGroupedHeads{options.qStrideHeads};
if (strideGroupedHeads == 0) {
strideGroupedHeads = options.mHeadDimQk;
}

medium

The logic to determine strideHeads and strideGroupedHeads is identical, which is redundant. You can simplify this by initializing strideGroupedHeads from strideHeads before strideHeads is potentially modified for GQA. This will make the code cleaner and easier to maintain.

Suggested change
-int32_t strideHeads{options.qStrideHeads};
-if (strideHeads == 0) {
-  strideHeads = options.mHeadDimQk;
-}
-// The stride between grouped heads (consecutive heads within a GQA group).
-// Use user-provided stride if available, otherwise use headDimQk.
-int32_t strideGroupedHeads{options.qStrideHeads};
-if (strideGroupedHeads == 0) {
-  strideGroupedHeads = options.mHeadDimQk;
-}
+int32_t strideHeads{options.qStrideHeads};
+if (strideHeads == 0) {
+  strideHeads = options.mHeadDimQk;
+}
+// The stride between grouped heads (consecutive heads within a GQA group) is the same as the base head stride.
+int32_t strideGroupedHeads{strideHeads};

)


def make_query_non_contiguous(q, num_qo_heads, head_dim):

medium

The parameters num_qo_heads and head_dim are unused in this function. The shape is inferred directly from the input tensor q. Please remove these unused parameters to simplify the function signature.

Suggested change
-def make_query_non_contiguous(q, num_qo_heads, head_dim):
+def make_query_non_contiguous(q):


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)

1478-1491: Remove unused parameters from the helper function.

The num_qo_heads and head_dim parameters are not used in the function body. The function only uses q.shape to extract the dimensions (line 1485). These parameters should be removed for clarity.

🔎 Suggested refactor
-def make_query_non_contiguous(q, num_qo_heads, head_dim):
+def make_query_non_contiguous(q):
     """
     Create a non-contiguous version of the query tensor.
     Create a (N, H, 2*D) tensor and slice the first D dimensions: x[..., :D]
     This produces a non-contiguous view with the same data.
     """
     n, h, d = q.shape
     # Create a larger tensor with 2*D in the last dimension
     large_tensor = torch.zeros(n, h, 2 * d, dtype=q.dtype, device=q.device)
     large_tensor[..., :d] = q
     # Slice to get non-contiguous query (only last dim is contiguous)
     q_non_contiguous = large_tensor[..., :d]
     assert not q_non_contiguous.is_contiguous(), "Query should be non-contiguous"
     return q_non_contiguous

And update the call sites on lines 536, 932:

-        q_input = make_query_non_contiguous(q, num_qo_heads, head_dim)
+        q_input = make_query_non_contiguous(q)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 519671d and 76a66e1.

📒 Files selected for processing (4)
  • csrc/trtllm_fmha_kernel_launcher.cu (6 hunks)
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1 hunks)
  • include/flashinfer/trtllm/fmha/kernelParams.h (1 hunks)
  • tests/attention/test_trtllm_gen_attention.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (2)
flashinfer/prefill.py (1)
  • trtllm_batch_context_with_kv_cache (3461-3673)
flashinfer/decode.py (1)
  • trtllm_batch_decode_with_kv_cache (2067-2361)
🪛 Ruff (0.14.8)
tests/attention/test_trtllm_gen_attention.py

1478-1478: Unused function argument: num_qo_heads

(ARG001)


1478-1478: Unused function argument: head_dim

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1)

235-238: LGTM! Clean addition of query stride fields.

The new qStrideTokens and qStrideHeads fields are well-positioned and properly documented. The zero-initialization via memset (line 332) will correctly set these to 0, which serves as a sentinel value for "use default stride" in the stride computation logic.

include/flashinfer/trtllm/fmha/kernelParams.h (2)

198-210: LGTM! Stride computation logic is correct.

The token stride computation properly handles both user-provided and inferred strides:

  • When qStrideTokens is 0, the code derives the stride from the tensor layout
  • The packed QKV case correctly adds K/V head dimensions to the hidden dimension
  • The check on line 206 prevents an invalid configuration (packed QKV with grouped heads)

212-229: LGTM! GQA stride handling is correctly implemented.

The stride computation for grouped query attention (GQA) is correct:

  • strideGroupedHeads represents the stride between consecutive heads within a group
  • strideHeads represents the stride between groups, which is why it's multiplied by numGroupedHeads on line 227
  • Both default to mHeadDimQk when user-provided stride is 0 (sentinel value)

The resulting stride vector [1, strideGroupedHeads, strideHeads, strideTokens] aligns properly with the shape vector dimensions.
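To make the GQA arithmetic above concrete, here is a small sketch of the head-stride rule (hypothetical Python with invented names; the real code is C++ in kernelParams.h):

```python
def gqa_head_strides(q_stride_heads, head_dim, num_grouped_heads):
    """Sketch of the GQA head-stride rule: 0 means "use the default"."""
    # Stride between consecutive heads inside one GQA group.
    stride_grouped_heads = q_stride_heads if q_stride_heads != 0 else head_dim
    # Stride between groups spans all grouped heads in a group.
    stride_heads = stride_grouped_heads * num_grouped_heads
    return stride_grouped_heads, stride_heads

# Default layout, head_dim=64, 4 grouped heads per group: (64, 256)
assert gqa_head_strides(0, 64, 4) == (64, 256)
```

With a non-contiguous view (say, a per-head stride of 128 from a sliced buffer), the group stride scales accordingly to 128 × 4 = 512.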

csrc/trtllm_fmha_kernel_launcher.cu (3)

76-88: LGTM! Launcher signature correctly extended with stride parameters.

The new stride parameters (q_stride_tokens, q_stride_heads, kv_stride_*) enable explicit stride specification for non-contiguous tensor support. The parameter placement and types are appropriate.


117-118: LGTM! Stride parameters correctly assigned to runner params.

The new query stride fields are properly propagated to runner_params, aligning with the updated TllmGenFmhaRunnerParams structure.


263-265: LGTM! Stride computation from tensor layout is correct.

The stride computation using query.stride(0) for tokens and query.stride(1) for heads correctly extracts the tensor's memory layout information. This automatically handles both contiguous and non-contiguous query tensors without requiring users to manually specify strides.

Query tensor shape is [num_tokens, num_heads, head_dim], so:

  • stride(0) → stride between consecutive tokens ✓
  • stride(1) → stride between consecutive heads ✓

Also applies to: 344-346

tests/attention/test_trtllm_gen_attention.py (2)

423-423: LGTM! Good test coverage for non-contiguous query paths.

The addition of the non_contiguous_query parameter and its parametrization with [False, True] ensures both contiguous and non-contiguous query tensors are tested, validating the new stride support.

Also applies to: 656-656, 672-672, 690-690


534-541: LGTM! Non-contiguous query handling is correctly implemented.

The test logic properly creates a non-contiguous query when requested and uses the same q_input for both the direct API call and the wrapper test, ensuring consistent testing of the stride support.

Also applies to: 602-603


yzh119 commented Dec 21, 2025

/bot run

@flashinfer-bot

GitLab MR !207 has been created, and the CI pipeline #40565938 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[SUCCESS] Pipeline #40565938: 12/20 passed


@PerkzZheng PerkzZheng left a comment


LGTM. Thanks!


3 participants