@@ -8,4 +8,5 @@ server_args: >-
--max-model-len 4096
--tensor-parallel-size 8
--enable-expert-parallel
--mamba-backend flashinfer
--speculative-config '{"method":"mtp","num_speculative_tokens":5}'
@@ -8,4 +8,5 @@ server_args: >-
--max-model-len 4096
--tensor-parallel-size 2
--enable-expert-parallel
--mamba-backend flashinfer
--speculative-config '{"method":"mtp","num_speculative_tokens":5}'
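
These benchmark configs only change the CLI surface: the new `--mamba-backend flashinfer` flag is added alongside the existing speculative-decoding settings. For reference, an offline-engine equivalent might look like the sketch below, assuming each flag maps to the same-named `EngineArgs` keyword; in particular, `mamba_backend` is an assumed spelling for the new flag and is not confirmed by this diff.

from vllm import LLM

# Sketch only: keyword names mirror the CLI flags in the configs above.
# `mamba_backend` is an assumption based on the new `--mamba-backend` flag.
llm = LLM(
    model="<model from the benchmark config>",
    max_model_len=4096,
    tensor_parallel_size=2,
    enable_expert_parallel=True,
    mamba_backend="flashinfer",
    speculative_config={"method": "mtp", "num_speculative_tokens": 5},
)
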
92 changes: 92 additions & 0 deletions tests/kernels/mamba/test_ssu_dispatch.py
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
FlashInferSSUBackend,
TritonSSUBackend,
get_mamba_ssu_backend,
initialize_mamba_ssu_backend,
selective_state_update,
)
from vllm.utils.torch_utils import set_random_seed

try:
import flashinfer.mamba # noqa: F401

HAS_FLASHINFER = True
except ImportError:
HAS_FLASHINFER = False


def test_default_backend_is_triton():
initialize_mamba_ssu_backend(MambaConfig())
backend = get_mamba_ssu_backend()
assert isinstance(backend, TritonSSUBackend)
assert backend.name == "triton"


def test_explicit_triton_backend():
initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON))
backend = get_mamba_ssu_backend()
assert isinstance(backend, TritonSSUBackend)


@pytest.mark.skipif(not HAS_FLASHINFER, reason="flashinfer not installed")
def test_flashinfer_backend_init():
initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.FLASHINFER))
backend = get_mamba_ssu_backend()
assert isinstance(backend, FlashInferSSUBackend)
assert backend.name == "flashinfer"


def test_uninitialized_backend_raises():
    import vllm.model_executor.layers.mamba.ops.ssu_dispatch as mod

    # Restore the module-level singleton even if the assertion fails,
    # so later tests are not affected by a leftover None.
    old = mod._mamba_ssu_backend
    mod._mamba_ssu_backend = None
    try:
        with pytest.raises(RuntimeError, match="not been initialized"):
            get_mamba_ssu_backend()
    finally:
        mod._mamba_ssu_backend = old


@pytest.mark.skipif(HAS_FLASHINFER, reason="flashinfer is installed")
def test_flashinfer_import_error():
with pytest.raises(ImportError, match="FlashInfer is required"):
FlashInferSSUBackend(MambaConfig())


def test_triton_basic_call():
set_random_seed(0)
initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON))
device = "cuda"
batch_size = 2
dim = 64
dstate = 16

state = torch.randn(batch_size, dim, dstate, device=device)
x = torch.randn(batch_size, dim, device=device)
out = torch.empty_like(x)
dt = torch.randn(batch_size, dim, device=device)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device)
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)

selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
)
assert not torch.isnan(out).any()
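
Taken together, these tests pin down the dispatch contract: a module-level singleton that `initialize_mamba_ssu_backend` fills in from a `MambaConfig`, a `get_mamba_ssu_backend` accessor that raises until initialization, and a `selective_state_update` wrapper that forwards to whichever backend is active. A minimal sketch of `ssu_dispatch` consistent with those tests follows; the FlashInfer kernel invocation is deliberately elided and the Triton path simply reuses the existing Triton kernel, so this illustrates the dispatch shape rather than the PR's actual implementation.

from vllm.config.mamba import MambaBackendEnum, MambaConfig

_mamba_ssu_backend = None  # module-level singleton the tests reset and restore


class TritonSSUBackend:
    name = "triton"

    def __init__(self, config: MambaConfig):
        self.config = config

    def __call__(self, state, x, dt, A, B, C, **kwargs):
        # Reuse the existing Triton selective_state_update kernel.
        from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
            selective_state_update as triton_selective_state_update,
        )

        return triton_selective_state_update(state, x, dt, A, B, C, **kwargs)


class FlashInferSSUBackend:
    name = "flashinfer"

    def __init__(self, config: MambaConfig):
        try:
            import flashinfer.mamba  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "FlashInfer is required for the flashinfer Mamba SSU backend"
            ) from e
        self.config = config

    def __call__(self, state, x, dt, A, B, C, **kwargs):
        # The concrete FlashInfer kernel call is elided in this sketch.
        raise NotImplementedError


def initialize_mamba_ssu_backend(config: MambaConfig) -> None:
    global _mamba_ssu_backend
    backend_cls = (
        FlashInferSSUBackend
        if config.backend == MambaBackendEnum.FLASHINFER
        else TritonSSUBackend
    )
    _mamba_ssu_backend = backend_cls(config)


def get_mamba_ssu_backend():
    if _mamba_ssu_backend is None:
        raise RuntimeError("Mamba SSU backend has not been initialized")
    return _mamba_ssu_backend


def selective_state_update(state, x, dt, A, B, C, **kwargs):
    return get_mamba_ssu_backend()(state, x, dt, A, B, C, **kwargs)
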
3 changes: 3 additions & 0 deletions vllm/config/__init__.py
@@ -16,6 +16,7 @@
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
from vllm.config.mamba import MambaConfig
from vllm.config.model import (
ModelConfig,
iter_architecture_defaults,
@@ -83,6 +84,8 @@
"LoadConfig",
# From vllm.config.lora
"LoRAConfig",
# From vllm.config.mamba
"MambaConfig",
# From vllm.config.model
"ModelConfig",
"iter_architecture_defaults",
34 changes: 0 additions & 34 deletions vllm/config/cache.py
@@ -123,14 +123,6 @@ class CacheConfig:
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
enable_mamba_cache_stochastic_rounding: bool = False
"""Enable stochastic rounding when writing SSM state to fp16 cache.
Uses random bits to unbias the rounding error, which can improve
numerical stability for long sequences."""
mamba_cache_philox_rounds: int = 0
"""Number of Philox PRNG rounds for stochastic rounding random number
generation. 0 uses the Triton default. Higher values improve randomness
quality at the cost of compute."""

# Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False)
@@ -258,29 +250,3 @@ def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
str(cache_dtype),
)
return cache_dtype

def __post_init__(self):
if self.enable_mamba_cache_stochastic_rounding:
from vllm.platforms import current_platform

if not current_platform.is_cuda():
raise ValueError(
"Stochastic rounding for Mamba cache is only supported "
"on NVIDIA CUDA platforms. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if not current_platform.is_device_capability_family(100):
raise ValueError(
"Stochastic rounding for Mamba cache requires compute "
"capability 10.0 (data center Blackwell). The `cvt.rs` PTX "
"instruction is not supported on your GPU. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if self.mamba_ssm_cache_dtype != "float16":
raise ValueError(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)
76 changes: 76 additions & 0 deletions vllm/config/mamba.py
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from enum import Enum, EnumMeta
from typing import Any

from pydantic import field_validator

from vllm.config.utils import config


class _MambaBackendEnumMeta(EnumMeta):
"""Metaclass for MambaBackendEnum to provide better error messages."""

def __getitem__(cls, name: str):
try:
return super().__getitem__(name)
except KeyError:
valid = ", ".join(cls.__members__.keys())
raise ValueError(
f"Unknown Mamba SSU backend: '{name}'. Valid options are: {valid}"
) from None


class MambaBackendEnum(Enum, metaclass=_MambaBackendEnumMeta):
"""Enumeration of supported Mamba SSU (selective state update) backends."""

TRITON = "triton"
FLASHINFER = "flashinfer"


@config
class MambaConfig:
"""Configuration for Mamba SSM backends."""

backend: MambaBackendEnum = MambaBackendEnum.TRITON
"""Mamba SSU backend to use."""

enable_stochastic_rounding: bool = False
"""Enable stochastic rounding when writing SSM state to fp16 cache.
Uses random bits to unbias the rounding error, which can improve
numerical stability for long sequences."""
stochastic_rounding_philox_rounds: int = 0
"""Number of Philox PRNG rounds for stochastic rounding random number
generation. 0 uses the Triton default. Higher values improve randomness
quality at the cost of compute."""

@field_validator("backend", mode="before")
@classmethod
def validate_backend_before(cls, value: Any) -> Any:
"""Enable parsing of the `backend` enum type from string."""
if isinstance(value, str):
return MambaBackendEnum[value.upper()]
return value

def __post_init__(self):
if self.enable_stochastic_rounding:
from vllm.platforms import current_platform

if not current_platform.is_cuda():
raise ValueError(
"Stochastic rounding for Mamba cache is only supported "
"on NVIDIA CUDA platforms. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if (
self.backend == MambaBackendEnum.TRITON
and not current_platform.is_device_capability_family(100)
):
raise ValueError(
"Stochastic rounding for Mamba cache with triton backend requires "
"compute capability 10.0 (data center Blackwell). The `cvt.rs` "
"PTX instruction is not supported on your GPU. Please do not "
"specify `--enable-mamba-cache-stochastic-rounding`, "
"or set `--mamba-backend flashinfer`."
)
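
As a quick illustration of how the new config behaves, based only on the definitions above: the default backend is Triton, string inputs are upper-cased and looked up on the enum by the before-mode validator, and unknown names raise a `ValueError` that lists the valid options instead of a bare `KeyError`.

from vllm.config.mamba import MambaBackendEnum, MambaConfig

# Defaults: Triton backend, stochastic rounding disabled.
cfg = MambaConfig()
assert cfg.backend is MambaBackendEnum.TRITON

# String inputs are upper-cased before the enum lookup, so CLI/JSON values
# like "flashinfer" resolve to MambaBackendEnum.FLASHINFER.
assert MambaBackendEnum["FLASHINFER"] is MambaBackendEnum.FLASHINFER

# Unknown names go through _MambaBackendEnumMeta.__getitem__ and raise a
# ValueError naming the valid choices.
try:
    MambaBackendEnum["CUTLASS"]
except ValueError as err:
    print(err)  # Unknown Mamba SSU backend: 'CUTLASS'. Valid options are: TRITON, FLASHINFER
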
15 changes: 15 additions & 0 deletions vllm/config/vllm.py
@@ -37,6 +37,7 @@
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .mamba import MambaConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .offload import OffloadConfig
@@ -275,6 +276,8 @@ class VllmConfig:
"""Model weight offloading configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
mamba_config: MambaConfig = Field(default_factory=MambaConfig)
"""Mamba configuration."""
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
"""Kernel configuration."""
lora_config: LoRAConfig | None = None
@@ -707,6 +710,18 @@ def __post_init__(self):
if self.lora_config is not None:
self.lora_config.verify_with_model_config(self.model_config)

if (
self.mamba_config.enable_stochastic_rounding
and self.cache_config.mamba_ssm_cache_dtype != "float16"
):
raise ValueError(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)

if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config