Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions vllm/model_executor/layers/mamba/gdn_linear_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,13 @@ def _forward_core(
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens

spec_decode_src_indices = attn_metadata.spec_decode_src_indices
if spec_decode_src_indices is not None:
assert non_spec_state_indices_tensor is not None
n_correct = spec_decode_src_indices.shape[0]
dst_indices = non_spec_state_indices_tensor[:n_correct]
ssm_state[dst_indices] = ssm_state[spec_decode_src_indices]

mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]
Expand Down Expand Up @@ -880,6 +887,9 @@ def _forward_core(
if attn_metadata.num_prefills > 0:
assert mixed_qkv_non_spec is not None
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
conv_num_accepted = (
num_accepted_tokens if spec_decode_src_indices is not None else None
)
# - "cache_indices" updates the conv_state cache in the positions
#   pointed to by "non_spec_state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
Expand All @@ -891,10 +901,14 @@ def _forward_core(
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
num_accepted_tokens=conv_num_accepted,
metadata=attn_metadata,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
assert mixed_qkv_non_spec is not None
conv_num_accepted = (
num_accepted_tokens if spec_decode_src_indices is not None else None
)
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec,
conv_state,
Expand All @@ -904,6 +918,7 @@ def _forward_core(
conv_state_indices=non_spec_state_indices_tensor[ # type: ignore[index]
: attn_metadata.num_actual_tokens # type: ignore[attr-defined]
],
num_accepted_tokens=conv_num_accepted,
validate_data=True,
)
else:
Expand Down Expand Up @@ -1065,20 +1080,32 @@ def _forward_core_decode_non_spec(
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens

spec_decode_src_indices = attn_metadata.spec_decode_src_indices
num_accepted_tokens = attn_metadata.num_accepted_tokens
if spec_decode_src_indices is not None:
assert non_spec_state_indices_tensor is not None
n_correct = spec_decode_src_indices.shape[0]
dst_indices = non_spec_state_indices_tensor[:n_correct]
ssm_state[dst_indices] = ssm_state[spec_decode_src_indices]

mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]

conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
conv_num_accepted = (
num_accepted_tokens if spec_decode_src_indices is not None else None
)
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index]
num_accepted_tokens=conv_num_accepted,
validate_data=False,
)
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
Expand Down
22 changes: 20 additions & 2 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
block_idx_last_scheduled_token, # (batch,)
initial_state_idx, # (batch,)
num_computed_tokens, # (batch,)
num_accepted_tokens_ptr, # (batch,) or None
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
dim: tl.constexpr,
Expand All @@ -55,6 +56,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_APC_ENABLED: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
HAS_NULL_BLOCK: tl.constexpr,
NP2_STATELEN: tl.constexpr,
BLOCK_M: tl.constexpr,
Expand All @@ -74,6 +76,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching

# single-sequence id
idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)

if IS_SPEC_DECODING:
conv_state_token_offset = (
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1
)
else:
conv_state_token_offset = 0
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))

# BLOCK_N elements along the feature-dimension (channel)
Expand Down Expand Up @@ -154,7 +163,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
if load_init_state:
# load from conv_states
prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok
prior_tokens = (
conv_states_base
+ (state_len - 1 + conv_state_token_offset) * stride_conv_state_tok
)
mask_w = idx_feats < dim
if KERNEL_WIDTH == 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
Expand Down Expand Up @@ -244,7 +256,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching
conv_states_ptr
+ (conv_states_input_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
+ (
(idx_tokens_conv + seqlen + conv_state_token_offset)
* stride_conv_state_tok
)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_states_input_coord < num_cache_lines)
Expand Down Expand Up @@ -477,6 +492,7 @@ def causal_conv1d_fn(
activation: str | None = "silu",
pad_slot_id: int = PAD_SLOT_ID,
null_block_id: int = NULL_BLOCK_ID,
num_accepted_tokens: torch.Tensor | None = None,
block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
Expand Down Expand Up @@ -712,6 +728,7 @@ def grid(META):
block_idx_last_scheduled_token,
initial_state_idx,
num_computed_tokens,
num_accepted_tokens,
out,
# Matrix dimensions
dim,
Expand All @@ -737,6 +754,7 @@ def grid(META):
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
HAS_NULL_BLOCK=null_block_id is not None,
NP2_STATELEN=np2_statelen,
# launch_cooperative_grid=True
Expand Down
34 changes: 32 additions & 2 deletions vllm/v1/attention/backends/gdn_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ class GDNAttentionMetadata:

num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]

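# 1D source block indices for state recovery after spec decode.
# When set, the SSM state held in these blocks is copied into the blocks
# given by non_spec_state_indices_tensor before the decode kernel runs, and
# num_accepted_tokens is forwarded to the causal-conv1d ops for the conv state.
# NOTE(review): only the SSM-state copy is visible in the consumers of this
# field; confirm how the conv state is recovered.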
spec_decode_src_indices: torch.Tensor | None = None

# Pre-computed FLA chunk metadata (avoids GPU->CPU sync in prepare_chunk_indices)
chunk_indices: torch.Tensor | None = None
chunk_offsets: torch.Tensor | None = None
Expand Down Expand Up @@ -196,6 +201,7 @@ def build( # type: ignore[override]
query_start_loc.device, non_blocking=True
)

spec_decode_src_indices = None
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1)
Expand All @@ -204,11 +210,34 @@ def build( # type: ignore[override]
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
non_spec_query_start_loc_cpu = query_start_loc_cpu
num_accepted_tokens = None
non_spec_state_indices_tensor = block_table_tensor[:, 0]
if (
self.use_spec_decode
and num_accepted_tokens is not None
and num_decodes > 0
):
col_indices = (num_accepted_tokens[:num_decodes] - 1).clamp(min=0)
spec_decode_src_indices = block_table_tensor[
torch.arange(num_decodes, device=block_table_tensor.device),
col_indices,
]
num_accepted_tokens = num_accepted_tokens[:num_decodes]
if num_prefills > 0:
num_accepted_tokens = torch.cat(
[
num_accepted_tokens,
torch.ones(
num_prefills,
dtype=num_accepted_tokens.dtype,
device=num_accepted_tokens.device,
),
]
)
else:
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
assert spec_sequence_masks_cpu is not None
Expand Down Expand Up @@ -443,6 +472,7 @@ def build( # type: ignore[override]
spec_token_indx=spec_token_indx,
non_spec_token_indx=non_spec_token_indx,
num_accepted_tokens=num_accepted_tokens,
spec_decode_src_indices=spec_decode_src_indices,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
Expand Down
11 changes: 11 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2242,6 +2242,17 @@ def _build_attn_group_metadata(
:num_reqs_padded
],
)
elif (
not use_spec_decode
and self.speculative_config is not None
and isinstance(
builder,
(Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder),
)
):
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
)

if for_cudagraph_capture:
attn_metadata_i = builder.build_for_cudagraph_capture(
Expand Down
Loading