feat: Add TRTLLM fmha_v2 library for SM90 attention with Skip-Softmax #2446
yzh119 merged 52 commits into flashinfer-ai:main
Conversation
Note: Reviews paused. This branch appears to be under active development, so to avoid overwhelming you with review comments during an influx of new commits, CodeRabbit has automatically paused this review. This behavior is configurable, and reviews can be resumed or managed with the usual CodeRabbit commands.
📝 Walkthrough

Adds a full FMHA v2 implementation: Jinja CUDA kernel templates and launchers (multiple variants), Hopper/TMA support, skip-softmax voting/barrier, Python JIT code generation and dispatcher, host runtime and JIT/FFI bindings, prefill integration, tests, and updated runtime/launcher APIs.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant PythonAPI as FlashInfer\n(prefill API)
    participant JITGen as JIT\nCodeGen
    participant Runtime as Host Runtime\n(fmha_v2_run)
    participant Dispatcher as Launcher\nDispatcher
    participant CUDA as CUDA\nLauncher
    participant GPU as Device
    User->>PythonAPI: trtllm_fmha_v2_prefill(...)
    PythonAPI->>JITGen: generate_jit_sources() / ensure module
    JITGen-->>PythonAPI: compiled module / launcher symbols
    PythonAPI->>Runtime: fmha_v2_run(params, layout, dtype, ...)
    Runtime->>Dispatcher: get_fmha_v2_kernel(dtype, head_dims)
    Dispatcher-->>Runtime: FMHAv2KernelLauncher
    Runtime->>Runtime: set_params() / determine_launch_params()
    Runtime->>CUDA: launcher_name(params, launch_params, stream)
    CUDA->>GPU: cudaLaunchKernel / cudaLaunchCooperativeKernel
    GPU-->>User: outputs
```
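The walkthrough above describes a launcher dispatch keyed on dtype and head dimension (`get_fmha_v2_kernel`). A minimal Python sketch of that registry pattern, with hypothetical names standing in for the real FlashInfer symbols:

```python
# Hypothetical registry sketch of the dispatch step in the diagram above;
# names and signatures are illustrative, not the real FlashInfer API.
_KERNEL_REGISTRY: dict = {}

def register_kernel(dtype: str, head_dim: int):
    """Decorator that files a launcher under its (dtype, head_dim) key."""
    def deco(fn):
        _KERNEL_REGISTRY[(dtype, head_dim)] = fn
        return fn
    return deco

@register_kernel("fp16", 128)
def launch_fp16_d128(params):
    # Stand-in for a compiled CUDA launcher symbol.
    return ("fp16", 128, params)

def get_fmha_v2_kernel(dtype: str, head_dim: int):
    """Look up a launcher; fail loudly for unsupported configurations."""
    try:
        return _KERNEL_REGISTRY[(dtype, head_dim)]
    except KeyError:
        raise ValueError(f"no FMHA v2 kernel for dtype={dtype}, head_dim={head_dim}")
```

The point of the indirection is that JIT code generation can register many specialized launchers while the host runtime performs a single keyed lookup per call.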
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces the TensorRT-LLM Fused Multi-Head Attention v2 (FMHAv2) library into FlashInfer, enabling dynamic generation and compilation of highly optimized CUDA kernels. This integration significantly expands FlashInfer's capabilities by supporting a wide range of attention configurations, input data layouts, and advanced features, with a particular focus on performance on Hopper GPUs.
Code Review
This pull request introduces a significant new feature for JIT compiling TRT-LLM FMHAv2 kernels. The changes are extensive, including new Jinja templates for CUDA code generation, Python code for managing the JIT process, and new public APIs. While the overall structure is sound, I've found several critical issues in the CUDA kernel launcher templates where incorrect kernel traits are used for different attention mask types, potentially leading to incorrect execution or launch failures. These need to be addressed. The addition of comprehensive tests is a great step towards ensuring the correctness of this new functionality.
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
flashinfer/jit/__init__.py (1)

90-94: ⚠️ Potential issue | 🟠 Major

Wrap CUDA runtime preload in exception handling to prevent import-time failures.

If libcudart.so.12 exists but is incompatible or has missing symbols, ctypes.CDLL() will raise OSError and make the flashinfer.jit module unimportable. Add error handling to allow graceful degradation:

Fix

```diff
 if os.path.exists(f"{cuda_lib_path}/libcudart.so.12"):
-    ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL)
+    try:
+        ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL)
+    except OSError:
+        pass
```

flashinfer/jit/attention/fmha_v2/generator_utils.py (1)
3816-3826: ⚠️ Potential issue | 🟠 Major

Avoid runtime subscripting of itertools.product.

The project supports Python >= 3.10, and itertools.product is not subscriptable at runtime. Line 3816 uses product[tuple[bool, bool, InputLayout, bool]](...), which will raise TypeError: type 'itertools.product' is not subscriptable when enumerate_hgmma_flash_warpspec_kernels executes.

Use a type annotation instead to preserve typing without runtime subscription:

Safer approach

```diff
+from typing import Iterable
 ...
-    combinations = product[tuple[bool, bool, InputLayout, bool]](
+    combinations: Iterable[tuple[bool, bool, InputLayout, bool]] = product(
         [False, True],
         [False, True],
         [
             InputLayout.PACKED_QKV,
             InputLayout.CONTIGUOUS_Q_KV,
             InputLayout.Q_PAGED_KV,
             InputLayout.SEPARATE_Q_K_V,
         ],
         [False, True],
     )
```

flashinfer/jit/attention/modules.py (1)
1903-1936: ⚠️ Potential issue | 🟡 Minor

Use or remove the unused device parameter.

device is unused (Line 1903), which triggers lint and misses a chance to validate SM120 at the entrypoint. Either remove it or use it for an arch guard.

🛠️ Suggested fix (use the parameter to validate SM120)

```diff
 def gen_trtllm_fmha_v2_sm120_module(device: torch.device) -> JitSpec:
+    if torch.cuda.get_device_properties(device).major < 12:
+        raise ValueError("trtllm_fmha_v2_sm120_module requires SM120+ device")
     uri = "trtllm_fmha_v2"
```

flashinfer/prefill.py (1)
3772-3778: ⚠️ Potential issue | 🔴 Critical

Fix mismatched get_trtllm_fmha_v2_module call signature.

Line 3777 passes 120, but get_trtllm_fmha_v2_module() is defined without parameters. This will raise a TypeError at runtime.

🐛 Suggested fix

```diff
-    module = get_trtllm_fmha_v2_module(120)
+    module = get_trtllm_fmha_v2_module()
```
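The failure mode is easy to reproduce with a stand-in stub (the real generator takes no arguments, per the comment above):

```python
# Stand-in stub reproducing the flagged bug: a zero-parameter function called
# with an SM version argument raises TypeError. Not the real FlashInfer code.
def get_trtllm_fmha_v2_module():
    return "fmha_v2_module"

try:
    get_trtllm_fmha_v2_module(120)  # mirrors the bad call site
    raised = False
except TypeError:
    raised = True
```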
🤖 Fix all issues with AI agents
In `@csrc/fmha_v2_run.cu`:
- Around line 298-308: fmha_v2_paged_run is an unimplemented stub that
fmha_v2_ragged_run forwards to; replace the TODO with a hard failure so callers
don't silently do nothing. Inside fmha_v2_paged_run add an assertion or explicit
runtime error (e.g. assert(false && "fmha_v2_paged_run not implemented") or
throw std::runtime_error("fmha_v2_paged_run not implemented")), and include the
required header (<cassert> or <stdexcept>) if missing so the build fails clearly
when this function is invoked.
- Around line 531-535: The code unconditionally allocates softmax_stats_d and
passes it to the kernel, ignoring the caller's softmax_stats; change the logic
so that when softmax_stats.has_value() is true you set params.softmax_stats_ptr
to softmax_stats.value().data_ptr() (and only allocate softmax_stats_d when
softmax_stats is not provided), use the existing softmax_stats_ptr variable for
the kernel invocation instead of the debug-only pointer, and if internal scratch
is still needed allocate a separate buffer and copy the kernel results into
softmax_stats.value().data_ptr() afterward; update the allocation path around
allocator.aligned_alloc, the softmax_stats_d/softmax_stats_ptr variables, and
the assignment to params.softmax_stats_ptr accordingly.
- Around line 521-525: When attention_mask_type equals
Attention_mask_type::CUSTOM_MASK the code allocates packed_mask_d via
allocator.aligned_alloc but never populates it or accepts a mask tensor; update
the API and call site to accept a packed/custom mask input (e.g., add a device
pointer or Tensor argument named packed_mask_input) and memcpy/copy that data
into packed_mask_d before calling set_params() (which currently consumes
packed_mask_d and mask_mode_code), or if you prefer not to support custom masks
yet, explicitly error out when attention_mask_type == CUSTOM_MASK (using a clear
runtime error) so uninitialized memory is never passed into set_params()/the
kernel; change references in this file for packed_mask_d, attention_mask_type,
Attention_mask_type::CUSTOM_MASK, allocator.aligned_alloc, set_params(), and
mask_mode_code accordingly.
In `@csrc/fmha_v2/templates/kernel_hopper.jinja`:
- Around line 324-358: The current `_nl` launcher uses a single constexpr
smem_size = Kernel_traits::BYTES_PER_SMEM but then launches kernels compiled
with specialized traits (Kernel_traits_causal_nl,
Kernel_traits_sliding_or_chunked_causal_nl, Kernel_traits_nl), so
cudaFuncSetAttribute may be given the wrong size; change the code to compute and
use a branch-specific smem_size before each cudaFuncSetAttribute call (use
Kernel_traits_causal_nl::BYTES_PER_SMEM for {{ causal_kernel_name }}_nl,
Kernel_traits_sliding_or_chunked_causal_nl::BYTES_PER_SMEM for {{
sliding_or_chunked_causal_kernel_name }}_nl, and
Kernel_traits_nl::BYTES_PER_SMEM for {{ kernel_name }}_nl) and remove or stop
using the original Kernel_traits::BYTES_PER_SMEM constant for these branches so
the attribute matches the actual kernel trait.
- Around line 181-210: The TMA descriptors are hardcoded to F16_RN which
misinterprets BF16 data; update tma_desc_QKV.set_tma_desctriptor calls (for Q,
K, V - references: params.tma_desc_q, params.tma_desc_k, params.tma_desc_v) to
use a templated/variable tma_desc_format instead of
fmha::cudaTmaDescFormat::F16_RN, and ensure that the template context receives a
mapped value (e.g., tmp["tma_desc_format"] =
dtype_to_tma_format_mapping[kspec.dtype]) so the kernel dtype selects F16_RN for
FP16 and BF16_RN for BF16 (apply same change for all three descriptors and for
warp-specialized BF16 paths when use_tma is enabled).
- Around line 147-151: The code sets tensor_size_qkv[1] using a non-existent
launch_params.seqlens member; replace that usage with
launch_params.total_kv_seqlen so tensor_size_qkv[1] = params.is_s_padded ?
params.s * params.b : launch_params.total_kv_seqlen; locate the assignment to
tensor_size_qkv in kernel_hopper.jinja and update it accordingly (references:
tensor_size_qkv, params.is_s_padded, params.s, params.b,
launch_params.total_kv_seqlen, Fused_multihead_attention_launch_params).
In `@flashinfer/jit/__init__.py`:
- Line 58: Add a deprecated alias so existing callers of
get_trtllm_fmha_v2_module continue to work: keep the exported gen_fmha_v2_module
and assign get_trtllm_fmha_v2_module = gen_fmha_v2_module in the same module,
and emit a DeprecationWarning (via the warnings module) when
get_trtllm_fmha_v2_module is accessed/called to guide users to the new name;
update the top-level export list if present to include both names.
In `@flashinfer/jit/attention/fmha_v2/fmha_library.py`:
- Around line 1132-1228: generate_jit_sources currently returns only
source_paths; change it to produce and return a JitSpec by wrapping the
generated files/specs with gen_jit_spec so the JIT wiring and GPU gating are
applied. After writing kernel sources and the API (get_api_code /
write_if_different) collect the same uri, gen_directory, source_paths and
specs_names and call gen_jit_spec(...) returning its result instead of
source_paths, and pass a supported_major_versions list (e.g.,
supported_major_versions=[8, 9] or the appropriate CUDA major versions for this
module) to ensure the module only compiles on supported GPUs. Ensure function
name generate_jit_sources remains the entry point but now returns the JitSpec
from gen_jit_spec.
- Around line 92-114: The unused-parameter linter warnings are caused by
parameters/variables not being referenced; rename them by prefixing with
underscores: in function select_ldgsts change the signature to use _sm,
_warp_specialization, _head_size, _dtype (and update any internal refs
accordingly) so Ruff ARG001 is silenced, and in get_signature rename the
parameter use_tma to _use_tma and the local variable sm_name to _sm_name
(updating any local references) to mark them intentionally unused.
- Around line 146-250: Import Any from typing and annotate the local variable
spec as dict[str, Any] so mypy knows it can contain heterogeneous types before
it's unpacked into FMHAv2KernelSpec; update the declaration of spec (the dict
passed to FMHAv2KernelSpec(**spec)) to use the dict[str, Any] annotation and add
the corresponding "from typing import Any" import at the top of the module.
- Around line 1173-1184: Annotate the other_configs variable (the result of
itertools.product in fmha_library.py) with a precise Iterator type instead of
leaving it untyped; add an import for typing.Iterator and declare other_configs
as Iterator[tuple] (i.e., use Iterator[tuple] as the annotation for the product
result) so mypy recognizes it as an iterator returned by itertools.product.
In `@flashinfer/prefill.py`:
- Around line 3827-3832: The parameter typing for sinks is incorrect: change the
annotation and any docs from Optional[List[torch.Tensor]] to
Optional[torch.Tensor] (i.e., a single tensor) wherever it appears (e.g., the
function signature containing
block_tables/out/out_dtype/kv_layout/sinks/pos_encoding_mode around the sinks
parameter and the other occurrence at lines 3888-3890) so it matches the FFI
signature; update any associated docstring/comments and any runtime checks that
expect a list to treat sinks as a tensor instead.
- Around line 3985-4031: The code calls mask_mode.lower() without guarding for
None (mask_mode is Optional), causing an AttributeError when mask_mode is None;
update the checks and mappings that use mask_mode.lower() — specifically the
is_non_causal assignment, the mask_mode_map lookup that sets mask_mode_code, and
any other use of mask_mode.lower() — to handle None safely (e.g., treat None as
the default "causal" or check mask_mode is not None before calling .lower()) so
mask_mode being None no longer raises.
In `@tests/attention/test_fmha_v2_prefill_deepseek.py`:
- Around line 254-275: Remove the debug print that prints the test output:
delete the call to print(output) after the trtllm_fmha_v2_prefill invocation
(referencing the variable output and the function trtllm_fmha_v2_prefill) so the
test no longer emits stdout noise during CI; ensure no other stray prints remain
in the same test function.
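Several of the prompts above (the Optional mask_mode one in particular) reduce to a None-safe normalization step. A minimal sketch, assuming "causal" is the intended default when mask_mode is omitted:

```python
# Minimal sketch of the None-safe mask_mode handling the prompt asks for.
# Treating None as "causal" is an assumption about the intended default.
def normalize_mask_mode(mask_mode):
    if mask_mode is None:
        return "causal"
    return mask_mode.lower()
```

With this guard, the is_non_causal check and the mask_mode_map lookup can operate on the normalized string without risking an AttributeError.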
🧹 Nitpick comments (2)
tests/attention/test_fmha_v2_prefill_deepseek.py (1)

183-190: Normalize mask_mode casing to match API expectations.

The prefill API defaults to lowercase mode strings (e.g., "causal", "sliding_window"). Normalizing here avoids case-sensitive mismatches. This pattern repeats in later parametrizations too.

🔧 Suggested normalization

```diff
-    (True, -1, "CAUSAL"),
-    (True, 127, "SLIDING_WINDOW"),
-    (True, 512, "SLIDING_WINDOW"),
+    (True, -1, "causal"),
+    (True, 127, "sliding_window"),
+    (True, 512, "sliding_window"),
```

flashinfer/jit/attention/fmha_v2/utils.py (1)
346-349: Address Ruff unused-argument warnings.

get_GMMA_shape ignores m/k, and get_signature unpacks sm_name without using it. Prefix unused parameters/variables with _ to satisfy Ruff.

✂️ Example cleanup

```diff
-def get_GMMA_shape(instruction_traits, m, n, k, warps_n):
+def get_GMMA_shape(instruction_traits, _m, n, _k, warps_n):
 ...
-    def get_signature(lname, version, cross_mha, use_tma):
+    def get_signature(lname, version, cross_mha, _use_tma):
         # The architecture that determines the instruction.
-    effective_sm, sm_name = get_effective_sm_and_name(kspec)
+    effective_sm, _sm_name = get_effective_sm_and_name(kspec)
```

Also applies to: 490-492
```cpp
// Packed mask (allocated conditionally for CUSTOM_MASK)
void* packed_mask_d =
    (attention_mask_type == Attention_mask_type::CUSTOM_MASK)
        ? allocator.aligned_alloc<void>(packed_mask_size_in_bytes, 128, "packed_mask_d")
        : nullptr;
```
CUSTOM_MASK path uses uninitialized mask data.
packed_mask_d is allocated (lines 523-525) but never populated, and there is no API input to supply a custom mask. The function signature accepts mask_mode_code as an integer flag but no corresponding mask tensor. When attention_mask_type == CUSTOM_MASK, uninitialized memory is passed to set_params() and subsequently to the kernel, resulting in undefined behavior. Add a mask tensor input parameter and populate packed_mask_d before calling set_params(), or reject CUSTOM_MASK with a clear error.
Also applies to: 692-695 (NOTE comment acknowledging the gap without implementation)
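Either remedy can be sketched in a few lines; the following Python stand-in (hypothetical names, not the CUDA code) shows the populate-or-reject logic the review recommends:

```python
# Hypothetical stand-in for the recommended fix: a custom mask must be supplied
# and copied into the packed buffer, otherwise CUSTOM_MASK is rejected outright
# so uninitialized memory never reaches the kernel.
def prepare_packed_mask(mask_type: str, user_mask, size_bytes: int):
    if mask_type != "custom":
        return None  # no packed mask needed for causal/sliding variants
    if user_mask is None:
        raise ValueError("CUSTOM_MASK requires a packed mask input")
    data = bytes(user_mask)
    if len(data) != size_bytes:
        raise ValueError("packed mask has unexpected size")
    return data  # stands in for the memcpy into packed_mask_d
```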
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/prefill.py`:
- Line 3846: The call module = get_trtllm_fmha_v2_module(120) passes an argument
to get_trtllm_fmha_v2_module which is defined without parameters, causing a
TypeError; fix by either removing the argument at the call site (change the call
to get_trtllm_fmha_v2_module()) if SM filtering is not needed, or update the
function signature def get_trtllm_fmha_v2_module(sm_version): and propagate/use
that sm_version inside the function to perform any SM-version-specific
filtering/logic; choose one approach and make the call and function signature
consistent.
🧹 Nitpick comments (2)
flashinfer/prefill.py (2)
4123-4126: Remove useless conditional.

Both branches of the ternary return 1.0, making the condition pointless. This appears to be leftover from development where different scales were intended for FP8 vs non-FP8.

♻️ Suggested fix

```diff
-    # Softmax scale is typically 1.0 for FP8 and auto for FP16/BF16
-    is_e4m3 = (
-        query.dtype == torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else False
-    )
-    scale_softmax = 1.0 if is_e4m3 else 1.0
+    scale_softmax = 1.0
```

If dtype-specific scaling is needed in the future, restore the conditional with different values.
3906-3908: Unused parameter and incomplete return type annotation.

- The non_blocking parameter (line 3906) is declared but never used in the function body.
- The return type annotation -> torch.Tensor doesn't reflect that the function returns Tuple[torch.Tensor, torch.Tensor] when save_softmax_stats=True.

♻️ Suggested fix

```diff
-    non_blocking: Optional[bool] = True,
     save_softmax_stats: Optional[bool] = False,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```

Either use non_blocking for tensor copies (e.g., when converting layouts) or remove the parameter to avoid confusion.
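The dual return shape is easy to express; a torch-free sketch of the suggested annotation, with stand-in values in place of tensors:

```python
from typing import Tuple, Union

# Torch-free sketch of the suggested signature: the return type depends on
# save_softmax_stats, so the annotation should be a Union. Names are stand-ins.
def prefill_stub(save_softmax_stats: bool = False) -> Union[str, Tuple[str, str]]:
    out = "attention_out"
    if save_softmax_stats:
        return out, "softmax_stats"  # (output, [max, sum_exp] stats)
    return out
```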
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Removed filtering of target SM versions for CUDA architectures.
Refactor input layout.
@hypdeb I have addressed your feedback - I've added a unit-test for chunked prefill (q < kv) with chunked attention mask, please let me know if this looks good to you now, thanks!

/bot run

[SUCCESS] Pipeline #45369480: 8/20 passed

/bot run

/bot run

[SUCCESS] Pipeline #45475570: 8/20 passed
bkryu left a comment:
Approving as owner of unit test files.
Can you add microbenchmark support?
🎉🎉🎉 |
…flashinfer-ai#2446)

## 📌 Description

Integrates TRT-LLM's Fused Multi-Head Attention v2 (FMHAv2) kernels into FlashInfer as a prefill attention backend via a new trtllm_fmha_v2_prefill Python API, backed by a CUDA runtime dispatch layer (fmha_v2_run.cu) that bridges FlashInfer's tensor conventions with TRT-LLM's fused kernel infrastructure.

- Data types: FP16, BF16, FP8 (E4M3) inputs with configurable output dtype (e.g., FP8 → BF16/FP16). FP32 accumulation for BF16/E4M3; fused BMM1/BMM2 scaling for quantized inference.
- Input layouts: Packed QKV [T, 3, H, D], Contiguous Q+KV [T, H, D]+[T, 2, H_kv, D], Separate Q/K/V, and Q + Paged KV Cache (page sizes 32/128, HND/NHD layouts). Block tables auto-expanded from [B, M] → [B, 2, M] for separate K/V offset addressing.
- Masking: Causal, sliding window (window_left), chunked causal (power-of-2 sizes).
- Ragged batching: Variable-length sequences via cumulative sequence length tensors.
- GQA/MQA: Grouped-query and multi-query attention (h % h_kv == 0).
- Attention sinks: Per-head sink values in softmax denominator.
- Logits soft capping: Fused tanh-based softcapping with configurable scale.
- ALiBi: Optional ALiBi positional encoding.
- Softmax stats (LSE) export: Optional save_softmax_stats returns [max, sum_exp] per token/head.
- Non-blocking execution

Also adds the skip-softmax optimization introduced in [TRT-LLM's documentation](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.md#accelerating-long-context-inference-with-skip-softmax-attention).

## Summary by CodeRabbit

* **New Features**
  * Full FMHA v2 runtime with JIT-generated, runtime-dispatched kernels supporting multiple QKV layouts, causal/sliding/custom masks, chunked/windowed attention, FP8 and mixed-precision, and optional skip-softmax optimization.
* **API**
  * New Python module generators and public prefill/run accessors to load and invoke v2 kernels.
* **Tests**
  * Extensive new tests covering layouts, paged KV, sinks, dtypes, and numerical correctness.
* **Bug Fixes**
  * Safer 64-bit offset arithmetic to prevent row-pointer overflow.

Signed-off-by: Yuxin <yuxinz@nvidia.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: Yuxin <yuxinz@nvidia.com>
Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
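The masking semantics above (causal plus an optional window_left) can be mirrored in a reference mask for testing. A pure-Python sketch, with the last query aligned to the last key as in chunked prefill (q < kv); an illustration, not the kernel code:

```python
# Reference causal / sliding-window mask (True = attend), mirroring the
# semantics described above; an illustrative sketch, not the kernel code.
def make_mask(q_len, kv_len, window_left=-1):
    offset = kv_len - q_len  # align last query to last key (chunked prefill)
    rows = []
    for qi in range(q_len):
        q_pos = qi + offset
        row = []
        for k_pos in range(kv_len):
            ok = k_pos <= q_pos  # causal
            if window_left >= 0:
                ok = ok and k_pos >= q_pos - window_left  # sliding window
            row.append(ok)
        rows.append(row)
    return rows
```

A reference like this is what the added unit tests compare the fused kernels against for each mask mode.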
📌 Description
Integrates TRT-LLM's Fused Multi-Head Attention v2 (FMHAv2) kernels into FlashInfer as a prefill attention backend via a new trtllm_fmha_v2_prefill Python API, backed by a CUDA runtime dispatch layer (fmha_v2_run.cu) that bridges FlashInfer's tensor conventions with TRT-LLM's fused kernel infrastructure.
Also adds skip-softmax optimization introduced in TRT-LLM's Documentation
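The linked technique skips KV tiles whose softmax contribution is negligible. A toy numeric sketch (not the kernel's actual voting/barrier implementation) of why a max-logit threshold bounds the skipped mass:

```python
import math

# Toy model: with streaming softmax, a KV tile whose max logit m_tile sits far
# below the running max m_run contributes at most tile_len * exp(m_tile - m_run)
# to the softmax denominator, so sufficiently cold tiles can be skipped with
# bounded relative error. The threshold choice here is illustrative.
def tile_contribution_bound(m_tile: float, m_run: float, tile_len: int) -> float:
    return tile_len * math.exp(m_tile - m_run)

def can_skip_tile(m_tile: float, m_run: float, tile_len: int, eps: float = 1e-6) -> bool:
    return tile_contribution_bound(m_tile, m_run, tile_len) < eps
```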