
Horizontal MTP kernel + DSTATE=96 support for selective_state_update#2845

Closed
ishovkun wants to merge 2 commits into flashinfer-ai:main from ishovkun:main

Conversation

@ishovkun
Contributor

@ishovkun ishovkun commented Mar 20, 2026

📌 Description

This PR adds a high-performance horizontal MTP kernel for selective_state_update and extends DSTATE support to non-power-of-2 values (e.g. DSTATE=96). It also includes the full history of the vertical MTP kernel, int16 block-scaled state, stochastic rounding, and test infrastructure improvements accumulated on this branch.

Key changes

Horizontal MTP kernel (new)

  • 5-warp design: 4 compute warps + 1 TMA warp per CTA, processing 1 head per CTA.
  • TMA-level pipelining: state chunks are double-buffered via NUM_IN_STAGES=2 pipeline stages. While compute processes chunk N, TMA preloads chunk N+1.
  • Tight-spin parity barriers: replaces cuda::barrier wait with a custom arrive_and_wait_parity pattern for lower-latency synchronization.
  • f32x2 packed SIMD: state recurrence uses float2 operations to reduce instruction count and register pressure.
  • HEADS_PER_CTA=1: benchmarking on B200 showed that single-head-per-CTA gives better occupancy at all practical batch sizes (the 2-head variant only wins at batch ≥ 1024, and by < 2%).
  • TMA_STATE_ROWS is a template parameter (set in the launcher), allowing future tuning per dtype/dim/dstate without kernel code changes.
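As a reference for what the compute warps are evaluating, the selective state update recurrence (omitting dt-bias/softplus and z-gating) can be sketched in plain Python. This is reference semantics only, not the kernel; names and the single-head layout are illustrative.

```python
import math

def selective_state_update_ref(state, x, dt, A, B, C, D):
    """Single-head selective state update, reference semantics.

    state: [dim][dstate] (updated in place), x: [dim], dt: [dim],
    A: [dim][dstate], B: [dstate], C: [dstate], D: [dim].
    Returns the output vector [dim].
    """
    dim, dstate = len(state), len(state[0])
    out = [0.0] * dim
    for d in range(dim):
        acc = 0.0
        for s in range(dstate):
            # dA = exp(dt * A): discretized decay applied to the old state.
            dA = math.exp(dt[d] * A[d][s])
            # dB = dt * B: discretized input projection of x into the state.
            state[d][s] = state[d][s] * dA + dt[d] * B[s] * x[d]
            acc += state[d][s] * C[s]
        # Output is C-projected state plus the skip connection D * x.
        out[d] = acc + D[d] * x[d]
    return out
```

In the horizontal kernel this inner `dstate` loop is what gets chunked, double-buffered through TMA stages, and evaluated with packed float2 arithmetic.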

Non-power-of-2 DSTATE support (DSTATE=96)

  • Tile math pads DSTATE to next power of 2 (DSTATE_PADDED = nextPow2(DSTATE)) for tiling only.
  • Threads whose columns land beyond the real DSTATE zero-fill their registers, skip B/C loads, and skip state writes.
  • No extra shared memory allocated — TMA still loads exactly DSTATE columns.
  • PackedAligned alignment fixed to use largestPow2Divisor instead of raw sizeof, so alignas(N) is always a valid power of 2.
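The padding and alignment arithmetic behind these bullets is small enough to sketch directly (helper names mirror the PR's C++ `nextPow2` / `largestPow2Divisor`, but this Python is illustrative, not the PR's code):

```python
def next_pow2(n: int) -> int:
    """Smallest power of 2 >= n; used only for tile math (DSTATE_PADDED)."""
    return 1 << (n - 1).bit_length()

def largest_pow2_divisor(n: int) -> int:
    """Largest power of 2 dividing n — always a valid alignas(N) value,
    unlike a raw sizeof, which need not be a power of 2."""
    return n & -n

def lane_is_active(col: int, dstate: int) -> bool:
    """Threads whose column lands beyond the real DSTATE zero-fill their
    registers and skip B/C loads and state writes."""
    return col < dstate
```

For DSTATE=96 this gives DSTATE_PADDED=128 for tiling, while TMA and the active-lane predicate still operate on exactly 96 columns, so no extra shared memory is needed.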

Dispatch & test coverage

  • dispatchRatio supports nheads/ngroups ratios: 1, 2, 4, 8, 16, 32, 64 (for both STP and MTP).
  • MTP tests parametrized over all three algorithms: simple, vertical, horizontal.
  • Vertical/horizontal tests require SM100+ and are skipped on older GPUs.
  • Int16 block-scaled state tests assert rejection for vertical/horizontal (unsupported).
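The ratio dispatch described above can be sketched as follows (illustrative only — the actual `dispatchRatio` lives in the C++ launcher, and error handling there may differ):

```python
SUPPORTED_RATIOS = (1, 2, 4, 8, 16, 32, 64)

def dispatch_ratio(nheads: int, ngroups: int) -> int:
    """Pick the nheads/ngroups ratio to instantiate a kernel for,
    rejecting unsupported combinations instead of mis-dispatching."""
    if ngroups <= 0 or nheads % ngroups != 0:
        raise ValueError(f"nheads={nheads} not divisible by ngroups={ngroups}")
    ratio = nheads // ngroups
    if ratio not in SUPPORTED_RATIOS:
        raise ValueError(f"unsupported nheads/ngroups ratio {ratio}")
    return ratio
```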

Benchmark (B200)

DIM=64, DSTATE=128

(benchmark chart: ssu_mtp_hor_32 — image not included)

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added a GPU capability detection utility function to identify support for stochastic rounding operations on compatible hardware.
  • Tests

    • Updated test infrastructure to utilize the new capability detection utility, improving GPU support validation across Mamba test suite fixtures.

ishovkun and others added 2 commits March 5, 2026 12:07
Do not assume the forward compatibility of cv.rs

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the selective_state_update functionality by integrating a new, highly optimized horizontal MTP kernel and enabling flexible DSTATE configurations beyond power-of-2 values. These changes aim to boost performance, especially on newer GPU architectures, and improve the robustness of the system by refining stochastic rounding logic and expanding test coverage.

Highlights

  • Horizontal MTP Kernel: Introduced a new high-performance horizontal MTP kernel for selective_state_update with a 5-warp design, TMA-level pipelining, tight-spin parity barriers, and f32x2 packed SIMD.
  • Non-Power-of-2 DSTATE Support: Extended DSTATE support to non-power-of-2 values (e.g., DSTATE=96) by padding for tiling, zero-filling registers, and skipping loads/writes for threads beyond the real DSTATE.
  • Dispatch & Test Coverage: Enhanced dispatchRatio to support various nheads/ngroups ratios and expanded MTP tests to cover simple, vertical, and horizontal algorithms, with specific GPU requirements.



Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new is_cvt_rs_supported utility function to more accurately check for GPU support of the cvt.rs.f16x2.f32 PTX instruction, and refactors several tests to use this new function. The changes improve the precision of hardware feature detection. My main feedback is to enhance the new utility function to also check for the required CUDA version, making it more robust and consistent with other similar checks in the codebase.

Comment on lines +597 to +599:

```python
major, _ = get_compute_capability(device)
# SM100a and SM110a support cvt.rs; SM120 does not.
return major in (10, 11)
```


Severity: high

The current implementation of is_cvt_rs_supported only checks the major compute capability. However, support for specific architecture features like sm_100a also depends on the CUDA toolkit version. For consistency with other support-check functions in this file (e.g., is_sm100a_supported), this function should also verify the minimum required CUDA version. This ensures that the check is robust and prevents runtime errors if an older CUDA toolkit is used with a compatible GPU.

Suggested change:

```diff
 major, _ = get_compute_capability(device)
 # SM100a and SM110a support cvt.rs; SM120 does not.
-return major in (10, 11)
+if major == 10:
+    return version_at_least(torch.version.cuda, "12.8")
+if major == 11:
+    return version_at_least(torch.version.cuda, "13.0")
+return False
```

@coderabbitai
Contributor

coderabbitai bot commented Mar 21, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between af1e02d and 5e7ae55.

📒 Files selected for processing (4)
  • flashinfer/utils.py
  • tests/mamba/test_philox_rounding.py
  • tests/mamba/test_selective_state_update_mtp.py
  • tests/mamba/test_selective_state_update_stp.py

📝 Walkthrough

Walkthrough

A new GPU capability helper is_cvt_rs_supported() was added to flashinfer/utils.py to detect compute capability support for CVT.RS stochastic rounding operations. Three Mamba test files were updated to use this helper instead of directly checking compute capability versions.

Changes

  • GPU Capability Helper — flashinfer/utils.py: Added is_cvt_rs_supported(device: torch.device = None) -> bool, which defaults to the CUDA device and returns True only for compute capability major versions 10 and 11, centralizing CVT.RS support detection.
  • Stochastic Rounding Tests — tests/mamba/test_philox_rounding.py, tests/mamba/test_selective_state_update_mtp.py, tests/mamba/test_selective_state_update_stp.py: Updated to import and use is_cvt_rs_supported() for gating CVT.RS stochastic rounding behavior, replacing direct compute capability checks and improving maintainability.
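From the walkthrough above, the capability check reduces to a predicate on the compute capability major version. A standalone sketch of that predicate (function name hypothetical — the actual helper is is_cvt_rs_supported in flashinfer/utils.py and takes a torch.device):

```python
def cvt_rs_supported_for_capability(major: int) -> bool:
    """True iff this compute capability major version supports the
    cvt.rs (stochastic-rounding convert) PTX instruction.
    SM100a and SM110a support cvt.rs; SM120 does not."""
    return major in (10, 11)
```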

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Possibly related PRs

  • #2699: Adds the same is_cvt_rs_supported() helper function and updates identical Mamba test files with the same refactoring pattern.

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • cyx-6
  • aleozlx
  • kahyunnam
  • jimmyzho
  • bkryu
  • nvmbreughe

Poem

🐰 A helper function hops into sight,
Checking if compute can stochastically round just right,
Three tests now gather around its call,
No more hardcoded checks—much cleaner for all! 🎉

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 inconclusive)

  • Title check — ❓ Inconclusive. The title describes a horizontal MTP kernel and DSTATE=96 support, but the reviewed changes focus on adding an is_cvt_rs_supported utility function and updating tests to use it for stochastic rounding detection. Resolution: clarify whether this PR's primary change is the is_cvt_rs_supported utility function or the horizontal MTP kernel, as the title does not reflect the code changes shown in the summary.
  • Description check — ❓ Inconclusive. The description provides detailed information about the horizontal MTP kernel and DSTATE support, but the raw summary shows only changes to GPU capability detection and test updates for stochastic rounding. Resolution: verify that the description accurately reflects all code changes in this PR, particularly the horizontal MTP kernel implementation versus the utility function updates.

✅ Passed checks (1 passed)

  • Docstring Coverage — ✅ Passed. Docstring coverage is 85.71%, which meets the required threshold of 80.00%.


@yzh119
Collaborator

yzh119 commented Mar 23, 2026

/bot run

@yzh119 yzh119 added the run-ci label Mar 23, 2026
@flashinfer-bot
Collaborator

GitLab MR !446 has been created, and the CI pipeline #46793108 is currently running. I'll report back once the pipeline job completes.

