310 changes: 220 additions & 90 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -1,13 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional
from typing import NamedTuple, Optional

import torch
from torch import nn
from torch.nn.parameter import Parameter

from vllm import envs
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -154,13 +156,38 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):

self.prefix = prefix

def _ssm_transform(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
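"""Compute the input-dependent SSM parameters from the conv output.

Projects ``x`` through ``x_proj``, splits the result into the time
step, B and C, applies the optional RMS norms, and runs ``dt_proj``
to produce the discretized time step (returned transposed).
"""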
if self.is_lora_enabled:
# Lora kernel requires contiguous tensor.
ssm_params = self.x_proj(x.contiguous())[0]
else:
ssm_params = self.x_proj(x)[0]
time_step, B, C = torch.split(
ssm_params,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
assert self.c_layernorm is not None
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
return discrete_time_step, B, C

def forward(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
return CustomOp.forward(self, hidden_states, mamba_cache_params)
else:
return self.forward_cuda(hidden_states, mamba_cache_params)
return self.forward_cuda(
hidden_states,
mamba_cache_params,
)

def forward_native(self,
hidden_states: torch.Tensor,
@@ -170,6 +197,27 @@ def forward_native(self,
def forward_cuda(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
"""
Run the Mamba-1 SSM pipeline.

Steps
-----
1. Apply the gated-MLP linear projection to the raw input.
2. Pass the projected sequence through the convolutional mixing layer.
3. Feed the result into the State-Space Model (SSM) blocks.
4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
to produce contextual representations.
5. Project the contextualized sequence back
to the output embedding dimension.

Batch handling
--------------
Prefill and decode tokens are processed by dedicated CUDA
kernels for both the convolutional (conv1d) and SSM stages.
In the case of a mixed batch (containing both prefill and
decode tokens), both sets of kernels are executed independently
and their outputs are concatenated before the final output projection.
"""

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -185,126 +233,142 @@ def forward_cuda(self,
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_state = mamba1_metadata.has_initial_states
context_lens_tensor = mamba1_metadata.context_lens_tensor
has_initial_states = mamba1_metadata.has_initial_states
else:
assert isinstance(attn_metadata, PlaceholderAttentionMetadata)
assert mamba_cache_params is not None
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
query_start_loc = attn_metadata.query_start_loc
context_lens_tensor = attn_metadata.context_lens_tensor

has_initial_states = None
if context_lens_tensor is not None:
has_initial_state = context_lens_tensor > 0
has_initial_states = context_lens_tensor > 0

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))

if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states = hidden_states.contiguous()
return self.out_proj(hidden_states.transpose(-2, -1))[0]

if query_start_loc is not None and context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
hidden_states_BC = hidden_states_BC.contiguous()
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]

num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (= request count; one token per decode)
has_prefill = num_prefill_tokens > 0
has_decode = num_decode_tokens > 0

prefill_decode_split = split_batch_to_prefill_and_decode(
hidden_states_BC,
gate,
state_indices_tensor,
query_start_loc,
has_initial_states,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
num_decodes,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
gate_p = prefill_decode_split.gate_p
gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
query_start_loc_p = prefill_decode_split.query_start_loc_p
has_initial_states_p = prefill_decode_split.has_initial_states_p

ssm_outputs = []

if has_prefill:
# 2. Convolution sequence transformation
conv_out_p = causal_conv1d_fn(
hidden_states_BC_p,
conv_weights,
bias=self.conv1d.bias,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=state_indices_tensor,
query_start_loc=query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
conv_out_p.transpose(-2, -1))
time_proj_bias = self._time_proj_bias()

# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_out_p = selective_scan_fn(
conv_out_p,
ssm_state,
discrete_time_step_p,
self.A,
B_p.transpose(-2, -1),
C_p.transpose(-2, -1),
self.D.float(),
gate_p,
time_proj_bias,
delta_softplus=True,
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p)
ssm_outputs.append(scan_out_p)

if has_decode:
# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
conv_state_indices=state_indices_tensor_d).transpose(0, 1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
# 3. State Space Model sequence transformation.
discrete_time_step_d, B_d, C_d = self._ssm_transform(
conv_out_d.transpose(-2, -1))
time_proj_bias = self._time_proj_bias()

if self.is_lora_enabled:
# lora kernel requires contiguous tensor
ssm_parameters = self.x_proj(
hidden_states.transpose(-2, -1).contiguous())[0]
else:
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

time_step, B, C = torch.split(
ssm_parameters,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
assert self.c_layernorm is not None
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())

discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)

if query_start_loc is not None and context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=state_indices_tensor,
has_initial_state=has_initial_state,
query_start_loc=query_start_loc)
else:
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_outputs_d = torch.empty_like(
hidden_states_BC_d.transpose(0, 1))
selective_state_update(ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B,
C,
B_d,
C_d,
self.D,
gate.transpose(0, 1),
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)

# 4. Final linear projection
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1).contiguous())[0]
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d)
scan_outputs_d = scan_outputs_d.transpose(0, 1)

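# Restore the batch ordering: decode outputs come first under V1
# and last under V0, matching the layout of the incoming batch.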
if envs.VLLM_USE_V1:
ssm_outputs.insert(0, scan_outputs_d)
else:
ssm_outputs.append(scan_outputs_d)

scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)

# 5. Final output projection
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
scan_outputs_combined = scan_outputs_combined.transpose(
-2, -1).contiguous()
out = self.out_proj(scan_outputs_combined)[0]
else:
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]

return out

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba1_state_shape(
@@ -317,3 +381,69 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
@property
def mamba_type(self) -> str:
return "mamba1"

def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
return None


class PrefillDecodeSplit(NamedTuple):
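"""Mixed-batch tensors split by request type: ``_p`` = prefill, ``_d`` = decode."""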
hidden_states_BC_p: torch.Tensor
hidden_states_BC_d: torch.Tensor
gate_p: torch.Tensor
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
query_start_loc_p: Optional[torch.Tensor]
has_initial_states_p: Optional[torch.Tensor]


def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor,
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
query_start_loc: torch.Tensor,
has_initial_states: Optional[torch.Tensor],
num_prefill_tokens: int,
num_decode_tokens: int,
num_prefills: int,
num_decodes: int,
) -> PrefillDecodeSplit:
if envs.VLLM_USE_V1:
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
gate_d, gate_p = torch.split(gate,
[num_decode_tokens, num_prefill_tokens],
dim=-1)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor, [num_decodes, num_prefills], dim=0)
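# query_start_loc covers decodes then prefills; keep the last
# num_prefills + 1 boundaries and subtract the num_decodes leading
# decode tokens so that prefill offsets start at zero.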
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
num_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
else:
# In v0, prefill tokens come first, then decode tokens.
hidden_states_BC_p, hidden_states_BC_d = torch.split(
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decode_tokens],
dim=-1)
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor, [num_prefills, num_decodes], dim=0)
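# Prefills come first in v0, so the first num_prefills + 1
# boundaries already start at zero; no rebasing is needed.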
query_start_loc_p = (query_start_loc[:num_prefills +
1] if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[:num_prefills] if (
has_initial_states is not None and num_prefills > 0) else None

return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
hidden_states_BC_d=hidden_states_BC_d,
gate_p=gate_p,
gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
)
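# A minimal sketch (illustrative only, assuming the V1 layout) of the
# split on toy tensors: two single-token decode requests followed by
# one 3-token prefill request.
#
#   hidden = torch.arange(10.).reshape(2, 5)  # (dim=2, 5 tokens)
#   split = split_batch_to_prefill_and_decode(
#       hidden, hidden, torch.arange(3), torch.tensor([0, 1, 2, 5]),
#       None, num_prefill_tokens=3, num_decode_tokens=2,
#       num_prefills=1, num_decodes=2)
#   split.hidden_states_BC_d.shape  # (2, 2): the two decode tokens
#   split.hidden_states_BC_p.shape  # (2, 3): the prefill tokens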