Update trtllm FMHA cubins#3317
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (3)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughThis PR: (1) uses an ChangesTRTLLM FMHA MLA Decode and TMA Reshape Updates
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request updates the TRT-LLM FMHA kernel launcher and parameters to align with newer cubin ABIs, including support for output scaling factors and refined TMA box reshaping logic. Key changes involve adding tmaOSf_ and mReshapeFactorKv to the KernelParams struct and implementing conditional mask type selection for MLA decode. Review feedback identifies a critical omission where the new tmaOSf_ descriptor is not initialized in setKernelParams, which could lead to undefined behavior. Additionally, it is recommended to refactor duplicated MLA detection logic for better maintainability and to use specific K/V data types when checking for TMA reshape compatibility to improve robustness.
| // TMA descriptor for V. | ||
| CUtensorMap tmaV_; | ||
| // TMA descriptor for output scaling factor. | ||
| CUtensorMap tmaOSf_; |
There was a problem hiding this comment.
The new tmaOSf_ member is added to the KernelParams struct to align with the newer cubin ABI, but it is not initialized in the setKernelParams function. If the newer cubins expect a valid TMA descriptor for output scaling factors (e.g., when performing FP4 quantization on output), this will lead to undefined behavior or crashes as the descriptor will be all zeros. Please add the necessary logic in setKernelParams to build the TMA descriptor for tmaOSf_ when options.oSfPtr is provided, similar to how tmaKSf_ and tmaVSf_ are handled.
There was a problem hiding this comment.
@djmmoss I am not quite sure why those changes are made. please add me to review next time. Thanks!
And no worries about that. I will revert them in my MR.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 703-706: The code block defining the canReshapeTmaKv boolean is
misformatted; run clang-format (e.g., pre-commit run clang-format --all-files)
to reformat this declaration so spacing and line breaks follow project style for
the canReshapeTmaKv initializer that references isPagedKv(options.mQkvLayout),
options.mHeadDimQk, swizzleKv, canUseTmaKvReshape(options,
kernelMeta.mDataTypeKv, /*isK*/ true/false), and ensure the file
include/flashinfer/trtllm/fmha/kernelParams.h is committed after formatting.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 12a74a50-4c9b-43b3-90f4-2ca34f9b744f
📒 Files selected for processing (3)
csrc/trtllm_fmha_kernel_launcher.cuflashinfer/artifacts.pyinclude/flashinfer/trtllm/fmha/kernelParams.h
ac48b6e to
9113c91
Compare
Point FMHA at the newer public trtllm cubin publish, align the FMHA parameter ABI, and use dense mask selection for MLA decode kernels.
9113c91 to
a6b9087
Compare
|
/bot run |
📌 Description
Updates the trtllm FMHA artifact path and checksum to the newer cubins. Aligns the FMHA parameter ABI expected by those cubins and uses dense mask selection for MLA decode generation kernels.
🔍 Related Issues
None.
🧪 Tests
pre-commit run --all-filespytest tests/attention/test_cute_dsl_mla_decode.py::test_cute_dsl_vs_trtllm_gen[True-128-1] tests/attention/test_trtllm_ragged_kv_stride.py -q -ra --tb=shorttests/attention/*.pyitem after-k trtllm:18314 passed, 22979 skipped, 342803 deselected, 686 warnings7/7 PASS, all mismatches0Reviewer Notes
None.