
[Fmha] support nvfp4 output keepsMmaAb generation kernels#2795

Open
PerkzZheng wants to merge 1 commit into flashinfer-ai:main from PerkzZheng:user/perkzz/add-e2m1-output

Conversation

@PerkzZheng (Contributor) commented Mar 16, 2026

  • Update cubin artifact path/checksum to new build with nvfp4 output support
  • Fix kernel selection: remove E2M1 output dtype condition from mixed-precision path, allowing nvfp4 output to use GQA generation kernel selection heuristics
  • Always invoke selectTileSizeQForGqaGeneration (not just for maxSeqLenQ > 1)
  • Add mUsesSharedPagedKvIdx field to KernelParams for vLLM/FlashInfer paged KV index
  • Remove speculative-decode skip for nvfp4 output in tests
  • Expand test coverage: head_dim [64, 128, 256], additional batch configs
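The kernel-selection change in the second bullet can be modeled as follows. This is a hedged Python sketch of the described behavior, not the actual C++ in fmhaKernels.cuh; the dtype names and function names here are illustrative assumptions.

```python
# Illustrative model of the kernel-selection change described above.
# Not the real C++; dtype strings and helper names are assumed.

E2M1 = "e2m1"  # element type backing nvfp4 output

def takes_mixed_precision_path_before(dtype_q, dtype_kv, dtype_out):
    # Before this PR: E2M1 (nvfp4) output forced the mixed-precision path,
    # bypassing the GQA generation kernel-selection heuristics.
    return dtype_q != dtype_kv or dtype_out == E2M1

def takes_mixed_precision_path_after(dtype_q, dtype_kv, dtype_out):
    # After this PR: only a Q/KV dtype mismatch takes the mixed path, so
    # nvfp4 output falls through to the GQA generation heuristics
    # (selectTileSizeQForGqaGeneration).
    return dtype_q != dtype_kv

# nvfp4 output with matching Q/KV dtypes now uses the GQA heuristics:
print(takes_mixed_precision_path_before("e4m3", "e4m3", E2M1))  # True
print(takes_mixed_precision_path_after("e4m3", "e4m3", E2M1))   # False
```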

AI-assisted

📌 Description

Qwen3-480B (num_qo_heads=96, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128)

Speedup (baseline / opt)

| s_qo | bs=8  | bs=16 | bs=32 | bs=40 | bs=64 |
|------|-------|-------|-------|-------|-------|
| 2    | 1.23x | 1.32x | 1.26x | 1.15x | 1.28x |
| 4    | 2.21x | 2.49x | 2.16x | 1.89x | 1.93x |
| 8    | 3.41x | 3.32x | 3.00x | 2.81x | 2.94x |

GPT-OSS (num_qo_heads=64, num_kv_heads=8, head_dim_qk=64, head_dim_vo=64)

Speedup (baseline / opt)

| s_qo | bs=8  | bs=16 | bs=32 | bs=40 | bs=64 |
|------|-------|-------|-------|-------|-------|
| 2    | 1.65x | 1.78x | 1.77x | 1.52x | 1.62x |
| 4    | 2.50x | 2.65x | 2.41x | 2.06x | 2.41x |
| 8    | 4.79x | 5.05x | 5.12x | 4.45x | 4.98x |

Summary

Speedup scales strongly with s_qo (speculative decode query length):

  • At s_qo=2: 1.1–1.8x speedup across both models
  • At s_qo=4: 1.9–2.6x speedup
  • At s_qo=8: 2.8–5.1x speedup (peak 5.12x on GPT-OSS, bs=32)
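The summary ranges can be re-derived from the two tables. The small check below uses the speedup values copied verbatim from the tables above:

```python
# Speedup values (baseline / opt) copied from the two tables above.
qwen3 = {2: [1.23, 1.32, 1.26, 1.15, 1.28],
         4: [2.21, 2.49, 2.16, 1.89, 1.93],
         8: [3.41, 3.32, 3.00, 2.81, 2.94]}
gpt_oss = {2: [1.65, 1.78, 1.77, 1.52, 1.62],
           4: [2.50, 2.65, 2.41, 2.06, 2.41],
           8: [4.79, 5.05, 5.12, 4.45, 4.98]}

for s_qo in (2, 4, 8):
    both = qwen3[s_qo] + gpt_oss[s_qo]
    print(f"s_qo={s_qo}: {min(both):.2f}x-{max(both):.2f}x")
# s_qo=2: 1.15x-1.78x
# s_qo=4: 1.89x-2.65x
# s_qo=8: 2.81x-5.12x  (peak 5.12x is GPT-OSS, bs=32)
```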

🔍 Related Issues

#2632

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Simplified mixed-precision gating logic in attention kernels, improving type handling consistency.
  • Tests

    • Expanded test coverage for decode operations across multiple head dimensions (64, 128, 256).
    • Removed previously unsupported data type restrictions in speculative decoding scenarios, improving test completeness.

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's FMHA generation kernels by enabling and optimizing support for nvfp4 output, which is crucial for improving performance, particularly in speculative decoding scenarios. The changes involve updating core cubin artifacts, refining the kernel selection mechanism to better leverage GQA generation heuristics for nvfp4, and integrating support for shared paged KV indices. Furthermore, the pull request expands the testing suite to ensure robust validation of these new capabilities across a broader range of configurations.

Highlights

  • NVFP4 Output Support: Updated cubin artifact paths and checksums to include new builds that support nvfp4 output for FMHA generation kernels.
  • Kernel Selection Logic Improvement: Modified the kernel selection logic to remove the E2M1 output dtype condition from the mixed-precision path, allowing nvfp4 output to correctly utilize GQA generation kernel selection heuristics.
  • Consistent GQA Kernel Timing Model: Ensured that the selectTileSizeQForGqaGeneration function is always invoked, removing the previous condition that limited its use to maxSeqLenQ > 1, leading to more consistent kernel optimization.
  • Paged KV Index Support: Introduced a new mUsesSharedPagedKvIdx field in KernelParams to support unified paged KV indices, relevant for systems like vLLM and FlashInfer.
  • Expanded Test Coverage: Removed the speculative-decode skip for nvfp4 output in tests and significantly expanded test coverage by adding head_dim values (64, 128, 256) and additional batch configurations.


Changelog
  • flashinfer/artifacts.py
    • Updated TRTLLM_GEN_FMHA artifact path to 13519b4f28fd79771189ef097ca14d6ee554306c/fmha/trtllm-gen/.
    • Updated TRTLLM_GEN_FMHA checksum to 64f1da78b3f41c02a106488e9dd85b36960e8370ac11597fd0a8fcf421bf649f.
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
    • Removed the mDtypeOut == DATA_TYPE_E2M1 condition from the mixed-precision kernel selection logic.
    • Removed the params.mMaxSeqLenQ > 1 condition, ensuring selectTileSizeQForGqaGeneration is always called.
  • include/flashinfer/trtllm/fmha/kernelParams.h
    • Added a new boolean field mUsesSharedPagedKvIdx to the KernelParams struct.
    • Initialized mUsesSharedPagedKvIdx to true in the create_kernel_params function.
  • tests/attention/test_trtllm_gen_attention.py
    • Removed the conditional pytest.skip for nvfp4 output when q_len_per_req > 1 or max_q_len > 1.
    • Expanded the head_dim parameter in _test_trtllm_batch_decode to include 64, 128, and 256.
    • Modified test_trtllm_batch_decode parameters to include head_dim and new batch configurations, and restricted kv_layout to HND.
    • Simplified test_trtllm_batch_decode_spec parameters by setting kv_layout to HND, enable_pdl to False, enable_sink to False, and skips_softmax to False.

coderabbitai bot (Contributor) commented Mar 16, 2026

📝 Walkthrough

This PR simplifies the GQA generation kernel selection logic by removing a special-case dtype check and making the tile-size heuristic selection unconditional, while also expanding test coverage for generative attention decoding across multiple head dimensions (64, 128, 256) and removing dtype-specific test skips.

Changes

Kernel Selection Logic — include/flashinfer/trtllm/fmha/fmhaKernels.cuh:
Simplified selectGqGenerationKernel by removing the mDtypeOut == DATA_TYPE_E2M1 special-case check (the early return now triggers only when mDtypeQ != mDtypeKv). Changed the kernel-timing heuristic invocation from conditional to unconditional, running selectTileSizeQForGqaGeneration regardless of sequence-length constraints.

Test Coverage Expansion — tests/attention/test_trtllm_gen_attention.py:
Removed the conditional pytest skip for nvfp4 output dtype with speculative decoding. Expanded test_trtllm_batch_decode head_dim parametrization from [128] to [64, 128, 256]. Removed the fixed head_dim constraint from test_trtllm_batch_decode_spec and extended the backend-specific configuration matrix to include head_dim as an additional dimension in the test tuples.
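The test-matrix expansion described above can be sketched as follows. The real file parametrizes via pytest.mark.parametrize; the output dtypes and query lengths below are illustrative placeholders, with only the head_dim values taken from the PR.

```python
# Illustrative sketch of the widened decode test matrix (not the real
# pytest file); only head_dims is taken from the PR description.
from itertools import product

head_dims = [64, 128, 256]             # widened from [128]
out_dtypes = ["bf16", "fp8", "nvfp4"]  # assumed placeholder values
q_lens_per_req = [1, 2, 4]             # >1 means speculative decode

cases = list(product(head_dims, out_dtypes, q_lens_per_req))
# Combinations the removed pytest.skip used to exclude
# (nvfp4 output together with speculative decode):
nvfp4_spec = [(h, d, q) for h, d, q in cases if d == "nvfp4" and q > 1]
print(len(cases), len(nvfp4_spec))  # 27 6
```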

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

model: dsv3.2

Suggested reviewers

  • aleozlx
  • cyx-6
  • joker-eph
  • yzh119
  • nvmbreughe
  • saltyminty
  • jiahanc


🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 60.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title '[Fmha] support nvfp4 output keepsMmaAb generation kernels' directly describes the main change: enabling nvfp4 output support for FMHA kernels.
  • Description check — ✅ Passed: the PR description includes a clear summary of changes, performance benchmarks, related issues, and completed pre-commit/test checklist items matching the template requirements.



@gemini-code-assist bot left a comment

Code Review

This pull request enables support for nvfp4 output in generation kernels, which involves updating cubin artifacts, adjusting kernel selection logic, and expanding test coverage. The changes appear to be well-aligned with the PR's objectives. I've identified one area for improvement concerning a hardcoded value marked with a FIXME, which should be addressed to prevent future issues.

Comment on lines +818 to +819
// FIXME: set this with options.mUsesSharedPagedKvIdx.
params.mUsesSharedPagedKvIdx = true;
medium

There's a FIXME here to set mUsesSharedPagedKvIdx from options, but it's currently hardcoded to true. This introduces technical debt and could lead to issues if this parameter needs to be configurable in the future. It would be best to plumb this option through from TllmGenFmhaRunnerParams and set it dynamically.

Suggested change:

```diff
 // FIXME: set this with options.mUsesSharedPagedKvIdx.
-params.mUsesSharedPagedKvIdx = true;
+params.mUsesSharedPagedKvIdx = options.mUsesSharedPagedKvIdx;
```

@coderabbitai bot left a comment

Actionable comments posted: 1



📥 Commits

Reviewing files that changed from the base of the PR and between 043bc43 and d085be5.

📒 Files selected for processing (4)
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py

Comment on lines +818 to +819
// FIXME: set this with options.mUsesSharedPagedKvIdx.
params.mUsesSharedPagedKvIdx = true;

⚠️ Potential issue | 🟠 Major

Wire mUsesSharedPagedKvIdx from options instead of hardcoding true.

At Line 819, forcing this to true makes all calls behave as shared-index mode and ignores caller intent, which can produce incorrect paged-KV indexing for non-shared layouts.

Suggested fix:

```diff
-    // FIXME: set this with options.mUsesSharedPagedKvIdx.
-    params.mUsesSharedPagedKvIdx = true;
+    params.mUsesSharedPagedKvIdx = options.mUsesSharedPagedKvIdx;
```

@gemini-code-assist bot left a comment

Code Review

This pull request enables support for nvfp4 output with keepsMmaAb generation kernels in FMHA. The changes include updating kernel selection logic, expanding test coverage by removing a pytest.skip and adding more configurations, and updating cubin artifacts. A new field mUsesSharedPagedKvIdx is introduced for vLLM/FlashInfer paged KV indices. My review includes a suggestion to address a FIXME related to the hardcoded initialization of this new field.

// TODO: Integrate trtllm block-sparse attention kernels when needed.
params.mUseBlockSparseAttention = false;
// FIXME: set this with options.mUsesSharedPagedKvIdx.
params.mUsesSharedPagedKvIdx = true;

high

As the FIXME comment on the preceding line indicates, mUsesSharedPagedKvIdx is currently hardcoded to true. This should be properly configured from the options object instead of being hardcoded. To fix this, you'll likely need to add the mUsesSharedPagedKvIdx field to the TllmGenFmhaRunnerParams struct and ensure it's populated correctly from the calling code.

Suggested change:

```diff
-params.mUsesSharedPagedKvIdx = true;
+params.mUsesSharedPagedKvIdx = options.mUsesSharedPagedKvIdx;
```

)
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("enable_sink", [True, False])
@pytest.mark.parametrize("enable_pdl", [False])
Contributor

Are enable_pdl and enable_sink also not expected to work?

Contributor Author

right, let me revert the changes. I used for debugging locally. Thanks!

Contributor Author

Done. Thanks!

@PerkzZheng (Contributor Author)

/bot run

@flashinfer-bot (Collaborator)

@PerkzZheng is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@coderabbitai bot left a comment

Actionable comments posted: 1

In flashinfer/artifacts.py, line 138: the constant TRTLLM_GEN_FMHA currently points to an artifact path that returns 404, which will break runtime kernel loading. Either upload the missing files to the referenced path, or update TRTLLM_GEN_FMHA to an existing artifact path with matching checksum values (ensuring the checksums.txt location and filenames match the uploaded cubin assets), so the runtime can fetch the CUDA kernels successfully.


📥 Commits

Reviewing files that changed from the base of the PR and between 14f9989 and 86c1c32.

📒 Files selected for processing (1)
  • flashinfer/artifacts.py

@saltyminty (Collaborator)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !429 has been created, and the CI pipeline #46475427 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #46475427: 10/20 passed

@baonudesifeizhai (Contributor) commented Mar 20, 2026

script: https://paste.ubuntu.com/p/k2WYykRJ4Y/
llama31-8b-nvfp4-tp1

rate=1.0
  req_tp   fused=0.997  unfused=0.997  delta=+0.01%
  tok_tp   fused=1147.165  unfused=1147.090  delta=+0.01%
  med_ttft fused=20.757 unfused=20.576 delta=+0.88%
  p99_ttft fused=31.199 unfused=27.889 delta=+11.87%
  med_itl  fused=3.034 unfused=3.092 delta=-1.88%
  p99_itl  fused=3.284 unfused=3.322 delta=-1.14%
  p99_e2el fused=439.575 unfused=444.967 delta=-1.21%

rate=5.0
  req_tp   fused=4.983  unfused=4.983  delta=+0.01%
  tok_tp   fused=5735.756  unfused=5735.417  delta=+0.01%
  med_ttft fused=21.692 unfused=21.542 delta=+0.69%
  p99_ttft fused=35.131 unfused=34.616 delta=+1.49%
  med_itl  fused=3.160 unfused=3.230 delta=-2.19%
  p99_itl  fused=12.974 unfused=12.585 delta=+3.09%
  p99_e2el fused=517.024 unfused=520.552 delta=-0.68%

rate=10.0
  req_tp   fused=9.964  unfused=9.963  delta=+0.01%
  tok_tp   fused=11468.030  unfused=11467.388  delta=+0.01%
  med_ttft fused=22.395 unfused=21.902 delta=+2.25%
  p99_ttft fused=36.214 unfused=36.203 delta=+0.03%
  med_itl  fused=3.200 unfused=3.284 delta=-2.56%
  p99_itl  fused=13.122 unfused=12.698 delta=+3.34%
  p99_e2el fused=583.897 unfused=587.033 delta=-0.53%

rate=15.0
  req_tp   fused=14.947  unfused=14.946  delta=+0.01%
  tok_tp   fused=17204.079  unfused=17202.394  delta=+0.01%
  med_ttft fused=22.699 unfused=22.851 delta=-0.67%
  p99_ttft fused=42.928 unfused=43.033 delta=-0.24%
  med_itl  fused=3.620 unfused=3.734 delta=-3.07%
  p99_itl  fused=13.085 unfused=12.944 delta=+1.09%
  p99_e2el fused=626.241 unfused=630.919 delta=-0.74%

rate=20.0
  req_tp   fused=19.921  unfused=19.919  delta=+0.01%
  tok_tp   fused=22928.542  unfused=22926.463  delta=+0.01%
  med_ttft fused=23.667 unfused=23.471 delta=+0.84%
  p99_ttft fused=43.897 unfused=42.634 delta=+2.96%
  med_itl  fused=3.807 unfused=3.878 delta=-1.84%
  p99_itl  fused=13.316 unfused=12.926 delta=+3.01%
  p99_e2el fused=669.095 unfused=670.505 delta=-0.21%

rate=inf
  req_tp   fused=126.734  unfused=125.732  delta=+0.80%
  tok_tp   fused=145870.917  unfused=144717.584  delta=+0.80%
  med_ttft fused=2934.189 unfused=2955.540 delta=-0.72%
  p99_ttft fused=6006.058 unfused=6027.260 delta=-0.35%
  med_itl  fused=26.832 unfused=23.673 delta=+13.35%
  p99_itl  fused=82.033 unfused=84.370 delta=-2.77%
  p99_e2el fused=7680.289 unfused=7742.735 delta=-0.81%

llama33-70b-nvfp4-tp4-pr2795:

rate=1.0
  req_tp   fused=0.991  unfused=0.991  delta=+0.03%
  tok_tp   fused=1140.812  unfused=1140.501  delta=+0.03%
  med_ttft fused=48.950 unfused=48.932 delta=+0.04%
  p99_ttft fused=74.607 unfused=73.578 delta=+1.40%
  med_itl  fused=8.225 unfused=8.586 delta=-4.20%
  p99_itl  fused=8.618 unfused=8.947 delta=-3.68%
  p99_e2el fused=1197.490 unfused=1245.535 delta=-3.86%

rate=5.0
  req_tp   fused=4.955  unfused=4.953  delta=+0.04%
  tok_tp   fused=5703.549  unfused=5701.304  delta=+0.04%
  med_ttft fused=52.386 unfused=52.430 delta=-0.08%
  p99_ttft fused=90.788 unfused=92.978 delta=-2.36%
  med_itl  fused=8.348 unfused=8.732 delta=-4.39%
  p99_itl  fused=33.220 unfused=32.888 delta=+1.01%
  p99_e2el fused=1501.603 unfused=1551.192 delta=-3.20%

rate=10.0
  req_tp   fused=9.906  unfused=9.900  delta=+0.05%
  tok_tp   fused=11401.289  unfused=11395.071  delta=+0.05%
  med_ttft fused=55.904 unfused=56.949 delta=-1.83%
  p99_ttft fused=112.963 unfused=108.872 delta=+3.76%
  med_itl  fused=9.283 unfused=9.695 delta=-4.25%
  p99_itl  fused=34.775 unfused=35.334 delta=-1.58%
  p99_e2el fused=1999.183 unfused=2058.705 delta=-2.89%

rate=15.0
  req_tp   fused=14.849  unfused=14.845  delta=+0.02%
  tok_tp   fused=17090.686  unfused=17086.819  delta=+0.02%
  med_ttft fused=71.098 unfused=71.691 delta=-0.83%
  p99_ttft fused=148.361 unfused=147.833 delta=+0.36%
  med_itl  fused=10.745 unfused=11.104 delta=-3.23%
  p99_itl  fused=49.911 unfused=51.147 delta=-2.42%
  p99_e2el fused=2543.160 unfused=2615.250 delta=-2.76%

rate=20.0
  req_tp   fused=19.779  unfused=19.773  delta=+0.03%
  tok_tp   fused=22765.261  unfused=22758.874  delta=+0.03%
  med_ttft fused=79.304 unfused=79.785 delta=-0.60%
  p99_ttft fused=163.917 unfused=173.058 delta=-5.28%
  med_itl  fused=11.123 unfused=11.456 delta=-2.90%
  p99_itl  fused=65.689 unfused=65.716 delta=-0.04%
  p99_e2el fused=3062.634 unfused=3165.766 delta=-3.26%

rate=inf
  req_tp   fused=43.053  unfused=42.558  delta=+1.16%
  tok_tp   fused=49554.427  unfused=48983.921  delta=+1.16%
  med_ttft fused=9523.693 unfused=9640.706 delta=-1.21%
  p99_ttft fused=19246.435 unfused=19524.213 delta=-1.42%
  med_itl  fused=140.547 unfused=142.331 delta=-1.25%
  p99_itl  fused=150.834 unfused=156.341 delta=-3.52%
  p99_e2el fused=22742.839 unfused=22984.450 delta=-1.05%

results saved to: /root/zdj/vllm/results-34988/llama33-70b-nvfp4-tp4-pr2795

@ProExpertProg
@baonudesifeizhai thank you! If I'm reading this right fused latency goes from +2% - -2% to 0% - -4%, so overall an average improvement of 2%?

@baonudesifeizhai (Contributor)

> @baonudesifeizhai thank you! If I'm reading this right fused latency goes from +2% - -2% to 0% - -4%, so overall an average improvement of 2%?

nvidia/Llama-3.3-70B-Instruct-NVFP4 4 card tp

nvidia/Llama-3.1-8B-Instruct-NVFP4 1 card....
yep and at least no regression in inf

@PerkzZheng (Contributor Author)

> [FAILED] Pipeline #46475427: 10/20 passed

@saltyminty it seems that the failing ones are not related. Is it okay to merge it ? Thanks.

@yongwww yongwww added the run-ci label Mar 20, 2026
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@PerkzZheng PerkzZheng force-pushed the user/perkzz/add-e2m1-output branch from 86c1c32 to 2632da4 Compare March 21, 2026 13:33
@PerkzZheng PerkzZheng requested a review from saltyminty as a code owner March 21, 2026 13:33
@PerkzZheng (Contributor Author)

> [FAILED] Pipeline #46475427: 10/20 passed
>
> @saltyminty it seems that the failing ones are not related. Is it okay to merge it ? Thanks.

Just rebased. Feel free to trigger CI again. Thanks!

@coderabbitai bot left a comment

🧹 Nitpick comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

776-777: Reduce selector overhead now that heuristic selection is unconditional.

Line 777 now runs on every GQA-generation selection path. Consider removing per-call std::unordered_map construction in selectTileSizeQForGqaGeneration to avoid extra host overhead in decode-heavy workloads.

♻️ Suggested refactor:

```diff
-    std::unordered_map<int, float> kernelMainloopCost = {
-        {128, 2.2}, {64, 1.68}, {32, 1.48}, {16, 1.2}, {8, 1.0}
-    };
-
-    std::unordered_map<int, float> kernelReductionCost = {
-        {128, 1.32}, {64, 1.2}, {32, 1.08}, {16, 1.03}, {8, 1.0}
-    };
+    auto kernelMainloopCost = [](int tileSizeQ) -> float {
+      switch (tileSizeQ) {
+        case 128: return 2.2f;
+        case 64: return 1.68f;
+        case 32: return 1.48f;
+        case 16: return 1.2f;
+        case 8: return 1.0f;
+        default: return FLT_MAX;
+      }
+    };
+    auto kernelReductionCost = [](int tileSizeQ) -> float {
+      switch (tileSizeQ) {
+        case 128: return 1.32f;
+        case 64: return 1.2f;
+        case 32: return 1.08f;
+        case 16: return 1.03f;
+        case 8: return 1.0f;
+        default: return FLT_MAX;
+      }
+    };
 ...
-      float modelingKernelTime = kernelMainloopCost.at(tileSizeQ) * seqLenPerCtaKv +
-                                 kernelReductionCost.at(tileSizeQ) * kernelReductionSeqLenFactor *
+      float modelingKernelTime = kernelMainloopCost(tileSizeQ) * seqLenPerCtaKv +
+                                 kernelReductionCost(tileSizeQ) * kernelReductionSeqLenFactor *
                                      ctaLaunchParams.mMaxNumCtasKv;
```


📥 Commits

Reviewing files that changed from the base of the PR and between 86c1c32 and 2632da4.

📒 Files selected for processing (2)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/attention/test_trtllm_gen_attention.py

@baonudesifeizhai (Contributor)

wait TRTLLM_GEN_FMHA: str = "3fec9f12548f83f44e4ca60394a2946238a677f1/fmha/trtllm-gen/" thats the whole points
...
the trtllm kernel link...

@PerkzZheng (Contributor Author)

wait TRTLLM_GEN_FMHA: str = "3fec9f12548f83f44e4ca60394a2946238a677f1/fmha/trtllm-gen/" thats the whole points ... the trtllm kernel link...

no worries. this is expected. we have another MR just merged (#2836) which includes the required cubins for this MR.
