diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index 7a0b54335baa..d1e62f549e30 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -837,6 +837,13 @@ def _forward_core( num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens + spec_decode_src_indices = attn_metadata.spec_decode_src_indices + if spec_decode_src_indices is not None: + assert non_spec_state_indices_tensor is not None + n_correct = spec_decode_src_indices.shape[0] + dst_indices = non_spec_state_indices_tensor[:n_correct] + ssm_state[dst_indices] = ssm_state[spec_decode_src_indices] + mixed_qkv = mixed_qkv[:num_actual_tokens] b = b[:num_actual_tokens] a = a[:num_actual_tokens] @@ -880,6 +887,9 @@ def _forward_core( if attn_metadata.num_prefills > 0: assert mixed_qkv_non_spec is not None mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + conv_num_accepted = ( + num_accepted_tokens if spec_decode_src_indices is not None else None + ) # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( @@ -891,10 +901,14 @@ def _forward_core( has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, + num_accepted_tokens=conv_num_accepted, metadata=attn_metadata, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: assert mixed_qkv_non_spec is not None + conv_num_accepted = ( + num_accepted_tokens if spec_decode_src_indices is not None else None + ) mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec, conv_state, @@ -904,6 +918,7 @@ def _forward_core( conv_state_indices=non_spec_state_indices_tensor[ # type: ignore[index] : attn_metadata.num_actual_tokens # type: ignore[attr-defined] ], + num_accepted_tokens=conv_num_accepted, validate_data=True, ) else: @@ -1065,6 +1080,14 @@ def _forward_core_decode_non_spec( ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens + spec_decode_src_indices = attn_metadata.spec_decode_src_indices + num_accepted_tokens = attn_metadata.num_accepted_tokens + if spec_decode_src_indices is not None: + assert non_spec_state_indices_tensor is not None + n_correct = spec_decode_src_indices.shape[0] + dst_indices = non_spec_state_indices_tensor[:n_correct] + ssm_state[dst_indices] = ssm_state[spec_decode_src_indices] + mixed_qkv = mixed_qkv[:num_actual_tokens] b = b[:num_actual_tokens] a = a[:num_actual_tokens] @@ -1072,6 +1095,9 @@ def _forward_core_decode_non_spec( conv_weights = self.conv1d.weight.view( self.conv1d.weight.size(0), self.conv1d.weight.size(2) ) + conv_num_accepted = ( + num_accepted_tokens if spec_decode_src_indices is not None else None + ) mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv, conv_state, @@ -1079,6 +1105,7 @@ def _forward_core_decode_non_spec( self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index] + num_accepted_tokens=conv_num_accepted, validate_data=False, ) out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 1160105ad101..72b2af01bfbc 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -30,6 +30,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching block_idx_last_scheduled_token, # (batch,) initial_state_idx, # (batch,) num_computed_tokens, # (batch,) + num_accepted_tokens_ptr, # (batch,) or None o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions dim: tl.constexpr, @@ -55,6 +56,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_APC_ENABLED: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, HAS_NULL_BLOCK: tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, @@ -74,6 +76,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching # single-sequence id idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64) + + if IS_SPEC_DECODING: + conv_state_token_offset = ( + tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + ) + else: + conv_state_token_offset = 0 chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) # BLOCK_N elements along the feature-dimension (channel) @@ -154,7 +163,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) if load_init_state: # load from conv_states - prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok + prior_tokens = ( + conv_states_base + + (state_len - 1 + conv_state_token_offset) * stride_conv_state_tok + ) mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] @@ -244,7 +256,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_states_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)[None, :] - + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] + + ( + (idx_tokens_conv + seqlen + conv_state_token_offset) + * stride_conv_state_tok + )[:, None] ) # [BLOCK_M, BLOCK_N] mask = ( (conv_states_input_coord < num_cache_lines) @@ -477,6 +492,7 @@ def causal_conv1d_fn( activation: str | None = "silu", pad_slot_id: int = PAD_SLOT_ID, null_block_id: int = NULL_BLOCK_ID, + num_accepted_tokens: torch.Tensor | None = None, block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, @@ -712,6 +728,7 @@ def grid(META): block_idx_last_scheduled_token, initial_state_idx, num_computed_tokens, + num_accepted_tokens, out, # Matrix dimensions dim, @@ -737,6 +754,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_APC_ENABLED=block_idx_last_scheduled_token is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, HAS_NULL_BLOCK=null_block_id is not None, NP2_STATELEN=np2_statelen, # launch_cooperative_grid=True diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 85715e91ab40..06777d74805a 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -63,6 +63,11 @@ class GDNAttentionMetadata: num_accepted_tokens: torch.Tensor | None = None # shape: [batch,] + # 1D source block indices for state recovery after spec decode. + # When set, conv/ssm state must be copied from these blocks to the + # blocks in non_spec_state_indices_tensor before the decode kernel. + spec_decode_src_indices: torch.Tensor | None = None + # Pre-computed FLA chunk metadata (avoids GPU->CPU sync in prepare_chunk_indices) chunk_indices: torch.Tensor | None = None chunk_offsets: torch.Tensor | None = None @@ -196,6 +201,7 @@ def build( # type: ignore[override] query_start_loc.device, non_blocking=True ) + spec_decode_src_indices = None if spec_sequence_masks is None: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills(m, decode_threshold=1) @@ -204,11 +210,34 @@ def build( # type: ignore[override] spec_token_indx = None non_spec_token_indx = None spec_state_indices_tensor = None - non_spec_state_indices_tensor = block_table_tensor[:, 0] spec_query_start_loc = None non_spec_query_start_loc = query_start_loc non_spec_query_start_loc_cpu = query_start_loc_cpu - num_accepted_tokens = None + non_spec_state_indices_tensor = block_table_tensor[:, 0] + if ( + self.use_spec_decode + and num_accepted_tokens is not None + and num_decodes > 0 + ): + col_indices = (num_accepted_tokens[:num_decodes] - 1).clamp(min=0) + spec_decode_src_indices = block_table_tensor[ + torch.arange(num_decodes, device=block_table_tensor.device), + col_indices, + ] + num_accepted_tokens = num_accepted_tokens[:num_decodes] + if num_prefills > 0: + num_accepted_tokens = torch.cat( + [ + num_accepted_tokens, + torch.ones( + num_prefills, + dtype=num_accepted_tokens.dtype, + device=num_accepted_tokens.device, + ), + ] + ) + else: + num_accepted_tokens = None else: query_lens = query_start_loc[1:] - query_start_loc[:-1] assert spec_sequence_masks_cpu is not None @@ -443,6 +472,7 @@ def build( # type: ignore[override] spec_token_indx=spec_token_indx, non_spec_token_indx=non_spec_token_indx, num_accepted_tokens=num_accepted_tokens, + spec_decode_src_indices=spec_decode_src_indices, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 386db4fecd4b..b6d6565819b4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2242,6 +2242,17 @@ def _build_attn_group_metadata( :num_reqs_padded ], ) + elif ( + not use_spec_decode + and self.speculative_config is not None + and isinstance( + builder, + (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder), + ) + ): + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], + ) if for_cudagraph_capture: attn_metadata_i = builder.build_for_cudagraph_capture(