Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/mamba/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
),
)

@staticmethod
def clear_stale_decode_states(
has_initial_states_d: torch.Tensor | None,
indices: torch.Tensor,
ssm_state: torch.Tensor,
conv_state: torch.Tensor,
) -> None:
if has_initial_states_d is None:
return

for state in (ssm_state, conv_state):
gathered = state[indices]
keep_state = has_initial_states_d.to(gathered.dtype)
keep_state = keep_state.view(-1, *([1] * (gathered.dim() - 1)))
state[indices] = gathered * keep_state

def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this Mamba layer."""
return get_mamba_attn_backend(self.mamba_type)
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
has_initial_states_d = attn_metadata.has_initial_states_d
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p

Expand Down Expand Up @@ -394,6 +395,14 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
else:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d

self.clear_stale_decode_states(
has_initial_states_d,
state_indices_tensor_d_input,
ssm_state,
conv_state,
)

# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
Expand Down
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def conv_ssm_forward(
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
has_initial_states_d = attn_metadata.has_initial_states_d
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
Expand Down Expand Up @@ -831,6 +832,13 @@ def conv_ssm_forward(
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d

self.clear_stale_decode_states(
has_initial_states_d,
state_indices_tensor_d_input,
ssm_state,
conv_state,
)

# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
Expand Down
15 changes: 15 additions & 0 deletions vllm/v1/attention/backends/mamba_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ def __init__(
device=device,
)

self.has_initial_states_d: torch.Tensor = torch.ones(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)

self._init_reorder_batch_threshold(1, self.use_spec_decode)
if self.use_spec_decode:
self.supports_update_block_table = False
Expand Down Expand Up @@ -501,6 +507,7 @@ def _update_metadata_for_cudagraph_capture(
Currently, only decode is supported for full cudagraphs with Mamba.
"""
state_indices_tensor_d = metadata.state_indices_tensor_d
has_initial_states_d = metadata.has_initial_states_d
query_start_loc_d = metadata.query_start_loc_d
num_accepted_tokens = metadata.num_accepted_tokens
block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token
Expand All @@ -517,6 +524,13 @@ def _update_metadata_for_cudagraph_capture(
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID

if has_initial_states_d is not None:
self.has_initial_states_d[: metadata.num_decodes].copy_(
has_initial_states_d, non_blocking=True
)
self.has_initial_states_d[metadata.num_decodes :] = True
has_initial_states_d = self.has_initial_states_d[:padded_bs]

if self.use_spec_decode:
assert query_start_loc_d is not None
assert num_accepted_tokens is not None
Expand Down Expand Up @@ -551,6 +565,7 @@ def _update_metadata_for_cudagraph_capture(
return replace(
metadata,
state_indices_tensor_d=state_indices_tensor_d,
has_initial_states_d=has_initial_states_d,
query_start_loc_d=query_start_loc_d,
num_accepted_tokens=num_accepted_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
Expand Down
6 changes: 0 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3683,12 +3683,6 @@ def execute_model(
)
)

mamba_utils.clear_stale_mamba_states(
attn_metadata,
self.attn_groups,
self.compilation_config.static_forward_context,
)

(
input_ids,
inputs_embeds,
Expand Down
38 changes: 0 additions & 38 deletions vllm/v1/worker/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
from vllm.v1.worker.utils import AttentionGroup


@triton.jit
Expand Down Expand Up @@ -216,42 +214,6 @@ def preprocess_mamba(
do_mamba_copy_block(copy_bufs)


def clear_stale_mamba_states(
    attn_metadata: list[dict[str, Any]] | dict[str, Any],
    attn_groups: list[list["AttentionGroup"]],
    forward_context: dict[str, Any],
) -> None:
    """Clear Mamba states for new requests in the decode batch.

    New requests (has_initial_states=False) would otherwise read stale
    state from a recycled cache slot.
    """
    # A list-shaped attn_metadata means per-microbatch metadata; this pass
    # only handles the single-batch dict form and bails out otherwise.
    if not isinstance(attn_metadata, dict):
        return

    for kv_cache_groups in attn_groups:
        for attn_group in kv_cache_groups:
            # Only Mamba-style caches need stale-state clearing.
            if not isinstance(attn_group.kv_cache_spec, MambaSpec):
                continue

            # All layers in a group share one metadata object; look it up
            # via the first layer name.
            metadata = attn_metadata.get(attn_group.layer_names[0])
            if not isinstance(metadata, BaseMambaAttentionMetadata):
                continue

            has_initial_states_d = metadata.has_initial_states_d
            # None signals there are no decode requests to consider.
            if has_initial_states_d is None:
                continue

            assert metadata.state_indices_tensor_d is not None
            # Cache-slot indices for the decode portion of the batch; the
            # ~mask selects slots belonging to brand-new requests.
            indices = metadata.state_indices_tensor_d[: metadata.num_decodes]
            new_indices = indices[~has_initial_states_d[: metadata.num_decodes]]

            for layer_name in attn_group.layer_names:
                layer = forward_context[layer_name]
                # kv_cache[0] presumably holds the per-layer state tensors
                # (conv state, ssm state) — TODO confirm layout against
                # the Mamba layer's kv_cache definition.
                for state in layer.kv_cache[0]:
                    state[new_indices] = 0


def postprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
Expand Down
Loading