[Hybrid] Warmup Mamba2 SSD kernel #39822

Merged
tomeras91 merged 4 commits into vllm-project:main from tdoublep:mamba2-ssd-kernel-warmup
May 12, 2026

Member

@tdoublep tdoublep commented Apr 14, 2026

Summary

Triton's auto-tuner for the Mamba2 SSD kernels currently runs lazily on the first inference request, causing a large latency spike. This PR adds a _warmup_ssd_kernels() method to MambaMixer2 that triggers auto-tuning during vLLM's profile phase (before SSM cache allocation), shifting the cost into server startup.

  • Runs a minimal mamba_chunk_scan_combined_varlen forward pass with dummy tensors during the V1 profile run
  • Covers both HAS_INITSTATES constexpr code paths (with and without initial_states)
  • Uses correct dtypes (including ssm_state_dtype) to match Triton's cache keys
  • Each layer warms up once, and subsequent layers hit the Triton cache
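The caching behaviour these bullets rely on can be illustrated with a toy model. This is a minimal sketch, not Triton's actual API: `ToyAutotuner`, the key names `headdim` and `dstate`, and the dtype tuples are all hypothetical. It assumes only what the PR states, namely that the cache key depends on argument dtypes and on a fixed set of key arguments, not on sequence length or batch size.

```python
class ToyAutotuner:
    """Toy model of a lazily autotuned kernel; not Triton's actual API."""

    def __init__(self, key_names):
        self.key_names = tuple(key_names)
        self.cache = {}          # cache key -> chosen config
        self.tuning_events = 0   # how many times the expensive search ran

    def launch(self, dtypes, **scalars):
        # Only the declared key arguments and the argument dtypes form the
        # cache key; anything else (seq_len, batch_size) is ignored.
        key = (tuple(scalars[name] for name in self.key_names), tuple(dtypes))
        if key not in self.cache:
            self.tuning_events += 1  # the slow benchmarking step would run here
            self.cache[key] = f"config-{len(self.cache)}"
        return self.cache[key]


kernel = ToyAutotuner(key_names=["headdim", "dstate"])  # hypothetical key names

# Warmup at startup: a tiny dummy input, but with the real dtypes.
warm = kernel.launch(("bfloat16", "float32"),
                     headdim=64, dstate=128, seq_len=256, batch_size=1)
events_after_warmup = kernel.tuning_events

# A later request with a different seq_len and batch_size reuses the config.
served = kernel.launch(("bfloat16", "float32"),
                       headdim=64, dstate=128, seq_len=4096, batch_size=8)
events_after_request = kernel.tuning_events

# A launch whose state dtype differs from the warmup misses the cache.
kernel.launch(("bfloat16", "bfloat16"),
              headdim=64, dstate=128, seq_len=4096, batch_size=8)
events_after_dtype_mismatch = kernel.tuning_events

print(events_after_warmup, events_after_request, events_after_dtype_mismatch)
# prints: 1 1 2
```

In this toy model a single small warmup launch is enough for any request shape, while a dtype mismatch between warmup and serving would leave the first real request paying for tuning.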

Benchmark Results

Measured on a single H100 80GB with nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16 (Mamba2 hybrid model), max_model_len=512, cold Triton cache (TRITON_PRINT_AUTOTUNING=1 confirmed autotuning location):

| Metric | Baseline (main) | With Warmup | Change |
|---|---|---|---|
| Model load time | 30.0s | 76.9s | +46.9s (autotuning shifted here) |
| First request latency | 31.343s | 2.890s | -28.5s (91% reduction) |
| Subsequent request avg | 0.083s | 0.083s | No change |
| First / subsequent ratio | 378x | 35x | 10.8x improvement |
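
The Change column can be rederived from the measured values. The short check below uses only the numbers reported in the table; the dictionary and variable names are illustrative.

```python
# Measured values from the benchmark table (seconds).
baseline = {"load": 30.0, "first": 31.343, "subsequent": 0.083}
warmup = {"load": 76.9, "first": 2.890, "subsequent": 0.083}

# Absolute changes between the two runs.
load_delta = warmup["load"] - baseline["load"]
first_delta = warmup["first"] - baseline["first"]

# Relative reduction in first-request latency, in percent.
first_reduction_pct = 100 * (baseline["first"] - warmup["first"]) / baseline["first"]

# First-request latency as a multiple of the steady-state latency.
ratio_baseline = baseline["first"] / baseline["subsequent"]
ratio_warmup = warmup["first"] / warmup["subsequent"]
ratio_improvement = ratio_baseline / ratio_warmup

print(f"Model load time:       {load_delta:+.1f}s")
print(f"First request latency: {first_delta:+.1f}s ({first_reduction_pct:.0f}% reduction)")
print(f"First / subsequent:    {ratio_baseline:.0f}x -> {ratio_warmup:.0f}x "
      f"({ratio_improvement:.1f}x improvement)")
```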

With TRITON_PRINT_AUTOTUNING=1, 58 kernel autotuning events occurred after model load on main (during the first request). With the warmup branch, zero autotuning events occurred after model load -- all SSD kernel autotuning completed during initialization.

The remaining ~2.9s first-request overhead (vs 0.08s subsequent) is clearly not from auto-tuning of any kernels, since TRITON_PRINT_AUTOTUNING=1 produces no output during the first request. The current suspicion is that this residual cost comes from Triton JIT compilation (as opposed to auto-tuning), but this requires further investigation as a follow-up.

Benchmark script

```python
#!/usr/bin/env python3
"""Measure first-request vs subsequent-request latency for a Mamba2 model."""
import os
import shutil
import time

for cache_dir in [
    os.path.expanduser("~/.triton/cache"),
    os.path.expanduser("~/.cache/vllm/torch_compile_cache"),
]:
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
        print(f"Cleared {cache_dir}")

from vllm import LLM, SamplingParams

NUM_REQUESTS = 5
PROMPT = "The capital of France is"

print("Loading model...")
t0 = time.perf_counter()
llm = LLM(
    model="nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16",
    max_model_len=512,
    trust_remote_code=True,
)
load_time = time.perf_counter() - t0
print(f"Model loaded in {load_time:.1f}s\n")

sampling_params = SamplingParams(temperature=0.0, max_tokens=20)

results = []
for i in range(NUM_REQUESTS):
    t_start = time.perf_counter()
    llm.generate([PROMPT], sampling_params)
    latency = time.perf_counter() - t_start
    results.append(latency)
    print(f"  Request {i+1}: {latency:.3f}s")

first = results[0]
rest_avg = sum(results[1:]) / len(results[1:])
print(f"\n  Model load time:    {load_time:.1f}s")
print(f"  First request:      {first:.3f}s")
print(f"  Subsequent avg:     {rest_avg:.3f}s")
print(f"  Ratio (1st / avg):  {first / rest_avg:.2f}x")
```

Test plan

  • Verified autotuning moves from first request to model init via TRITON_PRINT_AUTOTUNING=1
  • Confirmed first-request latency drops from 31.3s to 2.9s
  • Confirmed subsequent request latency is unchanged
  • Run existing Mamba2 model tests

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a warmup mechanism for Mamba2 SSD kernels in vllm/model_executor/layers/mamba/mamba_mixer2.py to trigger Triton autotuning during the initial profile run. This ensures that autotuning completes before SSM cache allocation, helping to prevent latency spikes or OOM errors during inference. Feedback suggests replacing torch.accelerator.empty_cache() with torch.cuda.empty_cache() to avoid potential AttributeError and ensure better compatibility across different execution environments.

Comment thread vllm/model_executor/layers/mamba/mamba_mixer2.py
@tdoublep tdoublep force-pushed the mamba2-ssd-kernel-warmup branch from 2fa9500 to 3c83e5f April 21, 2026 22:10
Run a minimal SSD forward pass during vLLM profile phase to trigger
Triton autotuning before SSM cache allocation. This shifts the ~31s
first-request latency spike into server startup, reducing it to ~2.9s.

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@tdoublep tdoublep force-pushed the mamba2-ssd-kernel-warmup branch from 3c83e5f to fb8da1e April 21, 2026 22:11
@tdoublep tdoublep marked this pull request as ready for review April 21, 2026 22:11
@tdoublep tdoublep requested a review from tomeras91 as a code owner April 21, 2026 22:11

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Member

@tomeras91 tomeras91 left a comment

Thanks @tdoublep for taking this! This is very helpful!

Added a few comments

A general comment - Can we add a sentence somewhere saying SSD kernels don't have seqlen/batch-size dependent autotune keys? This would preempt the obvious question: "why didn't you autotune for different seqlens and batch sizes?"

Comment thread vllm/model_executor/layers/mamba/mamba_mixer2.py
Comment thread vllm/model_executor/layers/mamba/mamba_mixer2.py Outdated
Comment thread vllm/model_executor/layers/mamba/mamba_mixer2.py Outdated
Comment thread vllm/model_executor/layers/mamba/mamba_mixer2.py Outdated
Comment thread vllm/model_executor/layers/mamba/mamba_mixer2.py Outdated
Comment thread vllm/model_executor/layers/mamba/mamba_mixer2.py
tdoublep added 2 commits May 11, 2026 19:59
- Use randn instead of zeros for warmup tensors to avoid kernel fast-paths
- Skip warmup when model_config is None instead of defaulting chunk_size
- Fix hasattr warmup guard to use __init__ flag (Mamba2 and GDN)
- Use logger.info_once for model-level log, logger.debug for per-layer
- Fix HAS_INITSTATES comment (JIT compilation, not autotuning)
- Add comment explaining autotune keys are shape-independent
- Fix get_mamba_chunk_size docstring (1024 -> 2048)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
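
The "warm up once per layer" guard mentioned in this commit can be sketched as follows. This is an illustration only: `ToyMixer`, `_warmed_up` and `warmup_runs` are hypothetical names, and the real method body in vLLM is not reproduced here. The point is that a flag initialised in `__init__` always exists, so the guard does not depend on whether an attribute happens to have been set elsewhere.

```python
class ToyMixer:
    """Hypothetical layer used to illustrate a run-once warmup guard."""

    def __init__(self):
        # Flag created in __init__, so every instance has it from the start.
        self._warmed_up = False
        self.warmup_runs = 0  # counter kept only for this illustration

    def warmup(self):
        # Skip if this layer has already run its warmup pass.
        if self._warmed_up:
            return
        self.warmup_runs += 1  # a dummy forward pass would go here
        self._warmed_up = True


layer = ToyMixer()
for _ in range(3):  # e.g. repeated profile runs calling into the same layer
    layer.warmup()

print(layer.warmup_runs)  # prints: 1
```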
Member Author

@tomeras91 Thanks for the review - I think I have addressed all feedback, please TAL.

The method always returns an int (defaults to 2048), so the signature
should be `-> int` not `-> int | None`. This fixes mypy errors in CI
where chunk_size was used in arithmetic without a None guard.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
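
A small sketch of why the narrower return type matters. `ToyConfig`, `get_chunk_size` and `num_chunks` are hypothetical stand-ins, not the vLLM functions; only the default of 2048 comes from the commit message. When the getter resolves the default itself and is annotated `-> int`, callers can do arithmetic on the result without a `None` check, which is what a type checker such as mypy requires.

```python
from dataclasses import dataclass
from typing import Optional

DEFAULT_CHUNK_SIZE = 2048  # default value named in the commit message


@dataclass
class ToyConfig:
    """Hypothetical config where the chunk size may be left unset."""
    chunk_size: Optional[int] = None


def get_chunk_size(config: ToyConfig) -> int:
    # The default is resolved here, so the return type is plain int.
    if config.chunk_size is None:
        return DEFAULT_CHUNK_SIZE
    return config.chunk_size


def num_chunks(seq_len: int, config: ToyConfig) -> int:
    chunk_size = get_chunk_size(config)
    # Ceiling division; no None guard needed because chunk_size is an int.
    return -(-seq_len // chunk_size)


print(num_chunks(5000, ToyConfig()))               # prints: 3
print(num_chunks(512, ToyConfig(chunk_size=256)))  # prints: 2
```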
Member

@tomeras91 tomeras91 left a comment

Thanks @tdoublep
LGTM!

@tomeras91 tomeras91 enabled auto-merge (squash) May 12, 2026 06:39
@tomeras91 tomeras91 added the ready label (ONLY add when PR is ready to merge/full CI is needed) May 12, 2026
@tomeras91 tomeras91 merged commit 8f89381 into vllm-project:main May 12, 2026
73 checks passed
@tdoublep tdoublep deleted the mamba2-ssd-kernel-warmup branch May 12, 2026 13:02
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
h1t35h pushed a commit to h1t35h/vllm that referenced this pull request May 21, 2026
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Labels

ready ONLY add when PR is ready to merge/full CI is needed

2 participants