
Add fmhav2 to flashinfer_benchmark and fix cuda graph violations#2841

Open
jimmyzho wants to merge 5 commits into flashinfer-ai:main from jimmyzho:fmhav2-bench

Conversation

@jimmyzho
Contributor

@jimmyzho jimmyzho commented Mar 20, 2026

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added support for the trtllm‑fmha‑v2 attention backend.
  • Improvements

    • Allow attention scaling as either a scalar or a tensor.
    • Use token vs. page cumulative indices and ensure contiguous KV cache for the new backend.
    • Updated backend mappings for newer GPU compute capabilities.
  • Bug Fixes

    • Simplified softmax-stats allocation and runtime handling.
  • Tests

    • Re-enabled attention prefill tests and narrowed FP8 skip conditions.

@coderabbitai
Contributor

coderabbitai bot commented Mar 20, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Walkthrough

Adds trtllm-fmha-v2 prefill backend support, passes bmm2 scale as scalar-or-tensor via a Variant/Union through Python→TVM→CUDA, distinguishes page vs token cumulative KV indirections, ensures contiguous KV cache for trtllm-fmha-v2, and re-enables the FMHA v2 prefill tests with adjusted skips.

Changes

Cohort / File(s) Summary
Benchmark dispatch & KV indirection
benchmarks/routines/attention.py, benchmarks/routines/flashinfer_benchmark_utils.py
Added trtllm-fmha-v2 to backend lists and dispatch branches; introduced kv_token_indptr (cumulative token counts) alongside kv_indptr (cumulative page counts) and route cum_seq_lens_kv appropriately; ensure contiguous KV cache for trtllm-fmha-v2; adjust input layouts for paged vs ragged prefill.
FMHA v2 C++/CUDA API
csrc/fmha_v2_jit_binding.cu, csrc/fmha_v2_run.cu
Replaced separate scalar+tensor scale parameters with tvm::ffi::Variant<double, tvm::ffi::Tensor> for scale_bmm2; extract scalar value or tensor device pointer at runtime and update kernel invocation/signature; removed fallback workspace allocation for softmax stats.
Python prefill wrapper
flashinfer/prefill.py
Changed trtllm_fmha_v2_prefill to accept bmm2_scale: Union[float, torch.Tensor] and forward it directly to module.run; removed creation/passing of device-side scale_bmm2_d.
Tests
tests/attention/test_fmha_v2_prefill.py
Uncommented module-level tests (module-level skip removed/commented); broadened FP8 skip to skip all FP8 input-dtype cases; preserved original hang-skip as commented block.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python Prefill
    participant TVM as TVM FFI
    participant CUDA as fmha_v2_run

    Py->>TVM: module.run(..., bmm2_scale: float OR torch.Tensor, kv_cache, cum_seq_lens_q, cum_seq_lens_kv, input_layout)
    TVM->>TVM: wrap bmm2_scale into Variant<double, Tensor>
    TVM->>CUDA: fmha_v2_run(Variant scale_bmm2, other params)
    CUDA->>CUDA: if Variant is double -> use scalar value
    CUDA->>CUDA: if Variant is Tensor -> use device ptr + scalar value
    CUDA-->>Py: attention outputs, LSE updates, KV cache updates
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • yongwww
  • yzh119
  • cyx-6
  • samuellees
  • saltyminty
  • bkryu
  • yyihuang
  • kahyunnam
  • nv-yunzheq

Poem

🐰 I hopped through tensors, scales in tow,
Pages counting tokens row by row,
Variant whispers: scalar or device,
FMHA v2 aligned and nice,
Cache contiguous — now watch me go!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check (⚠️ Warning): The PR description contains only the repository template with no actual content filled in; all sections are empty placeholders, providing no explanation of changes, related issues, or implementation details. Resolution: fill in the description with what changed and why, link any related issues, and add reviewer notes highlighting the key areas (CUDA graph fixes, fmhav2 integration, API signature changes).
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check (✅ Passed): The title directly addresses the main objective, adding fmhav2 support to flashinfer_benchmark and fixing cuda graph violations, which aligns with the comprehensive changes across multiple files.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on integrating the trtllm-fmha-v2 backend into the FlashInfer benchmarking framework. It addresses critical CUDA graph compatibility issues by adjusting how certain backends interact with CUDA graph capture, ensuring robust performance testing. Additionally, it refines the handling of scaling parameters and memory layouts to properly support the new attention kernel.

Highlights

  • New Backend Integration: The trtllm-fmha-v2 kernel backend has been successfully integrated into the FlashInfer benchmarking suite, enabling performance evaluation for this new fused multi-head attention kernel.
  • CUDA Graph Compatibility Fixes: Modifications were made to benchmark runners to explicitly exclude trtllm-fmha-v2, trtllm-native, and cudnn-native backends from CUDA graph capture, resolving potential violations and ensuring stable execution.
  • Flexible bmm2_scale Handling: The fmha_v2_run C++ function and its Python binding were updated to accept the bmm2_scale parameter as either a float or a tensor, providing greater flexibility for different scaling requirements.
  • TRT-LLM Specific Indptr: A new kv_token_indptr was introduced for TRT-LLM backends to correctly handle token-based cumulative sequence lengths, which is crucial for accurate attention calculations.
  • Memory Layout Adjustment: Logic was added to ensure the kv_cache for trtllm-fmha-v2 maintains a contiguous memory layout when utilizing the HND paged KV format, optimizing data access.
  • Benchmark Test Activation: The test_fmha_v2_prefill.py test file, which was previously skipped, has been re-enabled, allowing for comprehensive testing of the FMHA v2 prefill functionality.
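The page vs. token indptr distinction highlighted above reduces to two different cumulative sums over the same sequence lengths. A small sketch (the function and variable names are assumptions, not the benchmark's actual code):

```python
import math
from itertools import accumulate


def build_indptrs(seq_lens, page_size):
    """Build page-based and token-based cumulative indices for a paged KV
    cache. kv_indptr accumulates page counts (ceil(len / page_size)), while
    kv_token_indptr accumulates raw token counts, which is what a
    token-indexed backend expects for cum_seq_lens_kv."""
    pages = [math.ceil(n / page_size) for n in seq_lens]
    kv_indptr = [0] + list(accumulate(pages))
    kv_token_indptr = [0] + list(accumulate(seq_lens))
    return kv_indptr, kv_token_indptr
```

With seq_lens=[3, 5] and page_size=4 this yields kv_indptr=[0, 1, 3] but kv_token_indptr=[0, 3, 8], so routing the wrong one into cum_seq_lens_kv silently truncates or overreads the KV range.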



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the trtllm-fmha-v2 backend to the FlashInfer benchmark suite and addresses some CUDA graph violations. The changes span across the Python benchmark scripts, the C++ bindings to support a variant type for scale_bmm2, and the core prefill logic. My review identified a potential bug where CUDA graph compatibility checks are inconsistent across different test wrappers, and some commented-out code that should be removed for clarity. Since this is a draft, I've also noted that several test cases are currently disabled, which should be re-enabled and verified before merging.

Comment thread benchmarks/routines/attention.py Outdated
             sleep_after_run=False,
             enable_cupti=args.use_cupti,
-            use_cuda_graph=(is_cuda_graph_compatible and cur_backend != "fa2"),
+            use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ("fa2")),
Contributor


Severity: high

The logic for use_cuda_graph appears to be inconsistent here. In other parts of this file (e.g., lines 682-686 and 1973-1977), you are excluding "trtllm-fmha-v2", "trtllm-native", and "cudnn-native" from CUDA graph usage, which is likely correct given the PR's goal to fix graph violations. However, this line only excludes "fa2". This could lead to CUDA graph errors for the other backends within this specific test wrapper. Please ensure this check is consistent with the others.

            use_cuda_graph=(
                is_cuda_graph_compatible
                and cur_backend
                not in ("fa2", "trtllm-fmha-v2", "trtllm-native", "cudnn-native")
            ),

Comment thread flashinfer/prefill.py Outdated
Comment on lines +4220 to +4228
# scale_bmm2_d = _create_scale_bmm2_d_tensor(scale_bmm2, query.dtype, query.device)
# scale_bmm2_d = torch.tensor([scale_bmm2], dtype=torch.float32, device=query.device)
# if isinstance(bmm2_scale, float):
# bmm2_scale_d = torch.tensor([bmm2_scale], dtype=torch.float32, device=query.device)
# else:
# bmm2_scale_d = bmm2_scale.to(device=query.device, dtype=torch.float32)
# scale_bmm2 = bmm2_scale_d.view(torch.float32)
# bmm2_scale_d = torch.empty(1, dtype=torch.float32, device=query.device)
# bmm2_scale_d.fill_(bmm2_scale)
Contributor


Severity: medium

This block contains a significant amount of commented-out code, which seems to be from previous implementation attempts. To improve code readability and maintainability, please remove this dead code before merging.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)

1354-1360: ⚠️ Potential issue | 🟠 Major

trtllm-fmha-v2 is still being CUDA-graphed in paged prefill.

This loop only disables capture for fa2, while the decode, ragged, and MLA benchmark loops all exempt trtllm-fmha-v2, trtllm-native, and cudnn-native. The new backend will still go through capture here and can hit the same graph-violation path you are trying to avoid.

Proposed fix
-            use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ("fa2")),
+            use_cuda_graph=(
+                is_cuda_graph_compatible
+                and cur_backend
+                not in ("fa2", "trtllm-fmha-v2", "trtllm-native", "cudnn-native")
+            ),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1354 - 1360, The CUDA-graph
capture is only being disabled for "fa2" here, causing "trtllm-fmha-v2" (and
other backends) to still be graphed; update the condition passed to bench_gpu_time
in the backend_times assignment so use_cuda_graph is False for the same set of
backends used elsewhere. Specifically, change the use_cuda_graph expression that
currently checks (is_cuda_graph_compatible and cur_backend not in ("fa2")) to
exclude "trtllm-fmha-v2", "trtllm-native", and "cudnn-native" as well (e.g.,
cur_backend not in ("fa2","trtllm-fmha-v2","trtllm-native","cudnn-native")),
leaving bench_gpu_time, run_backend_wrapper, is_cuda_graph_compatible and
cur_backend unchanged.
🧹 Nitpick comments (1)
tests/attention/test_fmha_v2_prefill.py (1)

797-804: Prefer marked params over commenting cases out.

Commenting these tuples out removes the unstable combinations from CI entirely, and every surviving call still passes bmm2_scale as a Python float. Converting the disabled cases to pytest.param(..., marks=skip/xfail) would keep the gaps explicit and make it easier to add a tensor-scale regression case later.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_fmha_v2_prefill.py` around lines 797 - 804, Replace the
commented-out parameter tuples in tests/attention/test_fmha_v2_prefill.py (e.g.,
("PACKED_QKV", None, False), ("CONTIGUOUS_Q_KV", None, False),
("CONTIGUOUS_Q_KV", None, True), ("SEPARATE_Q_K_V", None, False),
("Q_PAGED_KV_NHD", 128, False), ("Q_PAGED_KV_HND", 128, False)) with
pytest.param entries that include pytest.mark.skip or pytest.mark.xfail as
appropriate so the combinations remain explicit in the param list; keep the same
tuple values but wrap them with pytest.param(...,
marks=pytest.mark.skip(reason="unstable")) (or xfail) so CI knows they are
intentionally disabled and can be re-enabled or converted to tensor-scale
regression cases later. Ensure these new pytest.param entries are added into the
same parameterization target (the test's param list that currently contains
("Q_PAGED_KV_NHD", 32, False) etc.) so the test harness still iterates over all
cases.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1158-1161: The code always materializes a contiguous HND copy via
kv_cache.contiguous() into _fmha_v2_kv_cache which forces a full KV allocation
for all benchmarks; change this to only create the contiguous copy when the
trtllm-fmha-v2 backend is actually used (i.e., check if 'trtllm-fmha-v2' is in
backends) and otherwise keep _fmha_v2_kv_cache as the original kv_cache (or
defer copying), so only trtllm-fmha-v2 triggers the extra allocation.
- Around line 1916-1932: The trtllm-fmha-v2 branch does not pass the validated
out_dtype into trtllm_fmha_v2_prefill, so FP8/output-type is ignored; update the
call in the backend == "trtllm-fmha-v2" branch to forward the validated
out_dtype (the same variable recorded earlier) by adding out_dtype=out_dtype to
the trtllm_fmha_v2_prefill invocation so its output dtype honors the requested
--out_dtype setting.

In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 307-310: The advertised backends list includes "trtllm-fmha-v2"
for CUDA SM "9.0" and "12.0" without checking CUDA toolkit version; update the
benchmark backend selection to gate FMHAv2 behind the runtime helpers: call
is_sm90a_supported() before including "trtllm-fmha-v2" for the "9.0" entry and
is_sm120a_supported() before including it for the "12.0" entry (or perform
equivalent CUDA-version checks inside the benchmark filter function that
constructs the backend list) so that FMHAv2 is only advertised when the CUDA
version meets the runtime requirements.

In `@flashinfer/prefill.py`:
- Around line 3955-3956: The code must preserve the old split representation for
BMM2 scale so the C++ FMHAv2 path receives both a host scalar and an encoded
device buffer: when handling the parameter bmm2_scale (formerly via
_create_scale_bmm2_d_tensor()), ensure you provide a float host scalar (used by
set_alpha / params.scale_bmm2) and a device-side packed uint32_t tensor/buffer
(used by params.scale_bmm2_d) before calling module.run(); if bmm2_scale is a
torch.Tensor, extract/convert its scalar value for the host and separately
encode/pack the device representation (matching the previous
_create_scale_bmm2_d_tensor behavior), and if it’s a Python float keep both
representations consistent—update the code paths that construct
params.scale_bmm2 and params.scale_bmm2_d (and any callers of set_alpha and
module.run()) to use these two distinct values.

---

Outside diff comments:
In `@benchmarks/routines/attention.py`:
- Around line 1354-1360: The CUDA-graph capture is only being disabled for "fa2"
here causing "trtllm-fmha-v2" (and other GPUs) to still be graphed; update the
condition passed to bench_gpu_time in the backend_times assignment so
use_cuda_graph is False for the same set of backends used elsewhere.
Specifically, change the use_cuda_graph expression that currently checks
(is_cuda_graph_compatible and cur_backend not in ("fa2")) to exclude
"trtllm-fmha-v2", "trtllm-native", and "cudnn-native" as well (e.g., cur_backend
not in ("fa2","trtllm-fmha-v2","trtllm-native","cudnn-native")), leaving
bench_gpu_time, run_backend_wrapper, is_cuda_graph_compatible and cur_backend
unchanged.

---

Nitpick comments:
In `@tests/attention/test_fmha_v2_prefill.py`:
- Around line 797-804: Replace the commented-out parameter tuples in
tests/attention/test_fmha_v2_prefill.py (e.g., ("PACKED_QKV", None, False),
("CONTIGUOUS_Q_KV", None, False), ("CONTIGUOUS_Q_KV", None, True),
("SEPARATE_Q_K_V", None, False), ("Q_PAGED_KV_NHD", 128, False),
("Q_PAGED_KV_HND", 128, False)) with pytest.param entries that include
pytest.mark.skip or pytest.mark.xfail as appropriate so the combinations remain
explicit in the param list; keep the same tuple values but wrap them with
pytest.param(..., marks=pytest.mark.skip(reason="unstable")) (or xfail) so CI
knows they are intentionally disabled and can be re-enabled or converted to
tensor-scale regression cases later. Ensure these new pytest.param entries are
added into the same parameterization target (the test's param list that
currently contains ("Q_PAGED_KV_NHD", 32, False) etc.) so the test harness still
iterates over all cases.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 258e53c5-6814-4444-b29e-7e3dc775097a

📥 Commits

Reviewing files that changed from the base of the PR and between 7cb016d and b27290f.

📒 Files selected for processing (6)
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • csrc/fmha_v2_jit_binding.cu
  • csrc/fmha_v2_run.cu
  • flashinfer/prefill.py
  • tests/attention/test_fmha_v2_prefill.py

Comment thread benchmarks/routines/attention.py Outdated
Comment thread benchmarks/routines/flashinfer_benchmark_utils.py
Comment thread flashinfer/prefill.py
Comment on lines 3955 to +3956
     bmm1_scale: float,
-    bmm2_scale: float,
+    bmm2_scale: Union[float, torch.Tensor],
Contributor


⚠️ Potential issue | 🔴 Critical

Keep scale_bmm2's scalar and device-buffer representations separate.

This change drops _create_scale_bmm2_d_tensor() and forwards bmm2_scale directly, but the FMHAv2 C++ path still needs both halves of the old contract: set_alpha(params.scale_bmm2, ...) consumes the scalar value, and params.scale_bmm2_d expects the prepacked device bits. With a tensor input, the C++ side now just reinterprets the tensor pointer as uint32_t* and the scalar side falls back to 0.f, so FP16/BF16 tensor scales will program the wrong BMM2 scale. Please keep the old split API, or reconstruct both a host scalar and an encoded device tensor before module.run().

Also applies to: 4218-4258

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3955 - 3956, The code must preserve the
old split representation for BMM2 scale so the C++ FMHAv2 path receives both a
host scalar and an encoded device buffer: when handling the parameter bmm2_scale
(formerly via _create_scale_bmm2_d_tensor()), ensure you provide a float host
scalar (used by set_alpha / params.scale_bmm2) and a device-side packed uint32_t
tensor/buffer (used by params.scale_bmm2_d) before calling module.run(); if
bmm2_scale is a torch.Tensor, extract/convert its scalar value for the host and
separately encode/pack the device representation (matching the previous
_create_scale_bmm2_d_tensor behavior), and if it’s a Python float keep both
representations consistent—update the code paths that construct
params.scale_bmm2 and params.scale_bmm2_d (and any callers of set_alpha and
module.run()) to use these two distinct values.
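The "prepacked device bits" half of the contract amounts to reinterpreting the float32 scale as its uint32 bit pattern. A stdlib sketch of that encoding (the function name is hypothetical, standing in for the packing step of the removed helper):

```python
import struct


def pack_scale_bmm2(scale: float) -> int:
    """Reinterpret a float32 scale as its uint32 bit pattern, i.e. the kind
    of value a packed device-side scale buffer would hold (hypothetical
    stand-in for the removed _create_scale_bmm2_d_tensor packing step)."""
    return struct.unpack("<I", struct.pack("<f", scale))[0]
```

For example, pack_scale_bmm2(1.0) returns 0x3F800000, the IEEE-754 single-precision bits of 1.0; a one-element device buffer holding such a value is what the kernel side would then reinterpret through a uint32_t*.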

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (2)
benchmarks/routines/attention.py (2)

1912-1928: ⚠️ Potential issue | 🟠 Major

Pass out_dtype through the ragged FMHAv2 benchmark path.

This branch validates --out_dtype up front but drops it here. FP8 runs will therefore benchmark trtllm_fmha_v2_prefill with its default output dtype instead of the requested one, which skews both output comparisons and bandwidth numbers.

♻️ Proposed fix
         elif backend == "trtllm-fmha-v2":
             _q_scale = q_scale if q_scale is not None else 1.0
             _k_scale = k_scale if k_scale is not None else 1.0
             return trtllm_fmha_v2_prefill(
                 qkv=(q, k, v),
                 input_layout="SEPARATE_Q_K_V",
                 workspace_buffer=workspace_buffer,
                 seq_lens=actual_seq_lens_kv_device.flatten(),
                 max_q_len=s_qo,
                 max_kv_len=s_kv,
                 bmm1_scale=_q_scale * _k_scale * scale,
                 bmm2_scale=_fmha_v2_bmm2_scale,
                 batch_size=batch_size,
                 cum_seq_lens_q=qo_indptr,
                 cum_seq_lens_kv=kv_indptr,
+                out_dtype=out_dtype,
                 mask_mode="causal" if causal else "padding",
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1912 - 1928, The
trtllm-fmha-v2 branch drops the validated out_dtype so trtllm_fmha_v2_prefill
runs with its default; update the call in the elif backend == "trtllm-fmha-v2"
branch to pass the out_dtype parameter (the same out_dtype validated earlier)
into trtllm_fmha_v2_prefill so FP8/other requested output dtypes are
honored—look for the trtllm_fmha_v2_prefill invocation and add the out_dtype
argument alongside the existing arguments (e.g., bmm1_scale, bmm2_scale,
batch_size, etc.).

1149-1157: ⚠️ Potential issue | 🟠 Major

Gate the FMHAv2 HND copy on backend selection.

kv_cache.contiguous() still runs unconditionally here, so every paged-prefill benchmark pays for a full extra KV allocation even when trtllm-fmha-v2 is not being timed. On larger cases that can distort memory pressure or OOM unrelated backends.

♻️ Proposed fix
-    _fmha_v2_kv_cache = kv_cache.contiguous()
+    _fmha_v2_kv_cache = (
+        kv_cache.contiguous() if "trtllm-fmha-v2" in backends else None
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1149 - 1157, The unconditional
contiguous copy causes extra KV allocation for all backends; wrap the contiguous
call so _fmha_v2_kv_cache = kv_cache.contiguous() only when the selected backend
equals "trtllm-fmha-v2" (leave _fmha_v2_kv_cache as kv_cache otherwise). Update
the code around the existing _fmha_v2_kv_cache and _fmha_v2_bmm2_scale
assignments to branch on the backend selection used in this routine (e.g.,
compare backend or impl string to "trtllm-fmha-v2") so only that backend pays
the copy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1912-1928: The trtllm-fmha-v2 branch drops the validated out_dtype
so trtllm_fmha_v2_prefill runs with its default; update the call in the elif
backend == "trtllm-fmha-v2" branch to pass the out_dtype parameter (the same
out_dtype validated earlier) into trtllm_fmha_v2_prefill so FP8/other requested
output dtypes are honored—look for the trtllm_fmha_v2_prefill invocation and add
the out_dtype argument alongside the existing arguments (e.g., bmm1_scale,
bmm2_scale, batch_size, etc.).
- Around line 1149-1157: The unconditional contiguous copy causes extra KV
allocation for all backends; wrap the contiguous call so _fmha_v2_kv_cache =
kv_cache.contiguous() only when the selected backend equals "trtllm-fmha-v2"
(leave _fmha_v2_kv_cache as kv_cache otherwise). Update the code around the
existing _fmha_v2_kv_cache and _fmha_v2_bmm2_scale assignments to branch on the
backend selection used in this routine (e.g., compare backend or impl string to
"trtllm-fmha-v2") so only that backend pays the copy.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 134a09e5-3fbc-4c60-adc6-c344113ccce5

📥 Commits

Reviewing files that changed from the base of the PR and between b27290f and 2b3241b.

📒 Files selected for processing (2)
  • benchmarks/routines/attention.py
  • flashinfer/prefill.py

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (4)
flashinfer/prefill.py (1)

4104-4104: ⚠️ Potential issue | 🟠 Major

Tighten the bmm2_scale tensor contract.

csrc/fmha_v2_run.cu still treats the tensor variant as the prepacked scale_bmm2_d buffer and just forwards data_ptr(). With this change, an ordinary torch.Tensor now bypasses _create_scale_bmm2_d_tensor() and reaches the kernel as raw floats, which is misencoded for the FP16/BF16 path. Either keep this API float-only, or validate/document that the tensor form must already be the packed device buffer.

Also applies to: 4385-4394

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` at line 4104, The bmm2_scale parameter is being
accepted as Union[float, torch.Tensor] but the CUDA kernel (csrc/fmha_v2_run.cu
/ fmha_v2_run) expects the tensor form to be the already-prepacked device buffer
(created by _create_scale_bmm2_d_tensor) and currently just forwards data_ptr(),
so ordinary torch.Tensors get misencoded for FP16/BF16; fix by tightening the
contract in prefill.py: either change the signature to accept only float (remove
torch.Tensor variant) or add a runtime check in the caller/handler that if
bmm2_scale is a torch.Tensor it is on CUDA, has the exact shape/dtype/strides of
the prepacked scale buffer (or a flag indicating it was created via
_create_scale_bmm2_d_tensor), otherwise raise a clear ValueError instructing
callers to use _create_scale_bmm2_d_tensor to produce the packed device buffer;
reference bmm2_scale, _create_scale_bmm2_d_tensor, and the fmha_v2_run kernel
when implementing the validation.
benchmarks/routines/flashinfer_benchmark_utils.py (1)

304-313: ⚠️ Potential issue | 🟠 Major

Gate trtllm-fmha-v2 on runtime support, not just compute capability.

filter_backends_by_compute_capability() only consults this CC table, so these unconditional 9.0/12.0 entries make the benchmark advertise FMHAv2 even on toolchains that the runtime backend will later reject. Please apply the same runtime support helper the backend uses instead of listing trtllm-fmha-v2 here unconditionally.

Also applies to: 317-326

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/flashinfer_benchmark_utils.py` around lines 304 - 313,
The compute-capability mapping currently lists "trtllm-fmha-v2" unconditionally
for "9.0" and "12.0", causing filter_backends_by_compute_capability() to
advertise FMHAv2 even when the runtime will later reject it; update this by
removing the unconditional "trtllm-fmha-v2" entries from the CC map and instead
add "trtllm-fmha-v2" only after checking the same runtime-support helper the
backend uses (call that helper where filter_backends_by_compute_capability is
used, e.g., runtime_supports_trtllm_fmha_v2 or the backend's support check) so
the backend is added only when both CC and runtime support are true; make the
same change for the duplicate block (the other mapping around the referenced
lines) so both places gate "trtllm-fmha-v2" on runtime support rather than
listing it unconditionally.
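A minimal sketch of the suggested gating, with the capability map and the runtime-support predicate injected so the logic stays testable (all names are illustrative, not the utility module's actual API):

```python
from typing import Callable, Dict, List


def advertised_backends(
    cc: str,
    cc_map: Dict[str, List[str]],
    runtime_supports: Callable[[str], bool],
) -> List[str]:
    """Return the backend list for a compute capability, advertising
    'trtllm-fmha-v2' only when the runtime-support predicate also passes."""
    backends = list(cc_map.get(cc, []))
    if "trtllm-fmha-v2" in backends and not runtime_supports(cc):
        backends.remove("trtllm-fmha-v2")
    return backends
```

With a predicate that mirrors is_sm90a_supported() / is_sm120a_supported(), the "9.0" and "12.0" entries would only expose FMHAv2 when the installed CUDA toolchain actually supports it.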
benchmarks/routines/attention.py (2)

1179-1182: ⚠️ Potential issue | 🟠 Major

Avoid materializing the HND KV copy unconditionally.

This contiguous() runs regardless of whether trtllm-fmha-v2 is actually in backends. For large KV caches, this doubles memory usage unnecessarily when only benchmarking other backends.

Proposed fix
-    _fmha_v2_kv_cache = kv_cache.contiguous()
+    _fmha_v2_kv_cache = (
+        kv_cache.contiguous() if "trtllm-fmha-v2" in backends else None
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1179 - 1182, The code
unconditionally materializes a contiguous HND copy with _fmha_v2_kv_cache =
kv_cache.contiguous(), which doubles memory even when the trtllm-fmha-v2 backend
isn't used; change this to only create the contiguous copy when "trtllm-fmha-v2"
is present in backends (e.g. if "trtllm-fmha-v2" in backends: _fmha_v2_kv_cache
= kv_cache.contiguous() else: _fmha_v2_kv_cache = kv_cache) so other backends
avoid the extra allocation and existing code using _fmha_v2_kv_cache continues
to work.

1943-1959: ⚠️ Potential issue | 🟠 Major

Honor --out_dtype in the ragged FMHA-v2 path.

The routine validates and uses out_dtype for other backends (e.g., cudnn-native at line 1917, trtllm-native via pre-allocated output at line 1871), but the trtllm-fmha-v2 branch doesn't pass it through. FP8 benchmarks will produce the wrong output dtype and yield incorrect bandwidth comparisons versus other backends.

Proposed fix
         elif backend == "trtllm-fmha-v2":
             _q_scale = q_scale if q_scale is not None else 1.0
             _k_scale = k_scale if k_scale is not None else 1.0
             return trtllm_fmha_v2_prefill(
                 qkv=(q, k, v),
                 input_layout="SEPARATE_Q_K_V",
                 workspace_buffer=workspace_buffer,
                 seq_lens=actual_seq_lens_kv_device.flatten(),
                 max_q_len=s_qo,
                 max_kv_len=s_kv,
                 bmm1_scale=_q_scale * _k_scale * scale,
                 bmm2_scale=_fmha_v2_bmm2_scale,
                 batch_size=batch_size,
                 cum_seq_lens_q=qo_indptr,
                 cum_seq_lens_kv=kv_indptr,
+                out_dtype=out_dtype,
                 mask_mode="causal" if causal else "padding",
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1943 - 1959, The
trtllm-fmha-v2 branch is not honoring the out_dtype variable when calling
trtllm_fmha_v2_prefill, causing FP8 outputs to be produced in the wrong dtype;
modify the trtllm-fmha-v2 branch to pass the validated out_dtype into the
trtllm_fmha_v2_prefill call (and ensure any pre-allocated output buffers used
for that path are created/typed using out_dtype), matching how out_dtype is
handled in the cudnn-native and trtllm-native branches so benchmark outputs and
bandwidth calculations are consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1179-1182: The code unconditionally materializes a contiguous HND
copy with _fmha_v2_kv_cache = kv_cache.contiguous(), which doubles memory even
when the trtllm-fmha-v2 backend isn't used; change this to only create the
contiguous copy when "trtllm-fmha-v2" is present in backends (e.g. if
"trtllm-fmha-v2" in backends: _fmha_v2_kv_cache = kv_cache.contiguous() else:
_fmha_v2_kv_cache = kv_cache) so other backends avoid the extra allocation and
existing code using _fmha_v2_kv_cache continues to work.
- Around line 1943-1959: The trtllm-fmha-v2 branch is not honoring the out_dtype
variable when calling trtllm_fmha_v2_prefill, causing FP8 outputs to be produced
in the wrong dtype; modify the trtllm-fmha-v2 branch to pass the validated
out_dtype into the trtllm_fmha_v2_prefill call (and ensure any pre-allocated
output buffers used for that path are created/typed using out_dtype), matching
how out_dtype is handled in the cudnn-native and trtllm-native branches so
benchmark outputs and bandwidth calculations are consistent.

In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 304-313: The compute-capability mapping currently lists
"trtllm-fmha-v2" unconditionally for "9.0" and "12.0", causing
filter_backends_by_compute_capability() to advertise FMHAv2 even when the
runtime will later reject it; update this by removing the unconditional
"trtllm-fmha-v2" entries from the CC map and instead add "trtllm-fmha-v2" only
after checking the same runtime-support helper the backend uses (call that
helper where filter_backends_by_compute_capability is used, e.g.,
runtime_supports_trtllm_fmha_v2 or the backend's support check) so the backend
is added only when both CC and runtime support are true; make the same change
for the duplicate block (the other mapping around the referenced lines) so both
places gate "trtllm-fmha-v2" on runtime support rather than listing it
unconditionally.

In `@flashinfer/prefill.py`:
- Line 4104: The bmm2_scale parameter is being accepted as Union[float,
torch.Tensor] but the CUDA kernel (csrc/fmha_v2_run.cu / fmha_v2_run) expects
the tensor form to be the already-prepacked device buffer (created by
_create_scale_bmm2_d_tensor) and currently just forwards data_ptr(), so ordinary
torch.Tensors get misencoded for FP16/BF16; fix by tightening the contract in
prefill.py: either change the signature to accept only float (remove
torch.Tensor variant) or add a runtime check in the caller/handler that if
bmm2_scale is a torch.Tensor it is on CUDA, has the exact shape/dtype/strides of
the prepacked scale buffer (or a flag indicating it was created via
_create_scale_bmm2_d_tensor), otherwise raise a clear ValueError instructing
callers to use _create_scale_bmm2_d_tensor to produce the packed device buffer;
reference bmm2_scale, _create_scale_bmm2_d_tensor, and the fmha_v2_run kernel
when implementing the validation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fc4fcf24-25ba-4c50-ab5a-035efbdc3060

📥 Commits

Reviewing files that changed from the base of the PR and between 2b3241b and 1526a1c.

📒 Files selected for processing (3)
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • flashinfer/prefill.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1947-1964: The trtllm-fmha-v2 branch must filter out FP8 query
tensors before calling trtllm_fmha_v2_prefill because that function rejects
torch.float8_e4m3fn and currently causes a runtime ValueError; add a guard at
the start of the backend == "trtllm-fmha-v2" block that checks q.dtype (and/or
k.dtype/v.dtype as appropriate) for torch.float8_e4m3fn and short-circuits the
branch (e.g., return None or otherwise skip this backend) so the benchmark
treats this combination as an unsupported backend instead of letting
trtllm_fmha_v2_prefill raise. Ensure you reference the existing variables q, k,
v and the call to trtllm_fmha_v2_prefill when inserting the guard.
- Around line 1333-1351: The FMHA-v2 paged branch calls trtllm_fmha_v2_prefill
but omits kv_block_scales, so NVFP4 KV runs without dequant metadata; fix by
forwarding the existing kv_block_scales into the trtllm_fmha_v2_prefill call
(alongside qkv=(q, _fmha_v2_kv_cache)) or alternatively short-circuit this
branch when kv_dtype == "nvfp4" to use the trtllm-native path; update the
trtllm_fmha_v2_prefill invocation to include kv_block_scales=kv_block_scales (or
add the guard checking kv_dtype) so NVFP4 metadata is preserved.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: dfefc35f-79cf-4186-be7e-6285a1aac93e

📥 Commits

Reviewing files that changed from the base of the PR and between 1526a1c and 55277ea.

📒 Files selected for processing (2)
  • benchmarks/routines/attention.py
  • tests/attention/test_fmha_v2_prefill.py

Comment on lines +1947 to +1964
elif backend == "trtllm-fmha-v2":
    _q_scale = q_scale if q_scale is not None else 1.0
    _k_scale = k_scale if k_scale is not None else 1.0
    return trtllm_fmha_v2_prefill(
        qkv=(q, k, v),
        input_layout="SEPARATE_Q_K_V",
        workspace_buffer=workspace_buffer,
        seq_lens=actual_seq_lens_kv_device.flatten(),
        max_q_len=s_qo,
        max_kv_len=s_kv,
        bmm1_scale=_q_scale * _k_scale * scale,
        bmm2_scale=_fmha_v2_bmm2_scale,
        batch_size=batch_size,
        cum_seq_lens_q=qo_indptr,
        cum_seq_lens_kv=kv_indptr,
        mask_mode="causal" if causal else "padding",
        out_dtype=out_dtype,
    )

⚠️ Potential issue | 🟠 Major

Filter out FP8 queries before the ragged FMHA-v2 call.

trtllm_fmha_v2_prefill(input_layout="SEPARATE_Q_K_V") explicitly rejects torch.float8_e4m3fn queries, but this routine still allows that configuration to reach the new branch. That turns a benchmark flag combination into a runtime ValueError instead of a clean backend skip.

🛡️ Suggested guard
     if "trtllm-native" in backends:
         remove_trtllm_native = False
         if not (head_dim_qk == 192 and head_dim_vo == 128) and not (
             head_dim_qk == 128 and head_dim_vo == 128
         ):
             print(
                 "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128 or head_dim_qk == 128 and head_dim_vo == 128. Skipping."
             )
             remove_trtllm_native = True
         if remove_trtllm_native:
             backends.remove("trtllm-native")
+    if "trtllm-fmha-v2" in backends and q_dtype == torch.float8_e4m3fn:
+        print(
+            "[INFO] trtllm-fmha-v2 does not support FP8 query with SEPARATE_Q_K_V. Skipping."
+        )
+        backends.remove("trtllm-fmha-v2")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1947 - 1964, The
trtllm-fmha-v2 branch must filter out FP8 query tensors before calling
trtllm_fmha_v2_prefill because that function rejects torch.float8_e4m3fn and
currently causes a runtime ValueError; add a guard at the start of the backend
== "trtllm-fmha-v2" block that checks q.dtype (and/or k.dtype/v.dtype as
appropriate) for torch.float8_e4m3fn and short-circuits the branch (e.g., return
None or otherwise skip this backend) so the benchmark treats this combination as
an unsupported backend instead of letting trtllm_fmha_v2_prefill raise. Ensure
you reference the existing variables q, k, v and the call to
trtllm_fmha_v2_prefill when inserting the guard.

@jimmyzho jimmyzho mentioned this pull request Apr 7, 2026
@jimmyzho jimmyzho changed the title Draft: add fmhav2 to flashinfer_benchmark and fix cuda graph violations Add fmhav2 to flashinfer_benchmark and fix cuda graph violations Apr 13, 2026

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
benchmarks/routines/attention.py (2)

1947-1964: ⚠️ Potential issue | 🟠 Major

Guard the ragged FMHA-v2 path against FP8 e4m3 queries.

Line 1950 calls trtllm_fmha_v2_prefill(..., input_layout="SEPARATE_Q_K_V"), but that helper explicitly raises when query.dtype == torch.float8_e4m3fn for this layout. Right now --q_dtype fp8_e4m3 still reaches this branch and turns an unsupported combo into a runtime failure instead of a backend skip.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1947 - 1964, The
trtllm-fmha-v2 ragged path currently calls trtllm_fmha_v2_prefill with
input_layout="SEPARATE_Q_K_V" even when the query tensor (q) uses FP8 e4m3
(torch.float8_e4m3fn), which that helper explicitly forbids; add a guard before
selecting the "trtllm-fmha-v2" branch that checks q.dtype (or q) and if it
equals torch.float8_e4m3fn, avoid this backend (e.g., skip/return fallback)
instead of calling trtllm_fmha_v2_prefill so the unsupported combo never reaches
the helper and causes a runtime failure. Ensure the check references backend ==
"trtllm-fmha-v2", q/q.dtype, and trtllm_fmha_v2_prefill to make the change easy
to locate.

1333-1351: ⚠️ Potential issue | 🟠 Major

Still skip FMHA-v2 for NVFP4 paged KV.

At Line 1337 _fmha_v2_kv_cache can already be the NVFP4-packed cache, but this call only forwards qkv plus scalar BMM scales. There is still no kv_cache_sf / block-scale argument here, so --kv_dtype nvfp4 remains an unsupported or wrong-math benchmark configuration instead of a clean backend skip.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1333 - 1351, The
trtllm-fmha-v2 prefill call can receive NVFP4-packed KV in _fmha_v2_kv_cache but
does not pass a kv cache scale/block-scale argument, so --kv_dtype nvfp4 yields
wrong math instead of a clean skip; update the trtllm_fmha_v2_prefill call site
in the branch that returns trtllm_fmha_v2_prefill to either (A) pass the kv
cache scale parameter (e.g., kv_cache_sf or block_scale) taken from the existing
NVFP4 packing logic into the call so the kernel knows the per-block scale,
referencing _fmha_v2_kv_cache and _fmha_v2_bmm2_scale symbols, or (B) detect
NVFP4 paged KV (the same condition that produces the packed cache) and
explicitly skip this backend by returning a clear error/skip instead of invoking
trtllm_fmha_v2_prefill; implement the chosen fix consistently with how other
backends handle kv_cache_sf/block-scale.
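Option (B) above, the clean skip, can be sketched as follows. The predicate and names (`usable_backends`, `kv_block_scales`) are illustrative, not the benchmark's real variables.

```python
# Hypothetical sketch of option (B): drop trtllm-fmha-v2 from the run when
# the KV cache is NVFP4-packed and no block-scale tensor is available, so the
# benchmark skips cleanly instead of computing with wrong math.
# Names are illustrative only.

def usable_backends(backends, kv_dtype, kv_block_scales):
    out = []
    for b in backends:
        if b == "trtllm-fmha-v2" and kv_dtype == "nvfp4" and kv_block_scales is None:
            print("[INFO] trtllm-fmha-v2 has no NVFP4 block scales here. Skipping.")
            continue
        out.append(b)
    return out

print(usable_backends(["fa2", "trtllm-fmha-v2"], "nvfp4", None))  # -> ['fa2']
```

This mirrors how the routine already removes trtllm-native for unsupported head-dim combinations: filter the backend list up front rather than raising inside the timed call.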
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1333-1351: The prefill path for backend "trtllm-fmha-v2" is
passing the raw 2-D block_tables into trtllm_fmha_v2_prefill which internally
calls torch.stack(...).contiguous(), causing allocations during CUDA graph
capture; to fix, pre-expand/stabilize the paged tables outside the
timed/captured call (e.g., build the Q_PAGED_KV_* expanded contiguous buffer
once and pass that buffer to trtllm_fmha_v2_prefill instead of block_tables) or,
as a temporary fallback, disable CUDA graph usage for the "trtllm-fmha-v2" path
until trtllm_fmha_v2_prefill is changed to accept a reusable expanded buffer;
reference trtllm_fmha_v2_prefill, block_tables, and the internal
torch.stack(...).contiguous() expansion when making the change.

---

Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1947-1964: The trtllm-fmha-v2 ragged path currently calls
trtllm_fmha_v2_prefill with input_layout="SEPARATE_Q_K_V" even when the query
tensor (q) uses FP8 e4m3 (torch.float8_e4m3fn), which that helper explicitly
forbids; add a guard before selecting the "trtllm-fmha-v2" branch that checks
q.dtype (or q) and if it equals torch.float8_e4m3fn, avoid this backend (e.g.,
skip/return fallback) instead of calling trtllm_fmha_v2_prefill so the
unsupported combo never reaches the helper and causes a runtime failure. Ensure
the check references backend == "trtllm-fmha-v2", q/q.dtype, and
trtllm_fmha_v2_prefill to make the change easy to locate.
- Around line 1333-1351: The trtllm-fmha-v2 prefill call can receive
NVFP4-packed KV in _fmha_v2_kv_cache but does not pass a kv cache
scale/block-scale argument, so --kv_dtype nvfp4 yields wrong math instead of a
clean skip; update the trtllm_fmha_v2_prefill call site in the branch that
returns trtllm_fmha_v2_prefill to either (A) pass the kv cache scale parameter
(e.g., kv_cache_sf or block_scale) taken from the existing NVFP4 packing logic
into the call so the kernel knows the per-block scale, referencing
_fmha_v2_kv_cache and _fmha_v2_bmm2_scale symbols, or (B) detect NVFP4 paged KV
(the same condition that produces the packed cache) and explicitly skip this
backend by returning a clear error/skip instead of invoking
trtllm_fmha_v2_prefill; implement the chosen fix consistently with how other
backends handle kv_cache_sf/block-scale.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 35e17035-1e9a-4592-b54e-698e585e7b5b

📥 Commits

Reviewing files that changed from the base of the PR and between 55277ea and a8b9ef8.

📒 Files selected for processing (3)
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • flashinfer/prefill.py

Comment on lines +1333 to +1351
elif backend == "trtllm-fmha-v2":
    _q_scale = q_scale if q_scale is not None else 1.0
    _k_scale = k_scale if k_scale is not None else 1.0
    return trtllm_fmha_v2_prefill(
        qkv=(q, _fmha_v2_kv_cache),
        input_layout="Q_PAGED_KV_HND",
        workspace_buffer=workspace_buffer,
        seq_lens=actual_seq_lens_kv_device.flatten(),
        max_q_len=s_qo,
        max_kv_len=s_kv,
        bmm1_scale=_q_scale * _k_scale * scale,
        bmm2_scale=_fmha_v2_bmm2_scale,
        batch_size=batch_size,
        cum_seq_lens_q=qo_indptr,
        cum_seq_lens_kv=kv_token_indptr,
        block_tables=block_tables,
        mask_mode="causal" if causal else "padding",
        out_dtype=o_data_type,
    )

⚠️ Potential issue | 🟠 Major

The paged FMHA-v2 path still allocates during graph capture.

At Line 1348 this call passes the raw 2-D block_tables. trtllm_fmha_v2_prefill() currently expands paged tables with a fresh torch.stack(...).contiguous() for Q_PAGED_KV_*, so this benchmark still does an allocation inside the captured callable. Because CUDA graphs stay enabled for this backend, that reintroduces the graph-safety problem this PR is trying to remove. Please either pre-expand/capture-stabilize the table outside the timed call or keep CUDA graphs disabled for trtllm-fmha-v2 until the helper can consume a reusable expanded buffer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1333 - 1351, The prefill path
for backend "trtllm-fmha-v2" is passing the raw 2-D block_tables into
trtllm_fmha_v2_prefill which internally calls torch.stack(...).contiguous(),
causing allocations during CUDA graph capture; to fix, pre-expand/stabilize the
paged tables outside the timed/captured call (e.g., build the Q_PAGED_KV_*
expanded contiguous buffer once and pass that buffer to trtllm_fmha_v2_prefill
instead of block_tables) or, as a temporary fallback, disable CUDA graph usage
for the "trtllm-fmha-v2" path until trtllm_fmha_v2_prefill is changed to accept
a reusable expanded buffer; reference trtllm_fmha_v2_prefill, block_tables, and
the internal torch.stack(...).contiguous() expansion when making the change.
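The pre-expansion idea can be illustrated in plain Python: perform the one allocation before the timed/captured region, then only refill that buffer in place inside it. This is a conceptual sketch of the pattern, not the actual tensor code.

```python
# Conceptual sketch: expand the paged block tables into a reusable buffer
# once, before graph capture, then refill it in place inside the captured
# call instead of allocating a fresh contiguous copy on every invocation.

def expand_block_tables(block_tables, out=None):
    """Flatten a 2-D table into `out` in place; allocate only if out is None."""
    if out is None:
        out = []          # the single allocation, done outside capture
    out.clear()
    for row in block_tables:
        out.extend(row)   # in-place refill: no new buffer per call
    return out

buf = expand_block_tables([[1, 2], [3, 4]])      # setup phase, outside capture
expand_block_tables([[5, 6], [7, 8]], out=buf)   # "captured" call reuses buf
print(buf)  # -> [5, 6, 7, 8]
```

The same shape applies to the tensor case: a pre-sized contiguous buffer filled with `copy_`-style in-place writes is capture-safe, whereas `torch.stack(...).contiguous()` inside the captured callable is not.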
