diff --git a/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-BF16.yaml b/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-BF16.yaml index d9110efaaad0..b0f886a86ad0 100644 --- a/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-BF16.yaml +++ b/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-BF16.yaml @@ -8,4 +8,5 @@ server_args: >- --max-model-len 4096 --tensor-parallel-size 8 --enable-expert-parallel + --mamba-backend flashinfer --speculative-config '{"method":"mtp","num_speculative_tokens":5}' diff --git a/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-NVFP4.yaml b/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-NVFP4.yaml index 50f097319462..71ba7d52f144 100644 --- a/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-NVFP4.yaml +++ b/tests/evals/gsm8k/configs/Nemotron-3-Super-120B-A12B-NVFP4.yaml @@ -8,4 +8,5 @@ server_args: >- --max-model-len 4096 --tensor-parallel-size 2 --enable-expert-parallel + --mamba-backend flashinfer --speculative-config '{"method":"mtp","num_speculative_tokens":5}' diff --git a/tests/kernels/mamba/test_ssu_dispatch.py b/tests/kernels/mamba/test_ssu_dispatch.py new file mode 100644 index 000000000000..96b04f44d220 --- /dev/null +++ b/tests/kernels/mamba/test_ssu_dispatch.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.config.mamba import MambaBackendEnum, MambaConfig +from vllm.model_executor.layers.mamba.ops.ssu_dispatch import ( + FlashInferSSUBackend, + TritonSSUBackend, + get_mamba_ssu_backend, + initialize_mamba_ssu_backend, + selective_state_update, +) +from vllm.utils.torch_utils import set_random_seed + +try: + import flashinfer.mamba # noqa: F401 + + HAS_FLASHINFER = True +except ImportError: + HAS_FLASHINFER = False + + +def test_default_backend_is_triton(): + initialize_mamba_ssu_backend(MambaConfig()) + backend = get_mamba_ssu_backend() + assert isinstance(backend, TritonSSUBackend) + assert backend.name == "triton" + + +def test_explicit_triton_backend(): + initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON)) + backend = get_mamba_ssu_backend() + assert isinstance(backend, TritonSSUBackend) + + +@pytest.mark.skipif(not HAS_FLASHINFER, reason="flashinfer not installed") +def test_flashinfer_backend_init(): + initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.FLASHINFER)) + backend = get_mamba_ssu_backend() + assert isinstance(backend, FlashInferSSUBackend) + assert backend.name == "flashinfer" + + +def test_uninitialized_backend_raises(): + import vllm.model_executor.layers.mamba.ops.ssu_dispatch as mod + + old = mod._mamba_ssu_backend + mod._mamba_ssu_backend = None + with pytest.raises(RuntimeError, match="not been initialized"): + get_mamba_ssu_backend() + mod._mamba_ssu_backend = old + + +@pytest.mark.skipif(HAS_FLASHINFER, reason="flashinfer is installed") +def test_flashinfer_import_error(): + with pytest.raises(ImportError, match="FlashInfer is required"): + FlashInferSSUBackend(MambaConfig()) + + +def test_triton_basic_call(): + set_random_seed(0) + initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON)) + device = "cuda" + batch_size = 2 + dim = 64 + dstate = 16 + + state = torch.randn(batch_size, dim, dstate, device=device) + x = torch.randn(batch_size, dim, device=device) + out = torch.empty_like(x) + dt = torch.randn(batch_size, dim, device=device) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, 
device=device) + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + ) + assert not torch.isnan(out).any() diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index d5a3e9bfd960..758605d25c60 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -16,6 +16,7 @@ from vllm.config.kv_transfer import KVTransferConfig from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig +from vllm.config.mamba import MambaConfig from vllm.config.model import ( ModelConfig, iter_architecture_defaults, @@ -83,6 +84,8 @@ "LoadConfig", # From vllm.config.lora "LoRAConfig", + # From vllm.config.mamba + "MambaConfig", # From vllm.config.model "ModelConfig", "iter_architecture_defaults", diff --git a/vllm/config/cache.py b/vllm/config/cache.py index cd1554590ea3..20721cc80923 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -123,14 +123,6 @@ class CacheConfig: - "align": only cache the mamba state of the last token of each scheduler step and when the token is at position i * block_size. """ - enable_mamba_cache_stochastic_rounding: bool = False - """Enable stochastic rounding when writing SSM state to fp16 cache. - Uses random bits to unbias the rounding error, which can improve - numerical stability for long sequences.""" - mamba_cache_philox_rounds: int = 0 - """Number of Philox PRNG rounds for stochastic rounding random number - generation. 0 uses the Triton default. Higher values improve randomness - quality at the cost of compute.""" # Will be set after profiling. num_gpu_blocks: int | None = field(default=None, init=False) @@ -258,29 +250,3 @@ def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: str(cache_dtype), ) return cache_dtype - - def __post_init__(self): - if self.enable_mamba_cache_stochastic_rounding: - from vllm.platforms import current_platform - - if not current_platform.is_cuda(): - raise ValueError( - "Stochastic rounding for Mamba cache is only supported " - "on NVIDIA CUDA platforms. Please do not specify " - "`--enable-mamba-cache-stochastic-rounding`." - ) - if not current_platform.is_device_capability_family(100): - raise ValueError( - "Stochastic rounding for Mamba cache requires compute " - "capability 10.0 (data center Blackwell). The `cvt.rs` PTX " - "instruction is not supported on your GPU. Please do not specify " - "`--enable-mamba-cache-stochastic-rounding`." - ) - if self.mamba_ssm_cache_dtype != "float16": - raise ValueError( - "Stochastic rounding for Mamba cache requires " - "the SSM cache to be float16. Please set it explicitly, " - "by specifying `--mamba-ssm-cache-dtype float16`, or disable " - "stochastic rounding by not specifying " - "`--enable-mamba-cache-stochastic-rounding`." 
- ) diff --git a/vllm/config/mamba.py b/vllm/config/mamba.py new file mode 100644 index 000000000000..996478c36760 --- /dev/null +++ b/vllm/config/mamba.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum, EnumMeta +from typing import Any + +from pydantic import field_validator + +from vllm.config.utils import config + + +class _MambaBackendEnumMeta(EnumMeta): + """Metaclass for MambaBackendEnum to provide better error messages.""" + + def __getitem__(cls, name: str): + try: + return super().__getitem__(name) + except KeyError: + valid = ", ".join(cls.__members__.keys()) + raise ValueError( + f"Unknown Mamba SSU backend: '{name}'. Valid options are: {valid}" + ) from None + + +class MambaBackendEnum(Enum, metaclass=_MambaBackendEnumMeta): + """Enumeration of supported Mamba SSU (selective state update) backends.""" + + TRITON = "triton" + FLASHINFER = "flashinfer" + + +@config +class MambaConfig: + """Configuration for Mamba SSM backends.""" + + backend: MambaBackendEnum = MambaBackendEnum.TRITON + """Mamba SSU backend to use.""" + + enable_stochastic_rounding: bool = False + """Enable stochastic rounding when writing SSM state to fp16 cache. + Uses random bits to unbias the rounding error, which can improve + numerical stability for long sequences.""" + stochastic_rounding_philox_rounds: int = 0 + """Number of Philox PRNG rounds for stochastic rounding random number + generation. 0 uses the Triton default. Higher values improve randomness + quality at the cost of compute.""" + + @field_validator("backend", mode="before") + @classmethod + def validate_backend_before(cls, value: Any) -> Any: + """Enable parsing of the `backend` enum type from string.""" + if isinstance(value, str): + return MambaBackendEnum[value.upper()] + return value + + def __post_init__(self): + if self.enable_stochastic_rounding: + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + raise ValueError( + "Stochastic rounding for Mamba cache is only supported " + "on NVIDIA CUDA platforms. Please do not specify " + "`--enable-mamba-cache-stochastic-rounding`." + ) + if ( + self.backend == MambaBackendEnum.TRITON + and not current_platform.is_device_capability_family(100) + ): + raise ValueError( + "Stochastic rounding for Mamba cache with triton backend requires " + "compute capability 10.0 (data center Blackwell). The `cvt.rs` " + "PTX instruction is not supported on your GPU. Please do not " + "specify `--enable-mamba-cache-stochastic-rounding`, " + "or set `--mamba-backend flashinfer`." 
+ ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 7bf75a67b2ce..84f63351d516 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -37,6 +37,7 @@ from .kv_transfer import KVTransferConfig from .load import LoadConfig from .lora import LoRAConfig +from .mamba import MambaConfig from .model import ModelConfig from .observability import ObservabilityConfig from .offload import OffloadConfig @@ -275,6 +276,8 @@ class VllmConfig: """Model weight offloading configuration.""" attention_config: AttentionConfig = Field(default_factory=AttentionConfig) """Attention configuration.""" + mamba_config: MambaConfig = Field(default_factory=MambaConfig) + """Mamba configuration.""" kernel_config: KernelConfig = Field(default_factory=KernelConfig) """Kernel configuration.""" lora_config: LoRAConfig | None = None @@ -707,6 +710,18 @@ def __post_init__(self): if self.lora_config is not None: self.lora_config.verify_with_model_config(self.model_config) + if ( + self.mamba_config.enable_stochastic_rounding + and self.cache_config.mamba_ssm_cache_dtype != "float16" + ): + raise ValueError( + "Stochastic rounding for Mamba cache requires " + "the SSM cache to be float16. Please set it explicitly, " + "by specifying `--mamba-ssm-cache-dtype float16`, or disable " + "stochastic rounding by not specifying " + "`--enable-mamba-cache-stochastic-rounding`." + ) + if self.quant_config is None and self.model_config is not None: self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c9b90848ff04..9d4a12c343e9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,6 +45,7 @@ KVTransferConfig, LoadConfig, LoRAConfig, + MambaConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, @@ -72,6 +73,7 @@ from vllm.config.device import Device from vllm.config.kernel import IrOpPriorityConfig, MoEBackend from vllm.config.lora import MaxLoRARanks +from vllm.config.mamba import MambaBackendEnum from vllm.config.model import ( ConvertOption, HfOverrides, @@ -578,6 +580,7 @@ class EngineArgs: pooler_config: PoolerConfig | None = ModelConfig.pooler_config compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") attention_config: AttentionConfig = get_field(VllmConfig, "attention_config") + mamba_config: MambaConfig = get_field(VllmConfig, "mamba_config") kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config") enable_flashinfer_autotune: bool = get_field( KernelConfig, "enable_flashinfer_autotune" @@ -610,10 +613,12 @@ class EngineArgs: mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode + + mamba_backend: MambaBackendEnum = MambaBackendEnum.TRITON enable_mamba_cache_stochastic_rounding: bool = ( - CacheConfig.enable_mamba_cache_stochastic_rounding + MambaConfig.enable_stochastic_rounding ) - mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds + mamba_cache_philox_rounds: int = MambaConfig.stochastic_rounding_philox_rounds additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -655,6 +660,8 @@ def __post_init__(self): self.compilation_config = CompilationConfig(**self.compilation_config) if isinstance(self.attention_config, dict): self.attention_config = AttentionConfig(**self.attention_config) + if isinstance(self.mamba_config, dict): + 
self.mamba_config = MambaConfig(**self.mamba_config) if isinstance(self.kernel_config, dict): self.kernel_config = KernelConfig(**self.kernel_config) if isinstance(self.eplb_config, dict): @@ -825,6 +832,22 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--attention-backend", **attention_kwargs["backend"] ) + # Mamba arguments + mamba_kwargs = get_kwargs(MambaConfig) + mamba_group = parser.add_argument_group( + title="MambaConfig", + description=MambaConfig.__doc__, + ) + mamba_group.add_argument("--mamba-backend", **mamba_kwargs["backend"]) + mamba_group.add_argument( + "--enable-mamba-cache-stochastic-rounding", + **mamba_kwargs["enable_stochastic_rounding"], + ) + mamba_group.add_argument( + "--mamba-cache-philox-rounds", + **mamba_kwargs["stochastic_rounding_philox_rounds"], + ) + # Structured outputs arguments structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) structured_outputs_group = parser.add_argument_group( @@ -1050,13 +1073,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] ) - cache_group.add_argument( - "--enable-mamba-cache-stochastic-rounding", - **cache_kwargs["enable_mamba_cache_stochastic_rounding"], - ) - cache_group.add_argument( - "--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"] - ) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1625,8 +1641,6 @@ def create_engine_config( mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_block_size=self.mamba_block_size, mamba_cache_mode=self.mamba_cache_mode, - enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding, - mamba_cache_philox_rounds=self.mamba_cache_philox_rounds, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) @@ -1931,6 +1945,22 @@ def create_engine_config( self.attention_backend ) + # Mamba config overrides + mamba_config = copy.deepcopy(self.mamba_config) + # Convert string to enum if needed (CLI parsing returns a string) + if isinstance(self.mamba_backend, str): + mamba_config.backend = MambaBackendEnum[self.mamba_backend.upper()] + else: + mamba_config.backend = self.mamba_backend + if self.enable_mamba_cache_stochastic_rounding: + mamba_config.enable_stochastic_rounding = ( + self.enable_mamba_cache_stochastic_rounding + ) + if self.mamba_cache_philox_rounds: + mamba_config.stochastic_rounding_philox_rounds = ( + self.mamba_cache_philox_rounds + ) + # Kernel config overrides kernel_config = copy.deepcopy(self.kernel_config) if self.enable_flashinfer_autotune is not None: @@ -2029,6 +2059,7 @@ def create_engine_config( load_config=load_config, offload_config=offload_config, attention_config=attention_config, + mamba_config=mamba_config, kernel_config=kernel_config, lora_config=lora_config, speculative_config=speculative_config, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6f8b19c3880f..4509a0956280 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -30,10 +30,8 @@ causal_conv1d_fn, causal_conv1d_update, ) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, - selective_state_update, -) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn +from vllm.model_executor.layers.mamba.ops.ssu_dispatch import 
selective_state_update from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils.torch_utils import ( @@ -431,14 +429,12 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): B_d, C_d, self.D, - gate_d.transpose(0, 1), time_proj_bias, + z=gate_d.transpose(0, 1), dt_softplus=True, state_batch_indices=state_indices_tensor_d_input, dst_state_batch_indices=state_indices_tensor_d_output, out=scan_outputs_d, - enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, - cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) scan_outputs_d = scan_outputs_d.transpose(0, 1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 17fcc5609c09..0518bde2f427 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -31,10 +31,10 @@ causal_conv1d_update, ) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated -from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined_varlen, ) +from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( LoaderFunction, @@ -890,8 +890,7 @@ def conv_ssm_forward( B_d, C_d, D_d, - z=None, - dt_bias=dt_bias, + dt_bias, dt_softplus=True, state_batch_indices=state_indices_tensor_d_input, dst_state_batch_indices=state_indices_tensor_d_output, @@ -899,8 +898,6 @@ def conv_ssm_forward( num_accepted_tokens=num_accepted_tokens, cu_seqlens=query_start_loc_d, is_blackwell=self.is_blackwell, - enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, - cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index c4a0ef385d1c..e3c8ba8312f2 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -323,9 +323,9 @@ def selective_state_update( A, B, C, - D=None, + D, + dt_bias, z=None, - dt_bias=None, dt_softplus=False, state_batch_indices=None, dst_state_batch_indices=None, @@ -374,11 +374,11 @@ def selective_state_update( B = B.unsqueeze(1) if C.dim() == 2: C = C.unsqueeze(1) - if D is not None and D.dim() == 1: + if D.dim() == 1: D = D.unsqueeze(0) if z is not None and z.dim() == 2: z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.dim() == 1: + if dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) if out.dim() == 2: out = out.unsqueeze(1) @@ -410,12 +410,10 @@ def selective_state_update( assert nheads % ngroups == 0, "nheads must be divisible by ngroups" assert B.shape == (batch, ngroups, dstate) assert C.shape == B.shape - if D is not None: - assert D.shape == (nheads, dim) + assert D.shape == (nheads, dim) if z is not None: assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (nheads, dim) + assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape[0] >= N assert state_batch_indices.shape[1] >= max_seqlen @@ -506,7 +504,8 @@ def selective_state_update( dt.stride(0), dt.stride(1), dt.stride(2), - 
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + dt_bias.stride(0), + dt_bias.stride(1), A.stride(0), A.stride(1), A.stride(2), @@ -516,7 +515,8 @@ def selective_state_update( C.stride(0), C.stride(1), C.stride(2), - *(D.stride(0), D.stride(1)) if D is not None else 0, + D.stride(0), + D.stride(1), z_strides[0], z_strides[1], z_strides[2], diff --git a/vllm/model_executor/layers/mamba/ops/ssu_dispatch.py b/vllm/model_executor/layers/mamba/ops/ssu_dispatch.py new file mode 100644 index 000000000000..8a86b1a068bb --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssu_dispatch.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Dispatch module for Mamba selective state update (SSU) backends. + +Provides a unified `selective_state_update` function that dispatches to +either the Triton or FlashInfer backend based on the configured +`MambaBackendEnum`. Follows SGLang's dispatch pattern adapted for vLLM. +""" + +from abc import ABC, abstractmethod + +import torch + +from vllm.config.mamba import MambaBackendEnum, MambaConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.utils import NULL_BLOCK_ID + +logger = init_logger(__name__) + + +class MambaSSUBackend(ABC): + """Abstract base class for Mamba SSU backends.""" + + def __init__(self, mamba_config: MambaConfig): + self._mamba_config = mamba_config + + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def __call__( + self, + state: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, + z: torch.Tensor | None = None, + dt_softplus: bool = False, + state_batch_indices: torch.Tensor | None = None, + dst_state_batch_indices: torch.Tensor | None = None, + null_block_id: int = NULL_BLOCK_ID, + out: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + is_blackwell: bool = False, + ) -> None: ... 
+ + +class TritonSSUBackend(MambaSSUBackend): + """Triton-based SSU backend (vLLM's default).""" + + def __init__(self, mamba_config: MambaConfig): + super().__init__(mamba_config) + from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update as _triton_selective_state_update, + ) + + self._kernel = _triton_selective_state_update + + @property + def name(self) -> str: + return "triton" + + def __call__( + self, + state: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, + z: torch.Tensor | None = None, + dt_softplus: bool = False, + state_batch_indices: torch.Tensor | None = None, + dst_state_batch_indices: torch.Tensor | None = None, + null_block_id: int = NULL_BLOCK_ID, + out: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + is_blackwell: bool = False, + ) -> None: + self._kernel( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=state_batch_indices, + dst_state_batch_indices=dst_state_batch_indices, + null_block_id=null_block_id, + out=out, + num_accepted_tokens=num_accepted_tokens, + cu_seqlens=cu_seqlens, + is_blackwell=is_blackwell, + enable_stochastic_rounding=self._mamba_config.enable_stochastic_rounding, + cache_philox_rounds=self._mamba_config.stochastic_rounding_philox_rounds, + ) + + +class FlashInferSSUBackend(MambaSSUBackend): + """FlashInfer-based SSU backend.""" + + def __init__(self, mamba_config: MambaConfig): + super().__init__(mamba_config) + try: + from flashinfer.mamba import selective_state_update as _fi_ssu + except ImportError as e: + raise ImportError( + "FlashInfer is required for the flashinfer Mamba SSU backend. 
" + "Please install flashinfer (>= 0.6.4): " + "pip install flashinfer-python" + ) from e + self._kernel = _fi_ssu + + @property + def name(self) -> str: + return "flashinfer" + + def __call__( + self, + state: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, + z: torch.Tensor | None = None, + dt_softplus: bool = False, + state_batch_indices: torch.Tensor | None = None, + dst_state_batch_indices: torch.Tensor | None = None, + null_block_id: int = NULL_BLOCK_ID, + out: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + is_blackwell: bool = False, + ) -> None: + rand_seed = ( + torch.randint(0, 2**32, (1,), device=state.device) + if self._mamba_config.enable_stochastic_rounding + else None + ) + + self._kernel( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=state_batch_indices, + dst_state_batch_indices=dst_state_batch_indices, + cu_seqlens=cu_seqlens, + num_accepted_tokens=num_accepted_tokens, + cache_steps=state_batch_indices.size(-1) + if cu_seqlens is not None and state_batch_indices is not None + else 0, + pad_slot_id=null_block_id, + out=out, + rand_seed=rand_seed, + philox_rounds=self._mamba_config.stochastic_rounding_philox_rounds or 10, + ) + + +_BACKEND_REGISTRY: dict[MambaBackendEnum, type[MambaSSUBackend]] = { + MambaBackendEnum.TRITON: TritonSSUBackend, + MambaBackendEnum.FLASHINFER: FlashInferSSUBackend, +} + +_mamba_ssu_backend: MambaSSUBackend | None = None + + +def initialize_mamba_ssu_backend(mamba_config: MambaConfig) -> None: + """Initialize the global Mamba SSU backend. + + Args: + mamba_config: Mamba configuration. + """ + global _mamba_ssu_backend + + backend = mamba_config.backend + if backend not in _BACKEND_REGISTRY: + raise ValueError( + f"Unknown Mamba SSU backend: {backend}. " + f"Valid options: {list(_BACKEND_REGISTRY.keys())}" + ) + + _mamba_ssu_backend = _BACKEND_REGISTRY[backend](mamba_config) + logger.info("Using %s Mamba SSU backend.", _mamba_ssu_backend.name) + + +def get_mamba_ssu_backend() -> MambaSSUBackend: + """Get the current Mamba SSU backend. Raises if not initialized.""" + if _mamba_ssu_backend is None: + raise RuntimeError( + "Mamba SSU backend has not been initialized. " + "Call initialize_mamba_ssu_backend() first." + ) + return _mamba_ssu_backend + + +def selective_state_update( + state: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, + z: torch.Tensor | None = None, + dt_softplus: bool = False, + state_batch_indices: torch.Tensor | None = None, + dst_state_batch_indices: torch.Tensor | None = None, + null_block_id: int = NULL_BLOCK_ID, + out: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + is_blackwell: bool = False, +) -> None: + """Unified dispatch for Mamba selective state update. + + Delegates to the initialized backend (Triton or FlashInfer). 
+ """ + get_mamba_ssu_backend()( + state, + x, + dt, + A, + B, + C, + D, + dt_bias, + z=z, + dt_softplus=dt_softplus, + state_batch_indices=state_batch_indices, + dst_state_batch_indices=dst_state_batch_indices, + null_block_id=null_block_id, + out=out, + num_accepted_tokens=num_accepted_tokens, + cu_seqlens=cu_seqlens, + is_blackwell=is_blackwell, + ) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index ce7acc1cb19f..e81541b29aec 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -38,10 +38,10 @@ causal_conv1d_fn, causal_conv1d_update, ) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined_varlen, ) +from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -447,13 +447,11 @@ def forward_impl( B, C, D, + dt_bias, z=gate_d.reshape(num_decodes, -1, self.head_dim), - dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices_tensor_d, out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), - enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, - cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) # 4. Final linear projection diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index ca52b8b66d0f..aea08115e40e 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -36,6 +36,9 @@ ) from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.ops.ssu_dispatch import ( + initialize_mamba_ssu_backend, +) from vllm.model_executor.model_loader import get_model_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors @@ -360,6 +363,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend( self.kv_cache_config, self.vllm_config, self.device ) + initialize_mamba_ssu_backend(self.vllm_config.mamba_config) cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes( attn_cg_support.min_cg_support, attn_cg_support.min_cg_attn_backend, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4f3e192772ff..e8df720e78c0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -56,6 +56,9 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( RoutedExpertsCapturer, ) +from vllm.model_executor.layers.mamba.ops.ssu_dispatch import ( + initialize_mamba_ssu_backend, +) from vllm.model_executor.layers.rotary_embedding import ( MRotaryEmbedding, XDRotaryEmbedding, @@ -6750,6 +6753,7 @@ def initialize_kv_cache( self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling) + initialize_mamba_ssu_backend(self.vllm_config.mamba_config) # The kernel block size for all KV cache groups. 
For example, if # kv_cache_manager uses block_size 256 for a given group, but the attention # backends for that group only supports block_size 64, we will return