Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b92d438
[Mamba] selective_state_update auto-tuning framework with NVIDIA GB10…
bananighosh Apr 30, 2026
f7efdb8
Apply code review changes
danisereb May 19, 2026
a5e44cb
Use cache (same as lru_cache with maxsize None) for try_get_optimal_s…
danisereb May 19, 2026
e1db1d7
Add tuned JSON files for H100
danisereb May 19, 2026
56a1928
Add CUDA graphs to script
danisereb May 20, 2026
1dd1de1
Add tuned JSONs for H100
danisereb May 20, 2026
aa4ad6d
Remove ref to fused_moe, use cache instead of lru_cache
danisereb May 20, 2026
35103b8
Remove duplicate functions from the tuning script
danisereb May 20, 2026
e77985f
Fix --validate failure
danisereb May 20, 2026
af47691
Reuse tuned measurements for comparisons
danisereb May 20, 2026
702faaf
Add support for mamba cache bf16 (match to fp16 config)
danisereb May 20, 2026
c856e72
Add flags --batch-sizes and --nheads
danisereb May 20, 2026
621137a
Move location of config files
danisereb May 20, 2026
878de24
Use contextmanager for override_ssm_config
danisereb May 20, 2026
24a2f61
Cleanup comments
danisereb May 20, 2026
e4c3181
Add upper limit to eff batch (to avoid cuda error)
danisereb May 21, 2026
7337562
Update JSON for B200
danisereb May 21, 2026
4953de9
Add tuned JSON files for GB200
danisereb May 21, 2026
3c3f3aa
Update JSON for H100
danisereb May 21, 2026
4bba278
Fix stale comment in get_ssm_config_file_name
danisereb May 24, 2026
e0152b2
Move selective_state_update_ref to new utils file
danisereb May 24, 2026
7757ae5
Fix hard coded block_m/warps in test_mamba_ssm_configs
danisereb May 24, 2026
f01f9f1
Add more coverage in test_mamba_ssm_configs
danisereb May 24, 2026
f1b0f51
Update comment for VLLM_TUNED_CONFIG_FOLDER
danisereb May 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
774 changes: 774 additions & 0 deletions benchmarks/kernels/benchmark_selective_state_update.py

Large diffs are not rendered by default.

Empty file.
73 changes: 1 addition & 72 deletions tests/kernels/mamba/test_mamba_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn.functional as F
from einops import rearrange, repeat

from tests.kernels.mamba.utils import selective_state_update_ref
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
Expand All @@ -17,78 +18,6 @@
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID


def selective_state_update_ref(
state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(
rearrange(dt, "b h d -> b h d 1") * A
) # (batch, nheads, dim, dstate)
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
B, "b h n -> b h 1 n"
) # (batch, nheads, dim, dstate)
state.copy_(
state * dA + dB * rearrange(x, "b h d -> b h d 1")
) # (batch, dim, dstate
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
out = (out if z is None else out * F.silu(z)).to(x.dtype)
if not has_heads:
out = out.squeeze(1)
return out


