[Kernel] Mamba2 SSD add fused kernel for 1.5-2.5x SSD (prefill) speedup #27299
Conversation
I collected some supporting performance results on a single H100.

Microbenchmark
The script is provided by @RishiAstra. Two settings were measured:
- default (CHUNK_SIZE_FUSED=128)
- matching the mamba config (CHUNK_SIZE_FUSED=256)

Latency benchmark
Since the changes mainly affect the Mamba2 SSD path used in prefill, I collected latency measurements with the fused SSD on and off. Test command examples were run for:
- ibm-granite/granite-4.0-h-tiny
- ibm-granite/granite-4.0-h-small
- nvidia/NVIDIA-Nemotron-Nano-12B-v2
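The exact benchmark commands aren't shown above; as a rough sketch, a fused-off/fused-on latency comparison could look like the following, assuming vLLM's `vllm bench latency` entry point and the `--hf-overrides '{"mamba2_fast_kernel": true}'` toggle discussed later in this PR. The model name and sequence lengths are placeholders, not the settings behind the numbers reported here.

```bash
MODEL=ibm-granite/granite-4.0-h-tiny   # placeholder model

# Baseline: fused SSD off (default code path)
vllm bench latency --model "$MODEL" \
  --input-len 8192 --output-len 1 --batch-size 1

# Fused SSD on via the hf_overrides feature flag added in this PR
vllm bench latency --model "$MODEL" \
  --hf-overrides '{"mamba2_fast_kernel": true}' \
  --input-len 8192 --output-len 1 --batch-size 1
```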
lm_eval Results
For each model, accuracy was measured with the fused kernel off and on:
- ibm-granite/granite-4.0-h-tiny (fused off / fused on)
- ibm-granite/granite-4.0-h-small (fused off / fused on)
- nvidia/NVIDIA-Nemotron-Nano-12B-v2 (fused off / fused on)
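The lm_eval invocations themselves aren't reproduced above; a minimal sketch of one such run, assuming the lm-evaluation-harness vLLM backend (the task and batch size here are placeholders, not necessarily what was used):

```bash
# Hypothetical example invocation; task and settings are placeholders.
lm_eval --model vllm \
  --model_args pretrained=ibm-granite/granite-4.0-h-tiny,max_model_len=8192 \
  --tasks gsm8k \
  --batch_size auto
```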
Codex Review: Didn't find any major issues. Hooray!
Amazing work!
I haven't encountered a feature flag being implemented through `--hf-overrides` before. A more pressing concern, though, is the maintenance cost of introducing so much new code for relatively small latency improvements. Do we understand why the bigger gains seen in the microbenchmarks (up to 2x, if I'm reading correctly?) don't translate into E2E speedups?
This kernel speeds up the Mamba2 SSD layer, and the microbenchmarks show that the speedup is ~2-3x. However, most models that contain Mamba2 SSD layers also contain many other layers. For example, even Mamba2-2.7b contains linear projections, diluting the 2-3x speedup down to ~15%. Some other models like granite-4.0-h-tiny, granite-4.0-h-small, and NVIDIA-Nemotron-Nano-12B-v2 contain even more (or heavier) other layers.

As a concrete example using Amdahl's law, imagine that the Mamba2 SSD layers take up 1/4 of the total runtime in Mamba2-2.7b and we speed them up by 2x. The Amdahl predicted speedup is 1 / ((1 - 1/4) + (1/4)/2) = 1 / 0.875 ≈ 1.14x, i.e. roughly a 14% end-to-end improvement.

Although the kernel is a lot of code, it's mostly contained in one file. There is a chance that future variants of Mamba would require modifying this kernel, adding maintenance cost, but there is also a chance that future models will use more or larger Mamba2 layers, causing larger speedups and more benefit.
Purpose
This PR speeds up the Mamba2 SSD prefill by about 1.5-2.5x (depending on state dtype) by adding a fully fused Triton kernel. This fused kernel can be used instead of the original Chunk Cumsum, BMM, Chunk State, State Passing, and Chunk Scan kernels. This fusion reorders work and uses synchronization to eliminate some intermediate VRAM writes/reads and increase cache locality. For Mamba2-2.7B with 64k context (tested in state-spaces/mamba), this results in an end-to-end speedup of ~15-17% for A100 and H100 GPUs.
Test Plan
More tests by @cyang49 below.
- Run all SSD tests from `tests/kernels/mamba/test_mamba_ssm_ssd.py` on the fused kernel. This effectively doubles the tests in that file, so it might reduce the CI speedup benefit from #26538.
- 16 extra tests are added as `test_mamba_chunk_scan_cont_batch_z_d` in `tests/kernels/mamba/test_mamba_ssm_ssd.py` to test that the kernels work with the optional args `z` and `d`.
- I also tested an end-to-end model (see the command sketch below).
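The concrete commands are not reproduced here; a minimal sketch of how this test plan could be exercised, assuming the test file and the `mamba2_fast_kernel` override named elsewhere in this PR (the served model and pytest flags are illustrative):

```bash
# Run the SSD kernel tests, including the new z/d variants
pytest tests/kernels/mamba/test_mamba_ssm_ssd.py -v
pytest tests/kernels/mamba/test_mamba_ssm_ssd.py -k test_mamba_chunk_scan_cont_batch_z_d

# End-to-end check with the fused kernel enabled (placeholder model)
vllm serve ibm-granite/granite-4.0-h-tiny \
  --hf-overrides '{"mamba2_fast_kernel": true}'
```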
Test Result
See more in-depth benchmark and accuracy results from @cyang49 below.
Tests pass locally on an RTX 4090.
For the `vllm serve` test, I get:

Questions
Is `--hf-overrides '{"mamba2_fast_kernel": true}'` the most appropriate way to give users the option to use the fused kernel? If so, how can we document it?

Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.