feat: add SM120 fmha_v2 kernels to AOT pip wheel builds #2885
blake-snc wants to merge 2 commits into flashinfer-ai:main from
Conversation
`gen_trtllm_fmha_v2_sm120_module()` was already callable via JIT (generate_kernels.py dispatches to it at runtime), but was never registered in `gen_all_modules()` in aot.py. SM120/SM121 devices getting flashinfer from a pip wheel would skip the fmha_v2 SM120 kernels entirely during the AOT build step, falling back to slower paths or missing support.

Add it to the `has_sm120 or has_sm121` section alongside the other SM120 modules (fused MOE, GEMM, FP4 quantization).

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
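For orientation, here is a minimal sketch of the wiring this commit describes. It is not the verbatim `aot.py` code: the function name and signature are illustrative stand-ins, and the fmha_v2 import path is assumed from the `jit/attention/modules.py` location mentioned above; the actual change is shown in the review diff further down.

```python
# Sketch only: mirrors the has_sm120/has_sm121 branch of aot.py's gen_all_modules().
# The fmha_v2 import path is an assumption based on "jit/attention/modules.py";
# imports for the pre-existing MOE/GEMM/FP4 generators are elided.
from flashinfer.jit.attention.modules import gen_trtllm_fmha_v2_sm120_module

def gen_sm12x_specs(has_sm120: bool, has_sm121: bool) -> list:
    """Toy stand-in for the SM12x section of gen_all_modules()."""
    jit_specs = []
    if has_sm120 or has_sm121:
        # SM120 and SM121 share the same SM12x kernel builds:
        # fused MOE, GEMM, FP4 quantization -- and, with this change, fmha_v2 attention.
        jit_specs.append(gen_cutlass_fused_moe_sm120_module())
        jit_specs.append(gen_gemm_sm120_module())
        jit_specs.append(gen_gemm_sm120_module_cutlass_fp4())
        jit_specs.append(gen_trtllm_fmha_v2_sm120_module())  # the previously missing registration
    return jit_specs
```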
📝 Walkthrough: This change imports the SM120 FMHA_V2 attention module generator and appends its JIT spec to the module list built in `gen_all_modules()` for SM120/SM121 targets.
Summary of Changes: This pull request enhances the AOT compilation process for FlashInfer's pip wheel builds by integrating the SM120 fmha_v2 attention module generator into `gen_all_modules()`.
Code Review
This pull request integrates a new TensorRT-LLM Flash Attention v2 module for SM120 architectures into the AOT compilation process. A review comment suggests updating an existing code comment to accurately reflect the inclusion of attention kernels alongside fused MOE and GEMM, improving clarity and maintainability.
```python
jit_specs.append(gen_cutlass_fused_moe_sm120_module())
jit_specs.append(gen_gemm_sm120_module())
jit_specs.append(gen_gemm_sm120_module_cutlass_fp4())
jit_specs.append(gen_trtllm_fmha_v2_sm120_module())
```
With the addition of this fmha_v2 module, the comment on lines 525-527 is now slightly outdated as it only mentions 'fused MOE and GEMM'. For better maintainability, please consider updating it to include attention kernels for clarity.
For example:
```diff
- # SM120 and SM121 share the same CUTLASS kernels for fused MOE and GEMM.
+ # SM120 and SM121 share the same kernels for fused MOE, GEMM, and attention.
```
🧹 Nitpick comments (1)
flashinfer/aot.py (1)
524-531: Consider decoupling FMHA v2 from the `add_moe` gate.

At Line 531, this is an attention kernel but it is only emitted when `add_moe` is `True`. For custom AOT configs (`--add-moe false`), that can unexpectedly drop FMHA v2.

♻️ Suggested placement change

```diff
@@
-    if add_moe:
+    if has_sm120 or has_sm121:
+        jit_specs.append(gen_trtllm_fmha_v2_sm120_module())
+
+    if add_moe:
@@
         if has_sm120 or has_sm121:
             # SM120 and SM121 share the same CUTLASS kernels for fused MOE and GEMM.
             # The SM120 module generators use supported_major_versions=[12] which
             # compiles for all SM12x targets.
             jit_specs.append(gen_cutlass_fused_moe_sm120_module())
             jit_specs.append(gen_gemm_sm120_module())
             jit_specs.append(gen_gemm_sm120_module_cutlass_fp4())
-            jit_specs.append(gen_trtllm_fmha_v2_sm120_module())
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/aot.py` around lines 524-531: the FMHA v2 module `gen_trtllm_fmha_v2_sm120_module()` is currently gated by the `add_moe` flag (inside the has_sm120/has_sm121 block), which causes FMHA v2 to be omitted when `--add-moe false`; update the logic so that `gen_trtllm_fmha_v2_sm120_module()` is appended to `jit_specs` independently of `add_moe` (i.e., move or duplicate the call out of the add_moe-specific branch in the SM120/SM121 handling code), or replace the `add_moe` check with a dedicated attention-kernel condition so FMHA v2 is always emitted for SM12x targets regardless of the MOE flag.
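For illustration, the "dedicated attention-kernel condition" alternative mentioned above could look roughly like the sketch below; `add_attention` is a hypothetical flag name, not an existing `aot.py` option.

```python
# Hypothetical sketch: gate SM12x attention kernels on their own flag instead of add_moe.
# "add_attention" is an illustrative name; the real aot.py may expose a different switch.
if add_attention and (has_sm120 or has_sm121):
    jit_specs.append(gen_trtllm_fmha_v2_sm120_module())

if add_moe and (has_sm120 or has_sm121):
    jit_specs.append(gen_cutlass_fused_moe_sm120_module())
```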
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
`gen_trtllm_fmha_v2_sm120_module()` exists in `jit/attention/modules.py` and the JIT runtime path (generate_kernels.py) already dispatches to it correctly. However, `aot.py`'s `gen_all_modules()` — which drives the pip wheel AOT build — was missing it from the `has_sm120 or has_sm121` section.

This means SM120/SM121 devices using a pip wheel would never get the fmha_v2 SM120 kernels compiled into the wheel, and would have to fall back to slower paths.

Fix: Add `gen_trtllm_fmha_v2_sm120_module()` to the `has_sm120 or has_sm121` block in `aot.py`, alongside the other SM120 modules (fused MOE, GEMM, FP4 quantization).

No behavior change for JIT users; only affects AOT pip wheel builds.
Addresses the AOT gap noted in #2555.
Contributed by Second Nature Computing (https://joinsecondnature.com)