Horizontal MTP kernel + DSTATE=96 support for selective_state_update #2845
ishovkun wants to merge 2 commits into flashinfer-ai:main from
Conversation
Do not assume the forward compatibility of cvt.rs
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Code Review
This pull request introduces a new `is_cvt_rs_supported` utility function to more accurately check for GPU support of the `cvt.rs.f16x2.f32` PTX instruction, and refactors several tests to use it. The changes improve the precision of hardware feature detection. My main feedback is to extend the new utility function to also check the required CUDA version, making it more robust and consistent with similar checks in the codebase.
```python
major, _ = get_compute_capability(device)
# SM100a and SM110a support cvt.rs; SM120 does not.
return major in (10, 11)
```
The current implementation of `is_cvt_rs_supported` only checks the major compute capability. However, support for architecture-specific features like `sm_100a` also depends on the CUDA toolkit version. For consistency with other support-check functions in this file (e.g., `is_sm100a_supported`), this function should also verify the minimum required CUDA version. This ensures the check is robust and prevents runtime errors if an older CUDA toolkit is used with a compatible GPU.
```diff
 major, _ = get_compute_capability(device)
 # SM100a and SM110a support cvt.rs; SM120 does not.
-return major in (10, 11)
+if major == 10:
+    return version_at_least(torch.version.cuda, "12.8")
+if major == 11:
+    return version_at_least(torch.version.cuda, "13.0")
+return False
```
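The suggested check can be exercised in isolation. Below is a hedged sketch that takes the compute-capability major and the CUDA toolkit version as plain arguments and inlines a minimal version comparison; in FlashInfer the real helpers (`get_compute_capability`, `version_at_least`) live in the utils module, and the minimum CUDA versions here follow the review suggestion above:

```python
def cvt_rs_supported(cc_major: int, cuda_version: str) -> bool:
    """Sketch (names hypothetical): cvt.rs.f16x2.f32 requires an
    SM100a/SM110a GPU *and* a CUDA toolkit new enough to target it."""
    def version_at_least(version: str, minimum: str) -> bool:
        to_tuple = lambda s: tuple(int(p) for p in s.split("."))
        return to_tuple(version) >= to_tuple(minimum)

    if cc_major == 10:  # SM100a: assumed to need CUDA 12.8+
        return version_at_least(cuda_version, "12.8")
    if cc_major == 11:  # SM110a: assumed to need CUDA 13.0+
        return version_at_least(cuda_version, "13.0")
    return False        # SM120 and all other architectures: no cvt.rs
```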
/bot run |
📌 Description
This PR adds a high-performance horizontal MTP kernel for `selective_state_update` and extends DSTATE support to non-power-of-2 values (e.g. DSTATE=96). It also includes the full history of the vertical MTP kernel, int16 block-scaled state, stochastic rounding, and test infrastructure improvements accumulated on this branch.

Key changes
Horizontal MTP kernel (new)
- `NUM_IN_STAGES=2` pipeline stages: while compute processes chunk N, TMA preloads chunk N+1.
- `cuda::barrier` wait replaced with a custom `arrive_and_wait_parity` pattern for lower-latency synchronization.
- `float2` operations to reduce instruction count and register pressure.
- `HEADS_PER_CTA=1`: benchmarking on B200 showed that single-head-per-CTA gives better occupancy at all practical batch sizes (the 2-head variant only wins at batch ≥ 1024, and by < 2%).
- `TMA_STATE_ROWS` is a template parameter (set in the launcher), allowing future tuning per dtype/dim/dstate without kernel code changes.

Non-power-of-2 DSTATE support (DSTATE=96)
- `DSTATE_PADDED = nextPow2(DSTATE)` is used for tiling only.
- `PackedAligned` alignment fixed to use `largestPow2Divisor` instead of raw `sizeof`, so `alignas(N)` is always a valid power of 2.

Dispatch & test coverage
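The two padding/alignment helpers named above have simple bit-trick definitions. Here is a Python sketch of the intended semantics (the actual FlashInfer versions are C++ compile-time functions; only the names come from the description):

```python
def next_pow2(n: int) -> int:
    # Smallest power of two >= n; e.g. DSTATE=96 pads to DSTATE_PADDED=128.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def largest_pow2_divisor(n: int) -> int:
    # Largest power of two dividing n -- always a valid alignas() argument,
    # unlike a raw sizeof, which for DSTATE=96 payloads need not be a power of 2.
    return n & -n
```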
- `dispatchRatio` supports nheads/ngroups ratios: 1, 2, 4, 8, 16, 32, 64 (for both STP and MTP).
- Kernel variants covered: `simple`, `vertical`, `horizontal`.

Benchmark (B200)
DIM=64, DSTATE=128
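The ratio support described under "Dispatch & test coverage" can be validated up front before kernel launch. A hypothetical Python sketch of that check (the real `dispatchRatio` is C++ template dispatch; the ratio set is from the PR description):

```python
SUPPORTED_RATIOS = (1, 2, 4, 8, 16, 32, 64)

def dispatch_ratio(nheads: int, ngroups: int) -> int:
    # Each state group serves nheads / ngroups query heads; only the
    # listed ratios have compiled instantiations (for both STP and MTP).
    if ngroups <= 0 or nheads % ngroups != 0:
        raise ValueError(f"nheads={nheads} must be a multiple of ngroups={ngroups}")
    ratio = nheads // ngroups
    if ratio not in SUPPORTED_RATIOS:
        raise ValueError(f"unsupported nheads/ngroups ratio: {ratio}")
    return ratio
```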
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes