[Mamba] Flashinfer selective_state_update#36162
Conversation
Code Review
This pull request introduces support for the Flashinfer selective_state_update kernel as an alternative to the existing Triton implementation. A dispatcher has been added to select between the two backends at runtime, controlled by a new MambaConfig configuration object. The changes are well-structured, with the new backend logic cleanly integrated into the existing architecture. The pull request also includes comprehensive tests for the new dispatcher and backend, including checks for unsupported features in the Flashinfer kernel, ensuring robustness. The argument parsing and engine configuration have been updated appropriately to expose the new backend option. Overall, this is a high-quality contribution that significantly improves performance for Mamba models.
Force-pushed c8b1288 to c2e89c3
Force-pushed 10c42a1 to 2761ad6
hmellor
left a comment
Could we not introduce an enum to vllm.config, as we have been reserving this namespace for config classes only.
```diff
@@ -16,6 +16,7 @@
 from vllm.config.kv_transfer import KVTransferConfig
 from vllm.config.load import LoadConfig
 from vllm.config.lora import LoRAConfig
+from vllm.config.mamba import MambaBackendEnum, MambaConfig
```

Suggested change:

```diff
-from vllm.config.mamba import MambaBackendEnum, MambaConfig
+from vllm.config.mamba import MambaConfig
```
```diff
@@ -82,6 +83,9 @@
 "LoadConfig",
 # From vllm.config.lora
 "LoRAConfig",
+# From vllm.config.mamba
+"MambaBackendEnum",
```

Suggested change:

```diff
-"MambaBackendEnum",
```
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed 2761ad6 to 5bdac84
tomeras91
left a comment
Overall looks good!
I do have some comments though:
- I agree with @amirkl94 that we should have the triton backend as default, so we don't need to deal with `None`.
- I'm not sure the current dispatch mechanism to choose between triton/FlashInfer is the most fitting one. I understand the pros and cons, but it's worth thinking about whether we should use a pattern that is already used elsewhere for backend selection instead of introducing something new. Maybe the recently added "Custom Ops IR" is more fitting?
- [most important] You mention in the description that the current FlashInfer impl lacks support for some features (specifically SpecDec). Yet these are not reflected in the code. So if the user selects the flashinfer backend and runs with SpecDec, they will get an opaque error from FlashInfer (or worse, silent failures resulting in bad outputs). Maybe we should add some defensive runtime checks? Something like this:
```python
if dst_state_batch_indices is not None and \
        dst_state_batch_indices is not state_batch_indices:
    raise NotImplementedError(
        "FlashInfer SSU backend does not yet support separate "
        "dst_state_batch_indices. Use --mamba-backend triton.")
if num_accepted_tokens is not None:
    raise NotImplementedError(
        "FlashInfer SSU backend does not yet support "
        "num_accepted_tokens (cache rewind). Use --mamba-backend triton.")
```
Signed-off-by: Roi Koren <roik@nvidia.com>
Force-pushed 539b43e to a6ced90
tdoublep
left a comment
LGTM
We might want to revisit whether the dispatch could be better handled by vLLM IR at a later stage.
That was @tomeras91's suggestion as well, but from what I saw there are a couple of issues there:
Signed-off-by: Roi Koren <roik@nvidia.com> Signed-off-by: zengxian <xiangdong.zeng@intel.com>
I think in #39262, IR is also used as a priority mechanism. Not exactly what we want here (control from a CLI arg), but also not only for different arg support across implementations.
Signed-off-by: Roi Koren <roik@nvidia.com> Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
Purpose
Add a wrapper for FlashInfer's selective_state_update kernel, with a runtime dispatcher, connected to a config field, to select between it and the existing Triton implementation.
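A minimal sketch of this kind of config-driven dispatch; the enum name `MambaBackendEnum` comes from the diff above, but the placeholder kernel functions and dispatch table here are illustrative, not vLLM's actual implementation:

```python
from enum import Enum


class MambaBackendEnum(Enum):
    # Values mirror the CLI choices (--mamba-backend triton|flashinfer).
    TRITON = "triton"
    FLASHINFER = "flashinfer"


def _triton_selective_state_update(*args, **kwargs):
    # Placeholder standing in for the existing Triton kernel.
    return "triton"


def _flashinfer_selective_state_update(*args, **kwargs):
    # Placeholder standing in for the FlashInfer kernel wrapper.
    return "flashinfer"


# Backend-to-kernel dispatch table, keyed by the config enum.
_BACKENDS = {
    MambaBackendEnum.TRITON: _triton_selective_state_update,
    MambaBackendEnum.FLASHINFER: _flashinfer_selective_state_update,
}


def selective_state_update(backend: MambaBackendEnum, *args, **kwargs):
    # Dispatch to the kernel selected via the config field at runtime.
    return _BACKENDS[backend](*args, **kwargs)


print(selective_state_update(MambaBackendEnum.TRITON))      # → triton
print(selective_state_update(MambaBackendEnum.FLASHINFER))  # → flashinfer
```

A table keyed by the enum keeps the dispatch a single lookup, so per-backend feature checks (like the SpecDec guards discussed below) can live inside each backend's wrapper rather than at the call site.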
As suggested by @tdoublep in #35753, I've introduced `MambaConfig` in this PR, and a followup (or this one, if you'd prefer) could move config fields relevant to Mamba to it.

Test Plan
New test file for the dispatcher's functionality.
Add e2e tests for Nemotron 3 Nano that use FlashInfer's kernel.
Benchmark the two options.
Test Result
Tests pass, with e2e showing the same GSM8K score for both backends for `nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16`.

Running `nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16` with the following command (replacing the `--mamba-backend` argument for the `flashinfer` measurement):

And benchmarking with:
Got the following results:
Triton:
Flashinfer:
Which comes out to a ~4-5% speedup across nearly all metrics: