feat: Add DiT-oriented kernels where Qk (Bmm1) type can be reinterpreted into Int8 or BFloat16#2711

Open
xrq-phys wants to merge 6 commits into flashinfer-ai:main from xrq-phys:ruqingx/feat/vx

Conversation


@xrq-phys xrq-phys commented Mar 6, 2026

📌 Description

This PR adds support for DiT-oriented TRTLLM kernels with 3 variants:

  • Qk in BFloat16 → Bmm1 in BFloat16
    • V in E4m3 → Bmm2 in E4m3
  • Qk in Int8 with SageAttention scaling factors → Bmm1 in Int8
    • V in E4m3 → Bmm2 in E4m3
  • Qk in E4m3 with SageAttention scaling factors → Bmm1 in E4m3
    • V in E4m3 → Bmm2 in E4m3
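
For the two SageAttention variants, Q and K arrive pre-quantized with per-block scaling factors. As a rough illustration of the scheme (a standalone NumPy sketch with our own helper names and an assumed per-row-block layout — not FlashInfer's actual quantization code):

```python
import numpy as np

def quantize_per_block_int8(x: np.ndarray, block_size: int):
    """Quantize [seq, head_dim] to int8 with one float32 scale per row block."""
    seq, dim = x.shape
    assert seq % block_size == 0
    blocks = x.reshape(seq // block_size, block_size, dim)
    # One scale factor per block: absmax / 127 (SageAttention-style).
    scales = np.abs(blocks).max(axis=(1, 2)) / 127.0
    q = np.clip(np.rint(blocks / scales[:, None, None]), -127, 127).astype(np.int8)
    return q.reshape(seq, dim), scales.astype(np.float32)

def dequantize_per_block(q: np.ndarray, scales: np.ndarray, block_size: int):
    seq, dim = q.shape
    blocks = q.reshape(seq // block_size, block_size, dim).astype(np.float32)
    return (blocks * scales[:, None, None]).reshape(seq, dim)

rng = np.random.default_rng(0)
q_fp = rng.standard_normal((8, 4)).astype(np.float32)
q_i8, sf_q = quantize_per_block_int8(q_fp, block_size=4)   # int8 Q + scale factors
q_hat = dequantize_per_block(q_i8, sf_q, block_size=4)     # rescaled approximation
```

In the kernels, Bmm1 runs on the int8 (or E4m3) data and the scaling factors are applied when rescaling the partial products, which is why the launcher threads the ptrSageAttnSfs* pointers and per-block element counts through to the kernel parameters.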

To integrate, the following changes are made to FlashInfer:

  • Our artifactory will produce a separate KernelMetaInfoVx definition in flashInferMetaInfo.h to tag dtypeQk, since these kernels do not fall under the regular dtypeQ/dtypeKv type trait.
    • (In the future, we expect TRTLLM kernels to yield a more comprehensive set of type traits that can unify flashInferMetaInfo.h and flashInferMetaInfoVx.h, specifically:)
      • Separate control over dtypeQ, dtypeK, and dtypeV
      • Explicit control over dtypeBmm1 and dtypeBmm2
  • This PR patches FmhaKernels to support the new DiT kernels. trtllm_ragged_attention_launcher is reused as the entry point to these kernels.
    • Since KernelMetaInfoVx yields different type traits, a compatibility layer was added as fmhaKernelMetaAdapter.h.
    • We expect fmhaKernelMetaAdapter.h to be removed once TRTLLM kernels are updated to support the kernel specs above, but until that API unification takes place, we still need such a halfway layer to make FlashInfer work with DiT SageAttn.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Optional SageAttention support (scaling-factor tensors + per-block counts) for ragged & paged attention; query/key reinterpretation for mixed dtypes and INT8 input handling; expanded kernel parameterization for improved FMHA paths.
  • Tests

    • Added end-to-end ragged attention tests covering SageAttention and mixed-dtype scenarios.
  • Documentation

    • Published ragged attention entry in the public API reference.
  • Chores

    • Updated runtime artifact path and checksum.

xrq-phys added 2 commits March 7, 2026 02:45
- Update artifact pointers.
- Create compat layer fmhaKernelMetaAdapter that aggregates LLM & VG
  layers.
- Forked kernelParams.h to kernelParamsVx.h, setting additionally
  SageAttention parameters. For now we are keeping LLM & VG params
  separate to allow async kernel refreshing schedules.
- Update fmhaKernels to: load/filter/hash kernels with
  dtypeQkReinterpret + SageAttention block settings.

Signed-off-by: Ruqing Xu <ruqingx@nvidia.com>

coderabbitai bot commented Mar 6, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

📝 Walkthrough

This change extends FMHA kernel selection and launch paths to support Q/K reinterpretation and SageAttention state: adds new key fields to runner/kernel caches, threads Sage scaling-factor pointers and per-block counts through launchers/runner/params, updates kernel metadata/params (Vx path), and adds tests/docs.

Changes

Cohort / File(s) Summary
Launcher & C++ entrypoints
csrc/trtllm_fmha_kernel_launcher.cu
Extended launcher signatures to accept qk_reinterpret_type, Sage-attn SF pointers and per-block counts; added logic to compute/forward qk reinterpretation and Sage pointers into runnerParams.
Runner, Cache & Params
include/flashinfer/trtllm/fmha/fmhaRunner.cuh, include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
Runner constructor and TllmGenFmhaRunnerCache key/get updated to include qk_reinterpret_type and four Sage element-counts; RunnerParams gains ptrSageAttnSfsQ/K/P/V.
Kernel metadata & factory
include/flashinfer/trtllm/fmha/fmhaKernels.cuh, include/flashinfer/trtllm/fmha/fmhaKernelMetaAdapter.h
Introduced KernelMetaAdapter and integrated it; kernel type/hash and kernel loading now include dtypeQkReinterpret and per-block Sage counts; logging updated.
Vx kernel params & TMA descriptors
include/flashinfer/trtllm/fmha/kernelParamsVx.h
New KernelParamsVx implementation: TMA descriptor builders, device-pointer resolution, and setKernelParams wiring (includes Sage per-block counts and SF descriptors).
FMHA kernel headers/aliases
include/flashinfer/trtllm/fmha/fmhaKernels.cuh, include/.../fmhaKernelMetaAdapter.h
Swapped KernelMeta typedef to use the adapter; propagate new public members through Kernel and factory interfaces.
Python API & Prefill
flashinfer/prefill.py
Public API trtllm_ragged_attention_deepseek updated to accept Sage-attn SF tensors and per-block counts; arguments unpacked and forwarded to run functions; FP8/INT8 handling updated.
Artifacts & docs
flashinfer/artifacts.py, docs/api/attention.rst
Updated TRTLLM FMHA artifact path/checksum and added documentation listing for new symbol.
Tests
tests/attention/test_trtllm_ragged_dit.py
Added tests exercising ragged attention including SageAttention paths and dtype combinations.
Misc small mappings
csrc/trtllm_fmha_kernel_launcher.cu (helper)
Extended dl_dtype_to_tllm_data_type to map dl_int8 to INT8 and updated KeyHash to include added tuple fields.

Sequence Diagram(s)

sequenceDiagram
    participant API as Python API
    participant Launcher as FMHA Launcher
    participant Cache as RunnerCache
    participant Runner as TllmGenFmhaRunner
    participant Factory as KernelFactory
    participant GPU as GPU Kernel

    API->>Launcher: call (Q,K,V, options, sage_sfs?, num_elts_sage?)
    Launcher->>Cache: get(q_dtype, kv_dtype, o_dtype, qk_reinterpret_type, num_elts_sage...)
    Cache-->>Launcher: runner (cached or newly created)
    Launcher->>Runner: populate RunnerParams (ptrSageAttnSfs*, counts, shapes, strides)
    Runner->>Factory: select kernel (includes dtypeQkReinterpret & Sage counts)
    Factory-->>Runner: kernel metadata / function pointer
    Runner->>GPU: cuLaunchKernel / KernelParamsVx or KernelParams
    GPU-->>Runner: (optional) reduction / post-process
    Runner-->>Launcher: completion (output buffer)
    Launcher-->>API: return output

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • aleozlx
  • joker-eph
  • cyx-6
  • djmmoss
  • yongwww
  • nvmbreughe
  • bkryu
  • IwakuraRein

Poem

"I’m a rabbit in a CUDA hut,
threading Sage scales through every nut.
Keys and queries dress anew,
kernels pick the perfect shoe.
Hop—softmax!—now outputs strut. 🐇"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 38.24%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: The title clearly describes the main change: adding DiT-oriented kernels with Qk type reinterpretation capability into Int8 or BFloat16, which aligns with the substantial modifications across multiple files in the changeset.
  • Description check — ✅ Passed: The description provides comprehensive context explaining the three kernel variants, integration changes, the compatibility layer, and forward-looking plans. Pre-commit checks are confirmed complete. However, test status is incomplete (not all tests marked passing despite the checklist item being checked).


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands FlashInfer's capabilities by integrating specialized attention kernels tailored for Diffusion Transformer (DiT) models within the TensorRT-LLM framework. The changes enable more flexible mixed-precision computations, including configurations with BFloat16, Int8, and E4m3 data types for query-key products and value tensors, alongside support for SageAttention scaling. This enhancement allows FlashInfer to leverage highly optimized kernels for a broader range of advanced generative AI models.

Highlights

  • DiT-oriented TRTLLM Kernel Support: Added support for three variants of Diffusion Transformer (DiT)-oriented TRTLLM kernels: Qk in BFloat16 with V in E4m3, Qk in Int8 with SageAttention scaling factors and V in E4m3, and Qk in E4m3 with SageAttention scaling factors and V in E4m3.
  • New Kernel Metadata and Adapter Layer: Introduced KernelMetaInfoVx for tagging dtypeQk to handle the distinct type traits of these new kernels, and created fmhaKernelMetaAdapter.h as a compatibility layer to unify the interface for both standard and new TRTLLM kernels.
  • Extended trtllm_ragged_attention_launcher: The existing trtllm_ragged_attention_launcher entry point was patched to accommodate the new DiT kernels, including new parameters for qk_reinterpret_type and SageAttention scaling factors.
  • SageAttention Scaling Factor Integration: Implemented mechanisms to pass SageAttention scaling factors (Q, K, P, V) and their block sizes through the API and into the kernel parameters for specialized quantization schemes.


Changelog
  • csrc/trtllm_fmha_kernel_launcher.cu
    • Extended TllmGenFmhaRunnerCache key and get method to include qk_reinterpret_type and SageAttention block size parameters.
    • Added support for DATA_TYPE_INT8 conversion in dl_dtype_to_tllm_data_type.
    • Modified trtllm_ragged_attention_launcher and trtllm_ragged_attention function signatures to accept new SageAttention scaling factor pointers and block size parameters.
  • docs/api/attention.rst
    • Added trtllm_ragged_attention_deepseek to the API documentation.
  • flashinfer/artifacts.py
    • Updated the TRTLLM_GEN_FMHA artifact path and its corresponding checksum.
  • flashinfer/prefill.py
    • Modified the trtllm_ragged_attention_deepseek Python function to accept sage_attn_sfs and num_elts_per_sage_attn_blk parameters, passing them to the C++ backend.
  • include/flashinfer/trtllm/fmha/fmhaKernelMetaAdapter.h
    • Added new file defining TllmGenFmhaKernelMetaInfoAdapter to provide a unified interface for standard and new Vx kernel metadata, including fields for SageAttention block sizes and dtypeQkReinterpret.
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
    • Updated includes to reference fmhaKernelMetaAdapter.h and kernelParamsVx.h.
    • Modified TllmGenFmhaKernel constructor and kernel selection logic to incorporate dtypeQkReinterpret and SageAttention block size parameters for kernel matching.
    • Implemented conditional kernel launching to use KernelParamsVx for Vx kernels and KernelParams for standard kernels, ensuring proper parameter handling.
  • include/flashinfer/trtllm/fmha/fmhaRunner.cuh
    • Modified TllmGenFmhaRunner constructor to accept new parameters for dtypeQkReinterpret and SageAttention block sizes.
    • Updated the call to getTllmFmhaKernels to pass the newly introduced parameters for kernel retrieval.
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
    • Added ptrSageAttnSfsQ, ptrSageAttnSfsK, ptrSageAttnSfsP, and ptrSageAttnSfsV fields to TllmGenFmhaRunnerParams to support SageAttention scaling factors.
  • include/flashinfer/trtllm/fmha/kernelParamsVx.h
    • Added new file defining KernelParamsVx, a specialized structure for handling kernel parameters for Vx kernels, including TMA descriptor building and SageAttention block size calculations.
  • tests/attention/test_trtllm_ragged_dit.py
    • Added new file containing unit tests for trtllm_ragged_attention_deepseek, covering QKV in FP8, QK in BF16 with V in FP8, and SageAttention with QK in INT8 and V in FP8.
Activity
  • Pre-commit checks were installed and run.
  • Tests were added or updated as needed.
  • All tests are passing.

@xrq-phys xrq-phys changed the title Ruqingx/feat/vx feat: Add DiT-oriented kernels where Qk (Bmm1) type can be reinterpreted into Int8 or BFloat16 Mar 6, 2026
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for DiT-oriented TRTLLM kernels, including variants with mixed-precision and SageAttention. The changes are extensive, involving updates to kernel selection logic, data type handling, and parameter passing to accommodate the new kernel specializations. A compatibility layer (fmhaKernelMetaAdapter.h) has been added to unify the metadata of existing and new kernels, which is a good approach for this transitional period. The addition of new tests for the DiT kernels is also a great inclusion.

My review includes a few suggestions to improve maintainability and robustness:

  • Improving the hash function in the kernel cache to reduce potential collisions.
  • Adding a clarifying comment for a potentially misleading variable name.
  • Refactoring a duplicated helper lambda into a common utility file.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/fmhaRunner.cuh (1)

28-53: ⚠️ Potential issue | 🔴 Critical

Allow DATA_TYPE_INT8 for the new SageAttention Q/K path.

The new ragged DiT coverage exercises int8 query/key inputs (tests/attention/test_trtllm_ragged_dit.py:181-193), but this constructor still rejects any mDtypeQ outside E4M3/FP16/BF16. That blocks the new int8 variant before kernel lookup.

🐛 Suggested fix
     FLASHINFER_CHECK(
-        mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 || mDtypeQ == DATA_TYPE_BF16,
+        mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 ||
+            mDtypeQ == DATA_TYPE_BF16 || mDtypeQ == DATA_TYPE_INT8,
         "Unsupported Q data type: " + std::string(toStr(mDtypeQ)));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/fmhaRunner.cuh` around lines 28 - 53, The
constructor TllmGenFmhaRunner currently rejects DATA_TYPE_INT8 for query/key
types, blocking the new SageAttention int8 path; update the validation checks
(the FLASHINFER_CHECK calls that inspect mDtypeQ and mDtypeKv) to include
DATA_TYPE_INT8 as an allowed type so int8 Q/K inputs pass validation before
calling getTllmFmhaKernels (leave the output-type check unchanged unless tests
require int8 output).
🧹 Nitpick comments (2)
tests/attention/test_trtllm_ragged_dit.py (1)

57-58: Drop the extra CUDA-availability skip.

This suite already assumes CUDA-capable runners, so this branch only hides misconfigured test environments instead of surfacing them.

Based on learnings, tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures. Ensure test files under tests/ follow this convention and avoid adding CPU-only guards in fixtures unless explicitly handling a non-CUDA environment.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_ragged_dit.py` around lines 57 - 58, Remove the
redundant CUDA availability guard by deleting the conditional block that checks
"if not torch.cuda.is_available()" and calls "pytest.skip(...)" in the test
file; tests under tests/ should assume CUDA is present, so remove the "if not
torch.cuda.is_available(): pytest.skip('CUDA not available.')" branch (search
for that exact conditional) to allow failures to surface in misconfigured
environments.
csrc/trtllm_fmha_kernel_launcher.cu (1)

73-81: Consider using a better hash combining strategy.

The small bit shifts (1-7 bits) combined with XOR may lead to hash collisions when multiple fields have similar values. While the practical impact is minimal given the limited number of unique kernel configurations cached, a more robust approach would use multiplicative hash combining.

♻️ Suggested improvement using hash_combine pattern
 struct KeyHash {
   std::size_t operator()(const Key& k) const {
-    return std::hash<int>()(static_cast<int>(std::get<0>(k))) ^
-           (std::hash<int>()(static_cast<int>(std::get<1>(k))) << 1) ^
-           (std::hash<int>()(static_cast<int>(std::get<2>(k))) << 2) ^
-           (std::hash<int>()(static_cast<int>(std::get<3>(k))) << 3) ^
-           (std::hash<int>()(std::get<4>(k)) << 4) ^ (std::hash<int>()(std::get<5>(k)) << 5) ^
-           (std::hash<int>()(std::get<6>(k)) << 6) ^ (std::hash<int>()(std::get<7>(k)) << 7);
+    std::size_t seed = 0;
+    auto hash_combine = [&seed](auto val) {
+      seed ^= std::hash<int>()(static_cast<int>(val)) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    };
+    hash_combine(std::get<0>(k));
+    hash_combine(std::get<1>(k));
+    hash_combine(std::get<2>(k));
+    hash_combine(std::get<3>(k));
+    hash_combine(std::get<4>(k));
+    hash_combine(std::get<5>(k));
+    hash_combine(std::get<6>(k));
+    hash_combine(std::get<7>(k));
+    return seed;
   }
 };
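
The collision risk is easy to reproduce outside C++. A standalone Python sketch (in CPython, hash(n) == n for small ints, which mirrors the identity hash that std::hash<int> typically uses) shows two distinct 8-field keys colliding under the shift-XOR scheme; the combine-style mix separates such pairs in practice:

```python
MASK64 = (1 << 64) - 1

def shift_xor_hash(key):
    # Mirrors the current KeyHash: XOR of field hashes shifted by 0..7 bits.
    h = 0
    for i, v in enumerate(key):
        h ^= hash(v) << i
    return h & MASK64

def hash_combine(key):
    # Boost-style combine: multiplicative constant plus seed feedback.
    seed = 0
    for v in key:
        seed = (seed ^ (hash(v) + 0x9E3779B9 + (seed << 6) + (seed >> 2))) & MASK64
    return seed

k1 = (2, 0, 0, 0, 0, 0, 0, 0)
k2 = (0, 1, 0, 0, 0, 0, 0, 0)
# hash(2) << 0 == 2 == hash(1) << 1, so the XOR-of-shifts scheme collides:
assert k1 != k2 and shift_xor_hash(k1) == shift_xor_hash(k2)
```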
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fmha_kernel_launcher.cu` around lines 73 - 81, The
KeyHash::operator() uses small fixed bit shifts and XOR which can produce
collisions; replace it with a robust hash_combine pattern: start with a seed
(std::size_t) and for each element of Key (use std::get<0>(k) ...
std::get<7>(k)) mix in std::hash<int>()(value) using a multiplicative constant
(e.g. 0x9e3779b97f4a7c15ULL) and seed ^= h + constant + (seed<<6) + (seed>>2) or
equivalent combine logic; update KeyHash to iterate the 8 fields and combine
each hash into the seed, then return seed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 552-563: The optional SageAttention scale-factor tensors
(sage_attn_sfs_q, sage_attn_sfs_k, sage_attn_sfs_p, sage_attn_sfs_v) are being
cast to float* without dtype checks; add the same TVM_FFI_ICHECK_EQ(...dtype(),
dl_float32) validations used for attention_sinks and lse before performing the
static_cast, and only set sage_attn_sfs_*_ptr to nullptr if the optional has no
value—this ensures each tensor's dtype is dl_float32 prior to casting in
trtllm_fmha_kernel_launcher.cu.

In `@flashinfer/prefill.py`:
- Around line 3446-3452: The SageAttention tensor tuple (sage_attn_sfs) and
corresponding block sizes (num_elts_per_sage_attn_blk) must be validated before
handing raw pointers to the C++ runner: ensure each tensor in sage_attn_sfs that
has a non-zero entry in num_elts_per_sage_attn_blk is non-None, is on the
expected device (use tensor.device or tensor.get_device()), has dtype
torch.float32, and is contiguous (or call .contiguous() before taking the
pointer); for None entries require the matching block size be zero. Update the
code paths that forward these values (the code working with sage_attn_sfs and
num_elts_per_sage_attn_blk around the SageAttention prep and the later block at
lines ~3573-3608) to perform these checks and raise/handle a clear error if
validation fails, then pass the tensor.data_ptr() only after validation.

In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 949-978: The hash currently maps blockSize==0 and blockSize==1 to
the same encoded value (computeLog2BlockSize returns 0), so update hashID to
mark “disabled” separately: keep computeLog2BlockSize as-is for nonzero sizes,
but when numEltsPerSageAttnBlkQ/K/P/V == 0 set reserved bits 28..31 as per-field
disabled flags (e.g., set bit 28 if numEltsPerSageAttnBlkQ==0, bit 29 for
numEltsPerSageAttnBlkK, bit 30 for P, bit 31 for V) before composing the final
return; reference the hashID function and the numEltsPerSageAttnBlkQ/K/P/V
parameters and ensure you OR the corresponding (1ULL << 28..31) flags into the
returned uint64_t so disabled is distinguishable from block size 1.

In `@include/flashinfer/trtllm/fmha/kernelParamsVx.h`:
- Around line 718-726: The O TMA descriptor and related metadata are being built
using Q-side values (kernelMeta.mDataTypeQ, numEltsInClampedHeadDimQ,
kernelMeta.mTileSizeQ) causing mismatch when O differs; update the O-path to use
O-side metadata: use the O data type (kernelMeta.mDataTypeO) when calling
buildNdTmaDescriptor, compute tileShapeO using O-specific sizes (e.g.,
numEltsInClampedHeadDimO and kernelMeta.mTileSizeO or equivalent O head-dim/
tile-size fields), and ensure mNumHiddenEltsO is computed from the O head
dimension (mHeadDimV) not Q; keep references to makeTmaShapeStrideO, tileShapeO,
params.tmaO_, buildNdTmaDescriptor, kernelMeta.mDataTypeQ/mDataTypeO,
mHeadDimV/mHeadDimQk, and mNumHiddenEltsO to locate the changes.
- Around line 784-814: The code computes params.mChunkedAttentionSizeLog2 when
isSlidingOrChunkedCausalMask(...) and options.mChunkedAttentionSize is set, but
then unconditionally resets params.mChunkedAttentionSizeLog2 to 0 at the end;
remove that reset so the computed value persists. Locate the assignment
params.mChunkedAttentionSizeLog2 = 0 (near the end of the block) and delete it
(or make it conditional only when chunked attention is disabled), ensuring the
earlier computation that uses options.mChunkedAttentionSize and
isSlidingOrChunkedCausalMask retains its result.



📥 Commits

Reviewing files that changed from the base of the PR and between 124a2d3 and 247edbd.

📒 Files selected for processing (10)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • docs/api/attention.rst
  • flashinfer/artifacts.py
  • flashinfer/prefill.py
  • include/flashinfer/trtllm/fmha/fmhaKernelMetaAdapter.h
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunner.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParamsVx.h
  • tests/attention/test_trtllm_ragged_dit.py

k_stride_batch *= dtype_size;
} else {
// no reinterpret (ignored when qk_reinterpret_type == E4m3 or q_data_type != E4m3)
qk_reinterpret_type = DATA_TYPE_E4M3;
Collaborator

Is it intended that qk_reinterpret_type is hardcoded to E4M3 here, as opposed to keeping it the same as the original dtypes? Not sure if I'm missing some functionality here – is DATA_TYPE_E4M3 the "key" for "ignore"?

Author

Yes. In these kernels the TMA loads the input "as if E4m3" but the Tensor Cores treat the elements as Int8 / BFloat16.

Currently only dtypeQ == E4m3 kernels consider reinterpreting.
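
To make the distinction concrete: "reinterpret" here means the bytes move unchanged (the TMA copy) while downstream math assigns them a different element type, as opposed to a value-converting cast. A standalone NumPy sketch of the difference (float16 and int8 stand in, since NumPy has no E4m3 or BFloat16 dtypes):

```python
import numpy as np

# Reinterpretation: same bytes in memory, different element type.
x = np.array([1.0, -2.0, 0.5], dtype=np.float16)
reinterpreted = x.view(np.int8)          # zero-copy: 6 int8 elements, bytes untouched
assert reinterpreted.tobytes() == x.tobytes()

# Conversion (a cast) would instead rewrite the bytes value-by-value:
converted = x.astype(np.int8)
assert converted.tolist() == [1, -2, 0]  # values converted, underlying bytes changed
```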

Signed-off-by: Ruqing Xu <ruqingx@nvidia.com>
@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
csrc/trtllm_fmha_kernel_launcher.cu (1)

570-570: Consider renaming kv_data_type for clarity.

This variable is now derived from value.dtype() only, making the name kv_data_type potentially misleading in the DiT case where K and V have different types. Consider renaming to v_data_type or adding a clarifying comment.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fmha_kernel_launcher.cu` at line 570, The variable kv_data_type
is misleading because it is derived only from value.dtype() and may not
represent K and V separately (e.g., DiT where K and V differ); rename
kv_data_type to v_data_type (or alternatively add a clarifying comment where
dl_dtype_to_tllm_data_type(value.dtype()) is assigned) and update any subsequent
uses to reference v_data_type so the code accurately reflects that this type is
for V only.


📥 Commits

Reviewing files that changed from the base of the PR and between 247edbd and 11ac55e.

📒 Files selected for processing (1)
  • csrc/trtllm_fmha_kernel_launcher.cu

To encode no-sage

Signed-off-by: Ruqing Xu <ruqingx@nvidia.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

952-963: ⚠️ Potential issue | 🔴 Critical

Fix the SageAttention block-size encoding in the factory key.

__builtin_ctz(blockSize + 1) is not log2(blockSize) + 1: for 2/4/8/... it returns 0, so those sizes alias the disabled case in the cache key. That can reuse a TllmGenFmhaKernel built for the wrong SageAttention shape.

🔧 Suggested fix
     auto const computeLog2BlockSizePlus1 = [](int blockSize) -> int {
       if (blockSize <= 0) {
         return 0;
       }
       FLASHINFER_CHECK((blockSize & (blockSize - 1)) == 0,
                        "SageAttention block size must be a power of 2.");
-      return __builtin_ctz(static_cast<unsigned int>(blockSize) + 1);
+      return __builtin_ctz(static_cast<unsigned int>(blockSize)) + 1;
     };
#!/bin/bash
python - <<'PY'
def ctz(n: int) -> int:
    return (n & -n).bit_length() - 1

def current(block: int) -> int:
    if block <= 0:
        return 0
    return ctz(block + 1)

def expected(block: int) -> int:
    if block <= 0:
        return 0
    return ctz(block) + 1

for b in [0, 1, 2, 4, 8, 16, 32, 64]:
    print(f"blockSize={b:>2} current={current(b)} expected={expected(b)}")
PY

Also applies to: 975-981
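To make the intended mapping concrete, here is a minimal standalone Python sketch of the corrected encoding (the function name `encode_sage_block_size` is illustrative, not from the codebase):

```python
def encode_sage_block_size(block_size: int) -> int:
    """Encode a SageAttention block size into the kernel-cache key.

    0 (disabled) maps to 0; a power-of-two block size b maps to
    log2(b) + 1, so every valid size gets a distinct nonzero code
    and no longer aliases the disabled case.
    """
    if block_size <= 0:
        return 0
    assert block_size & (block_size - 1) == 0, "block size must be a power of 2"
    # For a power of two b, bit_length() == log2(b) + 1
    return block_size.bit_length()

# disabled -> 0, then 1, 2, 4, ..., 64 -> 1..7, all distinct
codes = [encode_sage_block_size(b) for b in (0, 1, 2, 4, 8, 16, 32, 64)]
assert codes == [0, 1, 2, 3, 4, 5, 6, 7]
```

With the buggy `ctz(b + 1)` variant, every power-of-two size except 1 would collide with the disabled code 0, which is exactly the cache-key aliasing the comment describes.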

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 952 - 963, The
factory key encodes SageAttention block-size incorrectly inside hashID: the
lambda computeLog2BlockSizePlus1 uses __builtin_ctz(blockSize + 1) which maps
powers of two to zero and collides with the disabled case; update
computeLog2BlockSizePlus1 (and the other identical occurrence in the same
function) to return 0 for blockSize <= 0, otherwise return
__builtin_ctz(static_cast<unsigned int>(blockSize)) + 1 so that powers-of-two
produce log2(blockSize)+1; keep the existing FLASHINFER_CHECK for power-of-two
validation and apply the same change to the duplicate block-size encoding sites
referenced in this function.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 312-320: The code launches Vx kernels via cuLaunchKernelEx before
verifying reduction mode; move the rejection check so you call
FLASHINFER_CHECK(!isGmemReductionWithSeparateKernel(static_cast<MultiCtasKvMode>(kernelMeta.mMultiCtasKvMode)),
...) inside the kernelMeta.isKernelVx() branch before creating kernelParams or
calling KernelParamsVx::setKernelParams and cuLaunchKernelEx. Ensure the check
uses the same kernelMeta and MultiCtasKvMode cast so unsupported Vx reduction
modes are rejected early and the kernel is never launched.
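The suggested check-before-launch ordering can be sketched in Python; all names here (`MultiCtasKvMode`, `launch_vx`, `is_gmem_reduction_with_separate_kernel`) are illustrative stand-ins for the C++ code, not the actual FlashInfer API:

```python
from enum import Enum

class MultiCtasKvMode(Enum):
    DISABLED = 0
    GMEM_REDUCTION = 1
    GMEM_REDUCTION_WITH_SEPARATE_KERNEL = 2

def is_gmem_reduction_with_separate_kernel(mode: MultiCtasKvMode) -> bool:
    return mode is MultiCtasKvMode.GMEM_REDUCTION_WITH_SEPARATE_KERNEL

def launch_vx(kernel_meta: dict, build_params, launch) -> None:
    mode = MultiCtasKvMode(kernel_meta["multi_ctas_kv_mode"])
    # Reject unsupported reduction modes *before* building kernel
    # params or launching, mirroring the review suggestion.
    if is_gmem_reduction_with_separate_kernel(mode):
        raise ValueError(
            "Vx kernels do not support GMEM reduction with a separate kernel")
    launch(build_params(kernel_meta))
```

The point is only the ordering: the validation must sit inside the Vx branch ahead of any `setKernelParams`/`cuLaunchKernelEx` call, so an unsupported configuration fails loudly instead of launching a kernel with the wrong reduction mode.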


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e50114de-f61e-4cc8-b6a5-a6420e1a97e7

📥 Commits

Reviewing files that changed from the base of the PR and between 11ac55e and 99f1a3e.

📒 Files selected for processing (1)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh

# do not support FP8 output for ragged attention)
if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
out_dtype = torch.bfloat16
else:
Collaborator

Do we need any special handling for torch.int8 here?

Author

Oh I'll try to add a more explicit handling here today. Sorry for the late reply.

Thanks!

Signed-off-by: Ruqing Xu <ruqingx@nvidia.com>
@saltyminty saltyminty enabled auto-merge (squash) March 11, 2026 21:33
@saltyminty saltyminty self-requested a review March 11, 2026 21:33
@bkryu bkryu added the run-ci label Mar 12, 2026
@bkryu
Collaborator

bkryu commented Mar 12, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !408 has been created, and the CI pipeline #45991403 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45991403: 6/20 passed

Collaborator

@nv-yunzheq nv-yunzheq left a comment

Internal CI looks good.

Signed-off-by: Ruqing Xu <ruqingx@nvidia.com>
auto-merge was automatically disabled March 13, 2026 03:41

Head branch was pushed to by a user without write access

Contributor

@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

3534-3558: ⚠️ Potential issue | 🟠 Major

Extend provided-out dtype guard to the new int8 path.

Line 3534 defaults int8-query outputs to BF16, but Line 3551 still validates only FP8-query cases. A caller-provided byte out can bypass the new int8 policy.

💡 Proposed fix
-        if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and out.dtype in (
+        if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.int8) and out.dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
+            torch.int8,
         ):
             raise ValueError(
-                "FP8 output is not supported for trtllm_ragged_attention_deepseek; "
-                "use bfloat16 or float16 for out."
+                "Byte-sized output is not supported for trtllm_ragged_attention_deepseek; "
+                "use bfloat16 or float16 for out."
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3534 - 3558, The provided-`out` dtype
guard must also cover the new int8 query path: in the block that currently
checks "if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and out.dtype
in (torch.float8_e4m3fn, torch.float8_e5m2):" update the condition to include
torch.int8 on the query side and ensure out.dtype is not a byte/int8 type (e.g.,
include torch.int8 in the forbidden out types), and update the error message for
trtllm_ragged_attention_deepseek to mention int8 alongside FP8 so callers cannot
bypass the int8->bfloat16 policy by passing a byte `out`.
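A minimal sketch of the extended guard, using dtype names as plain strings so it stays dependency-free (in `flashinfer/prefill.py` these would be `torch` dtypes, and the function name is illustrative):

```python
# Byte-sized dtypes that trtllm_ragged_attention_deepseek cannot emit as
# output; int8 is included so a caller-provided byte `out` cannot bypass
# the int8 -> bfloat16 policy.
BYTE_DTYPES = {"float8_e4m3fn", "float8_e5m2", "int8"}

def check_out_dtype(query_dtype: str, out_dtype: str) -> None:
    """Reject byte-sized outputs when the query is byte-sized (FP8/int8)."""
    if query_dtype in BYTE_DTYPES and out_dtype in BYTE_DTYPES:
        raise ValueError(
            "Byte-sized output is not supported for "
            "trtllm_ragged_attention_deepseek; use bfloat16 or float16 for out.")
```

With this shape of check, `("int8", "bfloat16")` passes while `("int8", "int8")` and the existing FP8 combinations are rejected with one consistent message.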

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6c2b41b2-5b8b-48b6-8cff-323efb135e9c

📥 Commits

Reviewing files that changed from the base of the PR and between 90ff697 and 02ac8e3.

📒 Files selected for processing (1)
  • flashinfer/prefill.py

@yongwww
Member

yongwww commented Mar 13, 2026

I canceled the PR test because the CI will not pass until #2781 lands. Please re-trigger the test after that PR is merged.

@saltyminty
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !408 has been updated with latest changes, and the CI pipeline #46476271 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46476271: 11/20 passed


6 participants