[Fmha] support nvfp4 output keepsMmaAb generation kernels #2988

bkryu merged 3 commits into flashinfer-ai:main from
Conversation
- Update cubin artifact path/checksum to a new build with nvfp4 output support
- Fix kernel selection: remove the E2M1 output dtype condition from the mixed-precision path, allowing nvfp4 output to use the GQA generation kernel selection heuristics
- Always invoke selectTileSizeQForGqaGeneration (not just for maxSeqLenQ > 1)
- Add mUsesSharedPagedKvIdx field to KernelParams for the vLLM/FlashInfer paged KV index
- Remove the speculative-decode skip for nvfp4 output in tests
- Expand test coverage: head_dim [64, 128, 256], additional batch configs

AI-assisted
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add o_data_type to plan_params in _test_trtllm_batch_prefill and _test_trtllm_batch_decode to properly test output dtype selection
- Remove head_dim=64 from the test_trtllm_batch_decode and test_trtllm_batch_decode_spec parametrize lists due to buggy Sm100f SwapsAbForGen cubins in the current artifact store
- Fix the std::max(1, ...) guard in fmhaKernels.cuh to avoid numTokensPerCtaQ=0 when mStepQ < mNumHeadsQPerKv
📝 Walkthrough

Clamps CTA token grouping to avoid zero-valued division and relaxes a dtype gate during GQA kernel selection; expands trtllm generation tests to explicitly set the output dtype and add head_dim=256 coverage for decode/speculative cases.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
/bot run
Code Review
This pull request updates the FMHA kernels to ensure at least one token per CTA during tile-size calculation, and relaxes the constraint on mixed-precision kernels by removing the specific check for E2M1 output data types. The test suite is updated to include the output data type in planning parameters, expand coverage to head dimension 256, and enable nvfp4 support for speculative decoding by removing the previous skips. I have no further feedback to provide, as the review comments were purely evaluative.
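The review above notes that the output data type is now part of the planning parameters, so dtype-driven kernel selection is actually exercised. A hypothetical sketch of that shape of change; `build_plan_params` and the dtype strings are illustrative stand-ins, not the real flashinfer test helpers:

```python
# Sketch (assumed names): before the fix, o_data_type was omitted from the
# plan parameters, so the output-dtype selection path went untested.
def build_plan_params(head_dim: int, o_data_type: str) -> dict:
    return {
        "head_dim": head_dim,
        "o_data_type": o_data_type,  # newly threaded through to kernel selection
    }

# Expanded coverage described in the PR: head_dim x output dtype grid,
# including the nvfp4 output this PR enables.
params = [
    build_plan_params(d, o)
    for d in (64, 128, 256)
    for o in ("bf16", "fp8", "nvfp4")
]
assert len(params) == 9
assert all("o_data_type" in p for p in params)
```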
[FAILED] Pipeline #47805888: 9/20 passed
@saltyminty sorry for asking you again. It seems that all related tests have passed (https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/292875212). Feel free to merge if everything looks good. Thank you!
@aleozlx please take a look; this is a request from the Red Hat team (vLLM).
/bot run
[FAILED] Pipeline #47955317: 8/20 passed
bkryu left a comment
@PerkzZheng, please check the internal CI failures on SM120 cards in tests/attention/test_trtllm_gen_attention.py.
I suspect it has to do with XQA at head size 256. If XQA is not supported there, we need to fix this on SM120.
Thanks. Let me skip XQA for headDim 256. @qsang-nv for vis.
/bot run
Same as #2795. It was recreated because the original branch was force-pushed.
AI-assisted
📌 Description
Qwen3-480B (num_qo_heads=96, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128)

Speedup (baseline / opt): [benchmark table not captured in this excerpt]

GPT-OSS (num_qo_heads=64, num_kv_heads=8, head_dim_qk=64, head_dim_vo=64)

Speedup (baseline / opt): [benchmark table not captured in this excerpt]
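The tables above report speedup as baseline latency divided by optimized latency. A trivial sketch of that convention; the millisecond values below are invented purely to illustrate the arithmetic, only the ratios quoted in the summary are real measurements:

```python
# Speedup convention used in the benchmark tables: baseline time / optimized time.
def speedup(baseline_ms: float, opt_ms: float) -> float:
    return baseline_ms / opt_ms

# A run that drops from a hypothetical 10.24 ms to 2.00 ms yields 5.12x,
# matching the peak speedup reported for GPT-OSS at s_qo=8, bs=32.
print(round(speedup(10.24, 2.00), 2))
```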
Summary

Speedup scales strongly with s_qo (speculative-decode query length):

- s_qo=2: 1.1–1.8x speedup across both models
- s_qo=4: 1.9–2.6x speedup
- s_qo=8: 2.8–5.1x speedup (peak 5.12x on GPT-OSS, bs=32)

🔍 Related Issues
#2632
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests are passing (unittest, etc.).

Reviewer Notes