A Blackwell-optimized version of selective_state_update (mamba) #2387
yzh119 merged 11 commits into flashinfer-ai:main
Conversation
Split the Hopper SM90 module from supporting all future architectures to only SM90. Add a new SM100+ module that will use horizontal producer-consumer kernel optimized for Blackwell and newer GPUs.
📝 Walkthrough
This PR introduces SM100 (Blackwell) and newer architecture support for the mamba selective state update operation. It adds SM100-specific horizontal producer-consumer kernels, JIT compilation specs, and runtime dispatch logic that routes based on compute capability. A CUDA error-checking utility macro is also added.
Changes
Sequence Diagram
sequenceDiagram
participant Host
participant AOT as AOT Compiler
participant JITCompiler as JIT Compiler
participant Runtime as Runtime Dispatch
participant Device as GPU Device
Host->>AOT: Trigger AOT compilation
AOT->>JITCompiler: Generate SM90 JIT spec (NVCC 9)
JITCompiler-->>AOT: SM90 JIT module
AOT->>JITCompiler: Generate SM100 JIT spec (NVCC 10-12)
JITCompiler-->>AOT: SM100 JIT module
AOT-->>Host: Compiled modules available
Host->>Runtime: Call selective_state_update()
Runtime->>Runtime: Query GPU compute capability
alt Compute Capability >= 10
Runtime->>Runtime: Load SM100 module
Runtime->>Device: Invoke horizontal producer-consumer kernel
else Compute Capability == 9
Runtime->>Runtime: Load SM90 module
Runtime->>Device: Invoke vertical kernel
else Older GPU
Runtime->>Runtime: Load base module
Runtime->>Device: Invoke base kernel
end
Device-->>Runtime: Kernel execution complete
Runtime-->>Host: Result ready
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a Blackwell-optimized version of the selective_state_update kernel, specifically targeting SM100+ architectures. The changes include adding new JIT compilation modules, updating the dispatch logic in Python to select the appropriate kernel based on compute capability, and implementing a new horizontal producer-consumer kernel in CUDA. The new kernel incorporates advanced features like TMA for state cache management and bank-conflict-free indexing for shared memory access. Overall, the changes are well-structured and aim to improve performance on newer NVIDIA architectures. However, there are a few critical and high-severity issues related to TMA configuration and error handling that need to be addressed.
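The compute-capability dispatch described above can be sketched in Python; the function and module names here are illustrative stand-ins, not the repo's actual JIT loader API:

```python
# Illustrative sketch of capability-based kernel dispatch; select_module()
# and the module names are hypothetical stand-ins for the real loader.
def select_module(major: int, minor: int) -> str:
    """Pick a kernel module name from the GPU compute capability."""
    if major >= 10:  # SM100+ (Blackwell and newer): horizontal kernel
        return "selective_state_update_sm100"
    if major == 9:   # SM90 (Hopper): vertical producer-consumer kernel
        return "selective_state_update_sm90"
    return "selective_state_update"  # pre-Hopper fallback


print(select_module(10, 0))  # Blackwell -> SM100 module
print(select_module(9, 0))   # Hopper -> SM90 module
print(select_module(8, 6))   # Ampere -> base module
```

In the real code the capability tuple would come from the CUDA runtime (e.g., the device's compute capability query) rather than being passed in by hand.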
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@include/flashinfer/mamba/selective_state_update.cuh`:
- Around lines 831-833: The z indexing in the computation of z_value incorrectly assumes a contiguous stride; it should use params.z_stride_batch. Update the z lookup in the expression that defines z_value (alongside x_value) to index z with `batch * params.z_stride_batch + head * DIM + d`, preserving the z null check and toFloat conversion, so it matches the other kernels that use params.z_stride_batch.
- Around lines 1083-1090: Remove the unconditional cudaDeviceSynchronize() call after launching scan_func to preserve stream-asynchronous behavior, and add error checking around cudaFuncSetAttribute by wrapping it with FLASHINFER_CUDA_CHECK. Specifically, replace the raw cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) with a FLASHINFER_CUDA_CHECK-wrapped call, and delete the cudaDeviceSynchronize() line so the scan_func<<<grid, block, smem_size, stream>>>(params, tensorState) launch remains asynchronous.
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
693-699: Add rationale for the alternate column permutation.
The `sizeof(state_t) != sizeof(input_t)` branch uses a custom permutation (`(group / 4) * 2`) without explaining the bank-conflict or throughput rationale. A short comment on why this layout is chosen (and any alternatives tried) would help future performance tuning. As per coding guidelines, add rationale for special algorithmic choices in hot paths.
The cudaDeviceSynchronize call was forcing synchronous execution which hurts performance. Also adds proper error checking for the cudaFuncSetAttribute call.
There are no Blackwell-specific instructions here; there is simply no point in compiling this specific kernel on Hopper, as it is substantially slower than the vertical kernel on H200. Given the number of template instantiations, this may help reduce compile times.
Clarify why the SM90 and SM100+ modules are restricted to specific architectures despite being technically compatible with other GPUs: the kernels perform significantly worse on non-target architectures. Also fixes the z tensor stride bug by using z_stride_batch instead of a hardcoded value.
/bot run
@ishovkun is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww |
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In @.zed/tasks.json:
- Around lines 12-19: The args array currently passes the combined string "-v -s -x" as a single element, which pytest will not parse. Update the args in the .zed/tasks.json task so each pytest flag is its own array element (replace the single "-v -s -x" entry with three entries: "-v", "-s", "-x") so Zed forwards them correctly to pytest when running the task.
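A sketch of what the corrected task entry might look like; the label and command are illustrative placeholders, and only the split of the args array reflects the requested fix:

```json
{
  "label": "run pytest (illustrative)",
  "command": "pytest",
  "args": ["-v", "-s", "-x"]
}
```

With the flags as separate elements, each one reaches pytest as its own argv entry instead of a single unrecognized token.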
- Around lines 31-38: The JSON task entry with label "print beer" contains a trailing comma after the "shell" object, which makes the file invalid JSON. Remove the comma immediately following the closing brace of the "shell" object (the object that contains "program": "bash") so the file parses; the "print beer" command entry itself remains unchanged.
In `@0`:
- Around lines 1-67: This JSON snapshot contains sensitive environment fields (e.g., "SSH_CONNECTION", "SSH_CLIENT", "HOST_OS", "P4PORT", "P4CLIENT", "USER"/"LOGNAME", "CDS_LIC_FILE", "SNPSLMD_LICENSE_FILE", "DISPLAY", "HOME", "PATH") and must be removed from the repo history and prevented from being re-committed. Delete the committed file, replace it with a sanitized template (e.g., ENV_TEMPLATE.json with placeholder values for the keys above) or nothing, and add the filename to .gitignore (or a rule that ignores similar snapshots). If the file has already been pushed, rotate any exposed credentials as required, and ensure the commit that removes the file does not reintroduce sensitive values in other files.
In `@flashinfer/jit/mamba/selective_state_update.py`:
- Around lines 33-39: Fix the typos and improve clarity in the module comment in selective_state_update.py: change "then an alternatice" to "than an alternative", make sure "This supports SM90 (Hopper) only." reads clearly, and reword the final sentence to something like "Therefore, this is excluded to reduce compilation overhead." Keep the mention of "TMA device functions (vertical producer-consumer kernel)" intact for context.
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
688-724: Hoist `dA` computation out of the inner loop.
`A_value` and `dt_value` are constant per thread; computing `__expf` once should reduce redundant math. Please verify the impact with profiling on SM100.
♻️ Proposed refactor
```diff
 __device__ __forceinline__ void consumer_func_horizontal(
     int d, int member, float A_value, float dt_value, float x_value,
     SharedStorageHorizontal<input_t, weight_t, matrixA_t, state_t, DIM,
                             DSTATE, colsPerStage, numStages>& sram,
     float& out_value) {
   namespace cde = cuda::device::experimental;
   constexpr auto lanesPerRow = (consumerWarps * warpSize) / DIM;
   constexpr auto itemsPerThread = colsPerStage / lanesPerRow;
   auto const group = d % (warpSize / lanesPerRow);
+  auto const dA = __expf(A_value * dt_value);
   // #pragma unroll 1
   for (int iBegin = 0, stage = 0; iBegin < DSTATE;
        iBegin += colsPerStage, stage = (stage + 1) % numStages) {
     ...
-    auto const dA = __expf(A_value * dt_value);
     auto const dB = B_value * dt_value;
     auto const new_state = state_value * dA + dB * x_value;
     ...
-    auto const dA = __expf(A_value * dt_value);
     auto const dB = B_value * dt_value;
     auto const new_state = state_value * dA + dB * x_value;
```
Done, please see the PR description at the very top.
/bot run
yzh119
left a comment
LGTM, the performance of horizontal kernel looks impressive!
<!-- .github/pull_request_template.md -->

## 📌 Description

`uv.lock` should not be part of the package (it was introduced by accident in #2387). This PR removes `uv.lock` and adds it to .gitignore to prevent accidental commits. Lock files are not recommended for libraries as they can conflict with downstream users' environments.

## 🔍 Related Issues

#2387

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

This PR itself should not involve any changes related to kernels and we can bypass kernel CI checks.
📌 Description
This contributes a `selective_state_update` kernel that has been manually optimized for SM100. Technically, the code in the SM100 module can be compiled and executed on SM90, but the performance of this kernel on Hopper is not great (see the image below).
Algorithm
The algorithmic differences between the new `selective_state_update_kernel_producer_consumer_horizontal` and the older Hopper-specific `selective_state_update_kernel_producer_consumer_vertical` are summarized below. Each CTA processes a single head. In the `vertical` case, each warp processes a row and then performs a reduction using `shfl_sync` intrinsics. In the `horizontal` kernel, each thread consecutively processes multiple row elements while keeping a running sum. This new treatment lifts the bottleneck off the warp-level reduction on Blackwell.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
Reviewer Notes
Could you please check `consumer_func_horizontal` and let me know whether the comments about the bank-conflict-free indexing suffice.
Summary by CodeRabbit
New Features
Refactor
Bug Fixes