Add fmhav2 to flashinfer_benchmark and fix cuda graph violations #2841
jimmyzho wants to merge 5 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Adds trtllm-fmha-v2 prefill backend support, passes bmm2 scale as scalar-or-tensor via a Variant/Union through Python→TVM→CUDA, distinguishes page vs token cumulative KV indirections, ensures contiguous KV cache for trtllm-fmha-v2, and re-enables the FMHA v2 prefill tests with adjusted skips.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python Prefill
    participant TVM as TVM FFI
    participant CUDA as fmha_v2_run
    Py->>TVM: module.run(..., bmm2_scale: float OR torch.Tensor, kv_cache, cum_seq_lens_q, cum_seq_lens_kv, input_layout)
    TVM->>TVM: wrap bmm2_scale into Variant<double, Tensor>
    TVM->>CUDA: fmha_v2_run(Variant scale_bmm2, other params)
    CUDA->>CUDA: if Variant is double -> use scalar value
    CUDA->>CUDA: if Variant is Tensor -> use device ptr + scalar value
    CUDA-->>Py: attention outputs, LSE updates, KV cache updates
```
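The scalar-or-tensor dispatch in the diagram can be sketched on the Python side. This is illustrative only; `resolve_bmm2_scale` is not a real FlashInfer function, it just mirrors what the `Variant<double, Tensor>` binding does with each input kind:

```python
import torch

def resolve_bmm2_scale(bmm2_scale):
    """Mirror of the Variant<double, Tensor> dispatch: a float travels as a
    host scalar, a tensor travels as a device pointer for the kernel."""
    if isinstance(bmm2_scale, torch.Tensor):
        # Tensor variant: the kernel consumes the buffer's address.
        assert bmm2_scale.numel() == 1, "expected a single-element scale tensor"
        return ("tensor", bmm2_scale.data_ptr())
    # Scalar variant: the value itself is forwarded on the host side.
    return ("scalar", float(bmm2_scale))
```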
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Code Review
This pull request adds support for the trtllm-fmha-v2 backend to the FlashInfer benchmark suite and addresses some CUDA graph violations. The changes span the Python benchmark scripts, the C++ bindings (which add a variant type for `scale_bmm2`), and the core prefill logic. My review identified a potential bug where CUDA graph compatibility checks are inconsistent across different test wrappers, and some commented-out code that should be removed for clarity. Since this is a draft, I've also noted that several test cases are currently disabled; they should be re-enabled and verified before merging.
```diff
  sleep_after_run=False,
  enable_cupti=args.use_cupti,
- use_cuda_graph=(is_cuda_graph_compatible and cur_backend != "fa2"),
+ use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ("fa2")),
```
The logic for use_cuda_graph appears to be inconsistent here. In other parts of this file (e.g., lines 682-686 and 1973-1977), you are excluding "trtllm-fmha-v2", "trtllm-native", and "cudnn-native" from CUDA graph usage, which is likely correct given the PR's goal to fix graph violations. However, this line only excludes "fa2". This could lead to CUDA graph errors for the other backends within this specific test wrapper. Please ensure this check is consistent with the others.
```python
use_cuda_graph=(
    is_cuda_graph_compatible
    and cur_backend
    not in ("fa2", "trtllm-fmha-v2", "trtllm-native", "cudnn-native")
),
```

```python
# scale_bmm2_d = _create_scale_bmm2_d_tensor(scale_bmm2, query.dtype, query.device)
# scale_bmm2_d = torch.tensor([scale_bmm2], dtype=torch.float32, device=query.device)
# if isinstance(bmm2_scale, float):
#     bmm2_scale_d = torch.tensor([bmm2_scale], dtype=torch.float32, device=query.device)
# else:
#     bmm2_scale_d = bmm2_scale.to(device=query.device, dtype=torch.float32)
# scale_bmm2 = bmm2_scale_d.view(torch.float32)
# bmm2_scale_d = torch.empty(1, dtype=torch.float32, device=query.device)
# bmm2_scale_d.fill_(bmm2_scale)
```
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)
1354-1360: ⚠️ Potential issue | 🟠 Major

`trtllm-fmha-v2` is still being CUDA-graphed in paged prefill.

This loop only disables capture for `fa2`, while the decode, ragged, and MLA benchmark loops all exempt `trtllm-fmha-v2`, `trtllm-native`, and `cudnn-native`. The new backend will still go through capture here and can hit the same graph-violation path you are trying to avoid.

Proposed fix

```diff
- use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ("fa2")),
+ use_cuda_graph=(
+     is_cuda_graph_compatible
+     and cur_backend
+     not in ("fa2", "trtllm-fmha-v2", "trtllm-native", "cudnn-native")
+ ),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1354 - 1360, The CUDA-graph capture is only being disabled for "fa2" here causing "trtllm-fmha-v2" (and other backends) to still be graphed; update the condition passed to bench_gpu_time in the backend_times assignment so use_cuda_graph is False for the same set of backends used elsewhere. Specifically, change the use_cuda_graph expression that currently checks (is_cuda_graph_compatible and cur_backend not in ("fa2")) to exclude "trtllm-fmha-v2", "trtllm-native", and "cudnn-native" as well (e.g., cur_backend not in ("fa2","trtllm-fmha-v2","trtllm-native","cudnn-native")), leaving bench_gpu_time, run_backend_wrapper, is_cuda_graph_compatible and cur_backend unchanged.
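One way to keep the exclusion consistent across all wrappers is a single shared constant. Note also a Python pitfall in the current expression: `("fa2")` is just the string `"fa2"` (a one-element tuple needs a trailing comma), so `not in ("fa2")` performs substring matching rather than tuple membership. A minimal sketch, where the constant and function names are suggestions rather than existing benchmark code:

```python
# Suggested names, not existing flashinfer_benchmark code.
GRAPH_INCOMPATIBLE_BACKENDS = frozenset(
    {"fa2", "trtllm-fmha-v2", "trtllm-native", "cudnn-native"}
)

def should_use_cuda_graph(is_cuda_graph_compatible: bool, cur_backend: str) -> bool:
    # A frozenset sidesteps the `not in ("fa2")` pitfall: ("fa2") is the
    # string "fa2", and `x in "fa2"` is substring matching.
    return is_cuda_graph_compatible and cur_backend not in GRAPH_INCOMPATIBLE_BACKENDS
```

Each call site then passes `use_cuda_graph=should_use_cuda_graph(is_cuda_graph_compatible, cur_backend)`, so adding a backend to the exclusion list is a one-line change.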
🧹 Nitpick comments (1)
tests/attention/test_fmha_v2_prefill.py (1)
797-804: Prefer marked params over commenting cases out.

Commenting these tuples out removes the unstable combinations from CI entirely, and every surviving call still passes `bmm2_scale` as a Python float. Converting the disabled cases to `pytest.param(..., marks=skip/xfail)` would keep the gaps explicit and make it easier to add a tensor-scale regression case later.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_fmha_v2_prefill.py` around lines 797 - 804, Replace the commented-out parameter tuples in tests/attention/test_fmha_v2_prefill.py (e.g., ("PACKED_QKV", None, False), ("CONTIGUOUS_Q_KV", None, False), ("CONTIGUOUS_Q_KV", None, True), ("SEPARATE_Q_K_V", None, False), ("Q_PAGED_KV_NHD", 128, False), ("Q_PAGED_KV_HND", 128, False)) with pytest.param entries that include pytest.mark.skip or pytest.mark.xfail as appropriate so the combinations remain explicit in the param list; keep the same tuple values but wrap them with pytest.param(..., marks=pytest.mark.skip(reason="unstable")) (or xfail) so CI knows they are intentionally disabled and can be re-enabled or converted to tensor-scale regression cases later. Ensure these new pytest.param entries are added into the same parameterization target (the test's param list that currently contains ("Q_PAGED_KV_NHD", 32, False) etc.) so the test harness still iterates over all cases.
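The suggested `pytest.param` conversion might look like the sketch below. The tuple values come from the test file; the param-list name, test body, and skip reasons are placeholders:

```python
import pytest

# Unstable combinations stay visible in the param list instead of being
# commented out; the marks record why each one is disabled.
FMHA_V2_LAYOUT_CASES = [
    ("Q_PAGED_KV_NHD", 32, False),
    pytest.param(
        "PACKED_QKV", None, False,
        marks=pytest.mark.skip(reason="unstable"),
    ),
    pytest.param(
        "Q_PAGED_KV_NHD", 128, False,
        marks=pytest.mark.skip(reason="unstable"),
    ),
]

@pytest.mark.parametrize("layout,page_size,scale_is_tensor", FMHA_V2_LAYOUT_CASES)
def test_layout_case(layout, page_size, scale_is_tensor):
    # Placeholder body; the real test exercises the FMHA-v2 prefill path.
    assert isinstance(layout, str)
```

Skipped cases then show up as `s` in the pytest summary instead of silently disappearing, and re-enabling one is a matter of deleting its mark.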
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1158-1161: The code always materializes a contiguous HND copy via
kv_cache.contiguous() into _fmha_v2_kv_cache which forces a full KV allocation
for all benchmarks; change this to only create the contiguous copy when the
trtllm-fmha-v2 backend is actually used (i.e., check if 'trtllm-fmha-v2' is in
backends) and otherwise keep _fmha_v2_kv_cache as the original kv_cache (or
defer copying), so only trtllm-fmha-v2 triggers the extra allocation.
- Around line 1916-1932: The trtllm-fmha-v2 branch does not pass the validated
out_dtype into trtllm_fmha_v2_prefill, so FP8/output-type is ignored; update the
call in the backend == "trtllm-fmha-v2" branch to forward the validated
out_dtype (the same variable recorded earlier) by adding out_dtype=out_dtype to
the trtllm_fmha_v2_prefill invocation so its output dtype honors the requested
--out_dtype setting.
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 307-310: The advertised backends list includes "trtllm-fmha-v2"
for CUDA SM "9.0" and "12.0" without checking CUDA toolkit version; update the
benchmark backend selection to gate FMHAv2 behind the runtime helpers: call
is_sm90a_supported() before including "trtllm-fmha-v2" for the "9.0" entry and
is_sm120a_supported() before including it for the "12.0" entry (or perform
equivalent CUDA-version checks inside the benchmark filter function that
constructs the backend list) so that FMHAv2 is only advertised when the CUDA
version meets the runtime requirements.
In `@flashinfer/prefill.py`:
- Around line 3955-3956: The code must preserve the old split representation for
BMM2 scale so the C++ FMHAv2 path receives both a host scalar and an encoded
device buffer: when handling the parameter bmm2_scale (formerly via
_create_scale_bmm2_d_tensor()), ensure you provide a float host scalar (used by
set_alpha / params.scale_bmm2) and a device-side packed uint32_t tensor/buffer
(used by params.scale_bmm2_d) before calling module.run(); if bmm2_scale is a
torch.Tensor, extract/convert its scalar value for the host and separately
encode/pack the device representation (matching the previous
_create_scale_bmm2_d_tensor behavior), and if it’s a Python float keep both
representations consistent—update the code paths that construct
params.scale_bmm2 and params.scale_bmm2_d (and any callers of set_alpha and
module.run()) to use these two distinct values.
---
Outside diff comments:
In `@benchmarks/routines/attention.py`:
- Around line 1354-1360: The CUDA-graph capture is only being disabled for "fa2"
here causing "trtllm-fmha-v2" (and other backends) to still be graphed; update the
condition passed to bench_gpu_time in the backend_times assignment so
use_cuda_graph is False for the same set of backends used elsewhere.
Specifically, change the use_cuda_graph expression that currently checks
(is_cuda_graph_compatible and cur_backend not in ("fa2")) to exclude
"trtllm-fmha-v2", "trtllm-native", and "cudnn-native" as well (e.g., cur_backend
not in ("fa2","trtllm-fmha-v2","trtllm-native","cudnn-native")), leaving
bench_gpu_time, run_backend_wrapper, is_cuda_graph_compatible and cur_backend
unchanged.
---
Nitpick comments:
In `@tests/attention/test_fmha_v2_prefill.py`:
- Around line 797-804: Replace the commented-out parameter tuples in
tests/attention/test_fmha_v2_prefill.py (e.g., ("PACKED_QKV", None, False),
("CONTIGUOUS_Q_KV", None, False), ("CONTIGUOUS_Q_KV", None, True),
("SEPARATE_Q_K_V", None, False), ("Q_PAGED_KV_NHD", 128, False),
("Q_PAGED_KV_HND", 128, False)) with pytest.param entries that include
pytest.mark.skip or pytest.mark.xfail as appropriate so the combinations remain
explicit in the param list; keep the same tuple values but wrap them with
pytest.param(..., marks=pytest.mark.skip(reason="unstable")) (or xfail) so CI
knows they are intentionally disabled and can be re-enabled or converted to
tensor-scale regression cases later. Ensure these new pytest.param entries are
added into the same parameterization target (the test's param list that
currently contains ("Q_PAGED_KV_NHD", 32, False) etc.) so the test harness still
iterates over all cases.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 258e53c5-6814-4444-b29e-7e3dc775097a
📒 Files selected for processing (6)
benchmarks/routines/attention.py
benchmarks/routines/flashinfer_benchmark_utils.py
csrc/fmha_v2_jit_binding.cu
csrc/fmha_v2_run.cu
flashinfer/prefill.py
tests/attention/test_fmha_v2_prefill.py
```diff
  bmm1_scale: float,
- bmm2_scale: float,
+ bmm2_scale: Union[float, torch.Tensor],
```
Keep scale_bmm2's scalar and device-buffer representations separate.
This change drops _create_scale_bmm2_d_tensor() and forwards bmm2_scale directly, but the FMHAv2 C++ path still needs both halves of the old contract: set_alpha(params.scale_bmm2, ...) consumes the scalar value, and params.scale_bmm2_d expects the prepacked device bits. With a tensor input, the C++ side now just reinterprets the tensor pointer as uint32_t* and the scalar side falls back to 0.f, so FP16/BF16 tensor scales will program the wrong BMM2 scale. Please keep the old split API, or reconstruct both a host scalar and an encoded device tensor before module.run().
Also applies to: 4218-4258
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/prefill.py` around lines 3955 - 3956, The code must preserve the
old split representation for BMM2 scale so the C++ FMHAv2 path receives both a
host scalar and an encoded device buffer: when handling the parameter bmm2_scale
(formerly via _create_scale_bmm2_d_tensor()), ensure you provide a float host
scalar (used by set_alpha / params.scale_bmm2) and a device-side packed uint32_t
tensor/buffer (used by params.scale_bmm2_d) before calling module.run(); if
bmm2_scale is a torch.Tensor, extract/convert its scalar value for the host and
separately encode/pack the device representation (matching the previous
_create_scale_bmm2_d_tensor behavior), and if it’s a Python float keep both
representations consistent—update the code paths that construct
params.scale_bmm2 and params.scale_bmm2_d (and any callers of set_alpha and
module.run()) to use these two distinct values.
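A sketch of reconstructing both halves of the old contract before `module.run()`. The helper name is hypothetical, `torch.int32` stands in for the `uint32_t` bit pattern the C++ side reads, and raw float32 bits are an assumption here; the exact packing performed by `_create_scale_bmm2_d_tensor` may differ:

```python
import torch

def split_bmm2_scale(bmm2_scale, device="cpu"):
    """Hypothetical helper mirroring the old split API: return a host float
    for set_alpha(params.scale_bmm2, ...) plus a device buffer holding the
    same value's raw float32 bits for params.scale_bmm2_d."""
    if isinstance(bmm2_scale, torch.Tensor):
        host_scalar = float(bmm2_scale.item())
    else:
        host_scalar = float(bmm2_scale)
    scale_f32 = torch.tensor([host_scalar], dtype=torch.float32, device=device)
    # Reinterpret the float32 bits in place; int32 carries the same 32-bit
    # pattern that the C++ side treats as uint32_t.
    packed_bits = scale_f32.view(torch.int32)
    return host_scalar, packed_bits
```

With this, a tensor input still yields a valid host scalar for `set_alpha` instead of the `0.f` fallback, and the device buffer is encoded once rather than reinterpreted from whatever pointer the caller passed.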
♻️ Duplicate comments (2)
benchmarks/routines/attention.py (2)
1912-1928: ⚠️ Potential issue | 🟠 Major

Pass `out_dtype` through the ragged FMHAv2 benchmark path.

This branch validates `--out_dtype` up front but drops it here. FP8 runs will therefore benchmark `trtllm_fmha_v2_prefill` with its default output dtype instead of the requested one, which skews both output comparisons and bandwidth numbers.

♻️ Proposed fix

```diff
 elif backend == "trtllm-fmha-v2":
     _q_scale = q_scale if q_scale is not None else 1.0
     _k_scale = k_scale if k_scale is not None else 1.0
     return trtllm_fmha_v2_prefill(
         qkv=(q, k, v),
         input_layout="SEPARATE_Q_K_V",
         workspace_buffer=workspace_buffer,
         seq_lens=actual_seq_lens_kv_device.flatten(),
         max_q_len=s_qo,
         max_kv_len=s_kv,
         bmm1_scale=_q_scale * _k_scale * scale,
         bmm2_scale=_fmha_v2_bmm2_scale,
         batch_size=batch_size,
         cum_seq_lens_q=qo_indptr,
         cum_seq_lens_kv=kv_indptr,
+        out_dtype=out_dtype,
         mask_mode="causal" if causal else "padding",
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1912 - 1928, The trtllm-fmha-v2 branch drops the validated out_dtype so trtllm_fmha_v2_prefill runs with its default; update the call in the elif backend == "trtllm-fmha-v2" branch to pass the out_dtype parameter (the same out_dtype validated earlier) into trtllm_fmha_v2_prefill so FP8/other requested output dtypes are honored—look for the trtllm_fmha_v2_prefill invocation and add the out_dtype argument alongside the existing arguments (e.g., bmm1_scale, bmm2_scale, batch_size, etc.).
1149-1157: ⚠️ Potential issue | 🟠 Major

Gate the FMHAv2 HND copy on backend selection.

`kv_cache.contiguous()` still runs unconditionally here, so every paged-prefill benchmark pays for a full extra KV allocation even when `trtllm-fmha-v2` is not being timed. On larger cases that can distort memory pressure or OOM unrelated backends.

♻️ Proposed fix

```diff
- _fmha_v2_kv_cache = kv_cache.contiguous()
+ _fmha_v2_kv_cache = (
+     kv_cache.contiguous() if "trtllm-fmha-v2" in backends else None
+ )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1149 - 1157, The unconditional contiguous copy causes extra KV allocation for all backends; wrap the contiguous call so _fmha_v2_kv_cache = kv_cache.contiguous() only when the selected backend equals "trtllm-fmha-v2" (leave _fmha_v2_kv_cache as kv_cache otherwise). Update the code around the existing _fmha_v2_kv_cache and _fmha_v2_bmm2_scale assignments to branch on the backend selection used in this routine (e.g., compare backend or impl string to "trtllm-fmha-v2") so only that backend pays the copy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1912-1928: The trtllm-fmha-v2 branch drops the validated out_dtype
so trtllm_fmha_v2_prefill runs with its default; update the call in the elif
backend == "trtllm-fmha-v2" branch to pass the out_dtype parameter (the same
out_dtype validated earlier) into trtllm_fmha_v2_prefill so FP8/other requested
output dtypes are honored—look for the trtllm_fmha_v2_prefill invocation and add
the out_dtype argument alongside the existing arguments (e.g., bmm1_scale,
bmm2_scale, batch_size, etc.).
- Around line 1149-1157: The unconditional contiguous copy causes extra KV
allocation for all backends; wrap the contiguous call so _fmha_v2_kv_cache =
kv_cache.contiguous() only when the selected backend equals "trtllm-fmha-v2"
(leave _fmha_v2_kv_cache as kv_cache otherwise). Update the code around the
existing _fmha_v2_kv_cache and _fmha_v2_bmm2_scale assignments to branch on the
backend selection used in this routine (e.g., compare backend or impl string to
"trtllm-fmha-v2") so only that backend pays the copy.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 134a09e5-3fbc-4c60-adc6-c344113ccce5
📒 Files selected for processing (2)
benchmarks/routines/attention.py
flashinfer/prefill.py
♻️ Duplicate comments (4)
flashinfer/prefill.py (1)
4104-4104: ⚠️ Potential issue | 🟠 Major

Tighten the `bmm2_scale` tensor contract.

`csrc/fmha_v2_run.cu` still treats the tensor variant as the prepacked `scale_bmm2_d` buffer and just forwards `data_ptr()`. With this change, an ordinary `torch.Tensor` now bypasses `_create_scale_bmm2_d_tensor()` and reaches the kernel as raw floats, which is misencoded for the FP16/BF16 path. Either keep this API float-only, or validate/document that the tensor form must already be the packed device buffer.

Also applies to: 4385-4394
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` at line 4104, The bmm2_scale parameter is being accepted as Union[float, torch.Tensor] but the CUDA kernel (csrc/fmha_v2_run.cu / fmha_v2_run) expects the tensor form to be the already-prepacked device buffer (created by _create_scale_bmm2_d_tensor) and currently just forwards data_ptr(), so ordinary torch.Tensors get misencoded for FP16/BF16; fix by tightening the contract in prefill.py: either change the signature to accept only float (remove torch.Tensor variant) or add a runtime check in the caller/handler that if bmm2_scale is a torch.Tensor it is on CUDA, has the exact shape/dtype/strides of the prepacked scale buffer (or a flag indicating it was created via _create_scale_bmm2_d_tensor), otherwise raise a clear ValueError instructing callers to use _create_scale_bmm2_d_tensor to produce the packed device buffer; reference bmm2_scale, _create_scale_bmm2_d_tensor, and the fmha_v2_run kernel when implementing the validation.

benchmarks/routines/flashinfer_benchmark_utils.py (1)
304-313: ⚠️ Potential issue | 🟠 Major

Gate `trtllm-fmha-v2` on runtime support, not just compute capability.

`filter_backends_by_compute_capability()` only consults this CC table, so these unconditional `9.0`/`12.0` entries make the benchmark advertise FMHAv2 even on toolchains that the runtime backend will later reject. Please apply the same runtime support helper the backend uses instead of listing `trtllm-fmha-v2` here unconditionally.

Also applies to: 317-326
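Gating the advertised list on the runtime helpers could look like the sketch below. `is_sm90a_supported`/`is_sm120a_supported` are the helpers named in this review, stubbed here so the sketch is self-contained, and the base backend list is illustrative rather than the real table in `flashinfer_benchmark_utils.py`:

```python
def is_sm90a_supported() -> bool:
    # Stub: the real helper checks CUDA toolkit/runtime support for sm90a.
    return False

def is_sm120a_supported() -> bool:
    # Stub: the real helper checks CUDA toolkit/runtime support for sm120a.
    return False

def backends_for_cc(cc: str) -> list:
    # Illustrative base list; FMHAv2 is appended only when both the compute
    # capability and the runtime support check agree.
    backends = ["fa2", "cudnn", "trtllm-native"]
    if (cc == "9.0" and is_sm90a_supported()) or (
        cc == "12.0" and is_sm120a_supported()
    ):
        backends.append("trtllm-fmha-v2")
    return backends
```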
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/flashinfer_benchmark_utils.py` around lines 304 - 313, The compute-capability mapping currently lists "trtllm-fmha-v2" unconditionally for "9.0" and "12.0", causing filter_backends_by_compute_capability() to advertise FMHAv2 even when the runtime will later reject it; update this by removing the unconditional "trtllm-fmha-v2" entries from the CC map and instead add "trtllm-fmha-v2" only after checking the same runtime-support helper the backend uses (call that helper where filter_backends_by_compute_capability is used, e.g., runtime_supports_trtllm_fmha_v2 or the backend's support check) so the backend is added only when both CC and runtime support are true; make the same change for the duplicate block (the other mapping around the referenced lines) so both places gate "trtllm-fmha-v2" on runtime support rather than listing it unconditionally.

benchmarks/routines/attention.py (2)
1179-1182: ⚠️ Potential issue | 🟠 Major

Avoid materializing the HND KV copy unconditionally.

This `contiguous()` runs regardless of whether `trtllm-fmha-v2` is actually in `backends`. For large KV caches, this doubles memory usage unnecessarily when only benchmarking other backends.

Proposed fix

```diff
- _fmha_v2_kv_cache = kv_cache.contiguous()
+ _fmha_v2_kv_cache = (
+     kv_cache.contiguous() if "trtllm-fmha-v2" in backends else None
+ )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1179 - 1182, The code unconditionally materializes a contiguous HND copy with _fmha_v2_kv_cache = kv_cache.contiguous(), which doubles memory even when the trtllm-fmha-v2 backend isn't used; change this to only create the contiguous copy when "trtllm-fmha-v2" is present in backends (e.g. if "trtllm-fmha-v2" in backends: _fmha_v2_kv_cache = kv_cache.contiguous() else: _fmha_v2_kv_cache = kv_cache) so other backends avoid the extra allocation and existing code using _fmha_v2_kv_cache continues to work.
1943-1959: ⚠️ Potential issue | 🟠 Major

Honor `--out_dtype` in the ragged FMHA-v2 path.

The routine validates and uses `out_dtype` for other backends (e.g., `cudnn-native` at line 1917, `trtllm-native` via pre-allocated output at line 1871), but the `trtllm-fmha-v2` branch doesn't pass it through. FP8 benchmarks will produce the wrong output dtype and yield incorrect bandwidth comparisons versus other backends.

Proposed fix

```diff
 elif backend == "trtllm-fmha-v2":
     _q_scale = q_scale if q_scale is not None else 1.0
     _k_scale = k_scale if k_scale is not None else 1.0
     return trtllm_fmha_v2_prefill(
         qkv=(q, k, v),
         input_layout="SEPARATE_Q_K_V",
         workspace_buffer=workspace_buffer,
         seq_lens=actual_seq_lens_kv_device.flatten(),
         max_q_len=s_qo,
         max_kv_len=s_kv,
         bmm1_scale=_q_scale * _k_scale * scale,
         bmm2_scale=_fmha_v2_bmm2_scale,
         batch_size=batch_size,
         cum_seq_lens_q=qo_indptr,
         cum_seq_lens_kv=kv_indptr,
+        out_dtype=out_dtype,
         mask_mode="causal" if causal else "padding",
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1943 - 1959, The trtllm-fmha-v2 branch is not honoring the out_dtype variable when calling trtllm_fmha_v2_prefill, causing FP8 outputs to be produced in the wrong dtype; modify the trtllm-fmha-v2 branch to pass the validated out_dtype into the trtllm_fmha_v2_prefill call (and ensure any pre-allocated output buffers used for that path are created/typed using out_dtype), matching how out_dtype is handled in the cudnn-native and trtllm-native branches so benchmark outputs and bandwidth calculations are consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1179-1182: The code unconditionally materializes a contiguous HND
copy with _fmha_v2_kv_cache = kv_cache.contiguous(), which doubles memory even
when the trtllm-fmha-v2 backend isn't used; change this to only create the
contiguous copy when "trtllm-fmha-v2" is present in backends (e.g. if
"trtllm-fmha-v2" in backends: _fmha_v2_kv_cache = kv_cache.contiguous() else:
_fmha_v2_kv_cache = kv_cache) so other backends avoid the extra allocation and
existing code using _fmha_v2_kv_cache continues to work.
- Around line 1943-1959: The trtllm-fmha-v2 branch is not honoring the out_dtype
variable when calling trtllm_fmha_v2_prefill, causing FP8 outputs to be produced
in the wrong dtype; modify the trtllm-fmha-v2 branch to pass the validated
out_dtype into the trtllm_fmha_v2_prefill call (and ensure any pre-allocated
output buffers used for that path are created/typed using out_dtype), matching
how out_dtype is handled in the cudnn-native and trtllm-native branches so
benchmark outputs and bandwidth calculations are consistent.
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 304-313: The compute-capability mapping currently lists
"trtllm-fmha-v2" unconditionally for "9.0" and "12.0", causing
filter_backends_by_compute_capability() to advertise FMHAv2 even when the
runtime will later reject it; update this by removing the unconditional
"trtllm-fmha-v2" entries from the CC map and instead add "trtllm-fmha-v2" only
after checking the same runtime-support helper the backend uses (call that
helper where filter_backends_by_compute_capability is used, e.g.,
runtime_supports_trtllm_fmha_v2 or the backend's support check) so the backend
is added only when both CC and runtime support are true; make the same change
for the duplicate block (the other mapping around the referenced lines) so both
places gate "trtllm-fmha-v2" on runtime support rather than listing it
unconditionally.
In `@flashinfer/prefill.py`:
- Line 4104: The bmm2_scale parameter is being accepted as Union[float,
torch.Tensor] but the CUDA kernel (csrc/fmha_v2_run.cu / fmha_v2_run) expects
the tensor form to be the already-prepacked device buffer (created by
_create_scale_bmm2_d_tensor) and currently just forwards data_ptr(), so ordinary
torch.Tensors get misencoded for FP16/BF16; fix by tightening the contract in
prefill.py: either change the signature to accept only float (remove
torch.Tensor variant) or add a runtime check in the caller/handler that if
bmm2_scale is a torch.Tensor it is on CUDA, has the exact shape/dtype/strides of
the prepacked scale buffer (or a flag indicating it was created via
_create_scale_bmm2_d_tensor), otherwise raise a clear ValueError instructing
callers to use _create_scale_bmm2_d_tensor to produce the packed device buffer;
reference bmm2_scale, _create_scale_bmm2_d_tensor, and the fmha_v2_run kernel
when implementing the validation.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: fc4fcf24-25ba-4c50-ab5a-035efbdc3060
📒 Files selected for processing (3)
benchmarks/routines/attention.py
benchmarks/routines/flashinfer_benchmark_utils.py
flashinfer/prefill.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1947-1964: The trtllm-fmha-v2 branch must filter out FP8 query
tensors before calling trtllm_fmha_v2_prefill because that function rejects
torch.float8_e4m3fn and currently causes a runtime ValueError; add a guard at
the start of the backend == "trtllm-fmha-v2" block that checks q.dtype (and/or
k.dtype/v.dtype as appropriate) for torch.float8_e4m3fn and short-circuits the
branch (e.g., return None or otherwise skip this backend) so the benchmark
treats this combination as an unsupported backend instead of letting
trtllm_fmha_v2_prefill raise. Ensure you reference the existing variables q, k,
v and the call to trtllm_fmha_v2_prefill when inserting the guard.
- Around line 1333-1351: The FMHA-v2 paged branch calls trtllm_fmha_v2_prefill
but omits kv_block_scales, so NVFP4 KV runs without dequant metadata; fix by
forwarding the existing kv_block_scales into the trtllm_fmha_v2_prefill call
(alongside qkv=(q, _fmha_v2_kv_cache)) or alternatively short-circuit this
branch when kv_dtype == "nvfp4" to use the trtllm-native path; update the
trtllm_fmha_v2_prefill invocation to include kv_block_scales=kv_block_scales (or
add the guard checking kv_dtype) so NVFP4 metadata is preserved.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: dfefc35f-79cf-4186-be7e-6285a1aac93e
📒 Files selected for processing (2)
benchmarks/routines/attention.py
tests/attention/test_fmha_v2_prefill.py
```python
elif backend == "trtllm-fmha-v2":
    _q_scale = q_scale if q_scale is not None else 1.0
    _k_scale = k_scale if k_scale is not None else 1.0
    return trtllm_fmha_v2_prefill(
        qkv=(q, k, v),
        input_layout="SEPARATE_Q_K_V",
        workspace_buffer=workspace_buffer,
        seq_lens=actual_seq_lens_kv_device.flatten(),
        max_q_len=s_qo,
        max_kv_len=s_kv,
        bmm1_scale=_q_scale * _k_scale * scale,
        bmm2_scale=_fmha_v2_bmm2_scale,
        batch_size=batch_size,
        cum_seq_lens_q=qo_indptr,
        cum_seq_lens_kv=kv_indptr,
        mask_mode="causal" if causal else "padding",
        out_dtype=out_dtype,
    )
```
Filter out FP8 queries before the ragged FMHA-v2 call.
trtllm_fmha_v2_prefill(input_layout="SEPARATE_Q_K_V") explicitly rejects torch.float8_e4m3fn queries, but this routine still allows that configuration to reach the new branch. That turns a benchmark flag combination into a runtime ValueError instead of a clean backend skip.
🛡️ Suggested guard

```diff
 if "trtllm-native" in backends:
     remove_trtllm_native = False
     if not (head_dim_qk == 192 and head_dim_vo == 128) and not (
         head_dim_qk == 128 and head_dim_vo == 128
     ):
         print(
             "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128 or head_dim_qk == 128 and head_dim_vo == 128. Skipping."
         )
         remove_trtllm_native = True
     if remove_trtllm_native:
         backends.remove("trtllm-native")
+ if "trtllm-fmha-v2" in backends and q_dtype == torch.float8_e4m3fn:
+     print(
+         "[INFO] trtllm-fmha-v2 does not support FP8 query with SEPARATE_Q_K_V. Skipping."
+     )
+     backends.remove("trtllm-fmha-v2")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/routines/attention.py` around lines 1947 - 1964, The
trtllm-fmha-v2 branch must filter out FP8 query tensors before calling
trtllm_fmha_v2_prefill because that function rejects torch.float8_e4m3fn and
currently causes a runtime ValueError; add a guard at the start of the backend
== "trtllm-fmha-v2" block that checks q.dtype (and/or k.dtype/v.dtype as
appropriate) for torch.float8_e4m3fn and short-circuits the branch (e.g., return
None or otherwise skip this backend) so the benchmark treats this combination as
an unsupported backend instead of letting trtllm_fmha_v2_prefill raise. Ensure
you reference the existing variables q, k, v and the call to
trtllm_fmha_v2_prefill when inserting the guard.
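A runnable, framework-free sketch of the same guard. The function name is illustrative, and `fp8_dtype` is a plain string stand-in for `torch.float8_e4m3fn` so the sketch runs without torch; the benchmark would compare actual dtypes.

```python
def skip_fp8_fmha_v2(backends, q_dtype, fp8_dtype="float8_e4m3fn"):
    """Drop trtllm-fmha-v2 when the query dtype is FP8 e4m3.

    In the real benchmark, `q_dtype` and `fp8_dtype` would be torch dtypes
    (torch.float8_e4m3fn); strings stand in here to keep the sketch
    dependency-free.
    """
    if "trtllm-fmha-v2" in backends and q_dtype == fp8_dtype:
        print(
            "[INFO] trtllm-fmha-v2 does not support FP8 query with "
            "SEPARATE_Q_K_V. Skipping."
        )
        backends = [b for b in backends if b != "trtllm-fmha-v2"]
    return backends
```

Filtering the backend list up front keeps the unsupported combination out of the timed loop entirely, matching how the trtllm-native head-dim check already works.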
Actionable comments posted: 1
♻️ Duplicate comments (2)
benchmarks/routines/attention.py (2)
1947-1964: ⚠️ Potential issue | 🟠 Major: Guard the ragged FMHA-v2 path against FP8 e4m3 queries.
Line 1950 calls `trtllm_fmha_v2_prefill(..., input_layout="SEPARATE_Q_K_V")`, but that helper explicitly raises when `query.dtype == torch.float8_e4m3fn` for this layout. Right now `--q_dtype fp8_e4m3` still reaches this branch and turns an unsupported combo into a runtime failure instead of a backend skip.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1947 - 1964, The trtllm-fmha-v2 ragged path currently calls trtllm_fmha_v2_prefill with input_layout="SEPARATE_Q_K_V" even when the query tensor (q) uses FP8 e4m3 (torch.float8_e4m3fn), which that helper explicitly forbids; add a guard before selecting the "trtllm-fmha-v2" branch that checks q.dtype (or q) and if it equals torch.float8_e4m3fn, avoid this backend (e.g., skip/return fallback) instead of calling trtllm_fmha_v2_prefill so the unsupported combo never reaches the helper and causes a runtime failure. Ensure the check references backend == "trtllm-fmha-v2", q/q.dtype, and trtllm_fmha_v2_prefill to make the change easy to locate.
1333-1351: ⚠️ Potential issue | 🟠 Major: Still skip FMHA-v2 for NVFP4 paged KV.
At Line 1337, `_fmha_v2_kv_cache` can already be the NVFP4-packed cache, but this call only forwards `qkv` plus scalar BMM scales. There is still no `kv_cache_sf` / block-scale argument here, so `--kv_dtype nvfp4` remains an unsupported or wrong-math benchmark configuration instead of a clean backend skip.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1333 - 1351, The trtllm-fmha-v2 prefill call can receive NVFP4-packed KV in _fmha_v2_kv_cache but does not pass a kv cache scale/block-scale argument, so --kv_dtype nvfp4 yields wrong math instead of a clean skip; update the trtllm_fmha_v2_prefill call site in the branch that returns trtllm_fmha_v2_prefill to either (A) pass the kv cache scale parameter (e.g., kv_cache_sf or block_scale) taken from the existing NVFP4 packing logic into the call so the kernel knows the per-block scale, referencing _fmha_v2_kv_cache and _fmha_v2_bmm2_scale symbols, or (B) detect NVFP4 paged KV (the same condition that produces the packed cache) and explicitly skip this backend by returning a clear error/skip instead of invoking trtllm_fmha_v2_prefill; implement the chosen fix consistently with how other backends handle kv_cache_sf/block-scale.
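Option (B) from the comment above can be sketched as a small pre-flight filter; the function name is hypothetical and `kv_dtype` mirrors the benchmark's CLI flag value rather than an actual torch dtype.

```python
def skip_nvfp4_fmha_v2(backends, kv_dtype):
    """Option (B): treat NVFP4 paged KV as unsupported for trtllm-fmha-v2
    until a kv_cache_sf / block-scale argument is plumbed through.

    `kv_dtype` here is the string from the --kv_dtype flag; the real
    benchmark may instead detect the NVFP4-packed cache directly.
    """
    if "trtllm-fmha-v2" in backends and kv_dtype == "nvfp4":
        print(
            "[INFO] trtllm-fmha-v2 has no block-scale argument for NVFP4 "
            "paged KV. Skipping."
        )
        backends = [b for b in backends if b != "trtllm-fmha-v2"]
    return backends
```

Skipping cleanly is preferable to option (A) only until the scale argument exists; once `trtllm_fmha_v2_prefill` accepts a block scale, the guard can be dropped and the scale forwarded instead.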
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1333-1351: The prefill path for backend "trtllm-fmha-v2" is
passing the raw 2-D block_tables into trtllm_fmha_v2_prefill which internally
calls torch.stack(...).contiguous(), causing allocations during CUDA graph
capture; to fix, pre-expand/stabilize the paged tables outside the
timed/captured call (e.g., build the Q_PAGED_KV_* expanded contiguous buffer
once and pass that buffer to trtllm_fmha_v2_prefill instead of block_tables) or,
as a temporary fallback, disable CUDA graph usage for the "trtllm-fmha-v2" path
until trtllm_fmha_v2_prefill is changed to accept a reusable expanded buffer;
reference trtllm_fmha_v2_prefill, block_tables, and the internal
torch.stack(...).contiguous() expansion when making the change.
---
Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1947-1964: The trtllm-fmha-v2 ragged path currently calls
trtllm_fmha_v2_prefill with input_layout="SEPARATE_Q_K_V" even when the query
tensor (q) uses FP8 e4m3 (torch.float8_e4m3fn), which that helper explicitly
forbids; add a guard before selecting the "trtllm-fmha-v2" branch that checks
q.dtype (or q) and if it equals torch.float8_e4m3fn, avoid this backend (e.g.,
skip/return fallback) instead of calling trtllm_fmha_v2_prefill so the
unsupported combo never reaches the helper and causes a runtime failure. Ensure
the check references backend == "trtllm-fmha-v2", q/q.dtype, and
trtllm_fmha_v2_prefill to make the change easy to locate.
- Around line 1333-1351: The trtllm-fmha-v2 prefill call can receive
NVFP4-packed KV in _fmha_v2_kv_cache but does not pass a kv cache
scale/block-scale argument, so --kv_dtype nvfp4 yields wrong math instead of a
clean skip; update the trtllm_fmha_v2_prefill call site in the branch that
returns trtllm_fmha_v2_prefill to either (A) pass the kv cache scale parameter
(e.g., kv_cache_sf or block_scale) taken from the existing NVFP4 packing logic
into the call so the kernel knows the per-block scale, referencing
_fmha_v2_kv_cache and _fmha_v2_bmm2_scale symbols, or (B) detect NVFP4 paged KV
(the same condition that produces the packed cache) and explicitly skip this
backend by returning a clear error/skip instead of invoking
trtllm_fmha_v2_prefill; implement the chosen fix consistently with how other
backends handle kv_cache_sf/block-scale.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 35e17035-1e9a-4592-b54e-698e585e7b5b
📒 Files selected for processing (3)
- benchmarks/routines/attention.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (2)
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/prefill.py
```python
elif backend == "trtllm-fmha-v2":
    _q_scale = q_scale if q_scale is not None else 1.0
    _k_scale = k_scale if k_scale is not None else 1.0
    return trtllm_fmha_v2_prefill(
        qkv=(q, _fmha_v2_kv_cache),
        input_layout="Q_PAGED_KV_HND",
        workspace_buffer=workspace_buffer,
        seq_lens=actual_seq_lens_kv_device.flatten(),
        max_q_len=s_qo,
        max_kv_len=s_kv,
        bmm1_scale=_q_scale * _k_scale * scale,
        bmm2_scale=_fmha_v2_bmm2_scale,
        batch_size=batch_size,
        cum_seq_lens_q=qo_indptr,
        cum_seq_lens_kv=kv_token_indptr,
        block_tables=block_tables,
        mask_mode="causal" if causal else "padding",
        out_dtype=o_data_type,
    )
```
The paged FMHA-v2 path still allocates during graph capture.
At Line 1348 this call passes the raw 2-D `block_tables`. `trtllm_fmha_v2_prefill()` currently expands paged tables with a fresh `torch.stack(...).contiguous()` for `Q_PAGED_KV_*`, so this benchmark still does an allocation inside the captured callable. Because CUDA graphs stay enabled for this backend, that reintroduces the graph-safety problem this PR is trying to remove. Please either pre-expand/capture-stabilize the table outside the timed call or keep CUDA graphs disabled for trtllm-fmha-v2 until the helper can consume a reusable expanded buffer.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/routines/attention.py` around lines 1333 - 1351, The prefill path
for backend "trtllm-fmha-v2" is passing the raw 2-D block_tables into
trtllm_fmha_v2_prefill which internally calls torch.stack(...).contiguous(),
causing allocations during CUDA graph capture; to fix, pre-expand/stabilize the
paged tables outside the timed/captured call (e.g., build the Q_PAGED_KV_*
expanded contiguous buffer once and pass that buffer to trtllm_fmha_v2_prefill
instead of block_tables) or, as a temporary fallback, disable CUDA graph usage
for the "trtllm-fmha-v2" path until trtllm_fmha_v2_prefill is changed to accept
a reusable expanded buffer; reference trtllm_fmha_v2_prefill, block_tables, and
the internal torch.stack(...).contiguous() expansion when making the change.
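One way to make the paged path graph-safe can be sketched as below, under the assumption that the internal expansion is a plain `torch.stack(...).contiguous()` over K/V views of the table; the buffer and function names are illustrative, not part of the actual API.

```python
import torch

# Hypothetical graph-safe pattern: expand the paged block tables once,
# outside the captured region, into a persistent contiguous buffer.
batch_size, max_blocks = 4, 8
block_tables = torch.arange(
    batch_size * max_blocks, dtype=torch.int32
).reshape(batch_size, max_blocks)

# One-time expansion before graph capture: stack a K view and a V view of
# the table and materialize the result contiguously (mirroring what the
# helper is assumed to do internally for Q_PAGED_KV_* layouts).
expanded_block_tables = torch.stack(
    [block_tables, block_tables], dim=1
).contiguous()

def captured_step(tables):
    # Inside the captured callable, only read from (or copy_ into) the
    # pre-allocated buffer; no new tensor is allocated during capture.
    assert tables.data_ptr() == expanded_block_tables.data_ptr()
    return tables.sum()
```

When block assignments change between graph replays, `expanded_block_tables.copy_(...)` from the new tables keeps the buffer address stable, which is what CUDA graph capture requires.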
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and all tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Tests