
A Blackwell-optimized version of selective_state_update (mamba)#2387

Merged
yzh119 merged 11 commits into flashinfer-ai:main from ishovkun:main
Jan 22, 2026

Conversation

@ishovkun (Contributor) commented Jan 20, 2026

📌 Description

This contributes a selective_state_update kernel that has been manually optimized for sm100.

(Figure: runtime vs. batch size on NVIDIA B200)

Technically, the code in the sm100 module can be compiled and executed on sm90, but the performance of this kernel on Hopper is not great (see the image below).

(Figure: kernel runtime on Hopper)

Algorithm

The algorithmic differences between the new selective_state_update_kernel_producer_consumer_horizontal and the older Hopper-specific selective_state_update_kernel_producer_consumer_vertical are summarized below. Each CTA processes a single head. In the vertical case, each warp processes a row and then performs a reduction using shfl_sync intrinsics. In the horizontal kernel, each thread processes multiple consecutive row elements while keeping a running sum. This lifts the warp-level reduction bottleneck on Blackwell.
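To make the difference concrete, here is a small Python sketch of the per-row update math (new_state = state · exp(A·dt) + B·dt·x, output = Σ new_state · C), contrasting the two work decompositions. Names and sizes are illustrative, not the FlashInfer API; on the GPU the "vertical" partial sums are combined with shfl_sync rather than a plain sum().

```python
import math

def update_row_horizontal(state_row, A, dt, x, B, C):
    # "Horizontal" decomposition: one thread walks consecutive columns of its
    # row, carrying the output sum in a register -- no cross-lane reduction.
    dA = math.exp(A * dt)            # discretized state decay, constant per row
    out = 0.0
    for j in range(len(state_row)):
        new_state = state_row[j] * dA + B[j] * dt * x
        state_row[j] = new_state     # state cache is updated in place
        out += new_state * C[j]      # running sum stays in one "thread"
    return out

def update_row_vertical(state_row, A, dt, x, B, C):
    # "Vertical" decomposition: columns are spread across warp lanes and the
    # per-lane partial products are reduced afterwards (shfl_sync on the GPU;
    # a plain sum() stands in for the warp shuffle here).
    dA = math.exp(A * dt)
    partials = []
    for j in range(len(state_row)):
        new_state = state_row[j] * dA + B[j] * dt * x
        state_row[j] = new_state
        partials.append(new_state * C[j])
    # This final reduction is the step the horizontal kernel avoids.
    return sum(partials)
```

Both decompositions compute identical results; the difference is purely in which hardware resource the reduction exercises.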

(Screenshot: algorithm comparison diagram, 2026-01-20)

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Could you please check consumer_func_horizontal and let me know whether the comments about the bank-conflict-free indexing suffice.

Summary by CodeRabbit

  • New Features

    • Added support for SM100 (Blackwell) and newer GPU architectures in Mamba selective state update operations
    • Automatic kernel selection based on device compute capability for optimal performance
  • Refactor

    • Improved runtime dispatch logic to route operations to architecture-specific optimized kernels
  • Bug Fixes

    • Enhanced CUDA error checking mechanism for better diagnostics


yzh119 and others added 3 commits January 20, 2026 03:28
Split the Hopper SM90 module from supporting all future architectures
to only SM90. Add a new SM100+ module that will use horizontal
producer-consumer kernel optimized for Blackwell and newer GPUs.
@gemini-code-assist (Contributor)

Summary of Changes

Hello @ishovkun, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the selective_state_update kernel for Mamba models by introducing a highly optimized version tailored for NVIDIA's latest Blackwell (SM100) architecture. This new kernel employs a horizontal producer-consumer pattern, complementing existing kernels for older GPU generations and ensuring optimal performance across different hardware. The changes involve integrating the new kernel into the JIT compilation system and implementing dynamic dispatch based on compute capability.

Highlights

  • Blackwell (SM100) Optimization: Introduced a new selective_state_update kernel specifically optimized for NVIDIA Blackwell (SM100) GPUs and newer architectures. This kernel utilizes a 'horizontal producer-consumer' design for enhanced performance.
  • Hopper (SM90) Kernel Refinement: The existing selective_state_update_sm90_module has been refined to explicitly target Hopper (SM90) GPUs only, clarifying its scope and preventing its use on future architectures where the new Blackwell kernel is more appropriate.
  • Dynamic Kernel Dispatch: The system now dynamically dispatches the appropriate selective_state_update kernel based on the detected CUDA compute capability: the new SM100+ kernel for Blackwell and newer, the SM90 kernel for Hopper, and a base kernel for older architectures.
  • Bank Conflict Avoidance: The new consumer_func_horizontal in the Blackwell kernel includes sophisticated indexing logic to avoid shared memory bank conflicts, crucial for maximizing performance on NVIDIA GPUs.
  • CUDA Error Handling: A new FLASHINFER_CUDA_CHECK macro has been added to improve CUDA error handling and debugging within the C++ codebase.


@coderabbitai (Contributor) commented Jan 20, 2026


📝 Walkthrough

This PR introduces SM100 (Blackwell) and newer architecture support for the mamba selective state update operation. It adds SM100-specific horizontal producer-consumer kernels, JIT compilation specs, and runtime dispatch logic that routes based on compute capability. A CUDA error-checking utility macro is also added.

Changes

  • JIT Compilation Setup — flashinfer/aot.py, flashinfer/jit/mamba/__init__.py, flashinfer/jit/mamba/selective_state_update.py
    Introduced SM100-specific JIT module generation constrained to compute capability majors 10, 11, and 12; refined the SM90 path to target compute capability 9 only; integrated the new module into the AOT compilation pipeline via gen_selective_state_update_sm100_module()
  • Runtime Dispatch — flashinfer/mamba/selective_state_update.py
    Added a get_selective_state_update_module_sm100() accessor and updated dispatch routing to select SM100+ kernels for compute capability ≥ 10, SM90 for capability == 9, and the base kernel for older GPUs
  • CUDA Kernels — include/flashinfer/mamba/selective_state_update.cuh
    Implemented a horizontal producer-consumer kernel path for SM100 with a new SharedStorageHorizontal struct, multi-stage buffering, per-stage barriers, TMA alignment checks, and device functions producer_func_horizontal() and consumer_func_horizontal(); added a global warpSize constant
  • Utilities — include/flashinfer/utils.cuh
    Added a FLASHINFER_CUDA_CHECK macro for centralized CUDA error validation with detailed error reporting
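The routing described under Runtime Dispatch can be sketched as a simple rule on the compute-capability major version. The function name and return values below are illustrative stand-ins, not the actual flashinfer API:

```python
def select_selective_state_update_module(cc_major: int) -> str:
    """Pick a kernel module by compute capability major (illustrative sketch)."""
    if cc_major >= 10:   # Blackwell (SM100) and newer
        return "sm100"   # horizontal producer-consumer kernel
    if cc_major == 9:    # Hopper (SM90)
        return "sm90"    # vertical producer-consumer kernel
    return "base"        # pre-Hopper fallback
```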

Sequence Diagram

sequenceDiagram
    participant Host
    participant AOT as AOT Compiler
    participant JITCompiler as JIT Compiler
    participant Runtime as Runtime Dispatch
    participant Device as GPU Device

    Host->>AOT: Trigger AOT compilation
    AOT->>JITCompiler: Generate SM90 JIT spec (NVCC 9)
    JITCompiler-->>AOT: SM90 JIT module
    AOT->>JITCompiler: Generate SM100 JIT spec (NVCC 10-12)
    JITCompiler-->>AOT: SM100 JIT module
    AOT-->>Host: Compiled modules available

    Host->>Runtime: Call selective_state_update()
    Runtime->>Runtime: Query GPU compute capability
    alt Compute Capability >= 10
        Runtime->>Runtime: Load SM100 module
        Runtime->>Device: Invoke horizontal producer-consumer kernel
    else Compute Capability == 9
        Runtime->>Runtime: Load SM90 module
        Runtime->>Device: Invoke vertical kernel
    else Older GPU
        Runtime->>Runtime: Load base module
        Runtime->>Device: Invoke base kernel
    end
    Device-->>Runtime: Kernel execution complete
    Runtime-->>Host: Result ready

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Suggested reviewers

  • IwakuraRein
  • kahyunnam
  • jiahanc
  • nvmbreughe
  • jimmyzho
  • cyx-6
  • aleozlx
  • yzh119

Poem

🐰 Blackwell hops with newer might,
Horizontal kernels blazing bright,
From Hopper's vertical tower tall,
SM100 answers the producer's call!

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 33.33%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ — The title clearly and concisely describes the main change: adding a Blackwell-optimized version of the selective_state_update kernel for mamba.
  • Description check ✅ — The PR description includes a detailed explanation of the changes, algorithmic differences, performance benefits with images, and pre-commit/test checklist completion, though reviewer notes are present.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a Blackwell-optimized version of the selective_state_update kernel, specifically targeting SM100+ architectures. The changes include adding new JIT compilation modules, updating the dispatch logic in Python to select the appropriate kernel based on compute capability, and implementing a new horizontal producer-consumer kernel in CUDA. The new kernel incorporates advanced features like TMA for state cache management and bank-conflict-free indexing for shared memory access. Overall, the changes are well-structured and aim to improve performance on newer NVIDIA architectures. However, there are a few critical and high-severity issues related to TMA configuration and error handling that need to be addressed.

@coderabbitai (bot) left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@include/flashinfer/mamba/selective_state_update.cuh`:
- Around line 831-833: The z indexing in the computation of z_value incorrectly
assumes a contiguous stride and should use params.z_stride_batch; update the z
lookup in the expression that defines z_value (the line creating z_value
alongside x_value) to index z with batch * params.z_stride_batch + head * DIM +
d (preserving the z null check and toFloat conversion) so it matches other
kernels that use params.z_stride_batch.
- Around line 1083-1090: Remove the unconditional cudaDeviceSynchronize() call
after launching scan_func to preserve stream-asynchronous behavior, and add
error checking around the cudaFuncSetAttribute call by wrapping it with
FLASHINFER_CUDA_CHECK; specifically, replace the raw
cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size) with a FLASHINFER_CUDA_CHECK-wrapped call and delete the
cudaDeviceSynchronize() line so the scan_func<<<grid, block, smem_size,
stream>>>(params, tensorState) launch remains asynchronous.
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

693-699: Add rationale for the alternate column permutation.
The sizeof(state_t) != sizeof(input_t) branch uses a custom permutation ((group / 4) * 2) without explaining the bank-conflict or throughput rationale. A short comment on why this layout is chosen (and any alternatives tried) would help future performance tuning. As per coding guidelines, add rationale for special algorithmic choices in hot paths.

The cudaDeviceSynchronize call was forcing synchronous execution, which hurts performance. This also adds proper error checking for the cudaFuncSetAttribute call.
@yzh119 (Collaborator) left a comment

Hi @ishovkun can you explain a little bit about the optimization? What specific feature did you use for sm_100/110/120 acceleration compared to sm_90?

@ishovkun (Contributor, Author)

Hi @ishovkun can you explain a little bit about the optimization? What specific feature did you use for sm_100/110/120 acceleration compared to sm_90?

There are no Blackwell-specific instructions here; there is simply no point in compiling this particular kernel on Hopper, as it is substantially slower than the vertical kernel on H200. Given the number of template instantiations, excluding it might also help reduce compile times.

Clarify why the SM90 and SM100+ modules are restricted to specific architectures despite being technically compatible with other GPUs: the kernels perform significantly worse on non-target architectures.

Also fix a z tensor stride bug by using z_stride_batch instead of a hardcoded value.
@ishovkun (Contributor, Author)

/bot run

@flashinfer-bot (Collaborator)

@ishovkun is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@coderabbitai (bot) left a comment

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In @.zed/tasks.json:
- Around line 12-19: The args array currently passes the combined string "-v -s
-x" as one element which pytest will not parse; update the args in the
.zed/tasks.json task so each pytest flag is its own array element (replace the
single "-v -s -x" entry with three entries "-v", "-s", and "-x") so Zed forwards
them correctly to pytest when running the task.
- Around line 31-38: The JSON task entry with label "print beer" contains a
trailing comma after the "shell" object which makes the file invalid JSON;
remove the comma immediately following the closing brace of the "shell" object
in the task entry (the object that contains "program": "bash") so the task is
valid JSON and the "print beer" command entry remains unchanged.

In `@0`:
- Around line 1-67: This JSON snapshot contains sensitive environment fields
(e.g., "SSH_CONNECTION", "SSH_CLIENT", "HOST_OS", "P4PORT", "P4CLIENT",
"USER"/"LOGNAME", "CDS_LIC_FILE", "SNPSLMD_LICENSE_FILE", "DISPLAY", "HOME",
"PATH") and must be removed from the repo history and prevented from being
re-committed: delete the committed file from the repo, replace it with a
sanitized template (e.g., ENV_TEMPLATE.json with placeholder values for the keys
above) or nothing, add the filename to .gitignore (or add a rule to ignore
similar snapshots), and if the file has already been pushed rotate any exposed
credentials/contacts as required; ensure the commit that removes the file does
not reintroduce sensitive values in other files.

In `@flashinfer/jit/mamba/selective_state_update.py`:
- Around line 33-39: Fix the typos and improve clarity in the module comment in
selective_state_update.py: change "then an alternatice" to "than an
alternative", correct "alternatice" -> "alternative", ensure "This supports SM90
(Hopper) only." reads clearly (e.g., "This supports SM90 (Hopper) only.") and
reword the final sentence to something like "Therefore, this is excluded to
reduce compilation overhead." Keep the mention of "TMA device functions
(vertical producer-consumer kernel)" intact for context.
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

688-724: Hoist dA computation out of the inner loop.

A_value and dt_value are constant per thread; computing __expf once should reduce redundant math. Please verify the impact with profiling on SM100.

♻️ Proposed refactor
 __device__ __forceinline__ void consumer_func_horizontal(
     int d, int member, float A_value, float dt_value, float x_value,
     SharedStorageHorizontal<input_t, weight_t, matrixA_t, state_t, DIM, DSTATE, colsPerStage,
                             numStages>& sram,
     float& out_value) {
   namespace cde = cuda::device::experimental;
   constexpr auto lanesPerRow = (consumerWarps * warpSize) / DIM;
   constexpr auto itemsPerThread = colsPerStage / lanesPerRow;
   auto const group = d % (warpSize / lanesPerRow);
+  auto const dA = __expf(A_value * dt_value);

   // #pragma unroll 1
   for (int iBegin = 0, stage = 0; iBegin < DSTATE;
        iBegin += colsPerStage, stage = (stage + 1) % numStages) {
...
-          auto const dA = __expf(A_value * dt_value);
           auto const dB = B_value * dt_value;
           auto const new_state = state_value * dA + dB * x_value;
...
-          auto const dA = __expf(A_value * dt_value);
           auto const dB = B_value * dt_value;
           auto const new_state = state_value * dA + dB * x_value;

@ishovkun (Contributor, Author)

Hi @ishovkun can you explain a little bit about the optimization? What specific feature did you use for sm_100/110/120 acceleration compared to sm_90?

Done, please see the PR description at the very top.

@yzh119 (Collaborator)

yzh119 commented Jan 21, 2026

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !254 has been created, and the CI pipeline #42206712 is currently running. I'll report back once the pipeline job completes.

@yzh119 (Collaborator) left a comment

LGTM, the performance of horizontal kernel looks impressive!

@yzh119 yzh119 merged commit 18804cd into flashinfer-ai:main Jan 22, 2026
21 checks passed
yzh119 added a commit that referenced this pull request Jan 22, 2026

## 📌 Description

`uv.lock` should not be part of the package (it was introduced by accident in #2387). This PR removes `uv.lock` and adds it to `.gitignore` to prevent accidental commits.

Lock files are not recommended for libraries as they can conflict with
downstream users' environments.

## 🔍 Related Issues

#2387 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

This PR itself should not involve any changes related to kernels and we
can bypass kernel CI checks.
