Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 106 additions & 1 deletion vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp, PluggableLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
Expand Down Expand Up @@ -53,6 +54,8 @@
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata

logger = init_logger(__name__)

# Added by the IBM Team, 2024


Expand Down Expand Up @@ -556,6 +559,106 @@ def forward(

return output

def _warmup_ssd_kernels(self, projected_states: torch.Tensor) -> None:
"""Run a minimal SSD forward pass to trigger Triton autotuning
while GPU memory is still plentiful (before SSM cache allocation).
"""
if hasattr(self, "_ssd_kernels_warmed_up"):
Comment thread
tdoublep marked this conversation as resolved.
Outdated
return
self._ssd_kernels_warmed_up = True
logger.info("Starting Mamba2 SSD autotuning warmup for layer %s", self.prefix)
Comment thread
tdoublep marked this conversation as resolved.
Outdated

device = projected_states.device
dtype = projected_states.dtype

nheads = self.num_heads // self.tp_size
ngroups = self.n_groups // self.tp_size
headdim = self.head_dim
dstate = self.ssm_state_size

chunk_size = (
self.model_config.get_mamba_chunk_size()
if self.model_config is not None
else None
)
if chunk_size is None:
chunk_size = 64
Comment thread
tdoublep marked this conversation as resolved.
Outdated

# Triton's autotuner includes tensor dtypes in its cache key,
# so state_dtype must match what real inference uses.
_, ssm_state_dtype = self.get_state_dtype()

seqlen = chunk_size
batch = 1
nchunks = seqlen // chunk_size # = 1

x = torch.zeros(seqlen, nheads, headdim, device=device, dtype=dtype)
Comment thread
tdoublep marked this conversation as resolved.
Outdated
dt = torch.zeros(seqlen, nheads, device=device, dtype=dtype)
B = torch.zeros(seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.zeros(seqlen, ngroups, dstate, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0, seqlen], device=device, dtype=torch.int32)
cu_chunk_seqlens = torch.tensor(
[i * chunk_size for i in range(nchunks + 1)],
device=device,
dtype=torch.int32,
)
last_chunk_indices = torch.tensor(
[nchunks - 1], device=device, dtype=torch.int32
)
seq_idx = torch.zeros(nchunks, device=device, dtype=torch.int32)
out = torch.empty(seqlen, nheads, headdim, device=device, dtype=dtype)

# Two kernels (_state_passing_fwd, _chunk_scan_fwd) use
Comment thread
tdoublep marked this conversation as resolved.
# HAS_INITSTATES as a constexpr, so we must warm up both
# code paths: without initial_states (first request) and
# with initial_states (subsequent requests).
for use_initial_states in (False, True):
initial_states = (
torch.zeros(
batch,
nheads,
headdim,
dstate,
device=device,
dtype=ssm_state_dtype,
)
if use_initial_states
else None
)
try:
mamba_chunk_scan_combined_varlen(
x=x,
dt=dt,
A=self.A,
B=B,
C=C,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
cu_chunk_seqlens=cu_chunk_seqlens,
last_chunk_indices=last_chunk_indices,
seq_idx=seq_idx,
out=out,
D=self.D,
z=None,
dt_bias=self.dt_bias,
initial_states=initial_states,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
state_dtype=ssm_state_dtype,
)
except Exception:
logger.warning(
"Mamba2 SSD kernel warmup failed for layer %s "
"(initial_states=%s). First inference may experience "
"latency spike or OOM due to autotuner.",
self.prefix,
use_initial_states,
exc_info=True,
)

logger.info("Mamba2 SSD kernel warmup completed for layer %s", self.prefix)
torch.accelerator.empty_cache()
Comment thread
tdoublep marked this conversation as resolved.

def conv_ssm_forward(
self,
projected_states: torch.Tensor,
Expand Down Expand Up @@ -605,7 +708,9 @@ def conv_ssm_forward(
num_decode_tokens = attn_metadata.num_decode_tokens

if attn_metadata is None:
# profile run
# V1 profile run -- warm up SSD kernels so that autotuning
# completes before SSM cache allocation.
self._warmup_ssd_kernels(projected_states)
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
Expand Down
Loading