feat: Add TRTLLM fmha_v2 library for SM90 attention with Skip-Softmax #2446

Merged
yzh119 merged 52 commits into flashinfer-ai:main from jimmyzho:fmhav2
Mar 9, 2026
Conversation

@jimmyzho (Contributor) commented Jan 30, 2026

📌 Description

Integrates TRT-LLM's Fused Multi-Head Attention v2 (FMHAv2) kernels into FlashInfer as a prefill attention backend via a new trtllm_fmha_v2_prefill Python API, backed by a CUDA runtime dispatch layer (fmha_v2_run.cu) that bridges FlashInfer's tensor conventions with TRT-LLM's fused kernel infrastructure.

  • Data types: FP16, BF16, FP8 (E4M3) inputs with configurable output dtype (e.g., FP8 → BF16/FP16). FP32 accumulation for BF16/E4M3; fused BMM1/BMM2 scaling for quantized inference.
  • Input layouts: Packed QKV [T, 3, H, D], Contiguous Q+KV [T, H, D]+[T, 2, H_kv, D], Separate Q/K/V, and Q + Paged KV Cache (page sizes 32/128, HND/NHD layouts). Block tables auto-expanded from [B, M] → [B, 2, M] for separate K/V offset addressing.
  • Masking: Causal, sliding window (window_left), chunked causal (power-of-2 sizes).
  • Ragged batching: Variable-length sequences via cumulative sequence length tensors.
  • GQA/MQA: Grouped-query and multi-query attention (h % h_kv == 0).
  • Attention sinks: Per-head sink values in softmax denominator.
  • Logits soft capping: Fused tanh-based softcapping with configurable scale.
  • ALiBi: Optional ALiBi positional encoding.
  • Softmax stats (LSE) export: Optional save_softmax_stats returns [max, sum_exp] per token/head.
  • Non-blocking execution
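The block-table expansion mentioned above ([B, M] → [B, 2, M] for separate K/V offset addressing) can be sketched in plain Python; the function name and list-of-lists representation are illustrative, not the actual FlashInfer tensor code:

```python
# Hypothetical sketch of the block-table auto-expansion described above:
# a [B, M] table of shared KV page indices is duplicated into [B, 2, M],
# so the K rows (index 0) and V rows (index 1) can be offset independently.

def expand_block_tables(block_tables):
    """Expand a [B, M] page table to [B, 2, M] (row 0: K pages, row 1: V pages)."""
    return [[list(row), list(row)] for row in block_tables]

tables = [[0, 3, 7], [1, 4, 9]]          # B=2 sequences, M=3 pages each
expanded = expand_block_tables(tables)
assert expanded[0] == [[0, 3, 7], [0, 3, 7]]
assert expanded[1][1] == [1, 4, 9]
```

In the real kernel the duplicated rows would then be biased by separate K and V base offsets within each page.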

Also adds the skip-softmax optimization introduced in TRT-LLM (see TRT-LLM's documentation).

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Full FMHA v2 runtime with JIT-generated, runtime-dispatched kernels supporting multiple QKV layouts, causal/sliding/custom masks, chunked/windowed attention, FP8 and mixed-precision, and optional skip-softmax optimization.
  • API

    • New Python module generators and public prefill/run accessors to load and invoke v2 kernels.
  • Tests

    • Extensive new tests covering layouts, paged KV, sinks, dtypes, and numerical correctness.
  • Bug Fixes

    • Safer 64-bit offset arithmetic to prevent row-pointer overflow.
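The overflow the fix above guards against is easy to reproduce: a row index times a byte stride can exceed the 32-bit signed range for large batches or sequence lengths. This stdlib-only illustration (not the kernel code) simulates the wraparound a 32-bit product would suffer:

```python
# Why the row-pointer arithmetic needs int64: the product below exceeds
# INT32_MAX, so a 32-bit intermediate wraps to a negative offset.

def to_int32(x):
    """Wrap a Python int to signed 32-bit, as a C `int` product would."""
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

row, bytes_per_row = 40_000, 65_536       # plausible large-seqlen strides
bad_offset = to_int32(row * bytes_per_row)  # 32-bit intermediate: wraps negative
good_offset = row * bytes_per_row           # 64-bit (exact in Python)

assert bad_offset == -1_673_527_296
assert good_offset == 2_621_440_000
```

Casting the index to `int64_t` before the multiply, as the PR does, keeps the whole computation in the exact (64-bit) path.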

@coderabbitai bot (Contributor) commented Jan 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a full FMHA v2 implementation: Jinja CUDA kernel templates and launchers (multiple variants), Hopper/TMA support, skip‑softmax voting/barrier, Python JIT code generation and dispatcher, host runtime and JIT/FFI bindings, prefill integration, tests, and updated runtime/launcher APIs.

Changes

Cohort / File(s) Summary
CUDA Kernel Templates
csrc/fmha_v2/templates/fa_kernel.jinja, csrc/fmha_v2/templates/kernel.jinja, csrc/fmha_v2/templates/kernel_hopper.jinja, csrc/fmha_v2/templates/kernel_hopper_ws.jinja
Add extensive templated fused-MHA kernels and launchers across many variants (padding/causal/sliding/custom masks, no-loop, multi-CTA, Hopper/TMA, WS). Expose templated extern "C" kernels/launchers, runtime helpers (get_max_heads_per_wave), dynamic shared-memory sizing, and CUDA-version guards.
Host Runtime & Dispatcher
csrc/fmha_v2_customize_config.jinja, csrc/fmha_v2_dispatcher.jinja, csrc/fmha_v2_run.cu, csrc/fmha_v2_jit_binding.cu
Add runtime dtype/head-dim dispatch macros and dispatcher returning launcher function pointers; implement fmha_v2_run host runtime (param construction, launch selection) and declare/export JIT FFI binding for fmha_v2_run.
JIT Code Generation
flashinfer/jit/attention/fmha_v2/fmha_library.py, .../utils.py, .../generator_utils.py
Add FMHAv2KernelSpec, spec generation/validation, kernel naming/trait selection, code rendering and JIT source emission. Thread new skip‑softmax/skip‑softmax metadata and skip-softmax feature through generator and cubin/header selection.
Module API & Integration
flashinfer/jit/attention/modules.py, flashinfer/jit/attention/__init__.py, flashinfer/jit/__init__.py, flashinfer/prefill.py, flashinfer/__init__.py
Add new module builders (gen_fmha_v2_module, gen_trtllm_fmha_v2_sm120_module), integrate generated sources into JIT builds (copy static runtime/JIT binding), update public exports, and expose trtllm_fmha_v2_prefill and v2 module accessors.
Tests
tests/attention/test_fmha_v2_prefill_deepseek.py
Add comprehensive prefill/deepseek tests and a Torch reference implementation covering multiple QKV layouts (including paged KV), mask modes, FP8/non-FP8 scaling, windowed/chunked attention, and attention sinks.
Skip‑Softmax Feature
csrc/fmha_v2/fmha/warpspec/kernel_traits.h, .../epilogue.h, .../compute.h, csrc/fmha_v2/fused_multihead_attention.h, ..._demo_bert_params.h, fused_multihead_attention.cpp
Introduce ENABLE_SKIP_SOFTMAX flag, SKIP_SOFTMAX barrier ID, shared double-buffered per-warpgroup vote buffers, update softmax API (compute_and_update_scale now returns bool and accepts vote pointer), propagate skip-softmax params and enable flag through host params/launch params, and add optional statistics instrumentation.
Build / JIT Integration / Static Sources
flashinfer/jit/attention/modules.py, csrc/fmha_v2_jit_binding.cu, csrc/fmha_v2_run.cu
Copy static runtime/JIT binding sources into generated JIT dir, add NVCC flags/include paths (SM-specific), and provide SM120-targeted module builders.
Minor gmem fixes
csrc/fmha_v2/fmha/gmem_tile_o_packed.h, .../gmem_tile_ps.h, hopper/gmem_tile_o_packed.h
Fix row-pointer arithmetic by casting indices to int64_t to avoid overflow in large-offset computations.
Misc/Host API Changes
various headers and cpp files under csrc/fmha_v2/...
Add skip_softmax_threshold_scale_factor to params and CLI, propagate through determine_launch_params and set_params, and add launch_params.enable_skip_softmax flag and optional device counters for skip-softmax stats.
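The skip-softmax mechanics summarized above (a `compute_and_update_scale` that now returns a bool, plus per-warpgroup vote buffers) can be modeled at a high level in Python. This is a hedged sketch of the idea only; the actual TRT-LLM threshold heuristic, barrier, and vote-buffer layout are not reproduced here:

```python
# Illustrative model of the skip-softmax vote: a tile may be skipped when its
# contribution exp(tile_max - running_max) is below a threshold, and the skip
# path is taken only if every voter in the warpgroup agrees.

import math

def compute_and_update_scale(tile_max, running_max, threshold):
    """Return (skip, new_running_max); skip is True when this tile cannot
    change the softmax result meaningfully."""
    new_max = max(tile_max, running_max)
    skip = math.exp(tile_max - new_max) < threshold
    return skip, new_max

def warpgroup_vote(votes):
    """Unanimous agreement is required before skipping."""
    return all(votes)

skip0, _ = compute_and_update_scale(-30.0, 5.0, 1e-6)   # negligible tile
skip1, _ = compute_and_update_scale(4.9, 5.0, 1e-6)     # significant tile
assert skip0 and not skip1
assert not warpgroup_vote([skip0, skip1])               # one dissent blocks skip
```

The double-buffered shared-memory vote buffers in the real kernel serve the same role as `votes` here, but across threads of a warpgroup between pipeline stages.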

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PythonAPI as FlashInfer\n(prefill API)
    participant JITGen as JIT\nCodeGen
    participant Runtime as Host Runtime\n(fmha_v2_run)
    participant Dispatcher as Launcher\nDispatcher
    participant CUDA as CUDA\nLauncher
    participant GPU as Device

    User->>PythonAPI: trtllm_fmha_v2_prefill(...)
    PythonAPI->>JITGen: generate_jit_sources() / ensure module
    JITGen-->>PythonAPI: compiled module / launcher symbols
    PythonAPI->>Runtime: fmha_v2_run(params, layout, dtype, ...)
    Runtime->>Dispatcher: get_fmha_v2_kernel(dtype, head_dims)
    Dispatcher-->>Runtime: FMHAv2KernelLauncher
    Runtime->>Runtime: set_params() / determine_launch_params()
    Runtime->>CUDA: launcher_name(params, launch_params, stream)
    CUDA->>GPU: cudaLaunchKernel / cudaLaunchCooperativeKernel
    GPU-->>User: outputs
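The dispatch step in the diagram (`get_fmha_v2_kernel(dtype, head_dims)` returning a launcher) amounts to a table lookup keyed by kernel traits. A minimal Python model follows; the real dispatcher is generated C++ returning function pointers, and the launcher names below are illustrative only:

```python
# Hypothetical model of the runtime dispatcher: map (dtype, head_dim) traits
# to a launcher, failing loudly when no kernel was generated for the combo.

LAUNCHERS = {
    ("fp16", 128): "fmha_v2_fp16_128_sm90_launcher",
    ("bf16", 128): "fmha_v2_bf16_128_sm90_launcher",
    ("e4m3", 128): "fmha_v2_e4m3_128_sm90_launcher",
}

def get_fmha_v2_kernel(dtype, head_dim):
    try:
        return LAUNCHERS[(dtype, head_dim)]
    except KeyError:
        raise ValueError(f"no FMHAv2 kernel for dtype={dtype}, head_dim={head_dim}")

assert get_fmha_v2_kernel("bf16", 128).startswith("fmha_v2_bf16")
```

Keeping the lookup total (raise on miss) mirrors why the review flags silently-stubbed paths like `fmha_v2_paged_run` as dangerous.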

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested labels

run-ci

Suggested reviewers

  • aleozlx
  • djmmoss
  • yzh119
  • cyx-6
  • bkryu
  • nvmbreughe
  • jiahanc

Poem

🐇 I hopped through Jinja, kernels, and streams,

I taught warps to whisper and vote on their dreams,
Launchers and dispatchers now dance in a row,
Prefill hums softly — tensors ready to go,
Hooray for v2 — may the GPUs glow.


@gemini-code-assist bot (Contributor)

Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the TensorRT-LLM Fused Multi-Head Attention v2 (FMHAv2) library into FlashInfer, enabling dynamic generation and compilation of highly optimized CUDA kernels. This integration significantly expands FlashInfer's capabilities by supporting a wide range of attention configurations, input data layouts, and advanced features, with a particular focus on performance on Hopper GPUs.

Highlights

  • TRT-LLM FMHAv2 Integration: Integrates the TensorRT-LLM Fused Multi-Head Attention v2 library, providing highly optimized CUDA kernels for attention operations.
  • JIT Kernel Generation: Implements a Just-In-Time (JIT) compilation system using Jinja2 templates to dynamically generate CUDA kernels tailored to specific attention configurations (e.g., head size, data type, mask type).
  • Flexible Input Layout Support: The new trtllm_fmha_v2_prefill API supports diverse input formats for query, key, and value tensors, including packed QKV, Q with paged KV, separate QKV, and Q with contiguous KV.
  • Hopper Architecture Optimizations: Includes specialized kernel templates and configurations optimized for NVIDIA Hopper (SM90+) GPUs, leveraging features like warp specialization and Tensor Memory Accelerator (TMA) for enhanced performance.
  • Comprehensive Attention Features: Supports advanced attention features such as causal masking, sliding window attention, ALiBi positional encoding, logits soft capping, and attention sinks.
  • Python API and Extensive Testing: Exposes a new Python API trtllm_fmha_v2_prefill and includes a robust suite of unit tests to validate functionality across various input layouts and attention configurations.
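Among the features listed above, logits soft capping has a simple closed form worth stating: scores are squashed into (-cap, cap) with a scaled tanh before softmax. A stdlib sketch (the fused kernel does this in-register; the formula itself is the standard one):

```python
# Tanh-based logits soft capping: soft_cap(x) = cap * tanh(x / cap).
# Small scores pass through almost unchanged; outliers saturate at +/- cap.

import math

def soft_cap(score, cap):
    return cap * math.tanh(score / cap)

assert abs(soft_cap(0.5, 30.0) - 0.5) < 1e-3   # near-identity for small x
assert soft_cap(1e6, 30.0) <= 30.0             # saturates at the cap
assert soft_cap(-1e6, 30.0) >= -30.0
```

The "configurable scale" in the PR description corresponds to choosing `cap` per model (e.g., attention-logit softcapping values used by some public LLMs).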



@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant new feature for JIT compiling TRT-LLM FMHAv2 kernels. The changes are extensive, including new Jinja templates for CUDA code generation, Python code for managing the JIT process, and new public APIs. While the overall structure is sound, I've found several critical issues in the CUDA kernel launcher templates where incorrect kernel traits are used for different attention mask types, potentially leading to incorrect execution or launch failures. These need to be addressed. The addition of comprehensive tests is a great step towards ensuring the correctness of this new functionality.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 14

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
flashinfer/jit/__init__.py (1)

90-94: ⚠️ Potential issue | 🟠 Major

Wrap CUDA runtime preload in exception handling to prevent import-time failures.

If libcudart.so.12 exists but is incompatible or has missing symbols, ctypes.CDLL() will raise OSError and make the flashinfer.jit module unimportable. Add error handling to allow graceful degradation:

Fix
 if os.path.exists(f"{cuda_lib_path}/libcudart.so.12"):
-    ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL)
+    try:
+        ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL)
+    except OSError:
+        pass
flashinfer/jit/attention/fmha_v2/generator_utils.py (1)

3816-3826: ⚠️ Potential issue | 🟠 Major

Avoid runtime subscripting of itertools.product.

The project supports Python >=3.10, and itertools.product is not subscriptable at runtime. Line 3816 uses product[tuple[bool, bool, InputLayout, bool]](...) which will raise TypeError: type 'itertools.product' is not subscriptable when enumerate_hgmma_flash_warpspec_kernels executes.

Use a type annotation instead to preserve typing without runtime subscription:

Safer approach
+from typing import Iterable
...
-    combinations = product[tuple[bool, bool, InputLayout, bool]](
+    combinations: Iterable[tuple[bool, bool, InputLayout, bool]] = product(
         [False, True],
         [False, True],
         [
             InputLayout.PACKED_QKV,
             InputLayout.CONTIGUOUS_Q_KV,
             InputLayout.Q_PAGED_KV,
             InputLayout.SEPARATE_Q_K_V,
         ],
         [False, True],
     )
flashinfer/jit/attention/modules.py (1)

1903-1936: ⚠️ Potential issue | 🟡 Minor

Use or remove the unused device parameter.

device is unused (Line 1903), which triggers lint and misses a chance to validate SM120 at the entrypoint. Either remove it or use it for an arch guard.

🛠️ Suggested fix (use the parameter to validate SM120)
 def gen_trtllm_fmha_v2_sm120_module(device: torch.device) -> JitSpec:
+    if torch.cuda.get_device_properties(device).major < 12:
+        raise ValueError("trtllm_fmha_v2_sm120_module requires SM120+ device")
     uri = "trtllm_fmha_v2"
flashinfer/prefill.py (1)

3772-3778: ⚠️ Potential issue | 🔴 Critical

Fix mismatched get_trtllm_fmha_v2_module call signature.

Line 3777 passes 120, but get_trtllm_fmha_v2_module() is defined without parameters. This will raise a TypeError at runtime.

🐛 Suggested fix
-    module = get_trtllm_fmha_v2_module(120)
+    module = get_trtllm_fmha_v2_module()
🤖 Fix all issues with AI agents
In `@csrc/fmha_v2_run.cu`:
- Around line 298-308: fmha_v2_paged_run is an unimplemented stub that
fmha_v2_ragged_run forwards to; replace the TODO with a hard failure so callers
don't silently do nothing. Inside fmha_v2_paged_run add an assertion or explicit
runtime error (e.g. assert(false && "fmha_v2_paged_run not implemented") or
throw std::runtime_error("fmha_v2_paged_run not implemented")), and include the
required header (<cassert> or <stdexcept>) if missing so the build fails clearly
when this function is invoked.
- Around line 531-535: The code unconditionally allocates softmax_stats_d and
passes it to the kernel, ignoring the caller's softmax_stats; change the logic
so that when softmax_stats.has_value() is true you set params.softmax_stats_ptr
to softmax_stats.value().data_ptr() (and only allocate softmax_stats_d when
softmax_stats is not provided), use the existing softmax_stats_ptr variable for
the kernel invocation instead of the debug-only pointer, and if internal scratch
is still needed allocate a separate buffer and copy the kernel results into
softmax_stats.value().data_ptr() afterward; update the allocation path around
allocator.aligned_alloc, the softmax_stats_d/softmax_stats_ptr variables, and
the assignment to params.softmax_stats_ptr accordingly.
- Around line 521-525: When attention_mask_type equals
Attention_mask_type::CUSTOM_MASK the code allocates packed_mask_d via
allocator.aligned_alloc but never populates it or accepts a mask tensor; update
the API and call site to accept a packed/custom mask input (e.g., add a device
pointer or Tensor argument named packed_mask_input) and memcpy/copy that data
into packed_mask_d before calling set_params() (which currently consumes
packed_mask_d and mask_mode_code), or if you prefer not to support custom masks
yet, explicitly error out when attention_mask_type == CUSTOM_MASK (using a clear
runtime error) so uninitialized memory is never passed into set_params()/the
kernel; change references in this file for packed_mask_d, attention_mask_type,
Attention_mask_type::CUSTOM_MASK, allocator.aligned_alloc, set_params(), and
mask_mode_code accordingly.

In `@csrc/fmha_v2/templates/kernel_hopper.jinja`:
- Around line 324-358: The current `_nl` launcher uses a single constexpr
smem_size = Kernel_traits::BYTES_PER_SMEM but then launches kernels compiled
with specialized traits (Kernel_traits_causal_nl,
Kernel_traits_sliding_or_chunked_causal_nl, Kernel_traits_nl), so
cudaFuncSetAttribute may be given the wrong size; change the code to compute and
use a branch-specific smem_size before each cudaFuncSetAttribute call (use
Kernel_traits_causal_nl::BYTES_PER_SMEM for {{ causal_kernel_name }}_nl,
Kernel_traits_sliding_or_chunked_causal_nl::BYTES_PER_SMEM for {{
sliding_or_chunked_causal_kernel_name }}_nl, and
Kernel_traits_nl::BYTES_PER_SMEM for {{ kernel_name }}_nl) and remove or stop
using the original Kernel_traits::BYTES_PER_SMEM constant for these branches so
the attribute matches the actual kernel trait.
- Around line 181-210: The TMA descriptors are hardcoded to F16_RN which
misinterprets BF16 data; update tma_desc_QKV.set_tma_desctriptor calls (for Q,
K, V - references: params.tma_desc_q, params.tma_desc_k, params.tma_desc_v) to
use a templated/variable tma_desc_format instead of
fmha::cudaTmaDescFormat::F16_RN, and ensure that the template context receives a
mapped value (e.g., tmp["tma_desc_format"] =
dtype_to_tma_format_mapping[kspec.dtype]) so the kernel dtype selects F16_RN for
FP16 and BF16_RN for BF16 (apply same change for all three descriptors and for
warp-specialized BF16 paths when use_tma is enabled).
- Around line 147-151: The code sets tensor_size_qkv[1] using a non-existent
launch_params.seqlens member; replace that usage with
launch_params.total_kv_seqlen so tensor_size_qkv[1] = params.is_s_padded ?
params.s * params.b : launch_params.total_kv_seqlen; locate the assignment to
tensor_size_qkv in kernel_hopper.jinja and update it accordingly (references:
tensor_size_qkv, params.is_s_padded, params.s, params.b,
launch_params.total_kv_seqlen, Fused_multihead_attention_launch_params).

In `@flashinfer/jit/__init__.py`:
- Line 58: Add a deprecated alias so existing callers of
get_trtllm_fmha_v2_module continue to work: keep the exported gen_fmha_v2_module
and assign get_trtllm_fmha_v2_module = gen_fmha_v2_module in the same module,
and emit a DeprecationWarning (via the warnings module) when
get_trtllm_fmha_v2_module is accessed/called to guide users to the new name;
update the top-level export list if present to include both names.

In `@flashinfer/jit/attention/fmha_v2/fmha_library.py`:
- Around line 1132-1228: generate_jit_sources currently returns only
source_paths; change it to produce and return a JitSpec by wrapping the
generated files/specs with gen_jit_spec so the JIT wiring and GPU gating are
applied. After writing kernel sources and the API (get_api_code /
write_if_different) collect the same uri, gen_directory, source_paths and
specs_names and call gen_jit_spec(...) returning its result instead of
source_paths, and pass a supported_major_versions list (e.g.,
supported_major_versions=[8, 9] or the appropriate CUDA major versions for this
module) to ensure the module only compiles on supported GPUs. Ensure function
name generate_jit_sources remains the entry point but now returns the JitSpec
from gen_jit_spec.
- Around line 92-114: The unused-parameter linter warnings are caused by
parameters/variables not being referenced; rename them by prefixing with
underscores: in function select_ldgsts change the signature to use _sm,
_warp_specialization, _head_size, _dtype (and update any internal refs
accordingly) so Ruff ARG001 is silenced, and in get_signature rename the
parameter use_tma to _use_tma and the local variable sm_name to _sm_name
(updating any local references) to mark them intentionally unused.
- Around line 146-250: Import Any from typing and annotate the local variable
spec as dict[str, Any] so mypy knows it can contain heterogeneous types before
it's unpacked into FMHAv2KernelSpec; update the declaration of spec (the dict
passed to FMHAv2KernelSpec(**spec)) to use the dict[str, Any] annotation and add
the corresponding "from typing import Any" import at the top of the module.
- Around line 1173-1184: Annotate the other_configs variable (the result of
itertools.product in fmha_library.py) with a precise Iterator type instead of
leaving it untyped; add an import for typing.Iterator and declare other_configs
as Iterator[tuple] (i.e., use Iterator[tuple] as the annotation for the product
result) so mypy recognizes it as an iterator returned by itertools.product.

In `@flashinfer/prefill.py`:
- Around line 3827-3832: The parameter typing for sinks is incorrect: change the
annotation and any docs from Optional[List[torch.Tensor]] to
Optional[torch.Tensor] (i.e., a single tensor) wherever it appears (e.g., the
function signature containing
block_tables/out/out_dtype/kv_layout/sinks/pos_encoding_mode around the sinks
parameter and the other occurrence at lines 3888-3890) so it matches the FFI
signature; update any associated docstring/comments and any runtime checks that
expect a list to treat sinks as a tensor instead.
- Around line 3985-4031: The code calls mask_mode.lower() without guarding for
None (mask_mode is Optional), causing an AttributeError when mask_mode is None;
update the checks and mappings that use mask_mode.lower() — specifically the
is_non_causal assignment, the mask_mode_map lookup that sets mask_mode_code, and
any other use of mask_mode.lower() — to handle None safely (e.g., treat None as
the default "causal" or check mask_mode is not None before calling .lower()) so
mask_mode being None no longer raises.

In `@tests/attention/test_fmha_v2_prefill_deepseek.py`:
- Around line 254-275: Remove the debug print that prints the test output:
delete the call to print(output) after the trtllm_fmha_v2_prefill invocation
(referencing the variable output and the function trtllm_fmha_v2_prefill) so the
test no longer emits stdout noise during CI; ensure no other stray prints remain
in the same test function.
🧹 Nitpick comments (2)
tests/attention/test_fmha_v2_prefill_deepseek.py (1)

183-190: Normalize mask_mode casing to match API expectations.
The prefill API defaults to lowercase mode strings (e.g., "causal", "sliding_window"). Normalizing here avoids case-sensitive mismatches. This pattern repeats in later parametrizations too.

🔧 Suggested normalization
-        (True, -1, "CAUSAL"),
-        (True, 127, "SLIDING_WINDOW"),
-        (True, 512, "SLIDING_WINDOW"),
+        (True, -1, "causal"),
+        (True, 127, "sliding_window"),
+        (True, 512, "sliding_window"),
flashinfer/jit/attention/fmha_v2/utils.py (1)

346-349: Address Ruff unused-argument warnings.

get_GMMA_shape ignores m/k, and get_signature unpacks sm_name without using it. Prefix unused parameters/variables with _ to satisfy Ruff.

✂️ Example cleanup
-def get_GMMA_shape(instruction_traits, m, n, k, warps_n):
+def get_GMMA_shape(instruction_traits, _m, n, _k, warps_n):
 ...
-    def get_signature(lname, version, cross_mha, use_tma):
+    def get_signature(lname, version, cross_mha, _use_tma):
         # The architecture that determines the instruction.
-        effective_sm, sm_name = get_effective_sm_and_name(kspec)
+        effective_sm, _sm_name = get_effective_sm_and_name(kspec)

Also applies to: 490-492

Comment on lines +521 to +525
// Packed mask (allocated conditionally for CUSTOM_MASK)
void* packed_mask_d =
    (attention_mask_type == Attention_mask_type::CUSTOM_MASK)
        ? allocator.aligned_alloc<void>(packed_mask_size_in_bytes, 128, "packed_mask_d")
        : nullptr;

⚠️ Potential issue | 🟠 Major

CUSTOM_MASK path uses uninitialized mask data.

packed_mask_d is allocated (lines 523-525) but never populated, and there is no API input to supply a custom mask. The function signature accepts mask_mode_code as an integer flag but no corresponding mask tensor. When attention_mask_type == CUSTOM_MASK, uninitialized memory is passed to set_params() and subsequently to the kernel, resulting in undefined behavior. Add a mask tensor input parameter and populate packed_mask_d before calling set_params(), or reject CUSTOM_MASK with a clear error.

Also applies to: 692-695 (NOTE comment acknowledging the gap without implementation)
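The "reject CUSTOM_MASK with a clear error" option could also be enforced one layer up, at the Python wrapper, before any device allocation happens. A hedged sketch (function and mode names here are hypothetical, not the actual FlashInfer API surface):

```python
# Illustrative guard at the wrapper layer: refuse the custom-mask path until a
# packed mask input is actually plumbed through fmha_v2_run, so uninitialized
# device memory can never reach the kernel.

def check_mask_mode(mask_mode: str) -> str:
    supported = {"causal", "sliding_window", "chunked_causal", "non_causal"}
    if mask_mode == "custom":
        raise NotImplementedError(
            "CUSTOM_MASK is not supported: no packed mask input is plumbed "
            "through fmha_v2_run yet"
        )
    if mask_mode not in supported:
        raise ValueError(f"unknown mask_mode: {mask_mode!r}")
    return mask_mode

assert check_mask_mode("causal") == "causal"
```

Failing early in Python gives a readable error instead of undefined behavior inside `set_params()`.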


@coderabbitai bot (Contributor) left a comment
Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/prefill.py`:
- Line 3846: The call module = get_trtllm_fmha_v2_module(120) passes an argument
to get_trtllm_fmha_v2_module which is defined without parameters, causing a
TypeError; fix by either removing the argument at the call site (change the call
to get_trtllm_fmha_v2_module()) if SM filtering is not needed, or update the
function signature def get_trtllm_fmha_v2_module(sm_version): and propagate/use
that sm_version inside the function to perform any SM-version-specific
filtering/logic; choose one approach and make the call and function signature
consistent.
🧹 Nitpick comments (2)
flashinfer/prefill.py (2)

4123-4126: Remove useless conditional.

Both branches of the ternary return 1.0, making the condition pointless. This appears to be leftover from development where different scales were intended for FP8 vs non-FP8.

♻️ Suggested fix
-    # Softmax scale is typically 1.0 for FP8 and auto for FP16/BF16
-    is_e4m3 = (
-        query.dtype == torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else False
-    )
-    scale_softmax = 1.0 if is_e4m3 else 1.0
+    scale_softmax = 1.0

If dtype-specific scaling is needed in the future, restore the conditional with different values.


3906-3908: Unused parameter and incomplete return type annotation.

  1. The non_blocking parameter (line 3906) is declared but never used in the function body.
  2. The return type annotation -> torch.Tensor doesn't reflect that the function returns Tuple[torch.Tensor, torch.Tensor] when save_softmax_stats=True.
♻️ Suggested fix
-    non_blocking: Optional[bool] = True,
     save_softmax_stats: Optional[bool] = False,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

Either use non_blocking for tensor copies (e.g., when converting layouts) or remove the parameter to avoid confusion.
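The corrected annotation can be illustrated with a minimal sketch; plain Python lists stand in for `torch.Tensor`, and `prefill_sketch` is a made-up name, not the real API:

```python
from typing import List, Tuple, Union


def prefill_sketch(
    query: List[float],
    save_softmax_stats: bool = False,
) -> Union[List[float], Tuple[List[float], List[float]]]:
    out = [q * 0.5 for q in query]  # stand-in for the attention output
    if save_softmax_stats:
        # Stand-in for the per-token/head [max, sum_exp] softmax stats.
        return out, [max(query), sum(query)]
    return out
```

The `Union` return type tells callers (and type checkers) that the result must be unpacked as a tuple when `save_softmax_stats=True`.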

bobboli and others added 9 commits March 4, 2026 17:13
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Removed filtering of target SM versions for CUDA architectures.
@jimmyzho jimmyzho requested a review from jdebache March 4, 2026 23:04
@jimmyzho
Contributor Author

jimmyzho commented Mar 4, 2026

@hypdeb I have addressed your feedback - I've added a unit-test for chunked prefill (q < kv) with chunked attention mask, please let me know if this looks good to you now, thanks!
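For readers unfamiliar with the mask being tested, a hypothetical reference implementation of a chunked attention mask is sketched below. It assumes the common "chunked causal" definition (causal attention restricted to the query's own chunk, with queries aligned to the last `q_len` key positions for the q < kv case); the real kernel's semantics may differ in details:

```python
def chunked_causal_mask(q_len: int, kv_len: int, chunk_size: int):
    """True where a query may attend to a key, under chunked-causal masking."""
    offset = kv_len - q_len  # align queries with the trailing keys (q < kv)
    mask = [[False] * kv_len for _ in range(q_len)]
    for qi in range(q_len):
        gi = qi + offset  # global position of this query token
        for kj in range(kv_len):
            causal = kj <= gi
            same_chunk = (kj // chunk_size) == (gi // chunk_size)
            mask[qi][kj] = causal and same_chunk
    return mask
```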

@jimmyzho
Contributor Author

jimmyzho commented Mar 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !377 has been created, and the CI pipeline #45369480 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #45369480: 8/20 passed

@jimmyzho
Contributor Author

jimmyzho commented Mar 5, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !377 has been updated with latest changes, and the CI pipeline #45457879 is currently running. I'll report back once the pipeline job completes.

@jimmyzho
Contributor Author

jimmyzho commented Mar 6, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !377 has been updated with latest changes, and the CI pipeline #45475570 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #45475570: 8/20 passed

Collaborator

@bkryu bkryu left a comment

Approving as owner of unit test files.

Can you add microbenchmark support?

Collaborator

@yzh119 yzh119 left a comment

Thank you for the great work @jimmyzho!

@yzh119 yzh119 merged commit c43bc92 into flashinfer-ai:main Mar 9, 2026
22 checks passed
@blake-snc
Contributor

🎉🎉🎉

frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
…flashinfer-ai#2446)

<!-- .github/pull_request_template.md -->

## 📌 Description

Integrates TRT-LLM's Fused Multi-Head Attention v2 (FMHAv2) kernels into
FlashInfer as a prefill attention backend via a new
trtllm_fmha_v2_prefill Python API, backed by a CUDA runtime dispatch
layer (fmha_v2_run.cu) that bridges FlashInfer's tensor conventions with
TRT-LLM's fused kernel infrastructure.

- Data types: FP16, BF16, FP8 (E4M3) inputs with configurable output
dtype (e.g., FP8 → BF16/FP16). FP32 accumulation for BF16/E4M3; fused
BMM1/BMM2 scaling for quantized inference.
- Input layouts: Packed QKV [T, 3, H, D], Contiguous Q+KV [T, H, D]+[T,
2, H_kv, D], Separate Q/K/V, and Q + Paged KV Cache (page sizes 32/128,
HND/NHD layouts). Block tables auto-expanded from [B, M] → [B, 2, M] for
separate K/V offset addressing.
- Masking: Causal, sliding window (window_left), chunked causal
(power-of-2 sizes).
- Ragged batching: Variable-length sequences via cumulative sequence
length tensors.
- GQA/MQA: Grouped-query and multi-query attention (h % h_kv == 0).
- Attention sinks: Per-head sink values in softmax denominator.
- Logits soft capping: Fused tanh-based softcapping with configurable
scale.
- ALiBi: Optional ALiBi positional encoding.
- Softmax stats (LSE) export: Optional save_softmax_stats returns [max,
sum_exp] per token/head.
- Non-blocking execution

Also adds skip-softmax optimization introduced in [TRT-LLM's
Documentation](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md#accelerating-long-context-inference-with-skip-softmax-attention)

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Full FMHA v2 runtime with JIT-generated, runtime-dispatched kernels
supporting multiple QKV layouts, causal/sliding/custom masks,
chunked/windowed attention, FP8 and mixed-precision, and optional
skip-softmax optimization.

* **API**
* New Python module generators and public prefill/run accessors to load
and invoke v2 kernels.

* **Tests**
* Extensive new tests covering layouts, paged KV, sinks, dtypes, and
numerical correctness.

* **Bug Fixes**
  * Safer 64-bit offset arithmetic to prevent row-pointer overflow.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Yuxin <yuxinz@nvidia.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: Yuxin <yuxinz@nvidia.com>
Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…flashinfer-ai#2446)
