
Conversation

@qsang-nv
Collaborator

@qsang-nv qsang-nv commented Oct 29, 2025

📌 Description

Expose the xqa backend through the trtllm attention interface, and improve layout coverage of the trtllm-gen and xqa backends.

Both trtllm-gen and xqa now support the NHD and HND KV-cache layouts.

  • Support the NHD layout for trtllm-gen
  • Refactor xqa (869c0c1)
    • Allow user-passed stride_page/head/token (see the stride sketch below)
    • Support both HND and NHD
    • Remove macros such as PAGED_KV_CACHE_LAYOUT and USE_PAGED_KV_CACHE
  • Add unit tests for both trtllm-gen and xqa on NHD/HND
  • Add a unified API for trtllm-gen/xqa, plus a unified unit test
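
The layout/stride points above can be made concrete with a small, illustrative snippet (plain PyTorch, not the FlashInfer API; variable names are ours). It shows how the stride_page/stride_token/stride_head values that the refactored xqa path consumes fall out of the two cache layouts directly via Tensor.stride():

import torch

num_pages, page_size, num_kv_heads, head_dim = 8, 16, 4, 128

# NHD layout: [num_pages, page_size, num_kv_heads, head_dim]
k_nhd = torch.zeros(num_pages, page_size, num_kv_heads, head_dim)
# HND layout: [num_pages, num_kv_heads, page_size, head_dim]
k_hnd = k_nhd.transpose(1, 2).contiguous()

for name, cache, token_axis, head_axis in [("NHD", k_nhd, 1, 2), ("HND", k_hnd, 2, 1)]:
    stride_page = cache.stride(0)            # elements between consecutive pages
    stride_token = cache.stride(token_axis)  # elements between consecutive tokens in a page
    stride_head = cache.stride(head_axis)    # elements between consecutive KV heads
    print(name, stride_page, stride_token, stride_head)
# NHD -> 8192 512 128, HND -> 8192 128 2048: same logical data, different strides,
# which is why passing strides removes the need for a compile-time layout macro.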

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added xqa-based batch decode API and public kv_layout option (NHD/HND); added enable_pdl toggle to inference wrappers.
  • Improvements

    • Automatic backend selection for decoding, consistent KV-layout normalization across paths, and unified stride-aware paged-KV handling with layout-aware shapes, scales, and workspace handling.
  • Tests

    • Expanded tests to cover both KV layouts, enable_pdl, new batch-decode workflows, backend/layout permutations, and fp8/mixed-dtype scenarios.

Signed-off-by: Qidi Sang <[email protected]>
@gemini-code-assist
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new XQA (eXtended Query Attention) backend to optimize batch decoding operations within the flashinfer library. It conditionally leverages specialized XQA kernels for improved performance on specific NVIDIA GPU architectures and data types, including FP16/BF16 and FP8 tensor core operations, thereby enhancing efficiency for large language model inference.

Highlights

  • XQA Backend Integration: Introduced and integrated new xqa and xqa_mla functions, likely representing an eXtended Query Attention backend, to enhance decoding performance.
  • Conditional XQA Usage for Batch Decode: The trtllm_batch_decode_with_kv_cache function now conditionally utilizes the xqa function for decoding. This optimization is applied on NVIDIA GPUs with compute capability 9 or 12, specifically for FP16 or BF16 query data types, excluding NVFP4 output.
  • Conditional XQA_MLA Usage for Multi-Layer Attention: The trtllm_batch_decode_with_kv_cache_mla function has been updated to conditionally use the xqa_mla function. This specialized path is activated for compute capability 12 GPUs when processing FP8 (float8_e4m3fn) query and KV cache types, and when sinks is not provided, targeting efficient multi-layer attention operations. (Both dispatch conditions are sketched right after this list.)
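
The two highlighted dispatch rules can be summarized as small predicates. The sketch below is illustrative only: the function names are ours, the authoritative checks are the decode.py snippets quoted later in this review, and it assumes get_compute_capability is importable from flashinfer.utils as the code-graph notes further down suggest.

import torch
from flashinfer.utils import get_compute_capability  # assumed import path

def prefers_xqa_decode(query: torch.Tensor, out_dtype) -> bool:
    # Decode path: SM90/SM120-class GPUs, fp16/bf16 queries, and not NVFP4 output.
    major = get_compute_capability(query.device)[0]
    return (
        major in (9, 12)
        and out_dtype != "nvfp4"
        and query.dtype in (torch.float16, torch.bfloat16)
    )

def prefers_xqa_mla_decode(query: torch.Tensor, kv_cache: torch.Tensor, sinks) -> bool:
    # MLA path: SM120-class GPUs with fp8 (e4m3) query/KV cache and no attention sinks.
    major = get_compute_capability(query.device)[0]
    return (
        major == 12
        and query.dtype == torch.float8_e4m3fn
        and kv_cache.dtype == torch.float8_e4m3fn
        and sinks is None
    )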
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai
Contributor

coderabbitai bot commented Oct 29, 2025

Warning

Rate limit exceeded

@yzh119 has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 11 minutes and 14 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between e040826 and 43bf624.

📒 Files selected for processing (1)
  • flashinfer/decode.py (9 hunks)

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Threads runtime PDL control and stride-based paged‑KV handling through Python APIs, C++ bindings, and CUDA kernels; adds kv_layout (HND/NHD) with normalization paths; extends tensor maps, kernel/launcher signatures to accept enable_pdl and kv_stride parameters; updates tests to cover layouts and PDL permutations.
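
One small but easy-to-miss piece of the walkthrough is the device-default enable_pdl resolution. A minimal sketch of that behavior, assuming device_support_pdl is importable from flashinfer.utils as the code-graph sections below indicate (the helper name here is illustrative):

from typing import Optional

import torch
from flashinfer.utils import device_support_pdl  # assumed import path

def resolve_enable_pdl(query: torch.Tensor, enable_pdl: Optional[bool] = None) -> bool:
    # None means "decide from the device"; per the review notes, PDL is only
    # supported on SM90+, which is what device_support_pdl checks.
    return device_support_pdl(query.device) if enable_pdl is None else enable_pdl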

Changes

Cohort / File(s) Summary
Python decode & public APIs
`flashinfer/decode.py`, `flashinfer/prefill.py`, `flashinfer/xqa.py`
Added xqa_batch_decode_with_kv_cache; threaded kv_layout and enable_pdl through APIs; device-default enable_pdl; normalize HND↔NHD where required (trtllm-gen path converts NHD→HND before decode); updated signatures/docstrings.
JIT flags & Python tests
`flashinfer/jit/xqa.py`, `tests/attention/*`, `tests/attention/test_xqa_batch_decode.py`, `tests/attention/test_trtllm_gen_attention.py`
Removed forced PAGED_KV_CACHE NVCC flags; tests parameterized for kv_layout and enable_pdl; added xqa batch-decode tests; updated test helpers/shapes and tolerances for NHD/HND and mixed dtypes.
C++ bindings & wrappers
`csrc/flashinfer_xqa_binding.cu`, `csrc/xqa/xqa_wrapper.cu`
Appended enable_pdl to xqa_wrapper/_mla signatures; unified argument lists to always pass k/v cache and page list; extract and pass kv stride values (elements) into launch calls.
Launch APIs & headers
`csrc/xqa/mha.h`, `csrc/xqa/mha.cu`, `csrc/xqa/mha_sm90.cu`, `csrc/xqa/mla_sm120.cu`
Extended launch functions with bool enable_pdl and uint64_t kv_stride_page, kv_stride_token, kv_stride_head; replaced compile-time ENABLE_PDL checks with runtime flag and convert strides into head units for kernels.
Paged‑KV addressing & tensor maps
`csrc/xqa/mhaUtils.cuh`, `csrc/xqa/tensorMap.cpp`, `csrc/xqa/tensorMap.h`
Replaced conditional layout code with stride-based addressing; added token/head/stride fields to HeadPtr/IndexedHeadPtrImpl; made makeTensorMapForPagedKVCache accept stride_page/token/head parameters.
Kernel impls & loaders
`csrc/xqa/mla_sm120.cu`, `csrc/xqa/mha_sm90.cu`, `csrc/xqa/mha.cu`, `csrc/xqa/mhaUtils.cuh`
Removed conditional paged‑KV branches; KVTilePartLoader/loader constructors updated to take nbPages; tensor-map/tile logic threaded with stride params; some loads made bounds-checked instead of reinterpret_cast.
Bindings & module declarations
`csrc/flashinfer_xqa_binding.cu`, `csrc/xqa/defines.h`, `csrc/xqa/utils.cuh`
Added enable_pdl to exported wrappers; removed compile-time paged-KV defaults; introduced CUDA-arch-aware ENABLE_PDL defaults and added CUDA_ARCH == 1210 gating.
TRT‑LLM integration & decode/prefill
`csrc/trtllm_fmha_kernel_launcher.cu`, `flashinfer/prefill.py`, `tests/attention/test_trtllm_gen_attention.py`
Threaded kv_layout into trtllm APIs; removed hard assertion for HND and added runtime NHD→HND transposes for trtllm-gen; tests exercise both layouts/backends.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Py as Python API
    participant Bind as C++ Binding
    participant Host as Launch/Host
    participant CU as CUDA Kernel

    Py->>Bind: xqa_batch_decode_with_kv_cache(q, k_cache, v_cache, kv_layout, enable_pdl)
    Bind->>Bind: resolve enable_pdl if None (device_support_pdl)
    alt kv_layout == "HND"
        Bind->>Bind: transpose K/V (HND → NHD) for xqa path
    end
    Bind->>Bind: extract kv_stride_page/token/head (in elements)
    Bind->>Host: call launch*(..., enable_pdl, kv_stride_page, kv_stride_token, kv_stride_head, stream)
    Host->>Host: makeLaunchConfig(..., enable_pdl)
    Host->>Host: makeTensorMapForPagedKVCache(..., stride_page, stride_token, stride_head)
    Host->>CU: kernel(..., stride_page_in_heads, stride_token_in_heads, stride_head_in_heads)
    CU->>CU: stride-based addressing selects page/head/token
    CU-->>Host: outputs
    Host-->>Bind: return tensor
    Bind-->>Py: result

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

  • Areas needing focused review:
    • Unit/scale consistency for kv_stride parameters (elements ↔ bytes ↔ heads) across Python → binding → tensorMap → kernel (see the unit-conversion sketch after this list).
    • ABI/signature consistency between Python/JIT entry calls, C++ bindings, and CUDA kernel entry points.
    • Correctness and placement of HND↔NHD transposes (both xqa and trtllm paths).
    • HeadPtr / IndexedHeadPtrImpl arithmetic, page-index bit-flag handling, and null-page cases.
    • FP8/bf16/fp16 mixed-dtype interactions and corresponding test tolerances.
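
To make the first review item concrete, here is an illustrative helper (names are ours, not the PR's) for the three unit systems a single page stride passes through: PyTorch reports strides in elements, byte-addressed consumers such as tensor maps need elements multiplied by the element size, and the kernels in this PR index KV data in head units (one head = head_dim elements):

import torch

def page_stride_units(cache: torch.Tensor, head_dim: int):
    stride_page_elems = cache.stride(0)                           # what Tensor.stride() returns
    stride_page_bytes = stride_page_elems * cache.element_size()  # for byte-addressed consumers
    assert stride_page_elems % head_dim == 0, "page stride must be a whole number of heads"
    stride_page_heads = stride_page_elems // head_dim             # "head units" used by the kernels
    return stride_page_elems, stride_page_bytes, stride_page_heads

# Example: fp16 NHD cache with 8 pages, page_size=16, 4 KV heads, head_dim=128
cache = torch.zeros(8, 16, 4, 128, dtype=torch.float16)
print(page_stride_units(cache, head_dim=128))  # (8192, 16384, 64)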

Possibly related PRs

Suggested reviewers

  • cyx-6
  • djmmoss
  • yongwww
  • aleozlx
  • wenscarl

Poem

"🐰 I hopped through pages with stride and cheer,

Swapped HND and NHD when the layout was near,
PDL listens when devices give the sign,
Kernels find tokens, heads, and pages in line,
A rabbit claps — decode now shines!"

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 28.79% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title Check ✅ Passed The PR title "feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend" is directly related to the main changes in the changeset. According to the raw_summary, this PR adds xqa backend support and extends both trtllm-gen and xqa backends to support NHD and HND KV cache layouts. The title accurately captures both the primary change (adding xqa backend) and the layout improvements, and it is specific enough that a reviewer scanning the history would understand the primary objectives. While the title is somewhat verbose, it avoids vague terminology and clearly conveys what was implemented.
Description Check ✅ Passed The PR description follows the required template structure and includes a comprehensive "📌 Description" section that clearly explains the changes: exposing the xqa backend to the trtllm attention interface, adding NHD/HND layout support for both backends, refactoring xqa with stride parameters while removing hardcoded macros, and adding test coverage. The description is detailed and on-topic, with specific technical details about the implementation. The "🔍 Related Issues" section is present but unfilled, and the "Reviewer Notes" section is empty; however, these are non-critical sections. The Pre-commit Checklist items are all marked as complete, though the "All tests are passing" checkbox is unchecked, which the author notes indicates tests still need verification.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new XQA (eXtended Query Attention) backend into the FlashInfer library. The primary goal is to enhance decoding performance by leveraging specialized xqa and xqa_mla kernels for specific GPU architectures and data types, such as FP16/BF16 and FP8. This change introduces conditional logic within existing batch decoding functions to dynamically select the most optimized backend based on the execution environment and tensor properties, thereby improving efficiency for supported configurations.

Highlights

  • New XQA Backend Integration: Introduces xqa and xqa_mla functions, providing optimized decoding paths for specific hardware and data types.
  • Conditional XQA Decoding: The trtllm_batch_decode_with_kv_cache function now conditionally uses the xqa backend for decoding when running on CUDA compute capabilities 9 or 12 with FP16/BF16 query types.
  • Conditional XQA_MLA Decoding for FP8: The trtllm_batch_decode_with_kv_cache_mla function now conditionally uses the xqa_mla backend for FP8 decoding on CUDA compute capability 12.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new xqa backend for decoding, which is enabled for specific GPU architectures. The changes are a good addition, but I've identified a few issues. There's a bug in the compute capability check that restricts the new backend from running on all supported GPUs. I also found a critical data type mismatch for the semaphore buffer being passed to the new xqa functions, which could lead to runtime errors. Additionally, there are some hardcoded values in scale calculations that should be replaced with function arguments to improve code maintainability.

Comment on lines 2238 to 2256
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
Contributor

critical

The xqa function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa kernel. You should reinterpret the tensor view to the correct dtype before passing it.

Suggested change
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )

Comment on lines 2452 to 2467
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
Contributor

critical

The xqa_mla function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa_mla kernel. You should reinterpret the tensor view to the correct dtype before passing it.

Suggested change
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )

    )
    # To decide if using xqa to decode
    if (
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
Contributor

high

The compute capability check for using the xqa backend is missing support for SM100 (compute capability 10) and SM110 (compute capability 11). The xqa kernel supports SM90, SM100, SM110, and SM120. The condition should be updated to include 10 and 11 to enable the backend on all supported architectures.

Suggested change
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 11, 12]

Comment on lines 2436 to 2450
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((128 + 64) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
Contributor

medium

The scale calculation for q_scale_value uses a hardcoded value (128 + 64). This corresponds to qk_nope_head_dim + qk_rope_head_dim. Using the function arguments qk_nope_head_dim and qk_rope_head_dim will improve code readability and maintainability.

Suggested change
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((128 + 64) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new XQA backend for decoding, which is a valuable addition. The implementation is mostly correct, but I've identified a few issues that need to be addressed. These include a missing compute capability in a conditional check, a critical data type mismatch for semaphore buffers that could lead to runtime errors, and the use of hardcoded values where function parameters should be used for better maintainability. Please see the detailed comments for specific suggestions.

Comment on lines 2238 to 2256
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
Contributor

critical

The semaphores parameter of the xqa function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.

Suggested change
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )

Comment on lines 2452 to 2467
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
Contributor

critical

The semaphores parameter of the xqa_mla function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.

Suggested change
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )

Comment on lines 2227 to 2231
    if (
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
        and out_dtype != "nvfp4"
        and query.dtype in [torch.float16, torch.bfloat16]
    ):
Contributor

high

The check for compute capability is missing support for SM100 (compute capability 10). The xqa implementation in flashinfer/xqa.py indicates support for SM90, SM100, and SM120, which correspond to compute capabilities 9, 10, and 12. This condition should be updated to include 10 to enable the XQA path on SM100 GPUs.

Suggested change
    if (
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
        and out_dtype != "nvfp4"
        and query.dtype in [torch.float16, torch.bfloat16]
    ):
    if (
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 12]
        and out_dtype != "nvfp4"
        and query.dtype in [torch.float16, torch.bfloat16]
    ):

Comment on lines 2436 to 2450
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((128 + 64) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
Contributor

medium

The head dimension is hardcoded as (128 + 64). This makes the code less maintainable. You should use the function parameters qk_nope_head_dim and qk_rope_head_dim instead, which are available in this function's scope.

Suggested change
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((128 + 64) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bb6b620 and 289d526.

📒 Files selected for processing (1)
  • flashinfer/decode.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (2)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/decode.py (1)

24-24: LGTM - Import statement is correct.

The xqa and xqa_mla imports are properly used in the conditional decode paths added below.

bkryu
bkryu previously requested changes Oct 29, 2025
Collaborator

@bkryu bkryu left a comment

Hi @qsang-nv, currently test_trtllm_gen_attention.py and test_trtllm_gen_mla.py check for SM 100f and skip the tests otherwise.

The current changes to decode.py only affect the SM90 and 120f cases, so no unit tests exercise this change yet. Can you add them?

Signed-off-by: Qidi Sang <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (2)
flashinfer/decode.py (2)

2518-2521: Use the query tensor’s device when checking compute capability.

This condition still probes cuda:0 unconditionally. On multi-GPU systems, a request running on another device will pick the wrong capability and skip the XQA-MLA fast path. The earlier review already flagged this; please switch to query.device.

-        get_compute_capability(torch.device(device="cuda"))[0] == 12
+        get_compute_capability(query.device)[0] == 12

2345-2368: Fix semaphore buffer dtype passed to XQA.

xqa expects its semaphore buffer to be torch.uint32, but workspace_1 remains torch.int8 after torch.chunk. Reinterpreting without converting will trip kernel-side type checks or yield garbage. Cast the second chunk to the correct view before passing it along.

-    workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
+    workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
+    semaphores = workspace_1.view(torch.uint32)
@@
-        workspace_1,
+        semaphores,
🧹 Nitpick comments (1)
flashinfer/xqa.py (1)

276-276: Inconsistent default value for cached function parameter.

get_xqa_module_mla has enable_pdl: bool = True with a default value, while get_xqa_module at line 40 has enable_pdl: bool without a default. For functions decorated with @functools.cache, it's better to be consistent: either both should have defaults or neither should. Cached functions typically avoid defaults so that every parameter shows up explicitly in the cache key (see the functools.cache snippet after the diff below).

Consider removing the default value for consistency:

-    enable_pdl: bool = True,
+    enable_pdl: bool,
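
For context on the nitpick above: functools.cache keys on the arguments exactly as they are passed, not on the resolved defaults, so a defaulted parameter can silently create duplicate cache entries. An illustrative snippet (the function name is a stand-in, not the real module loader):

import functools

@functools.cache
def get_module(enable_pdl: bool = True):
    return object()  # stand-in for an expensive JIT build

a = get_module()      # cache key: ()
b = get_module(True)  # cache key: (True,) -> a second, redundant entry
assert a is not b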
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a9f8bc8 and c2a0cad.

📒 Files selected for processing (6)
  • flashinfer/aot.py (2 hunks)
  • flashinfer/decode.py (4 hunks)
  • flashinfer/jit/xqa.py (6 hunks)
  • flashinfer/xqa.py (10 hunks)
  • tests/attention/test_xqa.py (5 hunks)
  • tests/attention/test_xqa_batch_decode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
flashinfer/jit/xqa.py (1)
flashinfer/jit/core.py (1)
  • gen_jit_spec (287-353)
flashinfer/xqa.py (2)
flashinfer/jit/core.py (1)
  • build_and_load (272-284)
flashinfer/utils.py (2)
  • register_custom_op (272-281)
  • register_custom_op (291-310)
tests/attention/test_xqa_batch_decode.py (3)
tests/test_helpers/sink_attention_reference.py (1)
  • sink_attention_unified (39-402)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/decode.py (4)
  • BatchDecodeWithPagedKVCacheWrapper (582-1411)
  • use_tensor_cores (780-781)
  • use_tensor_cores (1577-1578)
  • xqa_batch_decode_with_kv_cache (2254-2377)
flashinfer/decode.py (2)
flashinfer/xqa.py (4)
  • xqa (56-93)
  • xqa (124-265)
  • xqa_mla (292-321)
  • xqa_mla (348-457)
flashinfer/utils.py (3)
  • device_support_pdl (568-572)
  • get_device_sm_count (595-596)
  • get_compute_capability (251-254)
🪛 Ruff (0.14.2)
tests/attention/test_xqa_batch_decode.py

277-277: Unpacked variable in_kv_lens is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

flashinfer/decode.py

2260-2260: Unused function argument: max_seq_len

(ARG001)


2266-2266: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/xqa.py (1)

53-53: LGTM: Registration names correctly include enable_pdl.

Including enable_pdl in both the custom op and fake op registration names (lines 53, 96, 289, 324) is correct and essential. This ensures that PDL-enabled and PDL-disabled kernels are registered as distinct operations, preventing cache collisions and ensuring the correct module is selected at runtime.

Also applies to: 96-96, 289-289, 324-324

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
flashinfer/decode.py (2)

2361-2379: Critical: Thread enable_pdl flag to the XQA kernel.

The enable_pdl flag is computed at line 2320 but never passed to the xqa call. This causes the kernel to always use its default value (True), ignoring the caller's intent and device capabilities. Note that xqa_mla at line 2559 correctly passes this parameter.

Apply this fix:

         sliding_win_size=window_left + 1 if window_left >= 0 else 0,
         sm_count=sm_count,
+        enable_pdl=enable_pdl,
     )

2521-2526: Fix device handling in compute capability check.

Line 2522 uses torch.device(device="cuda") which defaults to cuda:0, causing incorrect capability detection on multi-GPU systems when the query tensor resides on a different GPU.

Apply this fix:

     if (
-        get_compute_capability(torch.device(device="cuda"))[0] == 12
+        get_compute_capability(query.device)[0] == 12
         and query.dtype == torch.float8_e4m3fn
🧹 Nitpick comments (1)
flashinfer/decode.py (1)

2264-2264: Consider whether max_seq_len should be used.

Static analysis flags that max_seq_len is unused in this function. If the parameter is genuinely not needed by the xqa kernel (which infers max_seq_len from block_tables and page_size), consider removing it from the signature to avoid confusion. If it's intended for future use or validation, add a comment explaining why it's currently unused.
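
If the parameter is dropped, the upper bound it encodes is still recoverable. A hedged sketch of the reviewer's point, assuming block_tables is laid out as [batch_size, max_pages_per_seq] (an assumption, not confirmed by this diff):

import torch

def implied_max_seq_len(block_tables: torch.Tensor, page_size: int) -> int:
    # Assumption: block_tables is [batch_size, max_pages_per_seq], so its width
    # already bounds the number of pages (and therefore tokens) per sequence.
    max_pages_per_seq = block_tables.shape[1]
    return max_pages_per_seq * page_size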

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c2a0cad and 21de9af.

📒 Files selected for processing (1)
  • flashinfer/decode.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (2)
flashinfer/xqa.py (4)
  • xqa (56-93)
  • xqa (124-265)
  • xqa_mla (292-321)
  • xqa_mla (348-457)
flashinfer/utils.py (3)
  • device_support_pdl (568-572)
  • get_device_sm_count (595-596)
  • get_compute_capability (251-254)
🪛 Ruff (0.14.2)
flashinfer/decode.py

2264-2264: Unused function argument: max_seq_len

(ARG001)


2270-2270: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Signed-off-by: Qidi Sang <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
flashinfer/decode.py (1)

2349-2379: Pass the right buffers into xqa and honor the caller’s output/PDL choices

We’re still hitting the same blockers that earlier reviews called out:

  • out defaults to None, but xqa requires a real tensor – callers that rely on the optional output will crash.
  • The semaphore chunk stays uint8; the kernel expects a uint32 view, so dtype checks fail.
  • enable_pdl never reaches xqa, forcing PDL on even when the caller disables it or the device doesn’t support it.

Please fix all three together.

     workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
     kv_scale_value = bmm2_scale
     q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
 
+    semaphores = workspace_1.view(torch.uint32)
+    if out is None:
+        out = torch.empty_like(query)
+
     k_cache_new = k_cache.reshape(-1, head_dim).contiguous()
     v_cache_new = v_cache.reshape(-1, head_dim).contiguous()
     query_new = query.unsqueeze(1).contiguous()
     seq_lens_new = seq_lens.unsqueeze(1).contiguous()
     sinks_new = (
@@
         workspace_0,
-        workspace_1,
+        semaphores,
         num_kv_heads,
         page_size,
         sinks=sinks_new,
         q_scale=q_scale_value,
         kv_scale=torch.tensor(
             [kv_scale_value], dtype=torch.float32, device=query.device
         ),
         sliding_win_size=window_left + 1 if window_left >= 0 else 0,
         sm_count=sm_count,
+        enable_pdl=enable_pdl,
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 21de9af and 595ee1b.

📒 Files selected for processing (3)
  • csrc/xqa/mha_sm90.cu (1 hunks)
  • flashinfer/decode.py (3 hunks)
  • tests/attention/test_xqa_batch_decode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_xqa_batch_decode.py (3)
tests/test_helpers/sink_attention_reference.py (1)
  • sink_attention_unified (39-402)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/decode.py (1)
  • xqa_batch_decode_with_kv_cache (2258-2381)
flashinfer/decode.py (2)
flashinfer/xqa.py (2)
  • xqa (56-93)
  • xqa (124-265)
flashinfer/utils.py (2)
  • device_support_pdl (568-572)
  • get_device_sm_count (595-596)
🪛 Ruff (0.14.2)
tests/attention/test_xqa_batch_decode.py

277-277: Unpacked variable in_kv_lens is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

flashinfer/decode.py

2264-2264: Unused function argument: max_seq_len

(ARG001)


2270-2270: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
csrc/xqa/mha_sm90.cu (1)

1969-1974: Clamping behavior verified—refactor is safe.

The "WithDup" naming is a deliberate design pattern used across similar load functions in the codebase (loadShmColWiseVecWithDup, loadShmRowWiseVecWithDup). The clamping behavior is intentional: when thread indices exceed bound, they clamp to the last valid index and duplicate its value.

The semantics are correct:

  • bound = headGrpSize - 1 (passed at line 2743) represents the last valid index (0-indexed, inclusive)
  • Accessing gmemVec[(headGrpSize - 1) * GmmaAccCoreMat::cols + j] is valid per the attentionSinksVec initialization (lines 1228–1231)
  • The refactor from reinterpret_cast to explicit element-wise copying improves type safety without changing behavior

Comment on lines 69 to 115
    # Create separate K and V caches with NHD layout
    max_seq_len = torch.max(seq_lens).item()
    num_tokens = max_seq_len * batch_size
    num_pages = (num_tokens + page_size - 1) // page_size
    ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype]
    if kv_dtype != "fp8":
        assert kv_dtype == ref_kv_dtype, (
            "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache"
        )

    # NHD layout: [num_pages, page_size, num_kv_heads, head_dim]
    k_cache = torch.randn(
        num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=ref_kv_dtype_torch,
        device=GPU_DEVICE,
    )
    v_cache = torch.randn(
        num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=ref_kv_dtype_torch,
        device=GPU_DEVICE,
    )

    # Convert K and V separately to fp8 if needed
    if kv_dtype == "fp8":
        k_cache, k_scale = to_float8(k_cache / 4.0)
        v_cache, v_scale = to_float8(v_cache / 4.0)
        # use high precision and fake-quantization for reference to avoid precision/functional issue
        ref_kv_cache = torch.stack(
            [
                k_cache.to(ref_kv_dtype_torch) * k_scale,
                v_cache.to(ref_kv_dtype_torch) * v_scale,
            ],
            dim=1,
        )
    else:
        k_scale = v_scale = 1.0
        ref_kv_cache = torch.stack([k_cache, v_cache], dim=1)
    # Combine K and V into interleaved format for the API
    kv_cache = torch.stack([k_cache, v_cache], dim=1)

    return kv_cache, k_scale, v_scale, ref_kv_cache
Contributor

⚠️ Potential issue | 🔴 Critical

Ensure the synthetic KV cache actually covers every page ID we emit

num_pages is computed from max_seq_len * batch_size, which can be smaller than the sum of ceil(seq_len_i / page_size) across the batch. Example: with page_size=16 and two requests with seq_lens=[17, 17], max_seq_len * batch_size = 34, so we allocate 3 pages, but the page tables legitimately reference 4 distinct pages. That produces out-of-range page IDs (crashing both the reference path and the kernel under test). Please size the cache by the total pages needed.

-    max_seq_len = torch.max(seq_lens).item()
-    num_tokens = max_seq_len * batch_size
-    num_pages = (num_tokens + page_size - 1) // page_size
+    page_per_seq = (seq_lens + page_size - 1) // page_size
+    num_pages = int(page_per_seq.sum().item())
+    max_seq_len = int(seq_lens.max().item())
🤖 Prompt for AI Agents
In tests/attention/test_xqa_batch_decode.py around lines 69 to 115, the
synthetic KV cache is undersized because num_pages is computed from max_seq_len
* batch_size rather than the actual total pages referenced by all sequences;
compute num_pages as the sum over batch of ceil(seq_lens[i] / page_size) (i.e.,
total_pages = sum((seq_lens + page_size - 1)//page_size)) and use that value to
allocate k_cache, v_cache, and any dependent quantities so every page ID emitted
by the page tables is covered.

@bkryu bkryu dismissed their stale review October 31, 2025 18:41

Dismissing to remove the "Request Change" tag after concerns are addressed

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
csrc/xqa/mha.cu (1)

1406-1410: Add runtime enable_pdl guard to PDL intrinsics and plumb parameter through kernel_mha

The kernel_mha function lacks an enable_pdl parameter. Both launchMHA and launchMHAFlashInfer possess this flag but fail to pass it to kernel_mha, causing unconditional PDL intrinsic execution at lines 1406-1410 when ENABLE_PDL is defined, regardless of runtime enable_pdl value.

Required changes:

  1. Add bool enable_pdl parameter to kernel_mha signature (mha.cu:2455)
  2. Guard PDL intrinsics with runtime flag at lines 1406-1410:
 #if ENABLE_PDL
+  if (enable_pdl) {
     preExit();
     acqBulk();
+  }
 #endif
  3. Pass enable_pdl from launchMHA to kernel_mha (lines 2577, 2627)
  4. Pass enable_pdl from launchMHAFlashInfer to kernel_mha (lines 2706, 2730)
csrc/xqa/mha_sm90.cu (1)

1254-1260: Add runtime guard for PDL intrinsics or remove compile-time conditional compilation

The review comment is accurate. PDL intrinsics (preExit() and acqBulk()) in both mha_sm90.cu (lines 1255, 1258, 1286, 1307, 1343, 1412) and mha.cu (lines 1407–1408) are currently gated only by compile-time ENABLE_PDL macros. The enable_pdl runtime parameter exists at the host level (passed to makeLaunchConfig()) but is not passed as a kernel parameter, so kernels cannot access it to conditionally execute these intrinsics.

This creates a runtime/compile-time divergence: launching with enable_pdl=false has no effect if the code was compiled with ENABLE_PDL defined.

To resolve:

  • Add bool enable_pdl as a kernel parameter to all affected kernels (kernel_mha_impl in mha.cu, kernel_mha in both files)
  • Thread it through the launch paths
  • Wrap intrinsic calls: if (enable_pdl) { preExit(); acqBulk(); }
flashinfer/xqa.py (1)

99-118: Missing enable_pdl parameter in fake op signature.

The _fake_xqa function is missing the enable_pdl: bool parameter that was added to the corresponding custom op at line 73. Fake ops must match the signature of their real counterparts for proper torch.compile integration.

Apply this diff:

     def _fake_xqa(
         run_sm90_fp8_mha: bool,
         sm_count: int,
         num_kv_heads: int,
         sliding_win_size: int,
         q_scale: float,
         output: torch.Tensor,
         q: torch.Tensor,
         sinks: Optional[torch.Tensor],
         k_cache: torch.Tensor,
         v_cache: torch.Tensor,
         page_table: torch.Tensor,
         max_seq_len: int,
         seq_lens: torch.Tensor,
         batch_size: int,
         kv_scale: torch.Tensor,
         semaphores: torch.Tensor,
         workspace_buffer: torch.Tensor,
+        enable_pdl: bool,
     ) -> None:
         pass
🧹 Nitpick comments (3)
csrc/xqa/mha_sm90.cu (2)

3039-3077: Host launcher: enable_pdl passed to launch config (good)

Passing enable_pdl into makeLaunchConfig is correct. Once the kernel accepts enable_pdl, also pass it as a kernel argument.


3179-3215: FlashInfer launcher: enable_pdl passed (good)

Same note: forward enable_pdl into the kernel parameter when added.

csrc/xqa/mha.cu (1)

2593-2594: Host launch: enable_pdl wired to launch config (good)

Once the kernel accepts enable_pdl, also append it to the cudaLaunchKernelEx argument list and the kernel signature.

Also applies to: 2700-2701

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 595ee1b and 9c08d33.

📒 Files selected for processing (8)
  • csrc/flashinfer_xqa_binding.cu (2 hunks)
  • csrc/xqa/defines.h (1 hunks)
  • csrc/xqa/mha.cu (5 hunks)
  • csrc/xqa/mha.h (6 hunks)
  • csrc/xqa/mha_sm90.cu (5 hunks)
  • csrc/xqa/mla_sm120.cu (4 hunks)
  • csrc/xqa/xqa_wrapper.cu (4 hunks)
  • flashinfer/xqa.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
csrc/xqa/mla_sm120.cu (1)
csrc/xqa/hostUtils.h (1)
  • makeLaunchConfig (4-12)
csrc/flashinfer_xqa_binding.cu (1)
csrc/xqa/mha_sm90.cu (2)
  • scratch (506-513)
  • scratch (506-506)
flashinfer/xqa.py (1)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
csrc/xqa/mha.cu (1)
csrc/xqa/hostUtils.h (1)
  • makeLaunchConfig (4-12)
csrc/xqa/mha_sm90.cu (1)
csrc/xqa/hostUtils.h (1)
  • makeLaunchConfig (4-12)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (12)
csrc/xqa/xqa_wrapper.cu (2)

31-48: MLA wrapper: enable_pdl correctly threaded

Parameter is appended and forwarded to launchMLAFlashInfer before stream. Looks good.

Please confirm all FFI callers updated to pass the new trailing bool to xqa_wrapper_mla.


67-95: MHA wrapper: enable_pdl correctly threaded

Parameter appended and propagated to the MHA launcher selection. No issues spotted.

Please confirm Python/TVM bindings and tests now pass enable_pdl to xqa_wrapper.

csrc/flashinfer_xqa_binding.cu (2)

20-31: Binding prototypes match wrapper signature change

Signatures include trailing enable_pdl; exported names unchanged. Looks good.

Ensure downstream TVM packed-call sites are regenerated to include the new bool (last arg).


34-53: Binding for xqa_wrapper: OK

Prototype and export aligned with wrapper changes.

csrc/xqa/mha.cu (1)

92-102: Add SM 1210 path: OK

Including CUDA_ARCH==1210 alongside 860/890/1200 is consistent with existing pattern.

Please confirm performance tuning (preferedKHeadPartBytes=64, cacheVTileSeqLen=32) is valid for 12.1 as well.

csrc/xqa/mla_sm120.cu (2)

1727-1727: LGTM! Clean transition from compile-time to runtime configuration.

The function signature now accepts enable_pdl as a runtime parameter, and it's correctly propagated to makeLaunchConfig. This replaces the previous compile-time ENABLE_PDL != 0 check with dynamic configuration.

Also applies to: 1765-1765


1864-1864: LGTM! Consistent with the launchMLA changes.

The same pattern is applied to launchMLAFlashInfer, threading enable_pdl through to the launch configuration.

Also applies to: 1888-1888

flashinfer/xqa.py (4)

29-29: LGTM! Appropriate import for hardware capability detection.

The device_support_pdl import is used to automatically detect whether the device supports PDL based on compute capability (SM90+).


73-73: LGTM! Parameter correctly threaded through custom op.

The enable_pdl parameter is properly added to the custom op signature and propagated to the underlying wrapper.

Also applies to: 93-93


411-413: Same terminology verification needed for xqa_mla.

The same PDL acronym expansion issue applies here. Ensure consistency with the verified terminology.


141-141: LGTM! Parameter correctly integrated with smart defaulting.

The enable_pdl parameter is properly added as optional, defaults to hardware capability detection via device_support_pdl, and is correctly propagated through the call chain.

Also applies to: 215-215, 271-271

csrc/xqa/mha.h (1)

131-131: LGTM! Consistent parameter addition across all launch functions.

All six launch function declarations correctly add the bool enable_pdl parameter in a consistent position (before the cudaStream_t stream parameter). This aligns with the implementation changes in the corresponding .cu files.

Also applies to: 150-150, 192-192, 211-212, 234-234, 252-252

Comment on lines 131 to 142
#ifndef ENABLE_PDL
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#if __CUDA_ARCH__ == 900
#define ENABLE_PDL 2
#else
#define ENABLE_PDL 1
#endif
#else
/* default for host or older architectures */
#define ENABLE_PDL 0
#endif
#endif
Contributor

⚠️ Potential issue | 🟠 Major

Runtime-vs-compile-time PDL mismatch

ENABLE_PDL defaults to 1/2 for SM≥90 at compile time, but kernels still use this macro while the host now passes a runtime enable_pdl. Host compilation sees ENABLE_PDL==0 (no CUDA_ARCH), so kernels may execute preExit/acqBulk even when enable_pdl=false at launch. This is inconsistent and can lead to invalid usage of programmatic stream serialization.

Thread enable_pdl into the kernels and guard PDL intrinsics with the runtime flag (while keeping the arch guards). See follow-up diffs in kernel files below.

🤖 Prompt for AI Agents
In csrc/xqa/defines.h around lines 131-142, the current ENABLE_PDL macro
selection based solely on __CUDA_ARCH__ causes kernels to unconditionally
compile PDL intrinsics even though the host uses a runtime enable_pdl flag;
update the headers and usages so that kernels still respect the arch guards but
also check a passed-in runtime boolean (e.g., enable_pdl) before invoking PDL
intrinsics: keep the existing __CUDA_ARCH__ checks to determine PDL availability
at compile time, expose a runtime enable_pdl parameter into kernels (thread it
through kernel arguments or capture in device lambdas), and wrap all calls to
preExit/acqBulk (and other PDL intrinsics) with a combined condition that
requires both compile-time availability and the runtime flag so that when
enable_pdl is false at launch no PDL intrinsics execute on device.

Comment on lines +1969 to 1975
    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
#pragma unroll
    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
      ret[i][j] = gmemVec[baseOffset + j];
    }
  }
Contributor

⚠️ Potential issue | 🔴 Critical

Out‑of‑bounds read in loadGmemColWiseVecWithDup for attention sinks

gmemVec points to a buffer of size headGrpSize (see finalizeAndWriteOut_sync passing attentionSinksVec[0]), but this code multiplies the index by GmmaAccCoreMat::cols and reads baseOffset+j, which can exceed headGrpSize. We should load a single sink value per head and duplicate it across columns, without advancing memory by cols.

Apply this fix:

-    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
-    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
-#pragma unroll
-    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
-      ret[i][j] = gmemVec[baseOffset + j];
-    }
+    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+#pragma unroll
+    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+      // Duplicate the same head sink across the 2 columns
+      ret[i][j] = gmemVec[clampedIdx];
+    }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
#pragma unroll
    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
      ret[i][j] = gmemVec[baseOffset + j];
    }
  }
    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
#pragma unroll
    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
      // Duplicate the same head sink across the 2 columns
      ret[i][j] = gmemVec[clampedIdx];
    }

kv_scale: Optional[torch.Tensor] = None,
sliding_win_size: int = 0,
sm_count: Optional[int] = None,
enable_pdl: Optional[bool] = None,
Contributor

Do we want to use the decorator for the support check?
@backend_requirement ?

Collaborator

Yes, using backend_requirement would be preferred. There are lots of functions with this flag; we can clean them up in a standalone PR.

@yzh119 yzh119 changed the title add xqa backend feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend Nov 1, 2025
@yzh119
Collaborator

yzh119 commented Nov 1, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !105 has been created, and the CI pipeline #37716350 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (6)
flashinfer/decode.py (6)

2152-2163: Critical: Squeeze singleton dimension when unpacking kv_cache.

When kv_cache.shape[1] == 1 at line 2156, both k_cache and v_cache are assigned to the same tensor without squeezing the singleton dimension. This preserves shape [num_pages, 1, ...] instead of the expected [num_pages, ...]. Downstream operations at lines 2200-2205 that transpose dimensions will operate on the wrong axes.

Apply this fix:

     if isinstance(kv_cache, tuple):
         k_cache, v_cache = kv_cache
     else:
         if kv_cache.shape[1] == 1:
-            k_cache, v_cache = kv_cache, kv_cache
+            k_cache = kv_cache.squeeze(1)
+            v_cache = k_cache
         else:
             assert kv_cache.shape[1] == 2, (
                 "When kv_cache is a single tensor, the second dimension must be 1 or 2"
             )
             # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
             # it doesn't change underlying storage
             k_cache, v_cache = kv_cache.unbind(dim=1)

2165-2168: Add architecture validation for xqa backend.

The auto backend selection defaults to xqa for any non-SM100 architecture, but xqa requires SM90+ (as enforced in xqa.py). This could cause runtime failures on older GPUs like SM80.

Consider adding an architecture check:

 if backend == "auto":
-    backend = (
-        "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
-    )
+    major = get_compute_capability(query.device)[0]
+    if major == 10:
+        backend = "trtllm-gen"
+    elif major >= 9:
+        backend = "xqa"
+    else:
+        raise GPUArchitectureError(
+            f"No compatible backend for SM{major}0. Requires SM90+ or SM100+."
+        )

2334-2343: Critical: Fix type hints, allocate output, and remove unused parameter.

Multiple issues in the function signature and parameter handling:

  1. Line 2341: enable_pdl: bool = None violates PEP 484. Must be Optional[bool] = None.
  2. Line 2334: max_seq_len parameter is declared but never used in the function body.
  3. Output allocation missing: When out=None, the function passes None directly to xqa at line 2450, but xqa expects a non-Optional torch.Tensor. This will cause a runtime error.

Apply this fix:

 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
     bmm1_scale: float,
     bmm2_scale: float,
     window_left: int = -1,
     out: Optional[torch.Tensor] = None,
     sinks: Optional[torch.Tensor] = None,
     kv_layout: str = "NHD",
-    enable_pdl: bool = None,
+    enable_pdl: Optional[bool] = None,
     q_len_per_req: Optional[int] = 1,
 ) -> torch.Tensor:
     ...
     enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
 
     assert q_len_per_req == 1, "xqa not support speculative decoding yet"
 
     ...
     
+    # Allocate output if not provided
+    if out is None:
+        out = torch.empty_like(query)
+    
     xqa(
         query_new,

2434-2434: Critical: workspace_1 dtype mismatch with xqa expectations.

Line 2434 uses torch.chunk to split workspace_buffer (uint8) into two halves, but at line 2452, workspace_1 is passed as the semaphores parameter to xqa, which expects a torch.uint32 tensor. This type mismatch will cause incorrect behavior or runtime errors.

You cannot safely reinterpret uint8 as uint32 with .view(torch.uint32) because the buffer may not be properly aligned or sized. Allocate a proper uint32 semaphores tensor:

-    workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
+    # Reserve first half for workspace_0 (uint8), allocate separate uint32 semaphores
+    workspace_0_size = workspace_buffer.numel() // 2
+    workspace_0 = workspace_buffer[:workspace_0_size]
+    
+    # Allocate semaphores as uint32 (estimate size based on xqa requirements)
+    semaphores = torch.zeros(
+        workspace_buffer.numel() // 8, dtype=torch.uint32, device=query.device
+    )
+    
     kv_scale_value = bmm2_scale
     q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
 
     query_new = query.unsqueeze(1).contiguous()
     seq_lens_new = seq_lens.unsqueeze(1).contiguous()
     sinks_new = (
         sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None
     )
 
     xqa(
         query_new,
         k_cache,
         v_cache,
         block_tables,
         seq_lens_new,
         out,
         workspace_0,
-        workspace_1,
+        semaphores,
         num_kv_heads,
         page_size,
         sinks=sinks_new,
         q_scale=q_scale_value,
         kv_scale=torch.tensor(
             [kv_scale_value], dtype=torch.float32, device=query.device
         ),
         sliding_win_size=window_left + 1 if window_left >= 0 else 0,
         kv_layout=kv_layout,
         sm_count=sm_count,
     )

2444-2463: Critical: Pass enable_pdl parameter to xqa.

The enable_pdl flag is computed at line 2396 but never passed to the xqa call. The kernel will always use its default value (True), ignoring the caller's choice and device support check.

Add the parameter to the xqa call:

     xqa(
         query_new,
         k_cache,
         v_cache,
         block_tables,
         seq_lens_new,
         out,
         workspace_0,
         workspace_1,
         num_kv_heads,
         page_size,
         sinks=sinks_new,
         q_scale=q_scale_value,
         kv_scale=torch.tensor(
             [kv_scale_value], dtype=torch.float32, device=query.device
         ),
         sliding_win_size=window_left + 1 if window_left >= 0 else 0,
         kv_layout=kv_layout,
         sm_count=sm_count,
+        enable_pdl=enable_pdl,
     )

2400-2411: Critical: Squeeze singleton dimension before assigning kv_cache.

When kv_cache.shape[1] == 1, line 2404 assigns both k_cache and v_cache to the same tensor without squeezing the singleton dimension. This means k_cache has shape [num_pages, 1, ...] instead of the expected [num_pages, ...]. Downstream shape extraction at lines 2423-2432 will read the wrong dimensions.

Squeeze the singleton dimension first:

     if isinstance(kv_cache, tuple):
         k_cache, v_cache = kv_cache
     else:
         if kv_cache.shape[1] == 1:
-            k_cache, v_cache = kv_cache, kv_cache
+            k_cache = kv_cache.squeeze(1)
+            v_cache = k_cache
         else:
             assert kv_cache.shape[1] == 2, (
                 "When kv_cache is a single tensor, the second dimension must be 1 or 2"
             )
             # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
             # it doesn't change underlying storage
             k_cache, v_cache = kv_cache.unbind(dim=1)
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)

231-237: Remove unnecessary no-op rearrange operations.

In the NHD layout branch, einops.rearrange(k_page, "p h d -> p h d") and the corresponding v_page rearrange are no-ops since the source and target patterns are identical. These can be removed for clarity and a slight performance improvement.

Apply this diff:

             else:  # NHD layout
                 # NHD layout: [page_size, num_kv_heads, head_dim]
                 if j == pages_i - 1:
                     k_page = k_page[:last_len_i, :, :]
                     v_page = v_page[:last_len_i, :, :]
-                k_list.append(einops.rearrange(k_page, "p h d -> p h d"))
-                v_list.append(einops.rearrange(v_page, "p h d -> p h d"))
+                k_list.append(k_page)
+                v_list.append(v_page)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4950b67 and e535e80.

📒 Files selected for processing (2)
  • flashinfer/decode.py (9 hunks)
  • tests/attention/test_trtllm_gen_attention.py (23 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (3)
flashinfer/xqa.py (2)
  • xqa (55-94)
  • xqa (125-283)
flashinfer/utils.py (7)
  • get_compute_capability (251-254)
  • FP4Tensor (599-669)
  • get_device_sm_count (595-596)
  • ceil_div (575-586)
  • round_up (589-591)
  • check_shape_dtype_device (518-536)
  • device_support_pdl (568-572)
csrc/trtllm_fmha_kernel_launcher.cu (2)
  • trtllm_paged_attention_decode (197-265)
  • trtllm_paged_attention_decode (197-204)
🪛 Ruff (0.14.2)
flashinfer/decode.py

2173-2173: Avoid specifying long messages outside the exception class

(TRY003)


2175-2175: Avoid specifying long messages outside the exception class

(TRY003)


2218-2218: Consider (*query.shape[:-1], ceil_div(query.shape[-1], 2)) instead of concatenation

Replace with (*query.shape[:-1], ceil_div(query.shape[-1], 2))

(RUF005)


2241-2241: Avoid specifying long messages outside the exception class

(TRY003)


2264-2268: Avoid specifying long messages outside the exception class

(TRY003)


2279-2279: Avoid specifying long messages outside the exception class

(TRY003)


2282-2282: Avoid specifying long messages outside the exception class

(TRY003)


2324-2324: Avoid specifying long messages outside the exception class

(TRY003)


2334-2334: Unused function argument: max_seq_len

(ARG001)


2341-2341: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
flashinfer/decode.py (1)

2200-2205: Verify layout conversion is applied after squeeze.

The NHD→HND transpose at lines 2204-2205 assumes the input has shape [..., N, H, D]. This depends on the kv_cache unpacking at lines 2152-2163 correctly removing any singleton dimensions. If the squeeze issue at line 2156 is not fixed, the transpose will operate on the wrong axes.

Once the squeeze issue is resolved, this code will be correct.
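
A small shape-only sketch of why the squeeze has to happen before the layout handling (illustrative sizes, not the actual decode path):

```python
import torch

num_pages, num_kv_heads, page_size, head_dim = 2, 4, 16, 64

# NHD paged K cache: [num_pages, page_size, num_kv_heads, head_dim]
k_nhd = torch.empty(num_pages, page_size, num_kv_heads, head_dim)
# trtllm-gen assumes HND, so an NHD input is converted by swapping the token and head axes.
k_hnd = k_nhd.transpose(-3, -2)
assert k_hnd.shape == (num_pages, num_kv_heads, page_size, head_dim)

# If the singleton axis of a combined [num_pages, 1, ...] cache is not squeezed first,
# positional shape reads (e.g. dim 1 for the head count in HND) see the stray 1 instead.
k_combined = torch.empty(num_pages, 1, num_kv_heads, page_size, head_dim)
assert k_combined.shape[1] == 1                        # not num_kv_heads
assert k_combined.squeeze(1).shape[1] == num_kv_heads  # as expected after the squeeze
```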

tests/attention/test_trtllm_gen_attention.py (7)

93-149: LGTM! Layout-aware KV cache generation is correct.

The function properly creates KV caches with the correct dimension ordering for both HND and NHD layouts.


342-342: LGTM! Test parameterization properly extended.

The addition of kv_layout and backend parameters ensures comprehensive coverage across different layout configurations and backend implementations.

Also applies to: 571-571, 897-898


645-662: LGTM! Backend-specific GPU checks are appropriate.

The GPU architecture requirements and feature limitations for each backend are correctly validated with clear skip messages.


801-804: LGTM! Backend-specific assertions are correctly applied.

The workspace buffer zero-check and wrapper testing are appropriately limited to the trtllm-gen backend where these requirements apply.

Also applies to: 842-845


819-821: LGTM! Tolerance adjustment for xqa+fp8 is reasonable.

The relaxed tolerances account for the precision characteristics of the xqa backend when using fp8 KV cache.


415-415: LGTM! Layout parameter properly threaded through test chain.

The kv_layout parameter is consistently passed through all helper functions and API calls, ensuring layout coherence across cache creation, flattening, and attention operations.

Also applies to: 470-470, 506-506, 690-690, 763-763, 796-796


976-976: Verify if these HND-only restrictions are still accurate.

Multiple comments state "trtllm-gen only support HND", but the PR objectives indicate that trtllm-gen now supports both HND and NHD layouts. These specialized tests (batch_size=1, head_dim=256, long sequences) restrict kv_layout to ["HND"] only.

If trtllm-gen indeed supports both layouts now, consider:

  1. Updating these comments to clarify that the HND restriction is specific to these test scenarios (not a general limitation)
  2. Or extending these tests to include NHD if the backend fully supports it

Run the following to check if there are any backend-specific layout restrictions documented elsewhere:

Also applies to: 1031-1031, 1097-1097

@flashinfer-bot
Collaborator

[FAILED] Pipeline #37716350: 1/17 passed

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

3398-3401: Validate kv_layout and enforce seq_lens dtype.

Add early kv_layout validation (consistent with other entry points) and ensure seq_lens is int32 as expected by the kernel.

     if enable_pdl is None:
         enable_pdl = device_support_pdl(query.device)
+    _check_kv_layout(kv_layout)
+    if seq_lens.dtype != torch.int32:
+        seq_lens = seq_lens.to(torch.int32)
♻️ Duplicate comments (4)
csrc/trtllm_fmha_kernel_launcher.cu (1)

297-297: Same layout naming issue as the decode path.

This comment has the same naming inconsistency as line 235. Apply the same clarification here for consistency.

flashinfer/prefill.py (1)

3338-3339: Default kv_layout differs from rest of prefill APIs; add validation.

Several APIs here default to "NHD", but this function defaults to "HND" and doesn't validate kv_layout. Please align defaults or document the deliberate difference and add _check_kv_layout(kv_layout).

flashinfer/decode.py (2)

2441-2442: Major: seq_lens must be uint32 with beam axis.

Cast to uint32 to match the xqa wrapper’s contract.

-    seq_lens_new = seq_lens.unsqueeze(1).contiguous()
+    seq_lens_new = seq_lens.to(torch.uint32).unsqueeze(1).contiguous()

2403-2411: Critical: squeeze single‑channel KV tensor before inferring shape.

When kv_cache has shape [P, 1, ...], shape extraction misreads page_size/num_heads. Squeeze first.

-        if kv_cache.shape[1] == 1:
-            k_cache, v_cache = kv_cache, kv_cache
+        if kv_cache.shape[1] == 1:
+            # [num_pages, 1, ...] -> [num_pages, ...]
+            k_cache = kv_cache.squeeze(1)
+            v_cache = k_cache
🧹 Nitpick comments (5)
csrc/trtllm_fmha_kernel_launcher.cu (1)

235-237: Clarify the layout naming convention.

The comment states "NHD layout: [..., H, N, D]", but in standard tensor notation, NHD typically refers to (N, H, D), where N is the sequence/token dimension, H is the number of heads, and D is the head dimension. The layout shown as [..., H, N, D] would conventionally be called HND layout (heads first).

This naming inconsistency could confuse developers familiar with standard conventions and lead to integration errors when interfacing with other systems or documentation.

Consider either:

  1. Using standard naming (change "NHD" to "HND" in the comment), or
  2. Adding a clarification that explains this codebase's specific naming convention
-  // Assume NHD layout: [..., H, N, D]
+  // Assume HND layout: [..., H, N, D]

Or:

-  // Assume NHD layout: [..., H, N, D]
+  // Assume layout: [..., H, N, D] (referred to as NHD in this codebase)
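
For reference, a shape-only sketch of the two naming conventions as they are used in this discussion (illustrative sizes; the variable names are not part of the API):

```python
import torch

page_size, num_kv_heads, head_dim = 16, 2, 64

nhd_page = torch.empty(page_size, num_kv_heads, head_dim)  # NHD: tokens (N), heads (H), head dim (D)
hnd_page = nhd_page.transpose(0, 1)                        # HND: heads (H), tokens (N), head dim (D)
assert hnd_page.shape == (num_kv_heads, page_size, head_dim)  # i.e. the [..., H, N, D] in the comment
```
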
flashinfer/prefill.py (1)

3338-3341: Type the sinks parameter correctly.

Kernels expect a single optional tensor for sinks, not a list. Update the annotation to avoid misuse downstream.

-    sinks: Optional[List[torch.Tensor]] = None,
+    sinks: Optional[torch.Tensor] = None,
flashinfer/decode.py (3)

2062-2080: API polish: sinks typing, kv_layout default consistency, Optional usage.

  • sinks should be Optional[torch.Tensor], not list.
  • kv_layout defaults to "HND" here while most decode/prefill APIs default to "NHD". Align or document.
  • Keep Optional types explicit.
-    sinks: Optional[List[torch.Tensor]] = None,
-    kv_layout: str = "HND",
-    enable_pdl: Optional[bool] = None,
+    sinks: Optional[torch.Tensor] = None,
+    kv_layout: str = "NHD",
+    enable_pdl: Optional[bool] = None,

Also add _check_kv_layout(kv_layout) early in the function.


2327-2343: Fix type hints and add kv_layout validation in xqa path.

  • enable_pdl should be Optional[bool]
  • q_len_per_req should be int (you assert == 1)
  • Validate kv_layout upfront
-def xqa_batch_decode_with_kv_cache(
+def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     max_seq_len: int,
     bmm1_scale: float,
     bmm2_scale: float,
     window_left: int = -1,
     out: Optional[torch.Tensor] = None,
     sinks: Optional[torch.Tensor] = None,
     kv_layout: str = "NHD",
-    enable_pdl: bool = None,
-    q_len_per_req: Optional[int] = 1,
+    enable_pdl: Optional[bool] = None,
+    q_len_per_req: int = 1,
 ) -> torch.Tensor:
+    _check_kv_layout(kv_layout)

1873-1890: Type hint nit: enable_pdl should be Optional.

Signature shows enable_pdl: bool = None. Make it Optional[bool] for correctness.

-        enable_pdl: bool = None,
+        enable_pdl: Optional[bool] = None,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e535e80 and 39e36dc.

📒 Files selected for processing (3)
  • csrc/trtllm_fmha_kernel_launcher.cu (2 hunks)
  • flashinfer/decode.py (10 hunks)
  • flashinfer/prefill.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (3)
flashinfer/xqa.py (2)
  • xqa (55-94)
  • xqa (125-283)
flashinfer/utils.py (7)
  • get_compute_capability (251-254)
  • FP4Tensor (599-669)
  • get_device_sm_count (595-596)
  • ceil_div (575-586)
  • round_up (589-591)
  • check_shape_dtype_device (518-536)
  • device_support_pdl (568-572)
csrc/trtllm_fmha_kernel_launcher.cu (2)
  • trtllm_paged_attention_decode (197-265)
  • trtllm_paged_attention_decode (197-204)
🪛 Ruff (0.14.2)
flashinfer/decode.py

2173-2173: Avoid specifying long messages outside the exception class

(TRY003)


2175-2175: Avoid specifying long messages outside the exception class

(TRY003)


2218-2218: Consider (*query.shape[:-1], ceil_div(query.shape[-1], 2)) instead of concatenation

Replace with (*query.shape[:-1], ceil_div(query.shape[-1], 2))

(RUF005)


2241-2241: Avoid specifying long messages outside the exception class

(TRY003)


2264-2268: Avoid specifying long messages outside the exception class

(TRY003)


2279-2279: Avoid specifying long messages outside the exception class

(TRY003)


2282-2282: Avoid specifying long messages outside the exception class

(TRY003)


2324-2324: Avoid specifying long messages outside the exception class

(TRY003)


2334-2334: Unused function argument: max_seq_len

(ARG001)


2341-2341: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/prefill.py (1)

2091-2096: Good: normalize layout for trtllm-gen.

The NHD→HND transpose here avoids kernel-side conditionals and matches trtllm-gen expectations. Nice and minimal.

flashinfer/decode.py (2)

1239-1243: Good: normalize layout for trtllm-gen.

The NHD→HND transpose here mirrors prefill and keeps decode kernels layout-stable.


2201-2206: Good: normalize layout for trtllm-gen path.

NHD→HND transpose keeps the paged cache consistent with trtllm-gen decode assumption.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (5)
flashinfer/decode.py (5)

2139-2144: Doc nit: correct architecture name.

The docstring incorrectly states "sm_90 and sm_91 (ampere architecture)". SM90/SM91 are Hopper architecture; Ampere corresponds to SM80/SM86.

Apply this fix:

     backend : str = "auto"
         The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
         When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
         For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
-        For sm_90 and sm_91 (ampere architecture), ``auto`` will choose ``xqa`` backend.
+        For sm_90 and sm_91 (Hopper architecture), ``auto`` will choose ``xqa`` backend.

2341-2342: Fix type hints: use explicit Optional and remove Optional from non-optional parameter.

  • enable_pdl: bool = None should be enable_pdl: Optional[bool] = None per PEP 484
  • q_len_per_req: Optional[int] = 1 should be q_len_per_req: int = 1 (the value is always an int, never None)

Apply this diff:

     kv_layout: str = "NHD",
-    enable_pdl: bool = None,
-    q_len_per_req: Optional[int] = 1,
+    enable_pdl: Optional[bool] = None,
+    q_len_per_req: int = 1,
 ) -> torch.Tensor:

As per coding guidelines.


2434-2436: Allocate uint32 semaphores buffer correctly.

The code slices workspace_u8 (uint8 dtype) for semaphores, but the xqa kernel expects a uint32 tensor. This dtype mismatch will cause incorrect behavior or runtime errors.

Apply this fix to properly allocate semaphores:

     workspace_u8 = workspace_buffer.view(torch.uint8)
-    semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphore
-    scratch = workspace_u8[8 * 1024 * 1024 :]
+    # Reserve 8MB for semaphores as uint32 (8MB / 4 bytes = 2M entries)
+    sem_bytes = 8 * 1024 * 1024
+    semaphores = workspace_u8[:sem_bytes].view(torch.uint32)
+    scratch = workspace_u8[sem_bytes:]

Based on learnings.


2165-2168: Guard unsupported GPU architectures in auto backend selection.

The current logic sends any non-SM10 device to xqa, but xqa requires SM90+ (see xqa.py:227). Devices with compute capability < 9.0 will fail at runtime. Additionally, SM11 and SM12 should be explicitly handled.

Apply this fix:

     if backend == "auto":
-        backend = (
-            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
-        )
+        major = get_compute_capability(query.device)[0]
+        if major == 10:
+            backend = "trtllm-gen"
+        elif major in [9, 11, 12]:
+            backend = "xqa"
+        else:
+            raise GPUArchitectureError(
+                f"No compatible decode backend for SM{major}0. "
+                f"Requires SM90+, SM100+, or SM110+."
+            )

2446-2465: Critical: Allocate output tensor and pass enable_pdl parameter.

Two critical issues:

  1. Output allocation: out is Optional at line 2338 but xqa requires a non-Optional tensor. When out=None, passing None to xqa will cause a runtime error.

  2. Missing enable_pdl: The enable_pdl flag is computed at line 2396 but never passed to the xqa call, so the kernel always uses its default behavior instead of respecting the caller's choice.

Apply this fix:

+    # Allocate output if not provided
+    if out is None:
+        out = torch.empty_like(query)
+
     xqa(
         query_new,
         k_cache,
         v_cache,
         block_tables,
         seq_lens_new,
         out,
         scratch,
         semaphores,
         num_kv_heads,
         page_size,
         sinks=sinks_new,
         q_scale=q_scale_value,
         kv_scale=torch.tensor(
             [kv_scale_value], dtype=torch.float32, device=query.device
         ),
         sliding_win_size=window_left + 1 if window_left >= 0 else 0,
         kv_layout=kv_layout,
         sm_count=sm_count,
+        enable_pdl=enable_pdl,
     )

Based on learnings.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 39e36dc and ed46ea9.

📒 Files selected for processing (1)
  • flashinfer/decode.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (3)
flashinfer/xqa.py (2)
  • xqa (55-94)
  • xqa (125-283)
flashinfer/utils.py (7)
  • get_compute_capability (251-254)
  • FP4Tensor (599-669)
  • get_device_sm_count (595-596)
  • ceil_div (575-586)
  • round_up (589-591)
  • check_shape_dtype_device (518-536)
  • device_support_pdl (568-572)
csrc/trtllm_fmha_kernel_launcher.cu (2)
  • trtllm_paged_attention_decode (197-265)
  • trtllm_paged_attention_decode (197-204)
🪛 Ruff (0.14.2)
flashinfer/decode.py

2173-2173: Avoid specifying long messages outside the exception class

(TRY003)


2175-2175: Avoid specifying long messages outside the exception class

(TRY003)


2218-2218: Consider (*query.shape[:-1], ceil_div(query.shape[-1], 2)) instead of concatenation

Replace with (*query.shape[:-1], ceil_div(query.shape[-1], 2))

(RUF005)


2241-2241: Avoid specifying long messages outside the exception class

(TRY003)


2264-2268: Avoid specifying long messages outside the exception class

(TRY003)


2279-2279: Avoid specifying long messages outside the exception class

(TRY003)


2282-2282: Avoid specifying long messages outside the exception class

(TRY003)


2324-2324: Avoid specifying long messages outside the exception class

(TRY003)


2334-2334: Unused function argument: max_seq_len

(ARG001)


2341-2341: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
flashinfer/decode.py (5)

24-24: LGTM: Import of xqa module.

The import is correctly placed and necessary for the new xqa backend integration.


1238-1242: LGTM: Layout conversion for trtllm-gen backend.

The NHD→HND conversion via transpose(-3, -2) correctly swaps the sequence and head dimensions as needed by the trtllm-gen backend. The conditional ensures it only applies when necessary.


2170-2199: LGTM: XQA backend routing with proper guards.

The xqa backend path correctly validates unsupported features (nvfp4 output, scale factors) and allocates output tensors as needed before routing to xqa_batch_decode_with_kv_cache. The parameter forwarding is complete.


2200-2324: LGTM: TRTLLM-gen backend with layout conversion and nvfp4 support.

The trtllm-gen path correctly:

  • Converts NHD→HND layout when necessary
  • Handles nvfp4 output with proper shape calculations and scale factor management
  • Falls back to standard output types with appropriate validation
  • Passes all parameters including enable_pdl to the kernel

2438-2444: LGTM: Scale computation and input preparation.

The scale calculations correctly compute q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5) and the input tensors are properly reshaped with unsqueeze(1) and made contiguous for the xqa kernel.
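
A condensed sketch of these preparation steps, assembled from the diffs above (hypothetical helper name; not the full xqa_batch_decode_with_kv_cache body):

```python
import torch


def _prepare_xqa_inputs(query, seq_lens, sinks, bmm1_scale, bmm2_scale, head_dim, num_kv_heads):
    # bmm2_scale acts as the kv scale; the query scale folds back the 1/sqrt(head_dim) factor.
    kv_scale_value = bmm2_scale
    q_scale_value = bmm1_scale / kv_scale_value * (head_dim ** 0.5)

    # Add the beam axis expected by the xqa wrapper and make the tensors contiguous.
    query_new = query.unsqueeze(1).contiguous()
    seq_lens_new = seq_lens.to(torch.uint32).unsqueeze(1).contiguous()
    sinks_new = sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None
    kv_scale = torch.tensor([kv_scale_value], dtype=torch.float32, device=query.device)
    return query_new, seq_lens_new, sinks_new, q_scale_value, kv_scale
```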

@yzh119
Collaborator

yzh119 commented Nov 1, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !105 has been updated with latest changes, and the CI pipeline #37735466 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/attention/test_xqa.py (1)

236-275: Consider extracting a helper for cache creation to reduce duplication.

The cache tensor creation logic is duplicated between NHD and HND branches. While functional, extracting this into a helper function could improve maintainability.

Example refactor:

def create_cache_tensors(layout, total_num_pages, tokens_per_page, nb_k_heads, valid_elems_per_head, dtype, device):
    if layout == "NHD":
        shape = (total_num_pages, tokens_per_page, nb_k_heads, valid_elems_per_head)
    else:  # HND
        shape = (total_num_pages, nb_k_heads, tokens_per_page, valid_elems_per_head)
    
    cache_k = torch.zeros(*shape, dtype=dtype, device=device)
    cache_v = torch.zeros(*shape, dtype=dtype, device=device)
    return cache_k, cache_v

cache_k_heads, cache_v_heads = create_cache_tensors(
    kv_layout, total_num_pages, tokens_per_page, nb_k_heads, 
    valid_elems_per_head, input_type, "cuda"
)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ed46ea9 and e040826.

📒 Files selected for processing (2)
  • csrc/xqa/mla_sm120.cu (10 hunks)
  • tests/attention/test_xqa.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/xqa/mla_sm120.cu (2)
csrc/xqa/hostUtils.h (1)
  • makeLaunchConfig (4-12)
csrc/xqa/tensorMap.cpp (2)
  • makeTensorMapForPagedKVCache (75-108)
  • makeTensorMapForPagedKVCache (75-79)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/attention/test_xqa.py (3)

43-60: LGTM! Clean implementation of dual KV cache layout support.

The kv_layout parameter addition to CacheSeq correctly handles both NHD and HND layouts with appropriate tensor indexing. The default value maintains backward compatibility.


301-324: LGTM! Correct handling of both cache layouts.

The cache_head_at helper correctly implements layout-aware cache indexing for both NHD and HND formats. While it duplicates some logic from CacheSeq.__getitem__, the direct cache access is more efficient for the padding zeroing use case.
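
A minimal sketch of the layout-aware indexing being described (hypothetical helper; the test's cache_head_at may differ in detail):

```python
import torch


def cache_head_at(cache: torch.Tensor, page_idx: int, token_in_page: int, head_idx: int,
                  kv_layout: str = "NHD") -> torch.Tensor:
    # NHD pages are [tokens_per_page, nb_k_heads, head_dim];
    # HND pages are [nb_k_heads, tokens_per_page, head_dim].
    if kv_layout == "NHD":
        return cache[page_idx, token_in_page, head_idx]
    return cache[page_idx, head_idx, token_in_page]
```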


446-616: LGTM! MLA test correctly uses NHD layout.

The enable_pdl parameter addition and cache initialization properly reflect that MLA only supports NHD layout. The CacheSeq construction correctly relies on the default kv_layout="NHD", which is appropriate for MLA.

csrc/xqa/mla_sm120.cu (3)

84-102: LGTM! Good addition of bounds checking for paged KV cache.

The nbPages parameter addition to KVTilePartLoader enables proper validation of page indices (used at line 144 for bound checks). The simplified baseOffset calculation using idxReq * cacheList.maxNbPagesPerSeq is cleaner and more maintainable.


1651-1723: LGTM! Clean integration of layout-flexible KV cache support.

The additions to launchMLA properly thread through:

  • Separate K/V cache pool pointers for explicit memory management
  • enable_pdl for programmatic dependent launch control
  • Stride parameters (kv_stride_page, kv_stride_token, kv_stride_head) enabling both NHD and HND layouts

The stride parameters are correctly propagated to makeTensorMapForPagedKVCache for both K and V tensor maps, enabling the flexible cache layouts tested in the Python layer.


1773-1831: LGTM! Consistent layout support in FlashInfer variant.

launchMLAFlashInfer mirrors the launchMLA changes with the same KV cache pool and stride parameter additions. The consistent API design between both launch functions maintains a clean interface.

@yzh119 yzh119 enabled auto-merge (squash) November 1, 2025 20:43
@yzh119 yzh119 disabled auto-merge November 1, 2025 22:47
@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #37735466: 13/17 passed

@yzh119 yzh119 merged commit 5854494 into flashinfer-ai:main Nov 2, 2025
4 checks passed
yzh119 pushed a commit that referenced this pull request Nov 7, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

In #2001, XQA decode kernels became available through
`trtllm_batch_decode_with_kv_cache` on SM90 and SM120.

This PR adds the ability to benchmark them through the microbenchmark script.

Example microbenchmark command and outputs before and after:
```
### Before current PR:
## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 9.0. Skipping.
[PERF] fa2            :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] cudnn          :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.519 TFLOPs/sec; achieved tb_per_sec 1.692 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 12.0. Skipping.
[PERF] fa2            :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.204 TFLOPs/sec; achieved tb_per_sec 1.027 TB/sec
[PERF] cudnn          :: median time 0.030 ms; std 0.000 ms; achieved tflops 8.943 TFLOPs/sec; achieved tb_per_sec 1.119 TB/sec

### After current PR:
## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[PERF] fa2            :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] trtllm-gen-nati:: median time 0.019 ms; std 0.002 ms; achieved tflops 13.820 TFLOPs/sec; achieved tb_per_sec 1.729 TB/sec
[PERF] cudnn          :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.574 TFLOPs/sec; achieved tb_per_sec 1.698 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck  --use_cupti
[PERF] fa2            :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.121 TFLOPs/sec; achieved tb_per_sec 1.016 TB/sec
[PERF] trtllm-gen-nati:: median time 0.034 ms; std 0.001 ms; achieved tflops 7.903 TFLOPs/sec; achieved tb_per_sec 0.989 TB/sec
[PERF] cudnn          :: median time 0.030 ms; std 0.001 ms; achieved tflops 9.020 TFLOPs/sec; achieved tb_per_sec 1.129 TB/sec
```

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Chores**
* Standardized backend identifier to "trtllm-native" and expanded its
support across benchmark routines and utilities.
* Argument parsing now canonicalizes deprecated backend aliases and
emits a deprecation warning when encountered.
* **Documentation**
* README and tool-facing messages updated to use the canonical backend
name and include contextual notes about the change.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->