[Mamba] Flashinfer selective_state_update#36162

Merged
mgoin merged 13 commits into vllm-project:main from
roikoren755:feat/flashinfer-selective-state-update
Apr 14, 2026

Conversation

@roikoren755
Contributor

@roikoren755 roikoren755 commented Mar 5, 2026

Purpose

Add a wrapper for FlashInfer's selective_state_update kernel, together with a runtime dispatcher connected to a new config field, to select between it and the existing Triton implementation.

As suggested by @tdoublep in #35753, I've introduced MambaConfig in this PR; a follow-up (or this PR, if you'd prefer) could move the config fields relevant to Mamba into it.
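
For illustration, a minimal sketch of this dispatch pattern; the registry and the register_ssu_backend / selective_state_update names here are hypothetical, not the exact identifiers added in this PR:

from typing import Callable

# Registry mapping backend name -> kernel implementation. The PR wires in
# the real Triton and FlashInfer kernels; entries here are placeholders.
_SSU_BACKENDS: dict[str, Callable] = {}

def register_ssu_backend(name: str):
    """Register a selective_state_update implementation under a backend name."""
    def decorator(fn: Callable) -> Callable:
        _SSU_BACKENDS[name] = fn
        return fn
    return decorator

def selective_state_update(*args, backend: str = "triton", **kwargs):
    """Dispatch to the kernel selected by the config-driven backend field."""
    try:
        impl = _SSU_BACKENDS[backend]
    except KeyError:
        raise ValueError(
            f"Unknown mamba backend {backend!r}; "
            f"available: {sorted(_SSU_BACKENDS)}") from None
    return impl(*args, **kwargs)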

Test Plan

Add a new test file covering the dispatcher's functionality (see the sketch after this list).
Add e2e tests for Nemotron 3 Nano that use FlashInfer's kernel.
Benchmark the two backends.
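
A rough sketch of what the dispatcher test could look like, reusing the hypothetical registry names from the sketch above (the real tests live in tests/kernels/mamba/test_ssu_dispatch.py):

import pytest

@pytest.mark.parametrize("backend", ["triton", "flashinfer"])
def test_dispatch_selects_backend(backend):
    calls = []

    # Register a fake kernel under the requested name so the test can
    # observe which implementation the dispatcher picks.
    @register_ssu_backend(backend)
    def fake_kernel(*args, **kwargs):
        calls.append(backend)

    selective_state_update(backend=backend)
    assert calls == [backend]

def test_unknown_backend_raises():
    with pytest.raises(ValueError):
        selective_state_update(backend="no-such-backend")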

Test Result

Tests pass; e2e runs show the same GSM8K score for both backends on nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.
I served nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 with the following command (swapping the --mamba-backend argument for the flashinfer measurement):

vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
    --tensor-parallel-size 4 --max-model-len 8192 --mamba-backend triton \
    --trust-remote-code

And benchmarked with:

vllm bench serve \
    --model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
    --num-prompts 500 --request-rate inf \
    --input-len 256 --output-len 256

Got the following results:
Triton:

============ Serving Benchmark Result ============
Successful requests:                     500
Failed requests:                         0
Benchmark duration (s):                  5.88
Total input tokens:                      128002
Total generated tokens:                  128000
Request throughput (req/s):              85.02
Output token throughput (tok/s):         21765.03
Peak output token throughput (tok/s):    31129.00
Peak concurrent requests:                500.00
Total token throughput (tok/s):          43530.41
---------------Time to First Token----------------
Mean TTFT (ms):                          1256.51
Median TTFT (ms):                        1233.83
P99 TTFT (ms):                           1770.09
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.77
Median TPOT (ms):                        17.90
P99 TPOT (ms):                           19.41
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.82
Median ITL (ms):                         15.93
P99 ITL (ms):                            69.00
==================================================

Flashinfer:

============ Serving Benchmark Result ============
Successful requests:                     500
Failed requests:                         0
Benchmark duration (s):                  5.64
Total input tokens:                      128002
Total generated tokens:                  128000
Request throughput (req/s):              88.60
Output token throughput (tok/s):         22682.25
Peak output token throughput (tok/s):    32938.00
Peak concurrent requests:                500.00
Total token throughput (tok/s):          45364.86
---------------Time to First Token----------------
Mean TTFT (ms):                          1183.95
Median TTFT (ms):                        1161.50
P99 TTFT (ms):                           1656.52
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.13
Median TPOT (ms):                        17.26
P99 TPOT (ms):                           18.60
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.18
Median ITL (ms):                         15.53
P99 ITL (ms):                            63.94
==================================================

This works out to a ~4-5% speedup across nearly all metrics:

  ┌────────────────────────┬──────────────┬──────────────┬───────┐
  │         Metric         │    Triton    │  FlashInfer  │ Delta │
  ├────────────────────────┼──────────────┼──────────────┼───────┤
  │ Duration               │ 5.88s        │ 5.64s        │ -4.1% │
  ├────────────────────────┼──────────────┼──────────────┼───────┤
  │ Request throughput     │ 85.02 req/s  │ 88.60 req/s  │ +4.2% │
  ├────────────────────────┼──────────────┼──────────────┼───────┤
  │ Output throughput      │ 21,765 tok/s │ 22,682 tok/s │ +4.2% │
  ├────────────────────────┼──────────────┼──────────────┼───────┤
  │ Peak output throughput │ 31,129 tok/s │ 32,938 tok/s │ +5.8% │
  ├────────────────────────┼──────────────┼──────────────┼───────┤
  │ TPOT median            │ 17.90 ms     │ 17.26 ms     │ -3.6% │
  ├────────────────────────┼──────────────┼──────────────┼───────┤
  │ TPOT P99               │ 19.41 ms     │ 18.60 ms     │ -4.2% │
  ├────────────────────────┼──────────────┼──────────────┼───────┤
  │ ITL P99                │ 69.00 ms     │ 63.94 ms     │ -7.3% │
  └────────────────────────┴──────────────┴──────────────┴───────┘

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for the Flashinfer selective_state_update kernel as an alternative to the existing Triton implementation. A dispatcher has been added to select between the two backends at runtime, controlled by a new MambaConfig configuration object. The changes are well-structured, with the new backend logic cleanly integrated into the existing architecture. The pull request also includes comprehensive tests for the new dispatcher and backend, including checks for unsupported features in the Flashinfer kernel, ensuring robustness. The argument parsing and engine configuration have been updated appropriately to expose the new backend option. Overall, this is a high-quality contribution that significantly improves performance for Mamba models.

@roikoren755 roikoren755 force-pushed the feat/flashinfer-selective-state-update branch from c8b1288 to c2e89c3 on March 5, 2026 17:49
@mergify mergify Bot added the ci/build label Mar 5, 2026
@roikoren755 roikoren755 force-pushed the feat/flashinfer-selective-state-update branch from 10c42a1 to 2761ad6 on March 15, 2026 13:53
Member

@hmellor hmellor left a comment


Could we not introduce an enum to vllm.config, as we have been reserving this namespace for config classes only?

Comment thread vllm/config/__init__.py Outdated
@@ -16,6 +16,7 @@
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
from vllm.config.mamba import MambaBackendEnum, MambaConfig
Member


Suggested change
from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.config.mamba import MambaConfig

Comment thread vllm/config/__init__.py Outdated
@@ -82,6 +83,9 @@
"LoadConfig",
# From vllm.config.lora
"LoRAConfig",
# From vllm.config.mamba
"MambaBackendEnum",
Member


Suggested change
"MambaBackendEnum",

@mgoin mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 28, 2026
@mergify
Contributor

mergify Bot commented Mar 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @roikoren755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 30, 2026
@roikoren755 roikoren755 force-pushed the feat/flashinfer-selective-state-update branch from 2761ad6 to 5bdac84 on April 5, 2026 18:45
@mergify mergify Bot removed the needs-rebase label Apr 5, 2026
@roikoren755 roikoren755 requested a review from vadiklyutiy as a code owner April 5, 2026 19:02
Comment thread tests/kernels/mamba/test_ssu_dispatch.py
Comment thread tests/kernels/mamba/test_ssu_dispatch.py
Comment thread vllm/config/mamba.py
Comment thread vllm/config/mamba.py
Comment thread vllm/model_executor/layers/mamba/ops/ssu_dispatch.py Outdated
Copy link
Copy Markdown
Member

@tomeras91 tomeras91 left a comment


Overall looks good!

I do have some comments though:

  1. I agree with @amirkl94 we should have the triton backend as default so we don't need to deal with None
  2. I'm not sure the current dispatch mechanism to choose between triton/FlashInfer is the most fitting one. I understand the pros and cons, but worth thinking if we should use a pattern that is already used elsewhere for backend selection instead of introducing something new. Maybe the recently added "Custom Ops IR" is more fitting?
  3. [most important] You mention in the description that the current FlashInfer impl lacks support for some features (specifically SpecDec). Yet, these are not reflected in the code. So if the user selects the flashinfer backend and runs with SpecDec, they will get an opaque error from flashinfer (or worse - silent failures resulting in bad outputs). Maybe we should add some defensive runtime checks? Something like this:
  if dst_state_batch_indices is not None and \
     dst_state_batch_indices is not state_batch_indices:
      raise NotImplementedError(
          "FlashInfer SSU backend does not yet support separate "
          "dst_state_batch_indices. Use --mamba-backend triton.")

  if num_accepted_tokens is not None:
      raise NotImplementedError(
          "FlashInfer SSU backend does not yet support "
          "num_accepted_tokens (cache rewind). Use --mamba-backend triton.")

@roikoren755
Contributor Author

Overall looks good!

I do have some comments though:

  1. I agree with @amirkl94 we should have the triton backend as default so we don't need to deal with None
  2. I'm not sure the current dispatch mechanism to choose between triton/FlashInfer is the most fitting one. I understand the pros and cons, but worth thinking if we should use a pattern that is already used elsewhere for backend selection instead of introducing something new. Maybe the recently added "Custom Ops IR" is more fitting?
  3. [most important] You mention in the description that the current FlashInfer impl lacks support for some features (specifically SpecDec). Yet, these are not reflected in the code. So if the user selects the flashinfer backend and runs with SpecDec, they will get an opaque error from flashinfer (or worse - silent failures resulting in bad outputs). Maybe we should add some defensive runtime checks? Something like this:
  if dst_state_batch_indices is not None and \
     dst_state_batch_indices is not state_batch_indices:
      raise NotImplementedError(
          "FlashInfer SSU backend does not yet support separate "
          "dst_state_batch_indices. Use --mamba-backend triton.")

  if num_accepted_tokens is not None:
      raise NotImplementedError(
          "FlashInfer SSU backend does not yet support "
          "num_accepted_tokens (cache rewind). Use --mamba-backend triton.")
  1. Dealing with None was actually simpler, but I'll fix that now.
  2. I may be mistaken, but it looks like this doesn't quite fit: there isn't (currently) a native PyTorch implementation, and the triton and flashinfer kernels support the same features, so it would boil down to an implementation priority list, which I don't think is what we want to aim for...
  3. That's a mistake, and I've deleted that part of the description. SpecDec (and all other features) is fully supported, and is covered in tests after this PR, so there's no need for the NotImplementedErrors.

@roikoren755 roikoren755 force-pushed the feat/flashinfer-selective-state-update branch from 539b43e to a6ced90 on April 13, 2026 11:02
Member

@tdoublep tdoublep left a comment


LGTM

We might want to revisit whether the dispatch could be better handled by vLLM IR at a later stage.

@mgoin mgoin merged commit ecf5ff7 into vllm-project:main Apr 14, 2026
80 checks passed
@roikoren755 roikoren755 deleted the feat/flashinfer-selective-state-update branch April 15, 2026 08:21
@roikoren755
Contributor Author

LGTM

We might want to revisit whether the dispatch could be better handled by vLLM IR at a later stage.

That was @tomeras91 's suggestion as well, but from what I saw there are a couple of issues there:

  • We need a native PyTorch base implementation ([CPU] Enable Granite 4 / Mamba models on CPU backend #39157 will add it)
  • The main differentiator in IR seems to be which args the different implementations support. Since the flashinfer and triton kernels have the same support, I'm not sure this is quite the right fit...

zxd1997066 pushed a commit to zxd1997066/vllm that referenced this pull request Apr 15, 2026
@tomeras91
Member

  • The main differentiator in IR seems to be which args the different implementations support. Since the flashinfer and triton kernels have the same support, I'm not sure this is quite the right fit...

I think in #39262 IR is also used as a priority mechanism. It's not exactly what we want here (control via a CLI arg), but it shows IR isn't only about different arg support across implementations.

whk-lab pushed a commit to whk-lab/vllm that referenced this pull request Apr 23, 2026
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026

Labels

ci/build, ready (ONLY add when PR is ready to merge/full CI is needed), v1
