From b92d438b706e0077cb4aa84a0f0b54371e9ffdd0 Mon Sep 17 00:00:00 2001 From: Banani Ghosh Date: Thu, 30 Apr 2026 11:40:09 -0700 Subject: [PATCH 01/24] [Mamba] selective_state_update auto-tuning framework with NVIDIA GB10 and B200 configs - Add benchmark and --validate script for generating tuned configs on any GPU - Add tuned dstate configs (16/32/64/128/256) for NVIDIA GB10 and B200 - Update mamba_ssm.py ops to support config-driven kernel selection - Add per-GPU subfolder config structure under vllm/model_executor/layers/mamba/configs/ - Add config loader unit tests with edge case coverage - Add type guard for non-dict JSON config files (prevents AttributeError) - Fix empty config dict crash path in _get_ssm_launch_config (prevents ValueError) Signed-off-by: Banani Ghosh --- .../benchmark_selective_state_update.py | 629 ++++++++++++++++++ tests/kernels/mamba/test_mamba_ssm_configs.py | 167 +++++ .../mamba/configs/NVIDIA_B200/dstate=128.json | 46 ++ .../mamba/configs/NVIDIA_B200/dstate=16.json | 46 ++ .../mamba/configs/NVIDIA_B200/dstate=256.json | 46 ++ .../mamba/configs/NVIDIA_B200/dstate=32.json | 46 ++ .../mamba/configs/NVIDIA_B200/dstate=64.json | 46 ++ .../mamba/configs/NVIDIA_GB10/dstate=128.json | 46 ++ .../mamba/configs/NVIDIA_GB10/dstate=16.json | 46 ++ .../mamba/configs/NVIDIA_GB10/dstate=256.json | 46 ++ .../mamba/configs/NVIDIA_GB10/dstate=32.json | 46 ++ .../mamba/configs/NVIDIA_GB10/dstate=64.json | 46 ++ .../layers/mamba/ops/mamba_ssm.py | 127 +++- 13 files changed, 1365 insertions(+), 18 deletions(-) create mode 100644 benchmarks/kernels/benchmark_selective_state_update.py create mode 100644 tests/kernels/mamba/test_mamba_ssm_configs.py create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json create mode 100644 
vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json create mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py new file mode 100644 index 000000000000..9858594fedc3 --- /dev/null +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -0,0 +1,629 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark and tuning script for the Mamba selective_state_update kernel. + +This script mirrors the fused MoE tuning workflow in vLLM: + - Sweeps BLOCK_SIZE_M x num_warps across all batch sizes for a given dstate + - Finds the best launch config per (batch, dstate) combination + - Optionally saves configs to JSON in vllm/model_executor/layers/mamba/configs/ + - Optionally compares tuned configs against the existing heuristic baseline + - Always saves a human-readable results file alongside this script + +Usage (tune all dstates, save configs + compare vs heuristic): + python benchmarks/kernels/benchmark_selective_state_update.py \\ + --all-dstates --save-configs --compare + +Usage (single dstate, show results only): + python benchmarks/kernels/benchmark_selective_state_update.py --dstate 128 + +Generated JSON configs are loaded automatically by selective_state_update +at runtime when a matching device config file is found. 
+""" + +import argparse +import json +import os +import sys +from io import StringIO +from itertools import product +from unittest.mock import patch + +import torch + +import vllm.model_executor.layers.mamba.ops.mamba_ssm as mamba_ssm_module +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update, +) +from vllm.platforms import current_platform + +_RESULTS_DIR = os.path.dirname(os.path.realpath(__file__)) + +# --------------------------------------------------------------------------- +# Tuning search space +# --------------------------------------------------------------------------- + +BLOCK_SIZE_M_CHOICES = [4, 8, 16, 32, 64] +NUM_WARPS_CHOICES = [1, 2, 4, 8] + +BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + +ALL_DSTATES = [16, 32, 64, 128, 256] + + +# --------------------------------------------------------------------------- +# Config file naming (mirrors fused_moe pattern) +# --------------------------------------------------------------------------- + + +def get_ssm_config_file_name(dstate: int) -> str: + return f"dstate={dstate}.json" + + +def get_device_name() -> str: + return current_platform.get_device_name().replace(" ", "_") + + +def get_ssm_configs_dir() -> str: + return os.path.normpath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "../../vllm/model_executor/layers/mamba/configs", + ) + ) + + +# --------------------------------------------------------------------------- +# Benchmark helper +# --------------------------------------------------------------------------- + + +def _make_inputs( + batch: int, + nheads: int, + dim: int, + dstate: int, + ngroups: int, + dtype: torch.dtype, + device: str = "cuda", +): + state = torch.randn(batch, nheads, dim, dstate, dtype=dtype, device=device) + x = torch.randn(batch, nheads, dim, dtype=dtype, device=device) + dt = torch.randn(batch, nheads, dim, dtype=dtype, device=device) + A = -torch.rand(nheads, dim, dstate, dtype=torch.float32, 
device=device) + B = torch.randn(batch, ngroups, dstate, dtype=dtype, device=device) + C = torch.randn(batch, ngroups, dstate, dtype=dtype, device=device) + D = torch.randn(nheads, dim, dtype=dtype, device=device) + dt_bias = torch.randn(nheads, dim, dtype=dtype, device=device) + out = torch.zeros(batch, nheads, dim, dtype=dtype, device=device) + return state, x, dt, A, B, C, D, dt_bias, out + + +def benchmark_config( + batch: int, + nheads: int, + dim: int, + dstate: int, + ngroups: int, + block_size_m: int, + num_warps_val: int, + dtype: torch.dtype, + num_iters: int = 100, + num_warmup: int = 20, +) -> float | None: + """ + Time one (BLOCK_SIZE_M, num_warps) config for selective_state_update. + Returns elapsed time in microseconds, or None on error. + """ + state, x, dt, A, B, C, D, dt_bias, out = _make_inputs( + batch, nheads, dim, dstate, ngroups, dtype + ) + + # Monkeypatch _get_ssm_launch_config to return the specific config + # without affecting the lru_cache on get_ssm_configs. 
+ def _fixed_launch_config(dstate_, batch_, is_blackwell_): + return block_size_m, num_warps_val + + try: + with patch.object( + mamba_ssm_module, "_get_ssm_launch_config", _fixed_launch_config + ): + # Warmup + for _ in range(num_warmup): + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + ) + torch.accelerator.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(num_iters): + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + ) + end.record() + torch.accelerator.synchronize() + return start.elapsed_time(end) / num_iters * 1000 # ms -> us + except Exception as e: + if "OutOfResources" not in str(e): + print( + f" Warning: config M={block_size_m},w={num_warps_val} " + f"raised {type(e).__name__}: {e}" + ) + return None + + +# --------------------------------------------------------------------------- +# Tuning loop +# --------------------------------------------------------------------------- + + +def tune_dstate( + dstate: int, + dtype: torch.dtype, + num_iters: int, + verbose: bool, + batch_sizes: list[int] | None = None, +) -> dict[int, dict]: + """ + For each batch size, sweep all (BLOCK_SIZE_M, num_warps) combos and + return a dict mapping batch_size -> best_config. + """ + # Use a representative shape for tuning (Mamba-2 style, common case). 
+ nheads, dim, ngroups = 64, 64, 1 + active_batches = batch_sizes if batch_sizes is not None else BATCH_SIZES + + best_per_batch: dict[int, dict] = {} + + print(f"\n{'=' * 74}") + print(f"Tuning dstate={dstate} nheads={nheads} dim={dim} dtype={dtype}") + print(f"{'=' * 74}") + + hdr = f"{'Batch':>7} | {'BLOCK_M':>7} | {'warps':>5} | {'us':>10} | note" + print(hdr) + print("-" * 50) + + for batch in active_batches: + best_time = float("inf") + best_cfg: dict = {} + + for bsm, nw in product(BLOCK_SIZE_M_CHOICES, NUM_WARPS_CHOICES): + t = benchmark_config( + batch=batch, + nheads=nheads, + dim=dim, + dstate=dstate, + ngroups=ngroups, + block_size_m=bsm, + num_warps_val=nw, + dtype=dtype, + num_iters=num_iters, + ) + if t is None: + continue + is_best = t < best_time + if is_best: + best_time = t + best_cfg = {"BLOCK_SIZE_M": bsm, "num_warps": nw} + if verbose: + marker = " <-- best" if is_best else "" + print(f"{batch:>7} | {bsm:>7} | {nw:>5} | {t:>10.2f} |{marker}") + + if not verbose and best_cfg: + print( + f"{batch:>7} | {best_cfg['BLOCK_SIZE_M']:>7} | " + f"{best_cfg['num_warps']:>5} | {best_time:>10.2f} | best" + ) + + best_per_batch[batch] = best_cfg + + return best_per_batch + + +# --------------------------------------------------------------------------- +# Correctness validation +# --------------------------------------------------------------------------- + + +def _selective_state_update_ref( + state: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, +) -> torch.Tensor: + """ + Pure-PyTorch CPU reference for selective_state_update (dt_softplus=True). 
+ + Shapes (all moved to CPU float32 internally): + state : (batch, nheads, dim, dstate) + x : (batch, nheads, dim) + dt : (batch, nheads, dim) + A : (nheads, dim, dstate) + B : (batch, ngroups, dstate) + C : (batch, ngroups, dstate) + D : (nheads, dim) + dt_bias: (nheads, dim) + Returns: + out : (batch, nheads, dim) in the original dtype + """ + orig_dtype = x.dtype + state = state.clone().cpu().float() + x = x.cpu().float() + dt = dt.cpu().float() + A = A.cpu().float() + B = B.cpu().float() + C = C.cpu().float() + D = D.cpu().float() + dt = dt + dt_bias.cpu().float() + dt = torch.nn.functional.softplus(dt) # (batch, nheads, dim) + + nheads, _, _ = A.shape + ngroups = B.shape[1] + + dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0)) # (batch, nheads, dim, dstate) + B_exp = B.repeat_interleave(nheads // ngroups, dim=1) # (batch, nheads, dstate) + C_exp = C.repeat_interleave(nheads // ngroups, dim=1) + dB = dt.unsqueeze(-1) * B_exp.unsqueeze(2) # (batch, nheads, dim, dstate) + + state_new = state * dA + dB * x.unsqueeze(-1) + out = (state_new * C_exp.unsqueeze(2)).sum(-1) # (batch, nheads, dim) + out = out + x * D.unsqueeze(0) + return out.to(orig_dtype) + + +def validate_configs( + dstate: int, + tuned: dict[int, dict], + dtype: torch.dtype, + atol: float = 1e-2, + rtol: float = 1e-2, +) -> dict[int, bool]: + """ + For every batch size in *tuned*, run the kernel with the tuned config and + compare against the CPU reference. Returns {batch: passed}. 
+ """ + nheads, dim, ngroups = 64, 64, 1 + + print(f"\n{'=' * 74}") + print(f"Validation dstate={dstate} dtype={dtype} atol={atol}") + print(f"{'=' * 74}") + print(f"{'Batch':>7} | {'MaxAbsErr':>12} | {'Status':>8}") + print("-" * 36) + + results: dict[int, bool] = {} + + for batch, cfg in sorted(tuned.items()): + state, x, dt, A, B, C, D, dt_bias, out = _make_inputs( + batch, nheads, dim, dstate, ngroups, dtype + ) + # Clone state before GPU kernel modifies it in-place + state_ref = state.clone() + + # GPU kernel output + def _fixed(dstate_, batch_, is_blackwell_, _cfg=cfg): + return _cfg["BLOCK_SIZE_M"], _cfg["num_warps"] + + with patch.object(mamba_ssm_module, "_get_ssm_launch_config", _fixed): + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + ) + torch.accelerator.synchronize() + gpu_out = out.detach().cpu() + + # CPU reference uses the original (unmodified) state + ref_out = _selective_state_update_ref(state_ref, x, dt, A, B, C, D, dt_bias) + + passed = torch.allclose(gpu_out.float(), ref_out.float(), atol=atol, rtol=rtol) + max_err = (gpu_out.float() - ref_out.float()).abs().max().item() + status = "PASS" if passed else "FAIL" + results[batch] = passed + print(f"{batch:>7} | {max_err:>12.6f} | {status:>8}") + + n_pass = sum(results.values()) + n_total = len(results) + print(f"\n {n_pass}/{n_total} configs passed validation for dstate={dstate}") + return results + + +# --------------------------------------------------------------------------- +# Save configs +# --------------------------------------------------------------------------- + + +def save_configs( + dstate: int, configs: dict[int, dict], save_dir: str | None = None +) -> str: + base_dir = save_dir if save_dir else get_ssm_configs_dir() + # Place configs in a per-GPU subfolder for easy multi-GPU organisation. 
+ configs_dir = os.path.join(base_dir, get_device_name()) + os.makedirs(configs_dir, exist_ok=True) + file_path = os.path.join(configs_dir, get_ssm_config_file_name(dstate)) + payload = {str(k): v for k, v in sorted(configs.items())} + with open(file_path, "w") as f: + json.dump(payload, f, indent=4) + return file_path + + +# --------------------------------------------------------------------------- +# Comparison table +# --------------------------------------------------------------------------- + + +def current_heuristic(dstate: int, is_blackwell: bool = False) -> dict: + """Return the current hard-coded BLOCK_SIZE_M / num_warps for dstate.""" + if dstate <= 16: + return {"BLOCK_SIZE_M": 32, "num_warps": 4} + elif dstate <= 32: + return {"BLOCK_SIZE_M": 16, "num_warps": 4} + elif dstate <= 64: + return {"BLOCK_SIZE_M": 8, "num_warps": 4} + else: + if is_blackwell: + return {"BLOCK_SIZE_M": 32, "num_warps": 8} + elif dstate <= 128: + return {"BLOCK_SIZE_M": 4, "num_warps": 4} + else: + return {"BLOCK_SIZE_M": 4, "num_warps": 8} + + +def compare_heuristic_vs_tuned( + dstate: int, + tuned: dict[int, dict], + dtype: torch.dtype, + num_iters: int, + is_blackwell: bool, +): + nheads, dim, ngroups = 64, 64, 1 + heur_cfg = current_heuristic(dstate, is_blackwell) + + print(f"\n{'=' * 74}") + print(f"Comparison dstate={dstate} — heuristic vs tuned") + print( + f"Heuristic: BLOCK_SIZE_M={heur_cfg['BLOCK_SIZE_M']}, " + f"num_warps={heur_cfg['num_warps']}" + ) + print(f"{'=' * 74}") + hdr = ( + f"{'Batch':>7} | {'Heur(us)':>10} | {'Tuned(us)':>10} | " + f"{'Speedup':>8} | Best config" + ) + print(hdr) + print("-" * len(hdr)) + + for batch in BATCH_SIZES: + t_h = benchmark_config( + batch, + nheads, + dim, + dstate, + ngroups, + heur_cfg["BLOCK_SIZE_M"], + heur_cfg["num_warps"], + dtype, + num_iters, + ) + best = tuned.get(batch, heur_cfg) + t_t = benchmark_config( + batch, + nheads, + dim, + dstate, + ngroups, + best["BLOCK_SIZE_M"], + best["num_warps"], + dtype, + 
num_iters, + ) + if t_h is None or t_t is None: + print(f"{batch:>7} | {'N/A':>10} | {'N/A':>10} | {'N/A':>8} |") + continue + speedup = t_h / t_t + marker = " <--" if speedup > 1.05 else "" + print( + f"{batch:>7} | {t_h:>10.2f} | {t_t:>10.2f} | " + f"{speedup:>7.2f}x | " + f"M={best['BLOCK_SIZE_M']},w={best['num_warps']}{marker}" + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def save_results(device_name: str, output: str, results_file: str | None = None) -> str: + """Save the full benchmark output to a results text file.""" + if results_file is None: + safe_name = device_name.replace(" ", "_") + results_file = os.path.join( + _RESULTS_DIR, f"ssm_benchmark_results_{safe_name}.txt" + ) + with open(results_file, "w") as f: + f.write(output) + return results_file + + +def main(): + parser = argparse.ArgumentParser( + description="Tune selective_state_update kernel for Mamba SSM" + ) + parser.add_argument( + "--dstate", + type=int, + default=128, + help="SSM state size to tune for (default: 128)", + ) + parser.add_argument( + "--all-dstates", + action="store_true", + help="Tune all common dstate values: " + str(ALL_DSTATES), + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16"], + help="Data type (default: bfloat16)", + ) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of timing iterations (default: 100)", + ) + parser.add_argument( + "--save-configs", + action="store_true", + help="Save best configs to JSON in mamba/configs/", + ) + parser.add_argument( + "--compare", + action="store_true", + help="Show comparison table: heuristic vs tuned", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print every (BLOCK_SIZE_M, num_warps) result, not just best", + ) + parser.add_argument( + "--results-file", + type=str, + default=None, + 
help="Path to save the benchmark results text file " + "(default: ssm_benchmark_results_.txt alongside this script)", + ) + parser.add_argument( + "--save-dir", + type=str, + default=None, + help="Base directory to save JSON configs. Configs are placed in a " + "per-GPU subfolder: //. " + "(default: vllm/model_executor/layers/mamba/configs/)", + ) + parser.add_argument( + "--batches", + type=int, + nargs="+", + default=None, + metavar="B", + help="Only tune these specific batch sizes, e.g. --batches 2 16 256. " + "Useful for stability re-checks on flagged configs.", + ) + parser.add_argument( + "--validate", + action="store_true", + help="After tuning, verify each best config against a CPU reference " + "implementation. Configs that fail are flagged in the output.", + ) + parser.add_argument( + "--atol", + type=float, + default=1e-2, + help="Absolute tolerance for --validate (default: 1e-2)", + ) + args = parser.parse_args() + + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + device_name = current_platform.get_device_name() + cap = torch.cuda.get_device_capability() + is_blackwell = cap[0] >= 10 + + # Mirror all output to a results file (like Unix tee). 
+ buf = StringIO() + + class _Tee: + """Writes to both the original stdout and an in-memory buffer.""" + + def write(self, s): + buf.write(s) + sys.__stdout__.write(s) + + def flush(self): + sys.__stdout__.flush() + + sys.stdout = _Tee() # type: ignore[assignment] + + try: + print(f"Device : {device_name} (sm_{cap[0]}{cap[1]:02d})") + print(f"Blackwell: {is_blackwell}") + print(f"dtype : {args.dtype}") + + dstates = ALL_DSTATES if args.all_dstates else [args.dstate] + + for dstate in dstates: + tuned = tune_dstate( + dstate, dtype, args.num_iters, args.verbose, args.batches + ) + + if args.compare: + compare_heuristic_vs_tuned( + dstate, tuned, dtype, args.num_iters, is_blackwell + ) + + if args.validate: + validity = validate_configs(dstate, tuned, dtype, args.atol) + # Filter out any configs that failed correctness check + failed = [b for b, ok in validity.items() if not ok] + if failed: + print( + f"\n WARNING: {len(failed)} config(s) failed " + f"validation for dstate={dstate}: batches {failed}" + ) + print(" These will NOT be saved even with --save-configs.") + tuned = { + b: cfg for b, cfg in tuned.items() if validity.get(b, True) + } + + if args.save_configs: + path = save_configs(dstate, tuned, args.save_dir) + print(f"\nSaved: {path}") + else: + print(f"\nBest configs for dstate={dstate}:") + for batch, cfg in sorted(tuned.items()): + print(f" batch={batch:>5}: {cfg}") + print("\n(Re-run with --save-configs to persist to JSON)") + finally: + sys.stdout = sys.__stdout__ + results_path = save_results(device_name, buf.getvalue(), args.results_file) + print(f"\nResults saved to: {results_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/mamba/test_mamba_ssm_configs.py b/tests/kernels/mamba/test_mamba_ssm_configs.py new file mode 100644 index 000000000000..66812831e846 --- /dev/null +++ b/tests/kernels/mamba/test_mamba_ssm_configs.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project +""" +Unit tests for the JSON-based config loader added to selective_state_update. + +Tests cover: + - Config filename generation + - VLLM_TUNED_CONFIG_FOLDER env-var override (per-GPU subfolder structure) + - Fallback to heuristic when no config file exists + - Nearest-batch interpolation +""" + +import json + +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + _get_ssm_launch_config, + get_ssm_config_file_name, + get_ssm_configs, + get_ssm_device_name, +) + +# --------------------------------------------------------------------------- +# Config filename generation +# --------------------------------------------------------------------------- + + +def test_config_file_name_format(): + name = get_ssm_config_file_name(128) + assert name == "dstate=128.json" + + +# --------------------------------------------------------------------------- +# VLLM_TUNED_CONFIG_FOLDER override (configs live in {folder}/{device_name}/dstate=N.json) +# --------------------------------------------------------------------------- + + +def test_env_override_loads_custom_config(monkeypatch, tmp_path): + """VLLM_TUNED_CONFIG_FOLDER should take precedence over the bundled dir.""" + device_name = get_ssm_device_name() + gpu_dir = tmp_path / device_name + gpu_dir.mkdir() + + config_path = gpu_dir / get_ssm_config_file_name(16) + payload = {"1": {"BLOCK_SIZE_M": 4, "num_warps": 1}} + with open(config_path, "w") as f: + json.dump(payload, f) + + monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) + get_ssm_configs.cache_clear() + + cfg = get_ssm_configs(16) + assert cfg is not None + assert cfg[1] == {"BLOCK_SIZE_M": 4, "num_warps": 1} + + get_ssm_configs.cache_clear() + + +# --------------------------------------------------------------------------- +# Fallback to heuristic when no config file exists +# --------------------------------------------------------------------------- + + +def test_fallback_when_no_config(monkeypatch, tmp_path): + 
"""_get_ssm_launch_config must fall back to the hard-coded heuristic + when no JSON file is found for the current device.""" + monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) + monkeypatch.setattr( + "vllm.model_executor.layers.mamba.ops.mamba_ssm._CONFIGS_DIR", + str(tmp_path), + ) + get_ssm_configs.cache_clear() + + # dstate=64 heuristic: BLOCK_SIZE_M=8, num_warps=4 + block_m, warps = _get_ssm_launch_config(dstate=64, batch=1, is_blackwell=False) + assert block_m == 8 + assert warps == 4 + + # dstate=16 heuristic: BLOCK_SIZE_M=32, num_warps=4 + block_m, warps = _get_ssm_launch_config(dstate=16, batch=1, is_blackwell=False) + assert block_m == 32 + assert warps == 4 + + get_ssm_configs.cache_clear() + + +# --------------------------------------------------------------------------- +# Nearest-batch interpolation +# --------------------------------------------------------------------------- + + +def test_nearest_batch_interpolation(monkeypatch, tmp_path): + """When the exact batch size is not in the config, the closest key + should be selected.""" + device_name = get_ssm_device_name() + gpu_dir = tmp_path / device_name + gpu_dir.mkdir() + + config_path = gpu_dir / get_ssm_config_file_name(32) + payload = { + "1": {"BLOCK_SIZE_M": 8, "num_warps": 1}, + "64": {"BLOCK_SIZE_M": 32, "num_warps": 4}, + } + with open(config_path, "w") as f: + json.dump(payload, f) + + monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) + get_ssm_configs.cache_clear() + + # batch=5 is closer to 1 than to 64 — expects M=8, w=1 + block_m, warps = _get_ssm_launch_config(dstate=32, batch=5, is_blackwell=False) + assert block_m == 8 and warps == 1 + + # batch=40 is closer to 64 — expects M=32, w=4 + block_m, warps = _get_ssm_launch_config(dstate=32, batch=40, is_blackwell=False) + assert block_m == 32 and warps == 4 + + get_ssm_configs.cache_clear() + + +# --------------------------------------------------------------------------- +# Edge cases: malformed / empty config 
files +# --------------------------------------------------------------------------- + + +def test_non_dict_json_returns_none(monkeypatch, tmp_path): + """A valid JSON file that is not a dict (e.g. a list) must be ignored + and return None rather than raising AttributeError.""" + device_name = get_ssm_device_name() + gpu_dir = tmp_path / device_name + gpu_dir.mkdir() + + config_path = gpu_dir / get_ssm_config_file_name(16) + with open(config_path, "w") as f: + json.dump([1, 2, 3], f) + + monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) + monkeypatch.setattr( + "vllm.model_executor.layers.mamba.ops.mamba_ssm._CONFIGS_DIR", + str(tmp_path), + ) + get_ssm_configs.cache_clear() + + assert get_ssm_configs(16) is None + + get_ssm_configs.cache_clear() + + +def test_empty_config_falls_back_to_heuristic(monkeypatch, tmp_path): + """An empty JSON object {} must not crash min() — should fall back + to the hard-coded heuristic.""" + device_name = get_ssm_device_name() + gpu_dir = tmp_path / device_name + gpu_dir.mkdir() + + config_path = gpu_dir / get_ssm_config_file_name(64) + with open(config_path, "w") as f: + json.dump({}, f) + + monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) + get_ssm_configs.cache_clear() + + # dstate=64 heuristic: BLOCK_SIZE_M=8, num_warps=4 + block_m, warps = _get_ssm_launch_config(dstate=64, batch=1, is_blackwell=False) + assert block_m == 8 + assert warps == 4 + + get_ssm_configs.cache_clear() diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json new file mode 100644 index 000000000000..6d571eb702ff --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "num_warps": 1 + }, + "2": { + "BLOCK_SIZE_M": 64, + "num_warps": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_M": 8, + "num_warps": 1 + }, + "16": { + 
"BLOCK_SIZE_M": 64, + "num_warps": 2 + }, + "32": { + "BLOCK_SIZE_M": 8, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json new file mode 100644 index 000000000000..7db73c38803a --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "2": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "4": { + "BLOCK_SIZE_M": 64, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 64, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json new file mode 100644 index 000000000000..61d58288b7ac --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "4": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "64": { + "BLOCK_SIZE_M": 4, + 
"num_warps": 1 + }, + "128": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json new file mode 100644 index 000000000000..cc0585cf7eef --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "4": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 8, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "num_warps": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json new file mode 100644 index 000000000000..1274671669c6 --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "2": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_M": 8, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "128": { + "BLOCK_SIZE_M": 8, + "num_warps": 1 + }, + "256": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + 
"512": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json new file mode 100644 index 000000000000..a35da36b3541 --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "2": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "4": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + }, + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json new file mode 100644 index 000000000000..136cbea31b6f --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "2": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 64, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + } +} \ No newline 
at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json new file mode 100644 index 000000000000..d2141e287759 --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "4": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json new file mode 100644 index 000000000000..63406b0aa1fb --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "512": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "1024": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json 
b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json new file mode 100644 index 000000000000..d983a23e44d2 --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json @@ -0,0 +1,46 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "2": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "64": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index e3c8ba8312f2..88a60a9553de 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -4,16 +4,123 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. 
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py +import functools +import json +import os +from typing import Any + import torch from packaging import version from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.v1.attention.backends.utils import NULL_BLOCK_ID +logger = init_logger(__name__) + TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0")) + +# --------------------------------------------------------------------------- +# JSON config loading (mirrors fused_moe pattern) +# --------------------------------------------------------------------------- + +_CONFIGS_DIR = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "..", "configs" +) + + +def get_ssm_config_file_name(dstate: int) -> str: + """Return the JSON filename for the given dstate. + + Config files are organised per GPU: + configs//dstate=.json + """ + return f"dstate={dstate}.json" + + +def get_ssm_device_name() -> str: + return current_platform.get_device_name().replace(" ", "_") + + +@functools.lru_cache +def get_ssm_configs(dstate: int) -> dict[int, Any] | None: + """ + Return tuned (BLOCK_SIZE_M, num_warps) configs for *selective_state_update* + keyed by batch size, or ``None`` if no config file is found. 
+ + Config files live in a per-GPU subfolder: + vllm/model_executor/layers/mamba/configs//dstate=.json + + They can be generated with: + benchmarks/kernels/benchmark_selective_state_update.py --save-configs + """ + device_name = get_ssm_device_name() + json_file_name = get_ssm_config_file_name(dstate) + + config_file_paths: list[str] = [] + + # User-supplied override (same env-var as fused_moe) + user_dir = os.environ.get("VLLM_TUNED_CONFIG_FOLDER") + if user_dir is not None: + config_file_paths.append(os.path.join(user_dir, device_name, json_file_name)) + + # Bundled default + config_file_paths.append(os.path.join(_CONFIGS_DIR, device_name, json_file_name)) + + for path in config_file_paths: + if os.path.exists(path): + with open(path) as f: + logger.info_once( + "Using SSM config from %s for selective_state_update.", + path, + scope="global", + ) + raw = json.load(f) + if isinstance(raw, dict): + return {int(k): v for k, v in raw.items()} + + return None + + +def _get_ssm_launch_config( + dstate: int, + batch: int, + is_blackwell: bool, +) -> tuple[int, int]: + """ + Return (BLOCK_SIZE_M, num_warps) for a given dstate and batch size. + + Tries the JSON config first; falls back to the original hard-coded + heuristic so existing behaviour is fully preserved when no config file + is present. 
+ """ + configs = get_ssm_configs(dstate) + if configs: + # Pick the closest batch size in the tuned grid (same strategy as MoE) + closest = min(configs.keys(), key=lambda x: abs(x - batch)) + cfg = configs[closest] + return cfg["BLOCK_SIZE_M"], cfg["num_warps"] + + # ---- original hard-coded heuristic (unchanged) ---- + BLOCK_SIZE_M, num_warps = 4, 8 + if dstate <= 16: + BLOCK_SIZE_M, num_warps = 32, 4 + elif dstate <= 32: + BLOCK_SIZE_M, num_warps = 16, 4 + elif dstate <= 64: + BLOCK_SIZE_M, num_warps = 8, 4 + else: + if is_blackwell: + BLOCK_SIZE_M, num_warps = 32, 8 + elif dstate <= 128: + BLOCK_SIZE_M, num_warps = 4, 4 + return BLOCK_SIZE_M, num_warps + + if TRITON3: @triton.jit @@ -440,24 +547,8 @@ def selective_state_update( else (0, 0) ) # We don't want autotune since it will overwrite the state. - # We instead tune by hand based on dstate. - - # Default - BLOCK_SIZE_M, num_warps = 4, 8 - - if dstate <= 16: - BLOCK_SIZE_M, num_warps = 32, 4 - elif dstate <= 32: - BLOCK_SIZE_M, num_warps = 16, 4 - elif dstate <= 64: - BLOCK_SIZE_M, num_warps = 8, 4 - else: - # dstate > 64 - if is_blackwell: - # Optimized for B200 with dstate>64 - BLOCK_SIZE_M, num_warps = 32, 8 - elif dstate <= 128: - BLOCK_SIZE_M, num_warps = 4, 4 + # Load from JSON config if available, otherwise fall back to heuristic. 
+ BLOCK_SIZE_M, num_warps = _get_ssm_launch_config(dstate, N, is_blackwell) tie_hdim = ( A.stride(-1) == 0 From f7efdb80e5ae11ba44cd37cbb0a2fe48025af6ea Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 19 May 2026 11:33:46 +0300 Subject: [PATCH 02/24] Apply code review changes Includes: - triton_version - Mamba state dtype flag - Use effective batch for tuning Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 427 +++++++++++++----- tests/kernels/mamba/test_mamba_ssm_configs.py | 174 ++++--- .../mamba/configs/NVIDIA_B200/dstate=128.json | 46 -- .../mamba/configs/NVIDIA_B200/dstate=16.json | 46 -- .../mamba/configs/NVIDIA_B200/dstate=256.json | 46 -- .../mamba/configs/NVIDIA_B200/dstate=32.json | 46 -- .../mamba/configs/NVIDIA_B200/dstate=64.json | 46 -- .../mamba/configs/NVIDIA_GB10/dstate=128.json | 46 -- .../mamba/configs/NVIDIA_GB10/dstate=16.json | 46 -- .../mamba/configs/NVIDIA_GB10/dstate=256.json | 46 -- .../mamba/configs/NVIDIA_GB10/dstate=32.json | 46 -- .../mamba/configs/NVIDIA_GB10/dstate=64.json | 46 -- ..._name=NVIDIA_B200,cache_dtype=float16.json | 123 +++++ ..._name=NVIDIA_B200,cache_dtype=float32.json | 123 +++++ .../layers/mamba/ops/mamba_ssm.py | 94 ++-- 15 files changed, 740 insertions(+), 661 deletions(-) delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json delete mode 
100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json delete mode 100644 vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json create mode 100644 vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json create mode 100644 vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index 9858594fedc3..a33217237908 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -4,22 +4,14 @@ """ Benchmark and tuning script for the Mamba selective_state_update kernel. -This script mirrors the fused MoE tuning workflow in vLLM: - - Sweeps BLOCK_SIZE_M x num_warps across all batch sizes for a given dstate - - Finds the best launch config per (batch, dstate) combination - - Optionally saves configs to JSON in vllm/model_executor/layers/mamba/configs/ - - Optionally compares tuned configs against the existing heuristic baseline - - Always saves a human-readable results file alongside this script - -Usage (tune all dstates, save configs + compare vs heuristic): - python benchmarks/kernels/benchmark_selective_state_update.py \\ - --all-dstates --save-configs --compare - -Usage (single dstate, show results only): - python benchmarks/kernels/benchmark_selective_state_update.py --dstate 128 +Mirrors the fused MoE tuning workflow: sweeps (BLOCK_SIZE_M, num_warps) across +an effective_batch grid for a given (headdim, dstate, ngroups, cache_dtype) and +saves the best config per effective_batch to JSON. Generated configs are picked +up by selective_state_update at runtime. -Generated JSON configs are loaded automatically by selective_state_update -at runtime when a matching device config file is found. 
+Usage: + python benchmarks/kernels/benchmark_selective_state_update.py \ + --all-dstates --save-configs --compare """ import argparse @@ -28,6 +20,7 @@ import sys from io import StringIO from itertools import product +from typing import Any from unittest.mock import patch import torch @@ -37,6 +30,13 @@ selective_state_update, ) from vllm.platforms import current_platform +from vllm.triton_utils import triton + +# MambaDType subset: bf16 is excluded (not commonly used) +_SSM_CACHE_DTYPE_MAP: dict[str, torch.dtype] = { + "float32": torch.float32, + "float16": torch.float16, +} _RESULTS_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -44,23 +44,73 @@ # Tuning search space # --------------------------------------------------------------------------- -BLOCK_SIZE_M_CHOICES = [4, 8, 16, 32, 64] +_BSM_CHOICES_ALL = [4, 8, 16, 32, 64, 128, 256] + NUM_WARPS_CHOICES = [1, 2, 4, 8] -BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + +def _block_size_m_choices(headdim: int) -> list[int]: + """BLOCK_SIZE_M candidates worth sweeping for a given headdim. + + BLOCK_SIZE_M > next_pow2(headdim) wastes >=50% of each tile via masking + (offs_m >= dim rows are zeroed out), so we cap the sweep there. + """ + ceiling = 1 + while ceiling < headdim: + ceiling <<= 1 + return [b for b in _BSM_CHOICES_ALL if b <= ceiling] + + +# effective_batch = batch * nheads_per_rank — the kernel grid scales with +# the product, so configs transfer across (model, TP) combos sharing +# (headdim, dstate, cache_dtype). +# Ceiling 262144 covers 256-head at TP1, max BS=1024 (256 * 1024). +EFFECTIVE_BATCH_SIZES = [ + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 192, + 256, + 384, + 512, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, + 6144, + 8192, + 12288, + 16384, + 24576, + 32768, + 49152, + 65536, + 98304, + 131072, + 196608, + 262144, +] ALL_DSTATES = [16, 32, 64, 128, 256] +# Default tuning shape — matches Nemotron-3-Super and Nemotron-3-Nano Mamba layers. 
+# Override with CLI flags for other architectures. +DEFAULT_HEADDIM = 64 +DEFAULT_NGROUPS = 8 + # --------------------------------------------------------------------------- # Config file naming (mirrors fused_moe pattern) # --------------------------------------------------------------------------- -def get_ssm_config_file_name(dstate: int) -> str: - return f"dstate={dstate}.json" - - def get_device_name() -> str: return current_platform.get_device_name().replace(" ", "_") @@ -86,9 +136,12 @@ def _make_inputs( dstate: int, ngroups: int, dtype: torch.dtype, + state_dtype: torch.dtype | None = None, device: str = "cuda", ): - state = torch.randn(batch, nheads, dim, dstate, dtype=dtype, device=device) + if state_dtype is None: + state_dtype = dtype + state = torch.randn(batch, nheads, dim, dstate, dtype=state_dtype, device=device) x = torch.randn(batch, nheads, dim, dtype=dtype, device=device) dt = torch.randn(batch, nheads, dim, dtype=dtype, device=device) A = -torch.rand(nheads, dim, dstate, dtype=torch.float32, device=device) @@ -109,6 +162,7 @@ def benchmark_config( block_size_m: int, num_warps_val: int, dtype: torch.dtype, + state_dtype: torch.dtype | None = None, num_iters: int = 100, num_warmup: int = 20, ) -> float | None: @@ -117,17 +171,17 @@ def benchmark_config( Returns elapsed time in microseconds, or None on error. """ state, x, dt, A, B, C, D, dt_bias, out = _make_inputs( - batch, nheads, dim, dstate, ngroups, dtype + batch, nheads, dim, dstate, ngroups, dtype, state_dtype=state_dtype ) - # Monkeypatch _get_ssm_launch_config to return the specific config + # Monkeypatch try_get_optimal_ssm_config to return the specific config # without affecting the lru_cache on get_ssm_configs. 
- def _fixed_launch_config(dstate_, batch_, is_blackwell_): + def _fixed_launch_config(*_args, **_kwargs): return block_size_m, num_warps_val try: with patch.object( - mamba_ssm_module, "_get_ssm_launch_config", _fixed_launch_config + mamba_ssm_module, "try_get_optimal_ssm_config", _fixed_launch_config ): # Warmup for _ in range(num_warmup): @@ -180,45 +234,111 @@ def _fixed_launch_config(dstate_, batch_, is_blackwell_): # --------------------------------------------------------------------------- +# CUDA grid Y/Z dim limit — both `batch` and `nheads` must fit individually, +# so effective_batch > 65535 has to be split across the two. +_CUDA_MAX_GRID_DIM = 65535 + + +def _factor_effective_batch( + effective_batch: int, ngroups: int +) -> tuple[int, int] | None: + """Return (batch, nheads) with batch*nheads == effective_batch such that + both fit the CUDA grid Y/Z dim limit and nheads is a positive multiple of + ngroups. Prefers batch=1 (the cheapest split) when it fits. + + Returns None if no valid factorization exists. + """ + for batch in range(1, _CUDA_MAX_GRID_DIM + 1): + if batch > effective_batch or effective_batch % batch != 0: + continue + nheads = effective_batch // batch + if nheads > _CUDA_MAX_GRID_DIM: + continue + if nheads % ngroups != 0: + continue + return batch, nheads + return None + + +def _resolve_effective_batches( + user_supplied: list[int] | None, + ngroups: int, +) -> list[tuple[int, int, int]]: + """Return [(effective_batch, batch, nheads)] for each valid sweep point. + + Drops any effective_batch with no valid (batch, nheads) factorization + that satisfies both the CUDA grid dim limit and nheads % ngroups == 0. 
+ """ + candidates = user_supplied if user_supplied is not None else EFFECTIVE_BATCH_SIZES + valid: list[tuple[int, int, int]] = [] + skipped: list[int] = [] + for eb in candidates: + if eb <= 0: + skipped.append(eb) + continue + factored = _factor_effective_batch(eb, ngroups) + if factored is None: + skipped.append(eb) + continue + batch, nheads = factored + valid.append((eb, batch, nheads)) + if skipped: + print( + f" Note: skipping effective_batch values with no valid " + f"(batch, nheads) factorization for ngroups={ngroups} " + f"under CUDA grid dim {_CUDA_MAX_GRID_DIM}: {skipped}" + ) + return valid + + def tune_dstate( dstate: int, + headdim: int, + ngroups: int, dtype: torch.dtype, num_iters: int, verbose: bool, - batch_sizes: list[int] | None = None, + effective_batches: list[int] | None = None, + state_dtype: torch.dtype | None = None, ) -> dict[int, dict]: + """For each effective_batch, sweep (BLOCK_SIZE_M, num_warps) and return + {effective_batch: best_config}. effective_batch is factored into + (batch, nheads) by `_factor_effective_batch`. """ - For each batch size, sweep all (BLOCK_SIZE_M, num_warps) combos and - return a dict mapping batch_size -> best_config. - """ - # Use a representative shape for tuning (Mamba-2 style, common case). 
- nheads, dim, ngroups = 64, 64, 1 - active_batches = batch_sizes if batch_sizes is not None else BATCH_SIZES + active = _resolve_effective_batches(effective_batches, ngroups) - best_per_batch: dict[int, dict] = {} + best_per_eb: dict[int, dict] = {} print(f"\n{'=' * 74}") - print(f"Tuning dstate={dstate} nheads={nheads} dim={dim} dtype={dtype}") + effective_state_dtype = state_dtype if state_dtype is not None else dtype + print( + f"Tuning headdim={headdim} dstate={dstate} ngroups={ngroups} " + f"dtype={dtype} ssm_cache_dtype={effective_state_dtype}" + ) print(f"{'=' * 74}") - hdr = f"{'Batch':>7} | {'BLOCK_M':>7} | {'warps':>5} | {'us':>10} | note" + bsm_choices = _block_size_m_choices(headdim) + print(f"BSM candidates (capped at next_pow2(headdim={headdim})): {bsm_choices}") + + hdr = f"{'EffBatch':>8} | {'BLOCK_M':>7} | {'warps':>5} | {'us':>10} | note" print(hdr) - print("-" * 50) + print("-" * 52) - for batch in active_batches: + for eb, batch, nheads in active: best_time = float("inf") best_cfg: dict = {} - for bsm, nw in product(BLOCK_SIZE_M_CHOICES, NUM_WARPS_CHOICES): + for bsm, nw in product(bsm_choices, NUM_WARPS_CHOICES): t = benchmark_config( batch=batch, nheads=nheads, - dim=dim, + dim=headdim, dstate=dstate, ngroups=ngroups, block_size_m=bsm, num_warps_val=nw, dtype=dtype, + state_dtype=state_dtype, num_iters=num_iters, ) if t is None: @@ -229,17 +349,24 @@ def tune_dstate( best_cfg = {"BLOCK_SIZE_M": bsm, "num_warps": nw} if verbose: marker = " <-- best" if is_best else "" - print(f"{batch:>7} | {bsm:>7} | {nw:>5} | {t:>10.2f} |{marker}") + print(f"{eb:>8} | {bsm:>7} | {nw:>5} | {t:>10.2f} |{marker}") + + if not best_cfg: + print( + f"{eb:>8} | {'-':>7} | {'-':>5} | {'-':>10} | " + f"no working config (skipped)" + ) + continue - if not verbose and best_cfg: + if not verbose: print( - f"{batch:>7} | {best_cfg['BLOCK_SIZE_M']:>7} | " + f"{eb:>8} | {best_cfg['BLOCK_SIZE_M']:>7} | " f"{best_cfg['num_warps']:>5} | {best_time:>10.2f} | best" ) - 
best_per_batch[batch] = best_cfg + best_per_eb[eb] = best_cfg - return best_per_batch + return best_per_eb # --------------------------------------------------------------------------- @@ -299,37 +426,52 @@ def _selective_state_update_ref( def validate_configs( dstate: int, + headdim: int, + ngroups: int, tuned: dict[int, dict], dtype: torch.dtype, atol: float = 1e-2, rtol: float = 1e-2, + state_dtype: torch.dtype | None = None, ) -> dict[int, bool]: """ - For every batch size in *tuned*, run the kernel with the tuned config and - compare against the CPU reference. Returns {batch: passed}. + For every effective_batch in *tuned*, run the kernel with the tuned config + and compare against the CPU reference. Returns {effective_batch: passed}. """ - nheads, dim, ngroups = 64, 64, 1 - print(f"\n{'=' * 74}") - print(f"Validation dstate={dstate} dtype={dtype} atol={atol}") + effective_state_dtype = state_dtype if state_dtype is not None else dtype + print( + f"Validation headdim={headdim} dstate={dstate} ngroups={ngroups} " + f"dtype={dtype} ssm_cache_dtype={effective_state_dtype} atol={atol}" + ) print(f"{'=' * 74}") - print(f"{'Batch':>7} | {'MaxAbsErr':>12} | {'Status':>8}") + print(f"{'EffBatch':>8} | {'MaxAbsErr':>12} | {'Status':>8}") print("-" * 36) results: dict[int, bool] = {} - for batch, cfg in sorted(tuned.items()): + for eb, cfg in sorted(tuned.items()): + factored = _factor_effective_batch(eb, ngroups) + if factored is None: + continue + batch, nheads = factored state, x, dt, A, B, C, D, dt_bias, out = _make_inputs( - batch, nheads, dim, dstate, ngroups, dtype + batch=batch, + nheads=nheads, + dim=headdim, + dstate=dstate, + ngroups=ngroups, + dtype=dtype, + state_dtype=state_dtype, ) # Clone state before GPU kernel modifies it in-place state_ref = state.clone() # GPU kernel output - def _fixed(dstate_, batch_, is_blackwell_, _cfg=cfg): + def _fixed(*_args, _cfg=cfg, **_kwargs): return _cfg["BLOCK_SIZE_M"], _cfg["num_warps"] - with 
patch.object(mamba_ssm_module, "_get_ssm_launch_config", _fixed): + with patch.object(mamba_ssm_module, "try_get_optimal_ssm_config", _fixed): selective_state_update( state, x, @@ -352,8 +494,8 @@ def _fixed(dstate_, batch_, is_blackwell_, _cfg=cfg): passed = torch.allclose(gpu_out.float(), ref_out.float(), atol=atol, rtol=rtol) max_err = (gpu_out.float() - ref_out.float()).abs().max().item() status = "PASS" if passed else "FAIL" - results[batch] = passed - print(f"{batch:>7} | {max_err:>12.6f} | {status:>8}") + results[eb] = passed + print(f"{eb:>8} | {max_err:>12.6f} | {status:>8}") n_pass = sum(results.values()) n_total = len(results) @@ -367,14 +509,25 @@ def _fixed(dstate_, batch_, is_blackwell_, _cfg=cfg): def save_configs( - dstate: int, configs: dict[int, dict], save_dir: str | None = None + headdim: int, + dstate: int, + cache_dtype: str, + configs: dict[int, dict], + save_dir: str | None = None, ) -> str: base_dir = save_dir if save_dir else get_ssm_configs_dir() - # Place configs in a per-GPU subfolder for easy multi-GPU organisation. 
- configs_dir = os.path.join(base_dir, get_device_name()) - os.makedirs(configs_dir, exist_ok=True) - file_path = os.path.join(configs_dir, get_ssm_config_file_name(dstate)) - payload = {str(k): v for k, v in sorted(configs.items())} + os.makedirs(base_dir, exist_ok=True) + file_path = os.path.join( + base_dir, + mamba_ssm_module.get_ssm_config_file_name( + headdim, dstate, cache_dtype, get_device_name() + ), + ) + # triton_version is informational only, the loader ignores it + payload: dict[str, Any] = { + "triton_version": triton.__version__, + **{str(k): v for k, v in sorted(configs.items())}, + } with open(file_path, "w") as f: json.dump(payload, f, indent=4) return file_path @@ -404,59 +557,70 @@ def current_heuristic(dstate: int, is_blackwell: bool = False) -> dict: def compare_heuristic_vs_tuned( dstate: int, + headdim: int, + ngroups: int, tuned: dict[int, dict], dtype: torch.dtype, num_iters: int, is_blackwell: bool, + effective_batches: list[int] | None = None, + state_dtype: torch.dtype | None = None, ): - nheads, dim, ngroups = 64, 64, 1 + active = _resolve_effective_batches(effective_batches, ngroups) heur_cfg = current_heuristic(dstate, is_blackwell) print(f"\n{'=' * 74}") - print(f"Comparison dstate={dstate} — heuristic vs tuned") + print( + f"Comparison headdim={headdim} dstate={dstate} " + f"ngroups={ngroups} — heuristic vs tuned" + ) print( f"Heuristic: BLOCK_SIZE_M={heur_cfg['BLOCK_SIZE_M']}, " f"num_warps={heur_cfg['num_warps']}" ) print(f"{'=' * 74}") hdr = ( - f"{'Batch':>7} | {'Heur(us)':>10} | {'Tuned(us)':>10} | " + f"{'EffBatch':>8} | {'Heur(us)':>10} | {'Tuned(us)':>10} | " f"{'Speedup':>8} | Best config" ) print(hdr) print("-" * len(hdr)) - for batch in BATCH_SIZES: + for eb, batch, nheads in active: t_h = benchmark_config( - batch, - nheads, - dim, - dstate, - ngroups, - heur_cfg["BLOCK_SIZE_M"], - heur_cfg["num_warps"], - dtype, - num_iters, + batch=batch, + nheads=nheads, + dim=headdim, + dstate=dstate, + ngroups=ngroups, + 
block_size_m=heur_cfg["BLOCK_SIZE_M"], + num_warps_val=heur_cfg["num_warps"], + dtype=dtype, + state_dtype=state_dtype, + num_iters=num_iters, ) - best = tuned.get(batch, heur_cfg) + # `tuned[eb]` may be missing if all configs failed in tune_dstate; + # in that case fall back to the heuristic so the table still prints. + best = tuned.get(eb) or heur_cfg t_t = benchmark_config( - batch, - nheads, - dim, - dstate, - ngroups, - best["BLOCK_SIZE_M"], - best["num_warps"], - dtype, - num_iters, + batch=batch, + nheads=nheads, + dim=headdim, + dstate=dstate, + ngroups=ngroups, + block_size_m=best["BLOCK_SIZE_M"], + num_warps_val=best["num_warps"], + dtype=dtype, + state_dtype=state_dtype, + num_iters=num_iters, ) if t_h is None or t_t is None: - print(f"{batch:>7} | {'N/A':>10} | {'N/A':>10} | {'N/A':>8} |") + print(f"{eb:>8} | {'N/A':>10} | {'N/A':>10} | {'N/A':>8} |") continue speedup = t_h / t_t marker = " <--" if speedup > 1.05 else "" print( - f"{batch:>7} | {t_h:>10.2f} | {t_t:>10.2f} | " + f"{eb:>8} | {t_h:>10.2f} | {t_t:>10.2f} | " f"{speedup:>7.2f}x | " f"M={best['BLOCK_SIZE_M']},w={best['num_warps']}{marker}" ) @@ -499,7 +663,14 @@ def main(): type=str, default="bfloat16", choices=["float16", "bfloat16"], - help="Data type (default: bfloat16)", + help="Activation / input data type (default: bfloat16)", + ) + parser.add_argument( + "--mamba-ssm-cache-dtype", + type=str, + default="float32", + choices=list(_SSM_CACHE_DTYPE_MAP.keys()), + help="SSM state cache dtype (default: float32)", ) parser.add_argument( "--num-iters", @@ -533,18 +704,28 @@ def main(): "--save-dir", type=str, default=None, - help="Base directory to save JSON configs. Configs are placed in a " - "per-GPU subfolder: //. " + help="Directory to save JSON configs. 
" "(default: vllm/model_executor/layers/mamba/configs/)", ) parser.add_argument( - "--batches", + "--headdim", + type=int, + default=DEFAULT_HEADDIM, + help=f"Per-head feature dim (default: {DEFAULT_HEADDIM})", + ) + parser.add_argument( + "--ngroups", + type=int, + default=DEFAULT_NGROUPS, + help=f"Number of B/C groups (default: {DEFAULT_NGROUPS})", + ) + parser.add_argument( + "--effective-batches", type=int, nargs="+", default=None, - metavar="B", - help="Only tune these specific batch sizes, e.g. --batches 2 16 256. " - "Useful for stability re-checks on flagged configs.", + metavar="EB", + help="Tune only these effective_batch values (default: full sweep)", ) parser.add_argument( "--validate", @@ -561,6 +742,7 @@ def main(): args = parser.parse_args() dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + state_dtype = _SSM_CACHE_DTYPE_MAP[args.mamba_ssm_cache_dtype] device_name = current_platform.get_device_name() cap = torch.cuda.get_device_capability() is_blackwell = cap[0] >= 10 @@ -581,43 +763,76 @@ def flush(self): sys.stdout = _Tee() # type: ignore[assignment] try: - print(f"Device : {device_name} (sm_{cap[0]}{cap[1]:02d})") + print(f"Device : {device_name} (sm_{cap[0]}{cap[1]})") print(f"Blackwell: {is_blackwell}") print(f"dtype : {args.dtype}") + print(f"ssm_cache_dtype: {args.mamba_ssm_cache_dtype}") + print(f"headdim: {args.headdim}") + print(f"ngroups: {args.ngroups}") + print(f"triton : {triton.__version__}") dstates = ALL_DSTATES if args.all_dstates else [args.dstate] for dstate in dstates: tuned = tune_dstate( - dstate, dtype, args.num_iters, args.verbose, args.batches + dstate=dstate, + headdim=args.headdim, + ngroups=args.ngroups, + dtype=dtype, + num_iters=args.num_iters, + verbose=args.verbose, + effective_batches=args.effective_batches, + state_dtype=state_dtype, ) if args.compare: compare_heuristic_vs_tuned( - dstate, tuned, dtype, args.num_iters, is_blackwell + dstate=dstate, + headdim=args.headdim, + 
ngroups=args.ngroups, + tuned=tuned, + dtype=dtype, + num_iters=args.num_iters, + is_blackwell=is_blackwell, + effective_batches=args.effective_batches, + state_dtype=state_dtype, ) if args.validate: - validity = validate_configs(dstate, tuned, dtype, args.atol) + validity = validate_configs( + dstate=dstate, + headdim=args.headdim, + ngroups=args.ngroups, + tuned=tuned, + dtype=dtype, + atol=args.atol, + state_dtype=state_dtype, + ) # Filter out any configs that failed correctness check - failed = [b for b, ok in validity.items() if not ok] + failed = [eb for eb, ok in validity.items() if not ok] if failed: print( - f"\n WARNING: {len(failed)} config(s) failed " - f"validation for dstate={dstate}: batches {failed}" + f"\n WARNING: {len(failed)} config(s) failed validation " + f"for dstate={dstate}: effective_batches {failed}" ) print(" These will NOT be saved even with --save-configs.") tuned = { - b: cfg for b, cfg in tuned.items() if validity.get(b, True) + eb: cfg for eb, cfg in tuned.items() if validity.get(eb, True) } if args.save_configs: - path = save_configs(dstate, tuned, args.save_dir) + path = save_configs( + headdim=args.headdim, + dstate=dstate, + cache_dtype=args.mamba_ssm_cache_dtype, + configs=tuned, + save_dir=args.save_dir, + ) print(f"\nSaved: {path}") else: print(f"\nBest configs for dstate={dstate}:") - for batch, cfg in sorted(tuned.items()): - print(f" batch={batch:>5}: {cfg}") + for eb, cfg in sorted(tuned.items()): + print(f" effective_batch={eb:>6}: {cfg}") print("\n(Re-run with --save-configs to persist to JSON)") finally: sys.stdout = sys.__stdout__ diff --git a/tests/kernels/mamba/test_mamba_ssm_configs.py b/tests/kernels/mamba/test_mamba_ssm_configs.py index 66812831e846..3629c2119657 100644 --- a/tests/kernels/mamba/test_mamba_ssm_configs.py +++ b/tests/kernels/mamba/test_mamba_ssm_configs.py @@ -4,55 +4,80 @@ Unit tests for the JSON-based config loader added to selective_state_update. 
Tests cover: - - Config filename generation - - VLLM_TUNED_CONFIG_FOLDER env-var override (per-GPU subfolder structure) + - Flat MoE-style filename generation + - VLLM_TUNED_CONFIG_FOLDER env-var override - Fallback to heuristic when no config file exists - - Nearest-batch interpolation + - Nearest effective_batch interpolation + - Edge cases: non-dict JSON, empty config """ import json from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - _get_ssm_launch_config, get_ssm_config_file_name, get_ssm_configs, get_ssm_device_name, + try_get_optimal_ssm_config, ) +# Common kwargs for try_get_optimal_ssm_config. Tests pick (batch, nheads) so +# their product (effective_batch) matches the value being probed. +_HEADDIM = 64 +_CACHE_DTYPE = "float32" + + +def _clear_caches() -> None: + get_ssm_configs.cache_clear() + try_get_optimal_ssm_config.cache_clear() + + +def _write_config(tmp_path, dstate: int, payload: dict) -> None: + """Write payload as the bundled config for (headdim, dstate, cache_dtype).""" + device_name = get_ssm_device_name() + config_path = tmp_path / get_ssm_config_file_name( + _HEADDIM, dstate, _CACHE_DTYPE, device_name + ) + with open(config_path, "w") as f: + json.dump(payload, f) + + # --------------------------------------------------------------------------- # Config filename generation # --------------------------------------------------------------------------- def test_config_file_name_format(): - name = get_ssm_config_file_name(128) - assert name == "dstate=128.json" + name = get_ssm_config_file_name( + headdim=64, dstate=128, cache_dtype="float32", device_name="NVIDIA_B200" + ) + assert name == ( + "headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json" + ) # --------------------------------------------------------------------------- -# VLLM_TUNED_CONFIG_FOLDER override (configs live in //dstate=N.json) +# VLLM_TUNED_CONFIG_FOLDER override # --------------------------------------------------------------------------- def 
test_env_override_loads_custom_config(monkeypatch, tmp_path): """VLLM_TUNED_CONFIG_FOLDER should take precedence over the bundled dir.""" - device_name = get_ssm_device_name() - gpu_dir = tmp_path / device_name - gpu_dir.mkdir() - - config_path = gpu_dir / get_ssm_config_file_name(16) - payload = {"1": {"BLOCK_SIZE_M": 4, "num_warps": 1}} - with open(config_path, "w") as f: - json.dump(payload, f) + _write_config( + tmp_path, + dstate=16, + payload={ + "1": {"BLOCK_SIZE_M": 4, "num_warps": 1}, + }, + ) monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) - get_ssm_configs.cache_clear() + _clear_caches() - cfg = get_ssm_configs(16) + cfg = get_ssm_configs(_HEADDIM, 16, _CACHE_DTYPE) assert cfg is not None assert cfg[1] == {"BLOCK_SIZE_M": 4, "num_warps": 1} - get_ssm_configs.cache_clear() + _clear_caches() # --------------------------------------------------------------------------- @@ -61,60 +86,85 @@ def test_env_override_loads_custom_config(monkeypatch, tmp_path): def test_fallback_when_no_config(monkeypatch, tmp_path): - """_get_ssm_launch_config must fall back to the hard-coded heuristic + """try_get_optimal_ssm_config must fall back to the hard-coded heuristic when no JSON file is found for the current device.""" monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) monkeypatch.setattr( "vllm.model_executor.layers.mamba.ops.mamba_ssm._CONFIGS_DIR", str(tmp_path), ) - get_ssm_configs.cache_clear() + _clear_caches() # dstate=64 heuristic: BLOCK_SIZE_M=8, num_warps=4 - block_m, warps = _get_ssm_launch_config(dstate=64, batch=1, is_blackwell=False) + block_m, warps = try_get_optimal_ssm_config( + headdim=_HEADDIM, + dstate=64, + batch=1, + nheads=1, + cache_dtype=_CACHE_DTYPE, + is_blackwell=False, + ) assert block_m == 8 assert warps == 4 # dstate=16 heuristic: BLOCK_SIZE_M=32, num_warps=4 - block_m, warps = _get_ssm_launch_config(dstate=16, batch=1, is_blackwell=False) + block_m, warps = try_get_optimal_ssm_config( + headdim=_HEADDIM, + 
dstate=16, + batch=1, + nheads=1, + cache_dtype=_CACHE_DTYPE, + is_blackwell=False, + ) assert block_m == 32 assert warps == 4 - get_ssm_configs.cache_clear() + _clear_caches() # --------------------------------------------------------------------------- -# Nearest-batch interpolation +# Nearest effective_batch interpolation # --------------------------------------------------------------------------- -def test_nearest_batch_interpolation(monkeypatch, tmp_path): - """When the exact batch size is not in the config, the closest key - should be selected.""" - device_name = get_ssm_device_name() - gpu_dir = tmp_path / device_name - gpu_dir.mkdir() - - config_path = gpu_dir / get_ssm_config_file_name(32) - payload = { - "1": {"BLOCK_SIZE_M": 8, "num_warps": 1}, - "64": {"BLOCK_SIZE_M": 32, "num_warps": 4}, - } - with open(config_path, "w") as f: - json.dump(payload, f) +def test_nearest_effective_batch_interpolation(monkeypatch, tmp_path): + """When effective_batch = batch*nheads is not an exact key, the closest + key should be selected.""" + _write_config( + tmp_path, + dstate=32, + payload={ + "64": {"BLOCK_SIZE_M": 8, "num_warps": 1}, + "4096": {"BLOCK_SIZE_M": 32, "num_warps": 4}, + }, + ) monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) - get_ssm_configs.cache_clear() - - # batch=5 is closer to 1 than to 64 — expects M=8, w=1 - block_m, warps = _get_ssm_launch_config(dstate=32, batch=5, is_blackwell=False) + _clear_caches() + + # effective_batch = 1*128 = 128 -> closer to 64 than to 4096 + block_m, warps = try_get_optimal_ssm_config( + headdim=_HEADDIM, + dstate=32, + batch=1, + nheads=128, + cache_dtype=_CACHE_DTYPE, + is_blackwell=False, + ) assert block_m == 8 and warps == 1 - # batch=40 is closer to 64 — expects M=32, w=4 - block_m, warps = _get_ssm_launch_config(dstate=32, batch=40, is_blackwell=False) + # effective_batch = 4*1024 = 4096 -> exact match on 4096 + block_m, warps = try_get_optimal_ssm_config( + headdim=_HEADDIM, + dstate=32, + 
batch=4, + nheads=1024, + cache_dtype=_CACHE_DTYPE, + is_blackwell=False, + ) assert block_m == 32 and warps == 4 - get_ssm_configs.cache_clear() + _clear_caches() # --------------------------------------------------------------------------- @@ -126,10 +176,9 @@ def test_non_dict_json_returns_none(monkeypatch, tmp_path): """A valid JSON file that is not a dict (e.g. a list) must be ignored and return None rather than raising AttributeError.""" device_name = get_ssm_device_name() - gpu_dir = tmp_path / device_name - gpu_dir.mkdir() - - config_path = gpu_dir / get_ssm_config_file_name(16) + config_path = tmp_path / get_ssm_config_file_name( + _HEADDIM, 16, _CACHE_DTYPE, device_name + ) with open(config_path, "w") as f: json.dump([1, 2, 3], f) @@ -138,30 +187,31 @@ def test_non_dict_json_returns_none(monkeypatch, tmp_path): "vllm.model_executor.layers.mamba.ops.mamba_ssm._CONFIGS_DIR", str(tmp_path), ) - get_ssm_configs.cache_clear() + _clear_caches() - assert get_ssm_configs(16) is None + assert get_ssm_configs(_HEADDIM, 16, _CACHE_DTYPE) is None - get_ssm_configs.cache_clear() + _clear_caches() def test_empty_config_falls_back_to_heuristic(monkeypatch, tmp_path): """An empty JSON object {} must not crash min() — should fall back to the hard-coded heuristic.""" - device_name = get_ssm_device_name() - gpu_dir = tmp_path / device_name - gpu_dir.mkdir() - - config_path = gpu_dir / get_ssm_config_file_name(64) - with open(config_path, "w") as f: - json.dump({}, f) + _write_config(tmp_path, dstate=64, payload={}) monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) - get_ssm_configs.cache_clear() + _clear_caches() # dstate=64 heuristic: BLOCK_SIZE_M=8, num_warps=4 - block_m, warps = _get_ssm_launch_config(dstate=64, batch=1, is_blackwell=False) + block_m, warps = try_get_optimal_ssm_config( + headdim=_HEADDIM, + dstate=64, + batch=1, + nheads=64, + cache_dtype=_CACHE_DTYPE, + is_blackwell=False, + ) assert block_m == 8 assert warps == 4 - 
get_ssm_configs.cache_clear() + _clear_caches() diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json deleted file mode 100644 index 6d571eb702ff..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=128.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 64, - "num_warps": 1 - }, - "2": { - "BLOCK_SIZE_M": 64, - "num_warps": 1 - }, - "4": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "8": { - "BLOCK_SIZE_M": 8, - "num_warps": 1 - }, - "16": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - }, - "32": { - "BLOCK_SIZE_M": 8, - "num_warps": 1 - }, - "64": { - "BLOCK_SIZE_M": 64, - "num_warps": 4 - }, - "128": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "256": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - }, - "512": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - }, - "1024": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json deleted file mode 100644 index 7db73c38803a..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=16.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 4, - "num_warps": 8 - }, - "2": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "4": { - "BLOCK_SIZE_M": 64, - "num_warps": 4 - }, - "8": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 - }, - "16": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 64, - "num_warps": 4 - }, - "64": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - }, - "128": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "256": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "512": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "num_warps": 1 - } -} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json deleted file mode 100644 index 61d58288b7ac..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=256.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "2": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "4": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "8": { - "BLOCK_SIZE_M": 4, - "num_warps": 4 - }, - "16": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - }, - "64": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "128": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "256": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "512": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "1024": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json deleted file mode 100644 index cc0585cf7eef..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=32.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "2": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "4": { - "BLOCK_SIZE_M": 4, - "num_warps": 8 - }, - "8": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "16": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 8, - "num_warps": 1 - }, - "64": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 - }, - "128": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - }, - "256": { - "BLOCK_SIZE_M": 64, - "num_warps": 1 - }, - "512": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json 
b/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json deleted file mode 100644 index 1274671669c6..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_B200/dstate=64.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "2": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 - }, - "4": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "8": { - "BLOCK_SIZE_M": 8, - "num_warps": 1 - }, - "16": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "64": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "128": { - "BLOCK_SIZE_M": 8, - "num_warps": 1 - }, - "256": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "512": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "1024": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json deleted file mode 100644 index a35da36b3541..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=128.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "2": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "4": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - }, - "8": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "16": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "64": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "128": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "256": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "512": { - "BLOCK_SIZE_M": 4, - "num_warps": 8 - }, - "1024": { - "BLOCK_SIZE_M": 32, - "num_warps": 8 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json deleted file mode 100644 index 
136cbea31b6f..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=16.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "2": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 - }, - "4": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "8": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "16": { - "BLOCK_SIZE_M": 64, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "64": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - }, - "128": { - "BLOCK_SIZE_M": 64, - "num_warps": 1 - }, - "256": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "512": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json deleted file mode 100644 index d2141e287759..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=256.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 64, - "num_warps": 4 - }, - "2": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "4": { - "BLOCK_SIZE_M": 4, - "num_warps": 8 - }, - "8": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "16": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "32": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "64": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "128": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "256": { - "BLOCK_SIZE_M": 4, - "num_warps": 8 - }, - "512": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json deleted file mode 100644 index 63406b0aa1fb..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=32.json +++ /dev/null @@ 
-1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "2": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - }, - "4": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "8": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "16": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "64": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "128": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "256": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "512": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "1024": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json b/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json deleted file mode 100644 index d983a23e44d2..000000000000 --- a/vllm/model_executor/layers/mamba/configs/NVIDIA_GB10/dstate=64.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "2": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "4": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "32": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "64": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "128": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "256": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "512": { - "BLOCK_SIZE_M": 32, - "num_warps": 8 - }, - "1024": { - "BLOCK_SIZE_M": 4, - "num_warps": 4 - } -} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json new file mode 100644 index 000000000000..2f15d30329a5 --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json 
@@ -0,0 +1,123 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "48": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "192": { + "BLOCK_SIZE_M": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "384": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "768": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_SIZE_M": 8, + "num_warps": 1 + }, + "3072": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "4096": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "6144": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "8192": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "12288": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "16384": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "24576": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "32768": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "49152": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "65536": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "98304": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "131072": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "196608": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "262144": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + } +} diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json new file mode 100644 index 
000000000000..7b98a6fd82c7 --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json @@ -0,0 +1,123 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "24": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "96": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "192": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "384": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "512": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "768": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "1536": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "3072": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "6144": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "8192": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "12288": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "24576": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32768": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "49152": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "65536": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "98304": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "131072": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "196608": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "262144": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + } +} diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 
88a60a9553de..2fa581835199 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -12,6 +12,7 @@ import torch from packaging import version +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp @@ -33,13 +34,17 @@ ) -def get_ssm_config_file_name(dstate: int) -> str: - """Return the JSON filename for the given dstate. +def get_ssm_config_file_name( + headdim: int, dstate: int, cache_dtype: str, device_name: str +) -> str: + """Return the JSON filename for the given kernel shape. - Config files are organised per GPU: - configs//dstate=.json + Layout: ``configs/headdim=,dstate=,device_name=,cache_dtype=
.json``. """ - return f"dstate={dstate}.json" + return ( + f"headdim={headdim},dstate={dstate}," + f"device_name={device_name},cache_dtype={cache_dtype}.json" + ) def get_ssm_device_name() -> str: @@ -47,29 +52,31 @@ def get_ssm_device_name() -> str: @functools.lru_cache -def get_ssm_configs(dstate: int) -> dict[int, Any] | None: +def get_ssm_configs( + headdim: int, dstate: int, cache_dtype: str +) -> dict[int, Any] | None: """ Return tuned (BLOCK_SIZE_M, num_warps) configs for *selective_state_update* - keyed by batch size, or ``None`` if no config file is found. - - Config files live in a per-GPU subfolder: - vllm/model_executor/layers/mamba/configs//dstate=.json + keyed by ``effective_batch = batch * nheads``, or ``None`` if no config + file is found for the (headdim, dstate, cache_dtype, device) combination. They can be generated with: benchmarks/kernels/benchmark_selective_state_update.py --save-configs """ device_name = get_ssm_device_name() - json_file_name = get_ssm_config_file_name(dstate) + json_file_name = get_ssm_config_file_name(headdim, dstate, cache_dtype, device_name) config_file_paths: list[str] = [] # User-supplied override (same env-var as fused_moe) - user_dir = os.environ.get("VLLM_TUNED_CONFIG_FOLDER") - if user_dir is not None: - config_file_paths.append(os.path.join(user_dir, device_name, json_file_name)) + user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER + if user_defined_config_folder is not None: + config_file_paths.append( + os.path.join(user_defined_config_folder, json_file_name) + ) # Bundled default - config_file_paths.append(os.path.join(_CONFIGS_DIR, device_name, json_file_name)) + config_file_paths.append(os.path.join(_CONFIGS_DIR, json_file_name)) for path in config_file_paths: if os.path.exists(path): @@ -81,31 +88,23 @@ def get_ssm_configs(dstate: int) -> dict[int, Any] | None: ) raw = json.load(f) if isinstance(raw, dict): + # triton_version included in the config file only for reference + raw.pop("triton_version", 
None) return {int(k): v for k, v in raw.items()} + logger.warning_once( + "Using default Mamba SSU config. Performance might be sub-optimal! " + "Config file not found at %s", + ", ".join(config_file_paths), + ) return None -def _get_ssm_launch_config( +def _get_default_ssm_launch_config( dstate: int, - batch: int, is_blackwell: bool, ) -> tuple[int, int]: - """ - Return (BLOCK_SIZE_M, num_warps) for a given dstate and batch size. - - Tries the JSON config first; falls back to the original hard-coded - heuristic so existing behaviour is fully preserved when no config file - is present. - """ - configs = get_ssm_configs(dstate) - if configs: - # Pick the closest batch size in the tuned grid (same strategy as MoE) - closest = min(configs.keys(), key=lambda x: abs(x - batch)) - cfg = configs[closest] - return cfg["BLOCK_SIZE_M"], cfg["num_warps"] - - # ---- original hard-coded heuristic (unchanged) ---- + """Hard-coded fallback heuristic used when no tuned config is available.""" BLOCK_SIZE_M, num_warps = 4, 8 if dstate <= 16: BLOCK_SIZE_M, num_warps = 32, 4 @@ -121,6 +120,32 @@ def _get_ssm_launch_config( return BLOCK_SIZE_M, num_warps +@functools.lru_cache +def try_get_optimal_ssm_config( + headdim: int, + dstate: int, + batch: int, + nheads: int, + cache_dtype: str, + is_blackwell: bool, +) -> tuple[int, int]: + """Return (BLOCK_SIZE_M, num_warps) for the given kernel shape. + + Tuning is keyed on ``effective_batch = batch * nheads`` (the kernel grid + scales with the product), so configs transfer across (model, TP) combos + sharing ``(headdim, dstate, cache_dtype)``. + """ + effective_batch = batch * nheads + configs = get_ssm_configs(headdim, dstate, cache_dtype) + if configs: + # Pick the closest effective_batch in the tuned grid (MoE strategy). 
+ closest = min(configs.keys(), key=lambda x: abs(x - effective_batch)) + cfg = configs[closest] + return cfg["BLOCK_SIZE_M"], cfg["num_warps"] + + return _get_default_ssm_launch_config(dstate, is_blackwell) + + if TRITON3: @triton.jit @@ -548,7 +573,10 @@ def selective_state_update( ) # We don't want autotune since it will overwrite the state. # Load from JSON config if available, otherwise fall back to heuristic. - BLOCK_SIZE_M, num_warps = _get_ssm_launch_config(dstate, N, is_blackwell) + cache_dtype = str(state.dtype).removeprefix("torch.") + BLOCK_SIZE_M, num_warps = try_get_optimal_ssm_config( + dim, dstate, N, nheads, cache_dtype, is_blackwell + ) tie_hdim = ( A.stride(-1) == 0 From a5e44cb984b2c27e288d1c999b872030ac39457e Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 19 May 2026 13:23:23 +0300 Subject: [PATCH 03/24] Use cache (same as lru_cache with maxsize None) for try_get_optimal_ssm_config Signed-off-by: Daniel Serebrenik --- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 2fa581835199..3a4a0d4b650a 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -120,7 +120,7 @@ def _get_default_ssm_launch_config( return BLOCK_SIZE_M, num_warps -@functools.lru_cache +@functools.cache def try_get_optimal_ssm_config( headdim: int, dstate: int, From e1db1d7d4fba201927a35679c937caaa3ea20768 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 19 May 2026 23:03:57 +0300 Subject: [PATCH 04/24] Add tuned JSON files for H100 Signed-off-by: Daniel Serebrenik --- ...IA_H100_80GB_HBM3,cache_dtype=float16.json | 123 ++++++++++++++++++ ...IA_H100_80GB_HBM3,cache_dtype=float32.json | 123 ++++++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 
vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json create mode 100644 vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json new file mode 100644 index 000000000000..4a2ca1b7aa9d --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json @@ -0,0 +1,123 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_M": 8, + "num_warps": 1 + }, + "24": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "48": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "num_warps": 8 + }, + "96": { + "BLOCK_SIZE_M": 64, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "192": { + "BLOCK_SIZE_M": 4, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "384": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "768": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "3072": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "4096": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "6144": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "8192": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "12288": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "16384": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "24576": { + "BLOCK_SIZE_M": 
16, + "num_warps": 8 + }, + "32768": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "49152": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "65536": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "98304": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "131072": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "196608": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "262144": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json new file mode 100644 index 000000000000..e5ba7b1e2f28 --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json @@ -0,0 +1,123 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "24": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "32": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "96": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "num_warps": 8 + }, + "192": { + "BLOCK_SIZE_M": 64, + "num_warps": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "384": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_M": 32, + "num_warps": 2 + }, + "768": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "1536": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "2048": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "3072": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "6144": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + 
"8192": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "12288": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "24576": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "32768": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "49152": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "65536": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "98304": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "196608": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "262144": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + } +} \ No newline at end of file From 56a19280c3b3e80c09eaff6aa91b8ca584ff76f3 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 11:19:11 +0300 Subject: [PATCH 05/24] Add CUDA graphs to script Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 76 ++++++----- ..._name=NVIDIA_B200,cache_dtype=float16.json | 86 ++++++------ ..._name=NVIDIA_B200,cache_dtype=float32.json | 50 +++---- ...IA_H100_80GB_HBM3,cache_dtype=float16.json | 123 ------------------ ...IA_H100_80GB_HBM3,cache_dtype=float32.json | 123 ------------------ 5 files changed, 113 insertions(+), 345 deletions(-) delete mode 100644 vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json delete mode 100644 vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index a33217237908..e6e8f7d559a6 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -165,10 +165,15 @@ def benchmark_config( state_dtype: torch.dtype | None = None, num_iters: int = 100, num_warmup: int = 20, + graph_batch_size: int = 10, ) -> float | None: """ Time 
one (BLOCK_SIZE_M, num_warps) config for selective_state_update. Returns elapsed time in microseconds, or None on error. + + Uses CUDA graph capture-and-replay to isolate kernel time from Python + eager-mode dispatch / kwarg-resolution overhead, mirroring the timing + methodology in benchmarks/kernels/benchmark_moe.py. """ state, x, dt, A, B, C, D, dt_bias, out = _make_inputs( batch, nheads, dim, dstate, ngroups, dtype, state_dtype=state_dtype @@ -179,47 +184,56 @@ def benchmark_config( def _fixed_launch_config(*_args, **_kwargs): return block_size_m, num_warps_val + def _call_kernel() -> None: + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + ) + try: with patch.object( mamba_ssm_module, "try_get_optimal_ssm_config", _fixed_launch_config ): - # Warmup + # Eager-mode warmup: triggers Triton autotune / JIT, primes caches. for _ in range(num_warmup): - selective_state_update( - state, - x, - dt, - A, - B, - C, - D=D, - z=None, - dt_bias=dt_bias, - dt_softplus=True, - out=out, - ) + _call_kernel() + torch.accelerator.synchronize() + + # Capture graph_batch_size invocations into a CUDA graph so the + # timed region runs without Python dispatch overhead per call. + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(graph_batch_size): + _call_kernel() + torch.accelerator.synchronize() + + # Warmup graph replays (let the runtime stabilize). 
+ for _ in range(5): + graph.replay() torch.accelerator.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - start.record() + latencies: list[float] = [] for _ in range(num_iters): - selective_state_update( - state, - x, - dt, - A, - B, - C, - D=D, - z=None, - dt_bias=dt_bias, - dt_softplus=True, - out=out, - ) - end.record() - torch.accelerator.synchronize() - return start.elapsed_time(end) / num_iters * 1000 # ms -> us + start.record() + graph.replay() + end.record() + end.synchronize() + latencies.append(start.elapsed_time(end)) + graph.reset() + # elapsed_time returns ms; each replay runs graph_batch_size kernels, + # so divide by (num_iters * graph_batch_size) and convert ms -> us. + return sum(latencies) / (num_iters * graph_batch_size) * 1000 except Exception as e: if "OutOfResources" not in str(e): print( diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json index 2f15d30329a5..40f061a2247e 100644 --- a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json @@ -1,87 +1,87 @@ { "triton_version": "3.6.0", "8": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 4, "num_warps": 4 }, "16": { - "BLOCK_SIZE_M": 64, - "num_warps": 4 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "24": { - "BLOCK_SIZE_M": 64, - "num_warps": 4 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "32": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 + "BLOCK_SIZE_M": 16, + "num_warps": 4 }, "48": { "BLOCK_SIZE_M": 4, - "num_warps": 1 + "num_warps": 2 }, "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "num_warps": 4 }, "96": { - "BLOCK_SIZE_M": 32, - "num_warps": 8 + "BLOCK_SIZE_M": 16, + "num_warps": 4 }, 
"128": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 + "BLOCK_SIZE_M": 32, + "num_warps": 8 }, "192": { - "BLOCK_SIZE_M": 64, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "256": { "BLOCK_SIZE_M": 16, - "num_warps": 4 + "num_warps": 2 }, "384": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 + "BLOCK_SIZE_M": 32, + "num_warps": 2 }, "512": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "num_warps": 1 }, "768": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "1024": { - "BLOCK_SIZE_M": 4, - "num_warps": 4 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "1536": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "2048": { "BLOCK_SIZE_M": 8, - "num_warps": 1 + "num_warps": 2 }, "3072": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, "4096": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "6144": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "8192": { "BLOCK_SIZE_M": 8, "num_warps": 2 }, "12288": { - "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_M": 16, "num_warps": 2 }, "16384": { @@ -93,31 +93,31 @@ "num_warps": 2 }, "32768": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "BLOCK_SIZE_M": 32, + "num_warps": 4 }, "49152": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "BLOCK_SIZE_M": 32, + "num_warps": 1 }, "65536": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "BLOCK_SIZE_M": 32, + "num_warps": 1 }, "98304": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "BLOCK_SIZE_M": 32, + "num_warps": 1 }, "131072": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "BLOCK_SIZE_M": 32, + "num_warps": 1 }, "196608": { "BLOCK_SIZE_M": 32, "num_warps": 1 }, "262144": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "num_warps": 2 } -} +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json index 
7b98a6fd82c7..4e71c71dfdc0 100644 --- a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json @@ -2,71 +2,71 @@ "triton_version": "3.6.0", "8": { "BLOCK_SIZE_M": 4, - "num_warps": 4 + "num_warps": 1 }, "16": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 4, "num_warps": 1 }, "24": { "BLOCK_SIZE_M": 4, - "num_warps": 8 + "num_warps": 1 }, "32": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "48": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "64": { "BLOCK_SIZE_M": 4, - "num_warps": 2 + "num_warps": 1 }, "96": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "128": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "192": { "BLOCK_SIZE_M": 4, - "num_warps": 8 + "num_warps": 1 }, "256": { "BLOCK_SIZE_M": 4, - "num_warps": 4 + "num_warps": 1 }, "384": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "512": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "768": { "BLOCK_SIZE_M": 4, - "num_warps": 4 + "num_warps": 1 }, "1024": { "BLOCK_SIZE_M": 4, - "num_warps": 2 + "num_warps": 1 }, "1536": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, "2048": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "3072": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 + "BLOCK_SIZE_M": 4, + "num_warps": 1 }, "4096": { "BLOCK_SIZE_M": 4, @@ -120,4 +120,4 @@ "BLOCK_SIZE_M": 4, "num_warps": 1 } -} +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json deleted file mode 100644 index 
4a2ca1b7aa9d..000000000000 --- a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json +++ /dev/null @@ -1,123 +0,0 @@ -{ - "triton_version": "3.6.0", - "8": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "16": { - "BLOCK_SIZE_M": 8, - "num_warps": 1 - }, - "24": { - "BLOCK_SIZE_M": 4, - "num_warps": 8 - }, - "32": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "48": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "64": { - "BLOCK_SIZE_M": 32, - "num_warps": 8 - }, - "96": { - "BLOCK_SIZE_M": 64, - "num_warps": 8 - }, - "128": { - "BLOCK_SIZE_M": 4, - "num_warps": 4 - }, - "192": { - "BLOCK_SIZE_M": 4, - "num_warps": 8 - }, - "256": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "384": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 - }, - "512": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "768": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "1024": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "1536": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "2048": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "3072": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "4096": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "6144": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 - }, - "8192": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "12288": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 - }, - "16384": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "24576": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "32768": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "49152": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "65536": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "98304": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "131072": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 - }, - "196608": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 - }, - "262144": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 - } -} \ No newline at end of file diff --git 
a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json deleted file mode 100644 index e5ba7b1e2f28..000000000000 --- a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json +++ /dev/null @@ -1,123 +0,0 @@ -{ - "triton_version": "3.6.0", - "8": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "16": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "24": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "32": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "48": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "64": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "96": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "128": { - "BLOCK_SIZE_M": 16, - "num_warps": 8 - }, - "192": { - "BLOCK_SIZE_M": 64, - "num_warps": 2 - }, - "256": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "384": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "512": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, - "768": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 - }, - "1024": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 - }, - "1536": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, - "2048": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "3072": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "4096": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "6144": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "8192": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "12288": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "16384": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "24576": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "32768": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "49152": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "65536": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "98304": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "131072": { - 
"BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "196608": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - }, - "262144": { - "BLOCK_SIZE_M": 16, - "num_warps": 1 - } -} \ No newline at end of file From 1dd1de1118e39b4aa01c2648f30e645e52524eef Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 14:48:02 +0300 Subject: [PATCH 06/24] Add tuned JSONs for H100 Signed-off-by: Daniel Serebrenik --- ...IA_H100_80GB_HBM3,cache_dtype=float16.json | 123 ++++++++++++++++++ ...IA_H100_80GB_HBM3,cache_dtype=float32.json | 123 ++++++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json create mode 100644 vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json new file mode 100644 index 000000000000..a51944b59cfa --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json @@ -0,0 +1,123 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "48": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "192": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "256": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "384": { + "BLOCK_SIZE_M": 8, + "num_warps": 2 + }, + "512": { + "BLOCK_SIZE_M": 8, + 
"num_warps": 2 + }, + "768": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "1536": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "3072": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "4096": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "6144": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "8192": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "12288": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "24576": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "49152": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "65536": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "98304": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "131072": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "196608": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "262144": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json new file mode 100644 index 000000000000..6ba079ba6fdc --- /dev/null +++ b/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json @@ -0,0 +1,123 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "24": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "48": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "96": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "192": 
{ + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "384": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_SIZE_M": 8, + "num_warps": 4 + }, + "768": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "1536": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "3072": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "6144": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "8192": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "12288": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "24576": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "49152": { + "BLOCK_SIZE_M": 4, + "num_warps": 2 + }, + "65536": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "98304": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "131072": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "196608": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + }, + "262144": { + "BLOCK_SIZE_M": 16, + "num_warps": 1 + } +} \ No newline at end of file From aa4ad6dc2c10475b66a036f6f1c4dfd5cfd6d00e Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 16:57:03 +0300 Subject: [PATCH 07/24] Remove ref to fused_moe, use cache instead of lru_cache Signed-off-by: Daniel Serebrenik --- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 3a4a0d4b650a..0f704345acd2 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -26,7 +26,7 @@ # --------------------------------------------------------------------------- -# JSON config loading (mirrors 
fused_moe pattern) +# JSON config loading # --------------------------------------------------------------------------- _CONFIGS_DIR = os.path.join( @@ -51,7 +51,7 @@ def get_ssm_device_name() -> str: return current_platform.get_device_name().replace(" ", "_") -@functools.lru_cache +@functools.cache def get_ssm_configs( headdim: int, dstate: int, cache_dtype: str ) -> dict[int, Any] | None: @@ -68,7 +68,7 @@ def get_ssm_configs( config_file_paths: list[str] = [] - # User-supplied override (same env-var as fused_moe) + # User-supplied override user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER if user_defined_config_folder is not None: config_file_paths.append( From 35103b8ce6f2314054afc7f599eb8cb1f6267fb4 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 17:33:55 +0300 Subject: [PATCH 08/24] Remove duplicate functions from the tuning script Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 93 +++---------------- tests/kernels/mamba/__init__.py | 0 2 files changed, 15 insertions(+), 78 deletions(-) create mode 100644 tests/kernels/mamba/__init__.py diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index e6e8f7d559a6..8f2bae460063 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -10,7 +10,7 @@ up by selective_state_update at runtime. 
Usage: - python benchmarks/kernels/benchmark_selective_state_update.py \ + python -m benchmarks.kernels.benchmark_selective_state_update \ --all-dstates --save-configs --compare """ @@ -26,10 +26,13 @@ import torch import vllm.model_executor.layers.mamba.ops.mamba_ssm as mamba_ssm_module +from tests.kernels.mamba.test_mamba_ssm import selective_state_update_ref from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + _get_default_ssm_launch_config, + get_ssm_config_file_name, + get_ssm_device_name, selective_state_update, ) -from vllm.platforms import current_platform from vllm.triton_utils import triton # MambaDType subset: bf16 is excluded (not commonly used) @@ -107,14 +110,10 @@ def _block_size_m_choices(headdim: int) -> list[int]: # --------------------------------------------------------------------------- -# Config file naming (mirrors fused_moe pattern) +# Config file naming # --------------------------------------------------------------------------- -def get_device_name() -> str: - return current_platform.get_device_name().replace(" ", "_") - - def get_ssm_configs_dir() -> str: return os.path.normpath( os.path.join( @@ -388,56 +387,6 @@ def tune_dstate( # --------------------------------------------------------------------------- -def _selective_state_update_ref( - state: torch.Tensor, - x: torch.Tensor, - dt: torch.Tensor, - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - D: torch.Tensor, - dt_bias: torch.Tensor, -) -> torch.Tensor: - """ - Pure-PyTorch CPU reference for selective_state_update (dt_softplus=True). 
- - Shapes (all moved to CPU float32 internally): - state : (batch, nheads, dim, dstate) - x : (batch, nheads, dim) - dt : (batch, nheads, dim) - A : (nheads, dim, dstate) - B : (batch, ngroups, dstate) - C : (batch, ngroups, dstate) - D : (nheads, dim) - dt_bias: (nheads, dim) - Returns: - out : (batch, nheads, dim) in the original dtype - """ - orig_dtype = x.dtype - state = state.clone().cpu().float() - x = x.cpu().float() - dt = dt.cpu().float() - A = A.cpu().float() - B = B.cpu().float() - C = C.cpu().float() - D = D.cpu().float() - dt = dt + dt_bias.cpu().float() - dt = torch.nn.functional.softplus(dt) # (batch, nheads, dim) - - nheads, _, _ = A.shape - ngroups = B.shape[1] - - dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0)) # (batch, nheads, dim, dstate) - B_exp = B.repeat_interleave(nheads // ngroups, dim=1) # (batch, nheads, dstate) - C_exp = C.repeat_interleave(nheads // ngroups, dim=1) - dB = dt.unsqueeze(-1) * B_exp.unsqueeze(2) # (batch, nheads, dim, dstate) - - state_new = state * dA + dB * x.unsqueeze(-1) - out = (state_new * C_exp.unsqueeze(2)).sum(-1) # (batch, nheads, dim) - out = out + x * D.unsqueeze(0) - return out.to(orig_dtype) - - def validate_configs( dstate: int, headdim: int, @@ -502,8 +451,10 @@ def _fixed(*_args, _cfg=cfg, **_kwargs): torch.accelerator.synchronize() gpu_out = out.detach().cpu() - # CPU reference uses the original (unmodified) state - ref_out = _selective_state_update_ref(state_ref, x, dt, A, B, C, D, dt_bias) + # Reference uses the original (unmodified) state + ref_out = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True + ).cpu() passed = torch.allclose(gpu_out.float(), ref_out.float(), atol=atol, rtol=rtol) max_err = (gpu_out.float() - ref_out.float()).abs().max().item() @@ -533,9 +484,7 @@ def save_configs( os.makedirs(base_dir, exist_ok=True) file_path = os.path.join( base_dir, - mamba_ssm_module.get_ssm_config_file_name( - headdim, dstate, cache_dtype, 
get_device_name() - ), + get_ssm_config_file_name(headdim, dstate, cache_dtype, get_ssm_device_name()), ) # triton_version is informational only, the loader ignores it payload: dict[str, Any] = { @@ -554,19 +503,8 @@ def save_configs( def current_heuristic(dstate: int, is_blackwell: bool = False) -> dict: """Return the current hard-coded BLOCK_SIZE_M / num_warps for dstate.""" - if dstate <= 16: - return {"BLOCK_SIZE_M": 32, "num_warps": 4} - elif dstate <= 32: - return {"BLOCK_SIZE_M": 16, "num_warps": 4} - elif dstate <= 64: - return {"BLOCK_SIZE_M": 8, "num_warps": 4} - else: - if is_blackwell: - return {"BLOCK_SIZE_M": 32, "num_warps": 8} - elif dstate <= 128: - return {"BLOCK_SIZE_M": 4, "num_warps": 4} - else: - return {"BLOCK_SIZE_M": 4, "num_warps": 8} + bsm, nw = _get_default_ssm_launch_config(dstate, is_blackwell) + return {"BLOCK_SIZE_M": bsm, "num_warps": nw} def compare_heuristic_vs_tuned( @@ -648,9 +586,8 @@ def compare_heuristic_vs_tuned( def save_results(device_name: str, output: str, results_file: str | None = None) -> str: """Save the full benchmark output to a results text file.""" if results_file is None: - safe_name = device_name.replace(" ", "_") results_file = os.path.join( - _RESULTS_DIR, f"ssm_benchmark_results_{safe_name}.txt" + _RESULTS_DIR, f"ssm_benchmark_results_{device_name}.txt" ) with open(results_file, "w") as f: f.write(output) @@ -757,7 +694,7 @@ def main(): dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 state_dtype = _SSM_CACHE_DTYPE_MAP[args.mamba_ssm_cache_dtype] - device_name = current_platform.get_device_name() + device_name = get_ssm_device_name() cap = torch.cuda.get_device_capability() is_blackwell = cap[0] >= 10 diff --git a/tests/kernels/mamba/__init__.py b/tests/kernels/mamba/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From e77985f5593673b7888998127ecb8bd826f8be26 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 18:16:59 +0300 Subject: [PATCH 09/24] 
Fix --validate failure Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index 8f2bae460063..a9c6c35738a7 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -399,8 +399,13 @@ def validate_configs( ) -> dict[int, bool]: """ For every effective_batch in *tuned*, run the kernel with the tuned config - and compare against the CPU reference. Returns {effective_batch: passed}. + and compare against the reference. Returns {effective_batch: passed}. """ + # Disable TF32 in the reference's matmul: at larger effective_batch the + # worst output value grows, so TF32 rounding shows up as a bf16-quantum + # mismatch (e.g. 1.0–4.0) versus the Triton kernel's true fp32 accumulation. + torch.set_float32_matmul_precision("highest") + print(f"\n{'=' * 74}") effective_state_dtype = state_dtype if state_dtype is not None else dtype print( @@ -452,9 +457,23 @@ def _fixed(*_args, _cfg=cfg, **_kwargs): gpu_out = out.detach().cpu() # Reference uses the original (unmodified) state - ref_out = selective_state_update_ref( - state_ref, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True - ).cpu() + # Upcast to fp32 so the reference sums in fp32 (matches the Triton + # kernel); summing in bf16 over `dstate` blows up the error. 
+ ref_out = ( + selective_state_update_ref( + state_ref.float(), + x.float(), + dt.float(), + A.float(), + B.float(), + C.float(), + D=D.float(), + dt_bias=dt_bias.float(), + dt_softplus=True, + ) + .to(out.dtype) + .cpu() + ) passed = torch.allclose(gpu_out.float(), ref_out.float(), atol=atol, rtol=rtol) max_err = (gpu_out.float() - ref_out.float()).abs().max().item() From af4769146399504f868e48f1a02eb8c2a26e616b Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 19:26:29 +0300 Subject: [PATCH 10/24] Reuse tuned measurements for comparisons Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 65 ++++++++++--------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index a9c6c35738a7..4dedfe75efa9 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -313,14 +313,16 @@ def tune_dstate( verbose: bool, effective_batches: list[int] | None = None, state_dtype: torch.dtype | None = None, -) -> dict[int, dict]: +) -> tuple[dict[int, dict], dict[int, dict[tuple[int, int], float]]]: """For each effective_batch, sweep (BLOCK_SIZE_M, num_warps) and return - {effective_batch: best_config}. effective_batch is factored into - (batch, nheads) by `_factor_effective_batch`. + ({effective_batch: best_config}, {effective_batch: {(bsm, nw): us}}). + The second map is the full timing grid, used downstream so we don't + re-measure the same config in the comparison phase. 
""" active = _resolve_effective_batches(effective_batches, ngroups) best_per_eb: dict[int, dict] = {} + timings: dict[int, dict[tuple[int, int], float]] = {} print(f"\n{'=' * 74}") effective_state_dtype = state_dtype if state_dtype is not None else dtype @@ -340,6 +342,7 @@ def tune_dstate( for eb, batch, nheads in active: best_time = float("inf") best_cfg: dict = {} + eb_timings: dict[tuple[int, int], float] = {} for bsm, nw in product(bsm_choices, NUM_WARPS_CHOICES): t = benchmark_config( @@ -356,6 +359,7 @@ def tune_dstate( ) if t is None: continue + eb_timings[(bsm, nw)] = t is_best = t < best_time if is_best: best_time = t @@ -364,6 +368,8 @@ def tune_dstate( marker = " <-- best" if is_best else "" print(f"{eb:>8} | {bsm:>7} | {nw:>5} | {t:>10.2f} |{marker}") + timings[eb] = eb_timings + if not best_cfg: print( f"{eb:>8} | {'-':>7} | {'-':>5} | {'-':>10} | " @@ -379,7 +385,7 @@ def tune_dstate( best_per_eb[eb] = best_cfg - return best_per_eb + return best_per_eb, timings # --------------------------------------------------------------------------- @@ -531,6 +537,7 @@ def compare_heuristic_vs_tuned( headdim: int, ngroups: int, tuned: dict[int, dict], + timings: dict[int, dict[tuple[int, int], float]], dtype: torch.dtype, num_iters: int, is_blackwell: bool, @@ -539,6 +546,7 @@ def compare_heuristic_vs_tuned( ): active = _resolve_effective_batches(effective_batches, ngroups) heur_cfg = current_heuristic(dstate, is_blackwell) + heur_key = (heur_cfg["BLOCK_SIZE_M"], heur_cfg["num_warps"]) print(f"\n{'=' * 74}") print( @@ -558,33 +566,30 @@ def compare_heuristic_vs_tuned( print("-" * len(hdr)) for eb, batch, nheads in active: - t_h = benchmark_config( - batch=batch, - nheads=nheads, - dim=headdim, - dstate=dstate, - ngroups=ngroups, - block_size_m=heur_cfg["BLOCK_SIZE_M"], - num_warps_val=heur_cfg["num_warps"], - dtype=dtype, - state_dtype=state_dtype, - num_iters=num_iters, - ) + eb_timings = timings.get(eb, {}) + + # Heuristic timing: reuse the tuning measurement 
if the heuristic + # config was in the swept grid; otherwise measure it once. + t_h = eb_timings.get(heur_key) + if t_h is None: + t_h = benchmark_config( + batch=batch, + nheads=nheads, + dim=headdim, + dstate=dstate, + ngroups=ngroups, + block_size_m=heur_cfg["BLOCK_SIZE_M"], + num_warps_val=heur_cfg["num_warps"], + dtype=dtype, + state_dtype=state_dtype, + num_iters=num_iters, + ) + # `tuned[eb]` may be missing if all configs failed in tune_dstate; # in that case fall back to the heuristic so the table still prints. best = tuned.get(eb) or heur_cfg - t_t = benchmark_config( - batch=batch, - nheads=nheads, - dim=headdim, - dstate=dstate, - ngroups=ngroups, - block_size_m=best["BLOCK_SIZE_M"], - num_warps_val=best["num_warps"], - dtype=dtype, - state_dtype=state_dtype, - num_iters=num_iters, - ) + t_t = eb_timings.get((best["BLOCK_SIZE_M"], best["num_warps"])) + if t_h is None or t_t is None: print(f"{eb:>8} | {'N/A':>10} | {'N/A':>10} | {'N/A':>8} |") continue @@ -744,7 +749,7 @@ def flush(self): dstates = ALL_DSTATES if args.all_dstates else [args.dstate] for dstate in dstates: - tuned = tune_dstate( + tuned, timings = tune_dstate( dstate=dstate, headdim=args.headdim, ngroups=args.ngroups, @@ -756,11 +761,13 @@ def flush(self): ) if args.compare: + # Use the measurements from tune_dstate compare_heuristic_vs_tuned( dstate=dstate, headdim=args.headdim, ngroups=args.ngroups, tuned=tuned, + timings=timings, dtype=dtype, num_iters=args.num_iters, is_blackwell=is_blackwell, From 702faaf65c750b0011cbf809129d8c97f827e2fb Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 19:58:06 +0300 Subject: [PATCH 11/24] Add support for mamba cache bf16 (match to fp16 config) Signed-off-by: Daniel Serebrenik --- benchmarks/kernels/benchmark_selective_state_update.py | 8 +++++++- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py 
b/benchmarks/kernels/benchmark_selective_state_update.py index 4dedfe75efa9..4c56a4888773 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -28,6 +28,7 @@ import vllm.model_executor.layers.mamba.ops.mamba_ssm as mamba_ssm_module from tests.kernels.mamba.test_mamba_ssm import selective_state_update_ref from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + _canonical_cache_dtype, _get_default_ssm_launch_config, get_ssm_config_file_name, get_ssm_device_name, @@ -35,10 +36,12 @@ ) from vllm.triton_utils import triton -# MambaDType subset: bf16 is excluded (not commonly used) +# bf16 maps to float16 +# (same number of bits, same tuned config should work for both) _SSM_CACHE_DTYPE_MAP: dict[str, torch.dtype] = { "float32": torch.float32, "float16": torch.float16, + "bfloat16": torch.float16, } _RESULTS_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -505,6 +508,9 @@ def save_configs( configs: dict[int, dict], save_dir: str | None = None, ) -> str: + # bf16 shares configs with fp16, use common filename for both + cache_dtype = _canonical_cache_dtype(cache_dtype) + base_dir = save_dir if save_dir else get_ssm_configs_dir() os.makedirs(base_dir, exist_ok=True) file_path = os.path.join( diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 0f704345acd2..9793fb7ed1c1 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -51,6 +51,12 @@ def get_ssm_device_name() -> str: return current_platform.get_device_name().replace(" ", "_") +def _canonical_cache_dtype(cache_dtype: str) -> str: + """Canonical key for config lookup. 
bf16 and fp16 share the same tuned + configs because the kernel only sees bit width when accessing state.""" + return "float16" if cache_dtype == "bfloat16" else cache_dtype + + @functools.cache def get_ssm_configs( headdim: int, dstate: int, cache_dtype: str @@ -63,6 +69,7 @@ def get_ssm_configs( They can be generated with: benchmarks/kernels/benchmark_selective_state_update.py --save-configs """ + cache_dtype = _canonical_cache_dtype(cache_dtype) device_name = get_ssm_device_name() json_file_name = get_ssm_config_file_name(headdim, dstate, cache_dtype, device_name) From c856e723791bae8af8bc4f608ed70c42bdaccd49 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 21:06:44 +0300 Subject: [PATCH 12/24] Add flags --batch-sizes and --nheads Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 161 +++++++----------- 1 file changed, 62 insertions(+), 99 deletions(-) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index 4c56a4888773..2cb87483fee6 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -67,42 +67,12 @@ def _block_size_m_choices(headdim: int) -> list[int]: return [b for b in _BSM_CHOICES_ALL if b <= ceiling] -# effective_batch = batch * nheads_per_rank — the kernel grid scales with -# the product, so configs transfer across (model, TP) combos sharing -# (headdim, dstate, cache_dtype). -# Ceiling 262144 covers 256-head at TP1, max BS=1024 (256 * 1024). -EFFECTIVE_BATCH_SIZES = [ - 8, - 16, - 24, - 32, - 48, - 64, - 96, - 128, - 192, - 256, - 384, - 512, - 768, - 1024, - 1536, - 2048, - 3072, - 4096, - 6144, - 8192, - 12288, - 16384, - 24576, - 32768, - 49152, - 65536, - 98304, - 131072, - 196608, - 262144, -] +# Default deployment shapes. 
effective_batch = batch * nheads scales the +# kernel grid, so configs transfer across (model, TP) combos sharing +# (headdim, dstate, cache_dtype). nheads=128 and 256 cover common Mamba2 +# deployment shapes (Nemotron-class with/without TP). +DEFAULT_BATCH_SIZES = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048] +DEFAULT_NHEADS = [128, 256] ALL_DSTATES = [16, 32, 64, 128, 256] @@ -250,61 +220,44 @@ def _call_kernel() -> None: # --------------------------------------------------------------------------- -# CUDA grid Y/Z dim limit — both `batch` and `nheads` must fit individually, -# so effective_batch > 65535 has to be split across the two. +# CUDA grid Y/Z dim limit — both `batch` and `nheads` must fit individually. _CUDA_MAX_GRID_DIM = 65535 -def _factor_effective_batch( - effective_batch: int, ngroups: int -) -> tuple[int, int] | None: - """Return (batch, nheads) with batch*nheads == effective_batch such that - both fit the CUDA grid Y/Z dim limit and nheads is a positive multiple of - ngroups. Prefers batch=1 (the cheapest split) when it fits. - - Returns None if no valid factorization exists. - """ - for batch in range(1, _CUDA_MAX_GRID_DIM + 1): - if batch > effective_batch or effective_batch % batch != 0: - continue - nheads = effective_batch // batch - if nheads > _CUDA_MAX_GRID_DIM: - continue - if nheads % ngroups != 0: - continue - return batch, nheads - return None - - -def _resolve_effective_batches( - user_supplied: list[int] | None, +def expand_batch_x_nheads( + batch_sizes: list[int], + nheads_list: list[int], ngroups: int, ) -> list[tuple[int, int, int]]: - """Return [(effective_batch, batch, nheads)] for each valid sweep point. - - Drops any effective_batch with no valid (batch, nheads) factorization - that satisfies both the CUDA grid dim limit and nheads % ngroups == 0. + """Cross-product batch_sizes × nheads_list → sorted [(effective_batch, + batch, nheads)], deduped by effective_batch. 
Filters pairs that exceed + the CUDA grid dim limit or where nheads is not a positive multiple of + ngroups. """ - candidates = user_supplied if user_supplied is not None else EFFECTIVE_BATCH_SIZES - valid: list[tuple[int, int, int]] = [] - skipped: list[int] = [] - for eb in candidates: - if eb <= 0: - skipped.append(eb) + seen: dict[int, tuple[int, int]] = {} + skipped_grid: list[tuple[int, int]] = [] + skipped_ngroups: list[tuple[int, int]] = [] + for b, n in product(batch_sizes, nheads_list): + if b <= 0 or n <= 0: + continue + if b > _CUDA_MAX_GRID_DIM or n > _CUDA_MAX_GRID_DIM: + skipped_grid.append((b, n)) continue - factored = _factor_effective_batch(eb, ngroups) - if factored is None: - skipped.append(eb) + if n % ngroups != 0: + skipped_ngroups.append((b, n)) continue - batch, nheads = factored - valid.append((eb, batch, nheads)) - if skipped: + seen.setdefault(b * n, (b, n)) + if skipped_grid: print( - f" Note: skipping effective_batch values with no valid " - f"(batch, nheads) factorization for ngroups={ngroups} " - f"under CUDA grid dim {_CUDA_MAX_GRID_DIM}: {skipped}" + f" Note: skipping (batch, nheads) pairs exceeding CUDA grid dim " + f"{_CUDA_MAX_GRID_DIM}: {skipped_grid}" ) - return valid + if skipped_ngroups: + print( + f" Note: skipping (batch, nheads) pairs where nheads % ngroups != 0 " + f"for ngroups={ngroups}: {skipped_ngroups}" + ) + return sorted((eb, b, n) for eb, (b, n) in seen.items()) def tune_dstate( @@ -314,16 +267,15 @@ def tune_dstate( dtype: torch.dtype, num_iters: int, verbose: bool, - effective_batches: list[int] | None = None, + active: list[tuple[int, int, int]], state_dtype: torch.dtype | None = None, ) -> tuple[dict[int, dict], dict[int, dict[tuple[int, int], float]]]: - """For each effective_batch, sweep (BLOCK_SIZE_M, num_warps) and return + """For each (effective_batch, batch, nheads) in *active*, sweep + (BLOCK_SIZE_M, num_warps) and return ({effective_batch: best_config}, {effective_batch: {(bsm, nw): us}}). 
The second map is the full timing grid, used downstream so we don't re-measure the same config in the comparison phase. """ - active = _resolve_effective_batches(effective_batches, ngroups) - best_per_eb: dict[int, dict] = {} timings: dict[int, dict[tuple[int, int], float]] = {} @@ -401,14 +353,16 @@ def validate_configs( headdim: int, ngroups: int, tuned: dict[int, dict], + active: list[tuple[int, int, int]], dtype: torch.dtype, atol: float = 1e-2, rtol: float = 1e-2, state_dtype: torch.dtype | None = None, ) -> dict[int, bool]: """ - For every effective_batch in *tuned*, run the kernel with the tuned config - and compare against the reference. Returns {effective_batch: passed}. + For every (effective_batch, batch, nheads) in *active* that has a tuned + config, run the kernel with that config and compare against the reference. + Returns {effective_batch: passed}. """ # Disable TF32 in the reference's matmul: at larger effective_batch the # worst output value grows, so TF32 rounding shows up as a bf16-quantum @@ -427,11 +381,10 @@ def validate_configs( results: dict[int, bool] = {} - for eb, cfg in sorted(tuned.items()): - factored = _factor_effective_batch(eb, ngroups) - if factored is None: + for eb, batch, nheads in active: + cfg = tuned.get(eb) + if cfg is None: continue - batch, nheads = factored state, x, dt, A, B, C, D, dt_bias, out = _make_inputs( batch=batch, nheads=nheads, @@ -544,13 +497,12 @@ def compare_heuristic_vs_tuned( ngroups: int, tuned: dict[int, dict], timings: dict[int, dict[tuple[int, int], float]], + active: list[tuple[int, int, int]], dtype: torch.dtype, num_iters: int, is_blackwell: bool, - effective_batches: list[int] | None = None, state_dtype: torch.dtype | None = None, ): - active = _resolve_effective_batches(effective_batches, ngroups) heur_cfg = current_heuristic(dstate, is_blackwell) heur_key = (heur_cfg["BLOCK_SIZE_M"], heur_cfg["num_warps"]) @@ -701,12 +653,21 @@ def main(): help=f"Number of B/C groups (default: 
{DEFAULT_NGROUPS})", ) parser.add_argument( - "--effective-batches", + "--batch-sizes", type=int, nargs="+", - default=None, - metavar="EB", - help="Tune only these effective_batch values (default: full sweep)", + default=DEFAULT_BATCH_SIZES, + metavar="B", + help=f"Decoder batch sizes to sweep (default: {DEFAULT_BATCH_SIZES})", + ) + parser.add_argument( + "--nheads", + type=int, + nargs="+", + default=DEFAULT_NHEADS, + metavar="N", + help=f"Number of heads per rank to sweep (default: {DEFAULT_NHEADS}). " + "effective_batch = batch * nheads; cross-product is deduped by eb.", ) parser.add_argument( "--validate", @@ -753,6 +714,7 @@ def flush(self): print(f"triton : {triton.__version__}") dstates = ALL_DSTATES if args.all_dstates else [args.dstate] + active = expand_batch_x_nheads(args.batch_sizes, args.nheads, args.ngroups) for dstate in dstates: tuned, timings = tune_dstate( @@ -762,7 +724,7 @@ def flush(self): dtype=dtype, num_iters=args.num_iters, verbose=args.verbose, - effective_batches=args.effective_batches, + active=active, state_dtype=state_dtype, ) @@ -774,10 +736,10 @@ def flush(self): ngroups=args.ngroups, tuned=tuned, timings=timings, + active=active, dtype=dtype, num_iters=args.num_iters, is_blackwell=is_blackwell, - effective_batches=args.effective_batches, state_dtype=state_dtype, ) @@ -787,6 +749,7 @@ def flush(self): headdim=args.headdim, ngroups=args.ngroups, tuned=tuned, + active=active, dtype=dtype, atol=args.atol, state_dtype=state_dtype, From 621137a18d0e171aa6c4cec07af22ff5224a29c8 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 21:56:04 +0300 Subject: [PATCH 13/24] Move location of config files Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 22 ++++--------------- ..._name=NVIDIA_B200,cache_dtype=float16.json | 0 ..._name=NVIDIA_B200,cache_dtype=float32.json | 0 ...IA_H100_80GB_HBM3,cache_dtype=float16.json | 0 ...IA_H100_80GB_HBM3,cache_dtype=float32.json | 0 
.../layers/mamba/ops/mamba_ssm.py | 6 +++-- 6 files changed, 8 insertions(+), 20 deletions(-) rename vllm/model_executor/layers/mamba/{configs => ops/configs/selective_state_update}/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json (100%) rename vllm/model_executor/layers/mamba/{configs => ops/configs/selective_state_update}/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json (100%) rename vllm/model_executor/layers/mamba/{configs => ops/configs/selective_state_update}/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json (100%) rename vllm/model_executor/layers/mamba/{configs => ops/configs/selective_state_update}/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json (100%) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index 2cb87483fee6..7e35367df6a3 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -28,6 +28,7 @@ import vllm.model_executor.layers.mamba.ops.mamba_ssm as mamba_ssm_module from tests.kernels.mamba.test_mamba_ssm import selective_state_update_ref from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + _CONFIGS_DIR, _canonical_cache_dtype, _get_default_ssm_launch_config, get_ssm_config_file_name, @@ -82,20 +83,6 @@ def _block_size_m_choices(headdim: int) -> list[int]: DEFAULT_NGROUPS = 8 -# --------------------------------------------------------------------------- -# Config file naming -# --------------------------------------------------------------------------- - - -def get_ssm_configs_dir() -> str: - return os.path.normpath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "../../vllm/model_executor/layers/mamba/configs", - ) - ) - - # --------------------------------------------------------------------------- # Benchmark helper # 
--------------------------------------------------------------------------- @@ -464,7 +451,7 @@ def save_configs( # bf16 shares configs with fp16, use common filename for both cache_dtype = _canonical_cache_dtype(cache_dtype) - base_dir = save_dir if save_dir else get_ssm_configs_dir() + base_dir = save_dir if save_dir else _CONFIGS_DIR os.makedirs(base_dir, exist_ok=True) file_path = os.path.join( base_dir, @@ -614,7 +601,7 @@ def main(): parser.add_argument( "--save-configs", action="store_true", - help="Save best configs to JSON in mamba/configs/", + help=f"Save best configs to JSON in {_CONFIGS_DIR}", ) parser.add_argument( "--compare", @@ -637,8 +624,7 @@ def main(): "--save-dir", type=str, default=None, - help="Directory to save JSON configs. " - "(default: vllm/model_executor/layers/mamba/configs/)", + help=f"Directory to save JSON configs (default: {_CONFIGS_DIR})", ) parser.add_argument( "--headdim", diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json similarity index 100% rename from vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json rename to vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json similarity index 100% rename from vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json rename to 
vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json similarity index 100% rename from vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json rename to vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json diff --git a/vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json similarity index 100% rename from vllm/model_executor/layers/mamba/configs/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json rename to vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 9793fb7ed1c1..97573b70f43c 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -30,7 +30,7 @@ # --------------------------------------------------------------------------- _CONFIGS_DIR = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "configs" + os.path.dirname(os.path.realpath(__file__)), "configs", "selective_state_update" ) @@ -97,7 +97,9 @@ def get_ssm_configs( if isinstance(raw, dict): # triton_version included in the config file only for 
reference raw.pop("triton_version", None) - return {int(k): v for k, v in raw.items()} + # Filter to integer-string keys to tolerate hand-edited + # configs with extra annotation fields. + return {int(k): v for k, v in raw.items() if k.isdigit()} logger.warning_once( "Using default Mamba SSU config. Performance might be sub-optimal! " From 878de248b8fb18c46dbc7e4a6e5e1ba6ecaffd26 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 22:03:37 +0300 Subject: [PATCH 14/24] Use contextmanager for override_ssm_config Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 18 ++----- tests/kernels/mamba/test_mamba_ssm_configs.py | 3 +- .../layers/mamba/ops/mamba_ssm.py | 49 ++++++++++++++++--- 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index 7e35367df6a3..6184026067df 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -21,11 +21,9 @@ from io import StringIO from itertools import product from typing import Any -from unittest.mock import patch import torch -import vllm.model_executor.layers.mamba.ops.mamba_ssm as mamba_ssm_module from tests.kernels.mamba.test_mamba_ssm import selective_state_update_ref from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( _CONFIGS_DIR, @@ -33,6 +31,7 @@ _get_default_ssm_launch_config, get_ssm_config_file_name, get_ssm_device_name, + override_ssm_config, selective_state_update, ) from vllm.triton_utils import triton @@ -138,11 +137,6 @@ def benchmark_config( batch, nheads, dim, dstate, ngroups, dtype, state_dtype=state_dtype ) - # Monkeypatch try_get_optimal_ssm_config to return the specific config - # without affecting the lru_cache on get_ssm_configs. 
- def _fixed_launch_config(*_args, **_kwargs): - return block_size_m, num_warps_val - def _call_kernel() -> None: selective_state_update( state, @@ -159,9 +153,7 @@ def _call_kernel() -> None: ) try: - with patch.object( - mamba_ssm_module, "try_get_optimal_ssm_config", _fixed_launch_config - ): + with override_ssm_config((block_size_m, num_warps_val)): # Eager-mode warmup: triggers Triton autotune / JIT, primes caches. for _ in range(num_warmup): _call_kernel() @@ -384,11 +376,7 @@ def validate_configs( # Clone state before GPU kernel modifies it in-place state_ref = state.clone() - # GPU kernel output - def _fixed(*_args, _cfg=cfg, **_kwargs): - return _cfg["BLOCK_SIZE_M"], _cfg["num_warps"] - - with patch.object(mamba_ssm_module, "try_get_optimal_ssm_config", _fixed): + with override_ssm_config((cfg["BLOCK_SIZE_M"], cfg["num_warps"])): selective_state_update( state, x, diff --git a/tests/kernels/mamba/test_mamba_ssm_configs.py b/tests/kernels/mamba/test_mamba_ssm_configs.py index 3629c2119657..97b89e9c8dd2 100644 --- a/tests/kernels/mamba/test_mamba_ssm_configs.py +++ b/tests/kernels/mamba/test_mamba_ssm_configs.py @@ -14,6 +14,7 @@ import json from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + _try_get_optimal_ssm_config_cached, get_ssm_config_file_name, get_ssm_configs, get_ssm_device_name, @@ -28,7 +29,7 @@ def _clear_caches() -> None: get_ssm_configs.cache_clear() - try_get_optimal_ssm_config.cache_clear() + _try_get_optimal_ssm_config_cached.cache_clear() def _write_config(tmp_path, dstate: int, payload: dict) -> None: diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 97573b70f43c..5d587ddfcbb9 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -7,6 +7,7 @@ import functools import json import os +from contextlib import contextmanager from typing import Any import torch @@ -130,7 +131,7 @@ def 
_get_default_ssm_launch_config( @functools.cache -def try_get_optimal_ssm_config( +def _try_get_optimal_ssm_config_cached( headdim: int, dstate: int, batch: int, @@ -138,12 +139,7 @@ def try_get_optimal_ssm_config( cache_dtype: str, is_blackwell: bool, ) -> tuple[int, int]: - """Return (BLOCK_SIZE_M, num_warps) for the given kernel shape. - - Tuning is keyed on ``effective_batch = batch * nheads`` (the kernel grid - scales with the product), so configs transfer across (model, TP) combos - sharing ``(headdim, dstate, cache_dtype)``. - """ + """Cached resolution. See :func:`try_get_optimal_ssm_config`.""" effective_batch = batch * nheads configs = get_ssm_configs(headdim, dstate, cache_dtype) if configs: @@ -155,6 +151,45 @@ def try_get_optimal_ssm_config( return _get_default_ssm_launch_config(dstate, is_blackwell) +# Override hook for benchmarks/tests, see `override_ssm_config`. +_ssm_config_override: tuple[int, int] | None = None + + +@contextmanager +def override_ssm_config(config: tuple[int, int]): + """Force ``try_get_optimal_ssm_config`` to return ``config`` for the + duration of the context. Used by the tuning benchmark to time specific + (BLOCK_SIZE_M, num_warps) pairs.""" + global _ssm_config_override + prev = _ssm_config_override + _ssm_config_override = config + try: + yield + finally: + _ssm_config_override = prev + + +def try_get_optimal_ssm_config( + headdim: int, + dstate: int, + batch: int, + nheads: int, + cache_dtype: str, + is_blackwell: bool, +) -> tuple[int, int]: + """Return (BLOCK_SIZE_M, num_warps) for the given kernel shape. + + Tuning is keyed on ``effective_batch = batch * nheads`` (the kernel grid + scales with the product), so configs transfer across (model, TP) combos + sharing ``(headdim, dstate, cache_dtype)``. 
+ """ + if _ssm_config_override is not None: + return _ssm_config_override + return _try_get_optimal_ssm_config_cached( + headdim, dstate, batch, nheads, cache_dtype, is_blackwell + ) + + if TRITON3: @triton.jit From 24a2f614b7a5ffbf0edcd72d98cab777b5f93018 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 20 May 2026 22:11:29 +0300 Subject: [PATCH 15/24] Cleanup comments Signed-off-by: Daniel Serebrenik --- .../kernels/benchmark_selective_state_update.py | 12 ++++-------- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 7 ++----- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index 6184026067df..6718b16f534d 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -36,8 +36,7 @@ ) from vllm.triton_utils import triton -# bf16 maps to float16 -# (same number of bits, same tuned config should work for both) +# bf16 shares configs with fp16 - same bit width. _SSM_CACHE_DTYPE_MAP: dict[str, torch.dtype] = { "float32": torch.float32, "float16": torch.float16, @@ -69,8 +68,7 @@ def _block_size_m_choices(headdim: int) -> list[int]: # Default deployment shapes. effective_batch = batch * nheads scales the # kernel grid, so configs transfer across (model, TP) combos sharing -# (headdim, dstate, cache_dtype). nheads=128 and 256 cover common Mamba2 -# deployment shapes (Nemotron-class with/without TP). +# (headdim, dstate, cache_dtype). DEFAULT_BATCH_SIZES = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048] DEFAULT_NHEADS = [128, 256] @@ -343,9 +341,8 @@ def validate_configs( config, run the kernel with that config and compare against the reference. Returns {effective_batch: passed}. """ - # Disable TF32 in the reference's matmul: at larger effective_batch the - # worst output value grows, so TF32 rounding shows up as a bf16-quantum - # mismatch (e.g. 
1.0–4.0) versus the Triton kernel's true fp32 accumulation. + # Disable TF32 so the reference's matmul matches the Triton kernel's + # fp32 accumulation; otherwise large ebs show bf16 rounding mismatches. torch.set_float32_matmul_precision("highest") print(f"\n{'=' * 74}") @@ -703,7 +700,6 @@ def flush(self): ) if args.compare: - # Use the measurements from tune_dstate compare_heuristic_vs_tuned( dstate=dstate, headdim=args.headdim, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 5d587ddfcbb9..185aae8a1f4b 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -98,8 +98,6 @@ def get_ssm_configs( if isinstance(raw, dict): # triton_version included in the config file only for reference raw.pop("triton_version", None) - # Filter to integer-string keys to tolerate hand-edited - # configs with extra annotation fields. return {int(k): v for k, v in raw.items() if k.isdigit()} logger.warning_once( @@ -157,9 +155,8 @@ def _try_get_optimal_ssm_config_cached( @contextmanager def override_ssm_config(config: tuple[int, int]): - """Force ``try_get_optimal_ssm_config`` to return ``config`` for the - duration of the context. Used by the tuning benchmark to time specific - (BLOCK_SIZE_M, num_warps) pairs.""" + """Pin ``try_get_optimal_ssm_config`` to ``config`` for the duration of + the context. 
Used by the tuning benchmark to time specific configs.""" global _ssm_config_override prev = _ssm_config_override _ssm_config_override = config From e4c31818e651dda75fd73bd1306fe7aebbc599ce Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 21 May 2026 10:19:05 +0300 Subject: [PATCH 16/24] Add upper limit to eff batch (to avoid cuda error) Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index 6718b16f534d..c893d1daf251 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -200,6 +200,11 @@ def _call_kernel() -> None: # CUDA grid Y/Z dim limit — both `batch` and `nheads` must fit individually. _CUDA_MAX_GRID_DIM = 65535 +# Above this, kernel state-offset arithmetic (batch * nheads * headdim * dstate) +# overflows int32 and the launch raises cudaErrorIllegalAddress. +# 262144 covers Nemotron Super TP1 BS=2048. +_MAX_EFFECTIVE_BATCH = 262144 + def expand_batch_x_nheads( batch_sizes: list[int], @@ -208,12 +213,13 @@ def expand_batch_x_nheads( ) -> list[tuple[int, int, int]]: """Cross-product batch_sizes × nheads_list → sorted [(effective_batch, batch, nheads)], deduped by effective_batch. Filters pairs that exceed - the CUDA grid dim limit or where nheads is not a positive multiple of - ngroups. + the CUDA grid dim limit, the effective_batch ceiling, or where nheads is + not a positive multiple of ngroups. 
""" seen: dict[int, tuple[int, int]] = {} skipped_grid: list[tuple[int, int]] = [] skipped_ngroups: list[tuple[int, int]] = [] + skipped_eb: list[tuple[int, int]] = [] for b, n in product(batch_sizes, nheads_list): if b <= 0 or n <= 0: continue @@ -223,6 +229,9 @@ def expand_batch_x_nheads( if n % ngroups != 0: skipped_ngroups.append((b, n)) continue + if b * n > _MAX_EFFECTIVE_BATCH: + skipped_eb.append((b, n)) + continue seen.setdefault(b * n, (b, n)) if skipped_grid: print( @@ -234,6 +243,11 @@ def expand_batch_x_nheads( f" Note: skipping (batch, nheads) pairs where nheads % ngroups != 0 " f"for ngroups={ngroups}: {skipped_ngroups}" ) + if skipped_eb: + print( + f" Note: skipping (batch, nheads) pairs whose effective_batch " + f"exceeds {_MAX_EFFECTIVE_BATCH}: {skipped_eb}" + ) return sorted((eb, b, n) for eb, (b, n) in seen.items()) From 73375622b793d7017e6b8c7ca10ac587982ab4b1 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 21 May 2026 10:39:34 +0300 Subject: [PATCH 17/24] Update JSON for B200 Signed-off-by: Daniel Serebrenik --- ..._name=NVIDIA_B200,cache_dtype=float16.json | 82 ++++++------------- ..._name=NVIDIA_B200,cache_dtype=float32.json | 40 +-------- 2 files changed, 25 insertions(+), 97 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json index 40f061a2247e..bbd24af8e330 100644 --- a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json +++ b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float16.json @@ -5,79 +5,43 @@ "num_warps": 4 }, "16": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "24": { - "BLOCK_SIZE_M": 4, - 
"num_warps": 1 + "BLOCK_SIZE_M": 8, + "num_warps": 8 }, "32": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, - "48": { "BLOCK_SIZE_M": 4, - "num_warps": 2 + "num_warps": 1 }, "64": { "BLOCK_SIZE_M": 16, "num_warps": 4 }, - "96": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, "128": { "BLOCK_SIZE_M": 32, - "num_warps": 8 - }, - "192": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "num_warps": 4 }, "256": { "BLOCK_SIZE_M": 16, "num_warps": 2 }, - "384": { - "BLOCK_SIZE_M": 32, - "num_warps": 2 - }, "512": { "BLOCK_SIZE_M": 32, - "num_warps": 1 - }, - "768": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 + "num_warps": 4 }, "1024": { "BLOCK_SIZE_M": 16, "num_warps": 2 }, - "1536": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - }, "2048": { - "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_M": 32, "num_warps": 2 }, - "3072": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "4096": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "6144": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "8192": { - "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_M": 16, "num_warps": 2 }, "12288": { @@ -93,31 +57,31 @@ "num_warps": 2 }, "32768": { - "BLOCK_SIZE_M": 32, - "num_warps": 4 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "49152": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "65536": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "98304": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "131072": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "196608": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "262144": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 16, "num_warps": 2 } -} \ No newline at end of file +} diff --git a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json 
b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json index 4e71c71dfdc0..0fe496162805 100644 --- a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json +++ b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_B200,cache_dtype=float32.json @@ -2,80 +2,44 @@ "triton_version": "3.6.0", "8": { "BLOCK_SIZE_M": 4, - "num_warps": 1 + "num_warps": 4 }, "16": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "24": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "32": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "48": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "64": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "96": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "128": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "192": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "256": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "384": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "512": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "768": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "1024": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "1536": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "2048": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "3072": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "4096": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "6144": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "8192": { "BLOCK_SIZE_M": 4, "num_warps": 1 @@ -120,4 +84,4 @@ "BLOCK_SIZE_M": 4, "num_warps": 1 } -} \ No newline at end of file +} From 4953de9d4a97aecee27fe5a8e873a8fd70d6ddfe Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 21 May 2026 13:10:38 +0300 Subject: [PATCH 18/24] Add tuned JSON files for GB200 Signed-off-by: Daniel Serebrenik --- ...name=NVIDIA_GB200,cache_dtype=float16.json | 87 +++++++++++++++++++ ...name=NVIDIA_GB200,cache_dtype=float32.json | 87 +++++++++++++++++++ 2 files changed, 174 
insertions(+) create mode 100644 vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float16.json create mode 100644 vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float32.json diff --git a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float16.json new file mode 100644 index 000000000000..b498fa3745d0 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float16.json @@ -0,0 +1,87 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_M": 8, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "num_warps": 1 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "4096": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "8192": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "12288": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "16384": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "24576": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "32768": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "49152": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "65536": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "98304": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "131072": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + }, + "196608": { + "BLOCK_SIZE_M": 16, + "num_warps": 
2 + }, + "262144": { + "BLOCK_SIZE_M": 16, + "num_warps": 2 + } +} diff --git a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float32.json new file mode 100644 index 000000000000..63fdcf5be246 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_GB200,cache_dtype=float32.json @@ -0,0 +1,87 @@ +{ + "triton_version": "3.6.0", + "8": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "4096": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "8192": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "12288": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "24576": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "32768": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "49152": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "65536": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "98304": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "131072": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "196608": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + }, + "262144": { + "BLOCK_SIZE_M": 4, + "num_warps": 1 + } +} From 3c3f3aa833d5763b031de7febd4a0f6e9451f671 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 21 May 2026 13:21:08 +0300 Subject: [PATCH 19/24] Update JSON for H100 Signed-off-by: Daniel Serebrenik --- 
...IA_H100_80GB_HBM3,cache_dtype=float16.json | 86 ++++++------------- ...IA_H100_80GB_HBM3,cache_dtype=float32.json | 70 ++++----------- 2 files changed, 42 insertions(+), 114 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json index a51944b59cfa..b479d7f69a43 100644 --- a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json +++ b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float16.json @@ -2,37 +2,21 @@ "triton_version": "3.6.0", "8": { "BLOCK_SIZE_M": 4, - "num_warps": 2 + "num_warps": 1 }, "16": { "BLOCK_SIZE_M": 16, "num_warps": 4 }, - "24": { + "32": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "32": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "48": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 - }, "64": { "BLOCK_SIZE_M": 16, "num_warps": 4 }, - "96": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 - }, "128": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 - }, - "192": { "BLOCK_SIZE_M": 8, "num_warps": 2 }, @@ -40,84 +24,64 @@ "BLOCK_SIZE_M": 8, "num_warps": 2 }, - "384": { - "BLOCK_SIZE_M": 8, - "num_warps": 2 - }, "512": { "BLOCK_SIZE_M": 8, "num_warps": 2 }, - "768": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "1024": { - "BLOCK_SIZE_M": 16, - "num_warps": 2 - }, - "1536": { - "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_M": 8, "num_warps": 1 }, "2048": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "3072": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "4096": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, - "6144": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, 
"8192": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "12288": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 4 }, "16384": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "24576": { "BLOCK_SIZE_M": 16, - "num_warps": 4 + "num_warps": 2 }, "32768": { "BLOCK_SIZE_M": 16, - "num_warps": 4 + "num_warps": 2 }, "49152": { "BLOCK_SIZE_M": 16, - "num_warps": 4 + "num_warps": 2 }, "65536": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "98304": { "BLOCK_SIZE_M": 16, "num_warps": 2 }, "131072": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "196608": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "262144": { - "BLOCK_SIZE_M": 32, - "num_warps": 1 + "BLOCK_SIZE_M": 16, + "num_warps": 2 } -} \ No newline at end of file +} diff --git a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json index 6ba079ba6fdc..57fe2996582e 100644 --- a/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json +++ b/vllm/model_executor/layers/mamba/ops/configs/selective_state_update/headdim=64,dstate=128,device_name=NVIDIA_H100_80GB_HBM3,cache_dtype=float32.json @@ -8,47 +8,23 @@ "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "24": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "32": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "48": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "64": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "96": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "128": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "192": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, 
"256": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "384": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "512": { - "BLOCK_SIZE_M": 8, - "num_warps": 4 - }, - "768": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, @@ -56,49 +32,37 @@ "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "1536": { - "BLOCK_SIZE_M": 4, - "num_warps": 1 - }, "2048": { "BLOCK_SIZE_M": 4, "num_warps": 1 }, - "3072": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, "4096": { - "BLOCK_SIZE_M": 8, - "num_warps": 8 - }, - "6144": { "BLOCK_SIZE_M": 4, - "num_warps": 2 + "num_warps": 1 }, "8192": { - "BLOCK_SIZE_M": 4, - "num_warps": 4 + "BLOCK_SIZE_M": 16, + "num_warps": 2 }, "12288": { - "BLOCK_SIZE_M": 16, - "num_warps": 4 + "BLOCK_SIZE_M": 64, + "num_warps": 8 }, "16384": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 + "BLOCK_SIZE_M": 16, + "num_warps": 1 }, "24576": { - "BLOCK_SIZE_M": 4, - "num_warps": 4 + "BLOCK_SIZE_M": 64, + "num_warps": 8 }, "32768": { - "BLOCK_SIZE_M": 4, - "num_warps": 4 + "BLOCK_SIZE_M": 16, + "num_warps": 1 }, "49152": { - "BLOCK_SIZE_M": 4, - "num_warps": 2 + "BLOCK_SIZE_M": 64, + "num_warps": 8 }, "65536": { "BLOCK_SIZE_M": 16, @@ -106,11 +70,11 @@ }, "98304": { "BLOCK_SIZE_M": 16, - "num_warps": 2 + "num_warps": 1 }, "131072": { "BLOCK_SIZE_M": 16, - "num_warps": 1 + "num_warps": 2 }, "196608": { "BLOCK_SIZE_M": 16, @@ -118,6 +82,6 @@ }, "262144": { "BLOCK_SIZE_M": 16, - "num_warps": 1 + "num_warps": 2 } -} \ No newline at end of file +} From 4bba278956aa03dfee79ff7ca8a90aad3ef5ec1a Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Sun, 24 May 2026 14:55:42 +0300 Subject: [PATCH 20/24] Fix stale comment in get_ssm_config_file_name Signed-off-by: Daniel Serebrenik --- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 185aae8a1f4b..2aef33375771 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ 
b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -40,7 +40,8 @@ def get_ssm_config_file_name( ) -> str: """Return the JSON filename for the given kernel shape. - Layout: ``configs/headdim=,dstate=,device_name=,cache_dtype=
.json``. + Layout: ``configs/selective_state_update/ + headdim=,dstate=,device_name=,cache_dtype=
.json``. """ return ( f"headdim={headdim},dstate={dstate}," From e0152b2e0313104ec7c4cfe85ac5398f254d34ea Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Sun, 24 May 2026 15:06:17 +0300 Subject: [PATCH 21/24] Move selective_state_update_ref to new utils file Signed-off-by: Daniel Serebrenik --- .../benchmark_selective_state_update.py | 2 +- tests/kernels/mamba/test_mamba_ssm.py | 73 +---------------- tests/kernels/mamba/utils.py | 78 +++++++++++++++++++ 3 files changed, 80 insertions(+), 73 deletions(-) create mode 100644 tests/kernels/mamba/utils.py diff --git a/benchmarks/kernels/benchmark_selective_state_update.py b/benchmarks/kernels/benchmark_selective_state_update.py index c893d1daf251..a8b73da2aa9a 100644 --- a/benchmarks/kernels/benchmark_selective_state_update.py +++ b/benchmarks/kernels/benchmark_selective_state_update.py @@ -24,7 +24,7 @@ import torch -from tests.kernels.mamba.test_mamba_ssm import selective_state_update_ref +from tests.kernels.mamba.utils import selective_state_update_ref from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( _CONFIGS_DIR, _canonical_cache_dtype, diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 81715be98941..d812242cba96 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from einops import rearrange, repeat +from tests.kernels.mamba.utils import selective_state_update_ref from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -17,78 +18,6 @@ from vllm.v1.attention.backends.utils import NULL_BLOCK_ID -def selective_state_update_ref( - state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False -): - """ - Argument: - state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) - dt: (batch, dim) or (batch, nheads, dim) - A: 
(dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) - C: (batch, dstate) or (batch, ngroups, dstate) - D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) - dt_bias: (dim,) or (nheads, dim) - Return: - out: (batch, dim) or (batch, nheads, dim) - """ - has_heads = state.dim() > 3 - if state.dim() == 3: - state = state.unsqueeze(1) - if x.dim() == 2: - x = x.unsqueeze(1) - if dt.dim() == 2: - dt = dt.unsqueeze(1) - if A.dim() == 2: - A = A.unsqueeze(0) - if B.dim() == 2: - B = B.unsqueeze(1) - if C.dim() == 2: - C = C.unsqueeze(1) - if D is not None and D.dim() == 1: - D = D.unsqueeze(0) - if z is not None and z.dim() == 2: - z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.dim() == 1: - dt_bias = dt_bias.unsqueeze(0) - batch, nheads, dim, dstate = state.shape - assert x.shape == (batch, nheads, dim) - assert dt.shape == x.shape - assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[1] - assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (nheads, dim) - dt = dt + dt_bias - dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp( - rearrange(dt, "b h d -> b h d 1") * A - ) # (batch, nheads, dim, dstate) - B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) - C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) - dB = rearrange(dt, "b h d -> b h d 1") * rearrange( - B, "b h n -> b h 1 n" - ) # (batch, nheads, dim, dstate) - state.copy_( - state * dA + dB * rearrange(x, "b h d -> b h d 1") - ) # (batch, dim, dstate - out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) - if D is not None: - out += (x * D).to(out.dtype) - out = (out if z is None else out * F.silu(z)).to(x.dtype) - if not 
has_heads: - out = out.squeeze(1) - return out - - def selective_scan_ref( u, delta, diff --git a/tests/kernels/mamba/utils.py b/tests/kernels/mamba/utils.py new file mode 100644 index 000000000000..fb8a4b0a28ec --- /dev/null +++ b/tests/kernels/mamba/utils.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +def selective_state_update_ref( + state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False +): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + 
assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp( + rearrange(dt, "b h d -> b h d 1") * A + ) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange( + B, "b h n -> b h 1 n" + ) # (batch, nheads, dim, dstate) + state.copy_( + state * dA + dB * rearrange(x, "b h d -> b h d 1") + ) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out From 7757ae5d8ccaaa8c643d46380ba1066f004ae9d0 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Sun, 24 May 2026 15:27:34 +0300 Subject: [PATCH 22/24] Fix hard coded block_m/warps in test_mamba_ssm_configs Signed-off-by: Daniel Serebrenik --- tests/kernels/mamba/test_mamba_ssm_configs.py | 45 ++++++++----------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_ssm_configs.py b/tests/kernels/mamba/test_mamba_ssm_configs.py index 97b89e9c8dd2..b5214fca3b94 100644 --- a/tests/kernels/mamba/test_mamba_ssm_configs.py +++ b/tests/kernels/mamba/test_mamba_ssm_configs.py @@ -14,6 +14,7 @@ import json from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + _get_default_ssm_launch_config, _try_get_optimal_ssm_config_cached, get_ssm_config_file_name, get_ssm_configs, @@ -96,29 +97,18 @@ def test_fallback_when_no_config(monkeypatch, tmp_path): ) _clear_caches() - # dstate=64 heuristic: BLOCK_SIZE_M=8, num_warps=4 - block_m, warps = try_get_optimal_ssm_config( - headdim=_HEADDIM, - dstate=64, - batch=1, - nheads=1, - cache_dtype=_CACHE_DTYPE, - is_blackwell=False, - ) - assert block_m == 8 - assert warps == 4 - - # dstate=16 heuristic: 
BLOCK_SIZE_M=32, num_warps=4 - block_m, warps = try_get_optimal_ssm_config( - headdim=_HEADDIM, - dstate=16, - batch=1, - nheads=1, - cache_dtype=_CACHE_DTYPE, - is_blackwell=False, - ) - assert block_m == 32 - assert warps == 4 + for dstate in (64, 16): + block_m, warps = try_get_optimal_ssm_config( + headdim=_HEADDIM, + dstate=dstate, + batch=1, + nheads=1, + cache_dtype=_CACHE_DTYPE, + is_blackwell=False, + ) + assert (block_m, warps) == _get_default_ssm_launch_config( + dstate, is_blackwell=False + ) _clear_caches() @@ -203,16 +193,17 @@ def test_empty_config_falls_back_to_heuristic(monkeypatch, tmp_path): monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) _clear_caches() - # dstate=64 heuristic: BLOCK_SIZE_M=8, num_warps=4 + dstate = 64 block_m, warps = try_get_optimal_ssm_config( headdim=_HEADDIM, - dstate=64, + dstate=dstate, batch=1, nheads=64, cache_dtype=_CACHE_DTYPE, is_blackwell=False, ) - assert block_m == 8 - assert warps == 4 + assert (block_m, warps) == _get_default_ssm_launch_config( + dstate=dstate, is_blackwell=False + ) _clear_caches() From f01f9f1e0193938a3a8d1ed0c391246946e1824f Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Sun, 24 May 2026 15:35:39 +0300 Subject: [PATCH 23/24] Add more coverage in test_mamba_ssm_configs Signed-off-by: Daniel Serebrenik --- tests/kernels/mamba/test_mamba_ssm_configs.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_ssm_configs.py b/tests/kernels/mamba/test_mamba_ssm_configs.py index b5214fca3b94..d35ce95746c7 100644 --- a/tests/kernels/mamba/test_mamba_ssm_configs.py +++ b/tests/kernels/mamba/test_mamba_ssm_configs.py @@ -88,27 +88,30 @@ def test_env_override_loads_custom_config(monkeypatch, tmp_path): def test_fallback_when_no_config(monkeypatch, tmp_path): - """try_get_optimal_ssm_config must fall back to the hard-coded heuristic - when no JSON file is found for the current device.""" + """try_get_optimal_ssm_config 
must fall back to _get_default_ssm_launch_config + when no JSON file is found for the current + (device, headdim, dstate, cache_dtype) combination. + """ monkeypatch.setenv("VLLM_TUNED_CONFIG_FOLDER", str(tmp_path)) monkeypatch.setattr( "vllm.model_executor.layers.mamba.ops.mamba_ssm._CONFIGS_DIR", str(tmp_path), ) - _clear_caches() - for dstate in (64, 16): - block_m, warps = try_get_optimal_ssm_config( - headdim=_HEADDIM, - dstate=dstate, - batch=1, - nheads=1, - cache_dtype=_CACHE_DTYPE, - is_blackwell=False, - ) - assert (block_m, warps) == _get_default_ssm_launch_config( - dstate, is_blackwell=False - ) + for dstate in (8, 16, 32, 64, 128, 256): + for is_blackwell in (False, True): + _clear_caches() + block_m, warps = try_get_optimal_ssm_config( + headdim=_HEADDIM, + dstate=dstate, + batch=1, + nheads=1, + cache_dtype=_CACHE_DTYPE, + is_blackwell=is_blackwell, + ) + assert (block_m, warps) == _get_default_ssm_launch_config( + dstate, is_blackwell=is_blackwell + ) _clear_caches() From f1b0f51b9f7415e80a4fb64c30b099201bf3bc3d Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Sun, 24 May 2026 16:24:55 +0300 Subject: [PATCH 24/24] Update comment for VLLM_TUNED_CONFIG_FOLDER Signed-off-by: Daniel Serebrenik --- vllm/envs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index 929feab3ef71..f5b2759e9934 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1744,7 +1744,10 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool( int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0")) ), - # Allows vllm to find tuned config under customized folder + # User override folder for tuned Triton-kernel configs. Shared by MoE, + # Mamba SSU, and LoRA. Filenames are distinct so one folder can hold all. + # Each component first checks this folder, then the configs shipped with + # vLLM (if any). If no JSON matches, it uses a hard-coded heuristic. 
"VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), # Valid values are container,code_interpreter,web_search_preview # ex VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter