[Hybrid] Warmup Mamba2 SSD kernel #39822
Code Review
This pull request introduces a warmup mechanism for Mamba2 SSD kernels in vllm/model_executor/layers/mamba/mamba_mixer2.py to trigger Triton autotuning during the initial profile run. This ensures that autotuning completes before SSM cache allocation, helping to prevent latency spikes or OOM errors during inference. Feedback suggests replacing torch.accelerator.empty_cache() with torch.cuda.empty_cache() to avoid potential AttributeError and ensure better compatibility across different execution environments.
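A variation on that suggestion is to use the accelerator-agnostic call only when it is present and fall back to the CUDA call otherwise. The following is a minimal sketch, not code from this PR: `empty_device_cache` is a hypothetical helper name, and the `torch` module is passed in as an argument so the attribute lookups are explicit.

```python
from typing import Any


def empty_device_cache(torch_mod: Any) -> str:
    """Release cached allocator blocks using whichever API is available.

    `torch_mod` is the imported `torch` module. Returns the name of the
    API that was used, or "none" if nothing could be called.
    Hypothetical helper for illustration only.
    """
    # Some PyTorch builds have no torch.accelerator, or have it without
    # empty_cache(); getattr avoids the AttributeError in both cases.
    accel = getattr(torch_mod, "accelerator", None)
    accel_empty = getattr(accel, "empty_cache", None)
    if callable(accel_empty):
        accel_empty()
        return "torch.accelerator"

    # Fall back to the long-standing CUDA-specific call.
    cuda = getattr(torch_mod, "cuda", None)
    if cuda is not None and cuda.is_available():
        cuda.empty_cache()
        return "torch.cuda"

    return "none"
```

Typical use would be `import torch` followed by `empty_device_cache(torch)` after the warmup pass.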
2fa9500 to 3c83e5f
Run a minimal SSD forward pass during the vLLM profile phase to trigger Triton autotuning before SSM cache allocation. This shifts the ~31s first-request latency spike into server startup, reducing first-request latency to ~2.9s.

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
3c83e5f to fb8da1e
tomeras91 left a comment
Thanks @tdoublep for taking this! This is very helpful!
Added a few comments
A general comment - can we add a sentence somewhere saying that the SSD kernels don't have seqlen/batch-size dependent autotune keys? That would preempt the obvious question: "why didn't you autotune for different seqlens and batch sizes?"
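The point about autotune keys can be illustrated with a toy cache. As far as I understand, Triton's autotuner stores the winning config under the values of the arguments named in `key` plus the dtypes of the tensor arguments, so arguments outside `key` (sequence length and batch size, in this case) do not trigger a new search. The sketch below is a pure-Python stand-in, not Triton code; `ToyAutotuner` and the argument names are hypothetical.

```python
from typing import Any


class ToyAutotuner:
    """Pure-Python stand-in for an autotune cache (not Triton itself)."""

    def __init__(self, key: list[str]) -> None:
        self.key = key
        self.cache: dict[tuple, str] = {}
        self.tuning_runs = 0  # how many times the expensive search ran

    def run(self, dtype: str, **kwargs: Any) -> str:
        # Only arguments listed in `key`, plus the dtype, form the cache key.
        cache_key = tuple(kwargs[name] for name in self.key) + (dtype,)
        if cache_key not in self.cache:
            self.tuning_runs += 1  # the expensive config search happens here
            self.cache[cache_key] = f"best-config-{self.tuning_runs}"
        return self.cache[cache_key]


tuner = ToyAutotuner(key=["chunk_size", "headdim"])

# Warmup with a tiny dummy batch.
tuner.run(dtype="bfloat16", chunk_size=256, headdim=64, seqlen=8, batch=1)
# A later request with a different seqlen/batch reuses the cached config.
tuner.run(dtype="bfloat16", chunk_size=256, headdim=64, seqlen=512, batch=16)
runs_after_same_dtype = tuner.tuning_runs

# A different dtype is a different cache key, so the search runs again.
tuner.run(dtype="float32", chunk_size=256, headdim=64, seqlen=512, batch=16)
runs_after_new_dtype = tuner.tuning_runs
```

Under this model a single warmup shape is enough, provided the key arguments and dtypes match what real requests will use.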
- Use randn instead of zeros for warmup tensors to avoid kernel fast-paths
- Skip warmup when model_config is None instead of defaulting chunk_size
- Fix hasattr warmup guard to use __init__ flag (Mamba2 and GDN)
- Use logger.info_once for model-level log, logger.debug for per-layer
- Fix HAS_INITSTATES comment (JIT compilation, not autotuning)
- Add comment explaining autotune keys are shape-independent
- Fix get_mamba_chunk_size docstring (1024 -> 2048)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
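The `hasattr` item refers to a common run-once pattern. A minimal sketch of the flag-based version, assuming hypothetical names (`ToyMixer`, `_warmed_up`, `_warmup`) that are not taken from the PR diff:

```python
class ToyMixer:
    """Hypothetical layer that runs a one-time warmup on first use."""

    def __init__(self) -> None:
        # The flag always exists after construction, so callers can test it
        # directly instead of probing with hasattr(self, "_warmed_up").
        self._warmed_up = False
        self.warmup_calls = 0  # counter kept only to make the demo observable

    def _warmup(self) -> None:
        # Placeholder for the real warmup work (e.g. a dummy forward pass).
        self.warmup_calls += 1

    def forward(self, x):
        if not self._warmed_up:
            self._warmup()
            self._warmed_up = True
        return x


layer = ToyMixer()
layer.forward(1)
layer.forward(2)  # second call skips the warmup
```

Initializing the flag in `__init__` keeps the attribute visible to type checkers and avoids a guard that silently depends on whether some other code path happened to set it.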
@tomeras91 Thanks for the review - I think I have addressed all feedback, please TAL.
The method always returns an int (defaults to 2048), so the signature should be `-> int` not `-> int | None`. This fixes mypy errors in CI where chunk_size was used in arithmetic without a None guard.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
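The typing issue can be sketched as follows. This is an illustration only: `ToyModelConfig` and `pad_to_chunk` are hypothetical, and only the method name and the 2048 default come from the commit message.

```python
from dataclasses import dataclass
from typing import Optional

DEFAULT_MAMBA_CHUNK_SIZE = 2048  # default stated in the commit message


@dataclass
class ToyModelConfig:
    """Hypothetical config holder; the stored value may be unset."""

    chunk_size: Optional[int] = None

    def get_mamba_chunk_size(self) -> int:
        # The fallback is resolved here, so the return type is a plain int
        # and callers need no None guard before doing arithmetic.
        if self.chunk_size is not None:
            return self.chunk_size
        return DEFAULT_MAMBA_CHUNK_SIZE


def pad_to_chunk(seqlen: int, config: ToyModelConfig) -> int:
    """Round seqlen up to a multiple of the chunk size."""
    chunk = config.get_mamba_chunk_size()
    return (seqlen + chunk - 1) // chunk * chunk
```

With `-> int | None`, the arithmetic in `pad_to_chunk` is what a type checker such as mypy would reject.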
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Summary
Triton's auto-tuner for the Mamba2 SSD kernels currently runs lazily on the first inference request, causing a large latency spike. This PR adds a `_warmup_ssd_kernels()` method to `MambaMixer2` that triggers auto-tuning during vLLM's profile phase (before SSM cache allocation), shifting the cost into server startup.

- Runs a minimal `mamba_chunk_scan_combined_varlen` forward pass with dummy tensors during the V1 profile run
- Covers both `HAS_INITSTATES` constexpr code paths (with and without `initial_states`)
- Uses the correct dtypes (`ssm_state_dtype`) to match Triton's cache keys

Benchmark Results
Measured on a single H100 80GB with `nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16` (Mamba2 hybrid model), `max_model_len=512`, cold Triton cache (`TRITON_PRINT_AUTOTUNING=1` confirmed autotuning location):

With `TRITON_PRINT_AUTOTUNING=1`, 58 kernel autotuning events occurred after model load on main (during the first request). With the warmup branch, zero autotuning events occurred after model load -- all SSD kernel autotuning completed during initialization.

The remaining ~2.9s first-request overhead (vs 0.08s subsequent) is clearly not from auto-tuning of any kernels, since `TRITON_PRINT_AUTOTUNING=1` produces no output during the first request. The current suspicion is that this residual cost comes from Triton JIT compilation (as opposed to auto-tuning), but this requires further investigation as a follow-up.

Benchmark script
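A minimal sketch of how first-request and steady-state latency might be compared. This is not the script used for the numbers above: `send_request` is a hypothetical callable that issues one request to the server, and the demo uses a fake request that is slow only on its first call.

```python
import time
from typing import Callable, List


def time_requests(send_request: Callable[[], object], n: int = 3) -> List[float]:
    """Call send_request n times and return the per-call latency in seconds."""
    latencies = []
    for _ in range(n):
        start = time.perf_counter()
        send_request()
        latencies.append(time.perf_counter() - start)
    return latencies


class FakeServer:
    """Stand-in for a server with a one-time lazy initialization cost."""

    def __init__(self) -> None:
        self.initialized = False

    def request(self) -> str:
        if not self.initialized:
            time.sleep(0.05)  # simulated first-request overhead
            self.initialized = True
        return "ok"


server = FakeServer()
latencies = time_requests(server.request, n=3)
first, steady = latencies[0], min(latencies[1:])
print(f"first request: {first:.3f}s, steady state: {steady:.3f}s")
```

Against a real server, `send_request` would wrap a single completion call, and the gap between `first` and `steady` is the quantity the warmup is meant to shrink.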
Test plan
TRITON_PRINT_AUTOTUNING=1