Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/mamba/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
),
)

@staticmethod
def clear_stale_decode_states(
has_initial_states_d: torch.Tensor | None,
indices: torch.Tensor,
ssm_state: torch.Tensor,
conv_state: torch.Tensor,
) -> None:
if has_initial_states_d is None:
return

for state in (ssm_state, conv_state):
gathered = state[indices]
keep_state = has_initial_states_d.to(gathered.dtype)
keep_state = keep_state.view(-1, *([1] * (gathered.dim() - 1)))
state[indices] = gathered * keep_state

def get_attn_backend(self) -> type[AttentionBackend]:
    """Resolve the attention backend class registered for this layer's mamba type."""
    backend_cls = get_mamba_attn_backend(self.mamba_type)
    return backend_cls
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
has_initial_states_d = attn_metadata.has_initial_states_d
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p

Expand Down Expand Up @@ -394,6 +395,14 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
else:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d

self.clear_stale_decode_states(
has_initial_states_d,
state_indices_tensor_d_input,
ssm_state,
conv_state,
)

# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
Expand Down
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def conv_ssm_forward(
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
has_initial_states_d = attn_metadata.has_initial_states_d
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
Expand Down Expand Up @@ -831,6 +832,13 @@ def conv_ssm_forward(
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d

self.clear_stale_decode_states(
has_initial_states_d,
state_indices_tensor_d_input,
ssm_state,
conv_state,
)

# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
Expand Down
31 changes: 31 additions & 0 deletions vllm/v1/attention/backends/mamba_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class BaseMambaAttentionMetadata:
# The following tensors are used for decode requests and
# speculative decoding compatibility, and will be None if the batch
# has no decode requests.
has_initial_states_d: torch.Tensor | None
state_indices_tensor_d: torch.Tensor | None
query_start_loc_d: torch.Tensor | None # shape: [num_decodes + 1,]

Expand Down Expand Up @@ -148,6 +149,12 @@ def __init__(
device=device,
)

self.has_initial_states_d: torch.Tensor = torch.ones(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)

self._init_reorder_batch_threshold(1, self.use_spec_decode)
if self.use_spec_decode:
self.supports_update_block_table = False
Expand Down Expand Up @@ -364,6 +371,7 @@ def _compute_common_metadata(

# Need flags to indicate if there are initial states
has_initial_states_p = None
has_initial_states_d = None
query_start_loc_p = None
query_start_loc_d = None
num_computed_tokens = None
Expand Down Expand Up @@ -414,6 +422,19 @@ def _compute_common_metadata(
]
state_indices_tensor_p = state_indices_tensor_p[:, 0]

# Only set when there are genuinely new decode requests
# (num_computed_tokens == 0, seq_lens > 0). Padded CG slots
# (seq_lens == 0) are excluded so their PAD_SLOT_ID doesn't
# cause zeroing of the last real cache slot.
if num_decodes > 0:
if num_computed_tokens is None:
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
has_initial_states_d = (num_computed_tokens[:num_decodes] > 0) | (
common_attn_metadata.seq_lens[:num_decodes] == 0
)
if has_initial_states_d.all():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might cause a cpu/gpu sync because has_initial_states_d is a GPU tensor?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@benchislett You're right, but in order to avoid the CPU<>GPU sync I would have to run this code on every decode step, regardless of whether has_initial_states_d has any false values. That could be a cheap "no-op" (we'd just be multiplying by 1), or we'd loop over the indices as in this PR version — it depends on which approach we decide to go with.

Do you maybe have another suggestion? I see other functions using .cpu(), which I could use instead of .all() if that's better, but I'm not sure it's any different.

has_initial_states_d = None

if num_decodes > 0 and self.use_spec_decode:
assert num_accepted_tokens is not None
query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
Expand Down Expand Up @@ -459,6 +480,7 @@ def _compute_common_metadata(
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
has_initial_states_d=has_initial_states_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
num_accepted_tokens=num_accepted_tokens,
Expand All @@ -485,6 +507,7 @@ def _update_metadata_for_cudagraph_capture(
Currently, only decode is supported for full cudagraphs with Mamba.
"""
state_indices_tensor_d = metadata.state_indices_tensor_d
has_initial_states_d = metadata.has_initial_states_d
query_start_loc_d = metadata.query_start_loc_d
num_accepted_tokens = metadata.num_accepted_tokens
block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token
Expand All @@ -501,6 +524,13 @@ def _update_metadata_for_cudagraph_capture(
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID

if has_initial_states_d is not None:
self.has_initial_states_d[: metadata.num_decodes].copy_(
has_initial_states_d, non_blocking=True
)
self.has_initial_states_d[metadata.num_decodes :] = True
has_initial_states_d = self.has_initial_states_d[:padded_bs]

if self.use_spec_decode:
assert query_start_loc_d is not None
assert num_accepted_tokens is not None
Expand Down Expand Up @@ -535,6 +565,7 @@ def _update_metadata_for_cudagraph_capture(
return replace(
metadata,
state_indices_tensor_d=state_indices_tensor_d,
has_initial_states_d=has_initial_states_d,
query_start_loc_d=query_start_loc_d,
num_accepted_tokens=num_accepted_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
Expand Down
Loading