Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
has_initial_states_d = attn_metadata.has_initial_states_d
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p

Expand Down Expand Up @@ -394,6 +395,13 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
else:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d

# Clear stale state for new requests classified as decodes
if has_initial_states_d is not None:
new_indices = state_indices_tensor_d_input[~has_initial_states_d]
conv_state[:, new_indices] = 0
ssm_state[new_indices] = 0

# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
Expand Down
7 changes: 7 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def conv_ssm_forward(
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
has_initial_states_d = attn_metadata.has_initial_states_d
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
Expand Down Expand Up @@ -831,6 +832,12 @@ def conv_ssm_forward(
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d

# Clear stale state for new requests classified as decodes
if has_initial_states_d is not None:
new_indices = state_indices_tensor_d_input[~has_initial_states_d]
conv_state[:, new_indices] = 0
ssm_state[new_indices] = 0

# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
Expand Down
7 changes: 7 additions & 0 deletions vllm/model_executor/models/plamo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def forward_impl(
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
has_initial_states_p = attn_metadata.has_initial_states_p
has_initial_states_d = attn_metadata.has_initial_states_d
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
Expand Down Expand Up @@ -393,6 +394,12 @@ def forward_impl(

# Process decode requests
if has_decode:
# Clear stale state for new requests classified as decodes
if has_initial_states_d is not None:
new_indices = state_indices_tensor_d[~has_initial_states_d]
conv_state[:, new_indices] = 0
ssm_state[new_indices] = 0

# 2. Convolution sequence transformation
hidden_states_d = causal_conv1d_update(
hidden_states_d,
Expand Down
6 changes: 0 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3683,12 +3683,6 @@ def execute_model(
)
)

mamba_utils.clear_stale_mamba_states(
attn_metadata,
self.attn_groups,
self.compilation_config.static_forward_context,
)

(
input_ids,
inputs_embeds,
Expand Down
38 changes: 0 additions & 38 deletions vllm/v1/worker/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
from vllm.v1.worker.utils import AttentionGroup


@triton.jit
Expand Down Expand Up @@ -216,42 +214,6 @@ def preprocess_mamba(
do_mamba_copy_block(copy_bufs)


def clear_stale_mamba_states(
attn_metadata: list[dict[str, Any]] | dict[str, Any],
attn_groups: list[list["AttentionGroup"]],
forward_context: dict[str, Any],
) -> None:
"""Clear Mamba states for new requests in the decode batch.

New requests (has_initial_states=False) would otherwise read stale
state from a recycled cache slot.
"""
if not isinstance(attn_metadata, dict):
return

for kv_cache_groups in attn_groups:
for attn_group in kv_cache_groups:
if not isinstance(attn_group.kv_cache_spec, MambaSpec):
continue

metadata = attn_metadata.get(attn_group.layer_names[0])
if not isinstance(metadata, BaseMambaAttentionMetadata):
continue

has_initial_states_d = metadata.has_initial_states_d
if has_initial_states_d is None:
continue

assert metadata.state_indices_tensor_d is not None
indices = metadata.state_indices_tensor_d[: metadata.num_decodes]
new_indices = indices[~has_initial_states_d[: metadata.num_decodes]]

for layer_name in attn_group.layer_names:
layer = forward_context[layer_name]
for state in layer.kv_cache[0]:
state[new_indices] = 0


def postprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
Expand Down