def selective_scan_ref(
u,
delta,
Expand Down
212 changes: 212 additions & 0 deletions tests/kernels/mamba/test_mamba_ssm_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for the JSON-based config loader added to selective_state_update.

Tests cover:
- Flat MoE-style filename generation
- VLLM_TUNED_CONFIG_FOLDER env-var override
- Fallback to heuristic when no config file exists
- Nearest effective_batch interpolation
- Edge cases: non-dict JSON, empty config
"""

import json

from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
_get_default_ssm_launch_config,
_try_get_optimal_ssm_config_cached,
get_ssm_config_file_name,
get_ssm_configs,
get_ssm_device_name,
try_get_optimal_ssm_config,
)

# Common kwargs for try_get_optimal_ssm_config. Tests pick (batch, nheads) so
# their product (effective_batch) matches the value being probed.
_HEADDIM = 64
_CACHE_DTYPE = "float32"


def _clear_caches() -> None:
get_ssm_configs.cache_clear()
_try_get_optimal_ssm_config_cached.cache_clear()


def _write_config(tmp_path, dstate: int, payload: dict) -> None:
"""Write payload as the bundled config for (headdim, dstate, cache_dtype)."""
device_name = get_ssm_device_name()
config_path = tmp_path / get_ssm_config_file_name(
_HEADDIM, dstate, _CACHE_DTYPE, device_name
)
with open(config_path, "w") as f:
json.dump(payload, f)


# ---------------------------------------------------------------------------
# Config filename generation
# ---------------------------------------------------------------------------


def test_config_file_name_format():
name = get_ssm_config_file_name(
headdim=64, dstate=128, cache_dtype="float32", device_name="NVIDIA_B200"
)
assert name == (
"headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json"
)


# ---------------------------------------------------------------------------
# VLLM_TUNED_CONFIG_FOLDER override
# ---------------------------------------------------------------------------


def test_env_override_loads_custom_config(monkeypatch, tmp_path):
"""VLLM_TUNED_CONFIG_FOLDER should take precedence over the bundled dir."""
_write_config(
tmp_path,
dstate=16,
payload={
"1": {"BLOCK_SIZE_M": 4, "num_warps": 1},
},
)

monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path))
_clear_caches()

cfg = get_ssm_configs(_HEADDIM, 16, _CACHE_DTYPE)
assert cfg is not None
assert cfg[1] == {"BLOCK_SIZE_M": 4, "num_warps": 1}

_clear_caches()


# ---------------------------------------------------------------------------
# Fallback to heuristic when no config file exists
# ---------------------------------------------------------------------------


def test_fallback_when_no_config(monkeypatch, tmp_path):
"""try_get_optimal_ssm_config must fall back to _get_default_ssm_launch_config
when no JSON file is found for the current
(device, headdim, dstate, cache_dtype) combination.
"""
monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path))
monkeypatch.setattr(
"vllm.model_executor.layers.mamba.ops.mamba_ssm._CONFIGS_DIR",
str(tmp_path),
)

for dstate in (8, 16, 32, 64, 128, 256):
for is_blackwell in (False, True):
_clear_caches()
block_m, warps = try_get_optimal_ssm_config(
headdim=_HEADDIM,
dstate=dstate,
batch=1,
nheads=1,
cache_dtype=_CACHE_DTYPE,
is_blackwell=is_blackwell,
)
assert (block_m, warps) == _get_default_ssm_launch_config(
dstate, is_blackwell=is_blackwell
)

_clear_caches()


# ---------------------------------------------------------------------------
# Nearest effective_batch interpolation
# ---------------------------------------------------------------------------


def test_nearest_effective_batch_interpolation(monkeypatch, tmp_path):
"""When effective_batch = batch*nheads is not an exact key, the closest
key should be selected."""
_write_config(
tmp_path,
dstate=32,
payload={
"64": {"BLOCK_SIZE_M": 8, "num_warps": 1},
"4096": {"BLOCK_SIZE_M": 32, "num_warps": 4},
},
)

monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path))
_clear_caches()

# effective_batch = 1*128 = 128 -> closer to 64 than to 4096
block_m, warps = try_get_optimal_ssm_config(
headdim=_HEADDIM,
dstate=32,
batch=1,
nheads=128,
cache_dtype=_CACHE_DTYPE,
is_blackwell=False,
)
assert block_m == 8 and warps == 1

# effective_batch = 4*1024 = 4096 -> exact match on 4096
block_m, warps = try_get_optimal_ssm_config(
headdim=_HEADDIM,
dstate=32,
batch=4,
nheads=1024,
cache_dtype=_CACHE_DTYPE,
is_blackwell=False,
)
assert block_m == 32 and warps == 4

_clear_caches()


# ---------------------------------------------------------------------------
# Edge cases: malformed / empty config files
# ---------------------------------------------------------------------------


def test_non_dict_json_returns_none(monkeypatch, tmp_path):
"""A valid JSON file that is not a dict (e.g. a list) must be ignored
and return None rather than raising AttributeError."""
device_name = get_ssm_device_name()
config_path = tmp_path / get_ssm_config_file_name(
_HEADDIM, 16, _CACHE_DTYPE, device_name
)
with open(config_path, "w") as f:
json.dump([1, 2, 3], f)

monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path))
monkeypatch.setattr(
"vllm.model_executor.layers.mamba.ops.mamba_ssm._CONFIGS_DIR",
str(tmp_path),
)
_clear_caches()

assert get_ssm_configs(_HEADDIM, 16, _CACHE_DTYPE) is None

_clear_caches()


def test_empty_config_falls_back_to_heuristic(monkeypatch, tmp_path):
"""An empty JSON object {} must not crash min() — should fall back
to the hard-coded heuristic."""
_write_config(tmp_path, dstate=64, payload={})

monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path))
_clear_caches()

dstate = 64
block_m, warps = try_get_optimal_ssm_config(
headdim=_HEADDIM,
dstate=dstate,
batch=1,
nheads=64,
cache_dtype=_CACHE_DTYPE,
is_blackwell=False,
)
assert (block_m, warps) == _get_default_ssm_launch_config(
dstate=dstate, is_blackwell=False
)

_clear_caches()
78 changes: 78 additions & 0 deletions tests/kernels/mamba/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def selective_state_update_ref(
state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(
rearrange(dt, "b h d -> b h d 1") * A
) # (batch, nheads, dim, dstate)
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
B, "b h n -> b h 1 n"
) # (batch, nheads, dim, dstate)
state.copy_(
state * dA + dB * rearrange(x, "b h d -> b h d 1")
) # (batch, dim, dstate
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
out = (out if z is None else out * F.silu(z)).to(x.dtype)
if not has_heads:
out = out.squeeze(1)
return out
5 changes: 4 additions & 1 deletion vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1744,7 +1744,10 @@ def _resolve_rust_frontend_path() -> str | None:
"VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool(
int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0"))
),
# Allows vllm to find tuned config under customized folder
# User override folder for tuned Triton-kernel configs. Shared by MoE,
# Mamba SSU, and LoRA. Filenames are distinct so one folder can hold all.
# Each component first checks this folder, then the configs shipped with
# vLLM (if any). If no JSON matches, it uses a hard-coded heuristic.
"VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Valid values are container,code_interpreter,web_search_preview
# ex VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
Expand Down
Loading
Loading