diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index fd2df329f17f..e616442a1440 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -27,6 +27,8 @@
 #    "yujiepan/mamba2-codestral-v0.1-tiny-random",
 ]
 
+GDN_MODELS = ["tiny-random/qwen3-next-moe"]
+
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
     "pfnet/plamo-2-1b",
@@ -35,8 +37,7 @@
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
     "LiquidAI/LFM2-1.2B",
-    "tiny-random/qwen3-next-moe",
-]
+] + GDN_MODELS
 
 FULL_CUDA_GRAPH_MODELS = [
     "ai21labs/Jamba-tiny-dev",
@@ -380,7 +381,7 @@ def _get_vLLM_output(
     return outs, vllm_model
 
 
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("n_repetitions", [2])
 # If num_logprobs is set to -1, then the stringent version
@@ -446,7 +447,7 @@ def test_apc_single_prompt(
     )
 
 
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("n_repetitions", [2])
 # If num_logprobs is set to -1, then the stringent version
@@ -528,7 +529,7 @@ def test_apc_single_prompt_block_align_alignment(
     )
 
 
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("n_repetitions", [2])
 # If num_logprobs is set to -1, then the stringent version
@@ -595,7 +596,7 @@ def test_apc_multiple_prompts_all_cached_outputs(
     )
 
 
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("n_repetitions", [2])
 # If num_logprobs is set to -1, then the stringent version
@@ -679,7 +680,7 @@ def test_apc_multiple_prompts_block_align_alignment(
     )
 
 
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("n_repetitions", [2])
 # If num_logprobs is set to -1, then the stringent version
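# Illustrative sketch (not part of the patch): the invariant the APC tests above
# exercise, written as a standalone script. Model name and sampling settings are
# assumptions; any entry of GDN_MODELS / HYBRID_MODELS would do.
from vllm import LLM, SamplingParams

def check_prefix_cache_consistency(model: str = "tiny-random/qwen3-next-moe") -> None:
    llm = LLM(model=model, enable_prefix_caching=True)
    params = SamplingParams(temperature=0.0, max_tokens=64)
    prompt = "The quick brown fox jumps over the lazy dog. " * 20

    first = llm.generate([prompt], params)[0].outputs[0].text
    second = llm.generate([prompt], params)[0].outputs[0].text  # second run hits the cache
    assert first == second, "prefix-cache hit changed the generated tokens"

if __name__ == "__main__":
    check_prefix_cache_consistency()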
diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py
index 4c8bf9f43997..65d41bb64ff6 100644
--- a/vllm/model_executor/layers/fla/ops/chunk.py
+++ b/vllm/model_executor/layers/fla/ops/chunk.py
@@ -32,6 +32,7 @@ def chunk_gated_delta_rule_fwd(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.LongTensor | None = None,
+    return_intermediate_states: bool = False,
 ):
     g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
     # obtain WY representation. u is actually the new v.
@@ -66,7 +67,15 @@ def chunk_gated_delta_rule_fwd(
         cu_seqlens=cu_seqlens,
     )
     if SUPPRESS_LEVEL < 3:
-        return g, o, A, final_state, None, None, None
+        return (
+            g,
+            o,
+            A,
+            final_state,
+            None,
+            h if return_intermediate_states else None,
+            None,
+        )
     elif SUPPRESS_LEVEL >= 3:
         return g, o, A, final_state, w, h, v_new
 
@@ -87,6 +96,7 @@ def forward(
         output_final_state: bool,
         cu_seqlens: torch.LongTensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
+        return_intermediate_states: bool = False,
     ):
         if use_qk_l2norm_in_kernel:
             q = l2norm_fwd(q)
@@ -102,10 +112,22 @@ def forward(
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
+            return_intermediate_states=return_intermediate_states,
         )
         ctx.scale = scale
         ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
-        return o.to(q.dtype), final_state
+        intermediate_states = None
+        if return_intermediate_states:
+            assert h is not None
+            # Convert intermediate states into "chunk-major" form
+            # Equal-length batches keep their batch dimension; flatten it together
+            # with the chunk axis so callers receive a contiguous chunk stream.
+            # Variable-length inputs collapse the batch dimension during preprocessing,
+            # so the kernel already emits a linearised chunk stream in ``states[:, i]``.
+            # Flattening mirrors the metadata builder's chunk enumeration order.
+            # Last three axes of h are [H, K, V], producing [num_chunks_total, H, K, V]
+            intermediate_states = h.reshape(-1, *h.shape[-3:])
+        return o.to(q.dtype), final_state, intermediate_states
 
 
 @torch.compiler.disable
@@ -121,6 +143,7 @@ def chunk_gated_delta_rule(
     cu_seqlens: torch.LongTensor | None = None,
     head_first: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
+    return_intermediate_states: bool = False,
 ):
     r"""
     Args:
@@ -155,6 +178,10 @@ def chunk_gated_delta_rule(
             Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
         final_state (torch.Tensor):
             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+        intermediate_states (Optional[torch.Tensor]):
+            When ``return_intermediate_states`` is ``True`` a tensor containing
+            the per-chunk state snapshots shaped ``[num_chunks_total, H, K, V]``.
+            Otherwise ``None``.
 
     Examples::
         >>> import torch
@@ -169,7 +196,7 @@ def chunk_gated_delta_rule(
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
-       >>> o, ht = chunk_gated_delta_rule(
+       >>> o, ht, _ = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
@@ -178,7 +205,7 @@ def chunk_gated_delta_rule(
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
-       >>> o_var, ht_var = chunk_gated_delta_rule(
+       >>> o_var, ht_var, _ = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
@@ -223,7 +250,7 @@ def chunk_gated_delta_rule(
         )
     if scale is None:
         scale = k.shape[-1] ** -0.5
-    o, final_state = ChunkGatedDeltaRuleFunction.apply(
+    o, final_state, intermediate_states = ChunkGatedDeltaRuleFunction.apply(
         q,
         k,
         v,
@@ -234,7 +261,8 @@ def chunk_gated_delta_rule(
         output_final_state,
         cu_seqlens,
         use_qk_l2norm_in_kernel,
+        return_intermediate_states,
     )
     if head_first:
         o = rearrange(o, "b t h ... -> b h t ...")
-    return o, final_state
+    return o, final_state, intermediate_states
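# Illustrative sketch (not part of the patch): the shape bookkeeping behind
# ``return_intermediate_states``. The kernel's history tensor is assumed to carry
# leading batch/chunk axes followed by [H, K, V]; flattening the leading axes
# yields one state per 64-token chunk, in chunk-enumeration order. Sizes are made up.
import torch

B, num_chunks, H, K, V = 2, 3, 4, 16, 16
h = torch.randn(B, num_chunks, H, K, V)        # equal-length batch layout
states = h.reshape(-1, *h.shape[-3:])
assert states.shape == (B * num_chunks, H, K, V)

# With cu_seqlens (variable-length), the batch axis is already collapsed upstream,
# so the same reshape just strips the singleton batch dimension.
h_varlen = torch.randn(1, 5, H, K, V)          # 5 chunks across all sequences
assert h_varlen.reshape(-1, *h_varlen.shape[-3:]).shape == (5, H, K, V)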
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index f452ba871582..bbf45637c65d 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -79,6 +79,7 @@
     IsHybrid,
     MixtureOfExperts,
     SupportsLoRA,
+    SupportsMambaPrefixCaching,
     SupportsPP,
 )
 from .utils import (
@@ -465,12 +466,206 @@ def _forward(
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        state_indices_tensor_d = attn_metadata.state_indices_tensor_d
+        state_indices_tensor_p = attn_metadata.state_indices_tensor_p
+        block_idx_last_computed_token_d = attn_metadata.block_idx_last_computed_token_d
+        block_idx_last_scheduled_token_d = (
+            attn_metadata.block_idx_last_scheduled_token_d
+        )
+        block_idx_first_scheduled_token_p = (
+            attn_metadata.block_idx_first_scheduled_token_p
+        )
+        block_idx_last_computed_token_p = attn_metadata.block_idx_last_computed_token_p
+        block_idx_last_scheduled_token_p = (
+            attn_metadata.block_idx_last_scheduled_token_p
+        )
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
 
+        prefix_caching_enabled = bool(
+            (
+                state_indices_tensor_d is not None
+                and block_idx_last_scheduled_token_d is not None
+            )
+            or (
+                state_indices_tensor_p is not None
+                and block_idx_last_scheduled_token_p is not None
+            )
+        )
+        non_spec_state_indices_runtime = non_spec_state_indices_tensor
+        state_indices_decode: torch.Tensor | None = None
+        state_indices_prefill: torch.Tensor | None = None
+
+        start_non_spec_prefill = attn_metadata.num_decodes
+        end_non_spec_prefill = start_non_spec_prefill + attn_metadata.num_prefills
+
+        if (
+            prefix_caching_enabled
+            and non_spec_state_indices_tensor is not None
+            and non_spec_state_indices_tensor.numel() > 0
+        ):
+            # Work on a copy so that updates to the runtime view don't leak back
+            # into the attention metadata shared across microbatches.
+            non_spec_state_indices_runtime = non_spec_state_indices_tensor.clone()
+
+            num_decodes = attn_metadata.num_decodes
+            if (
+                num_decodes > 0
+                and state_indices_tensor_d is not None
+                and block_idx_last_computed_token_d is not None
+                and block_idx_last_scheduled_token_d is not None
+            ):
+                decode_slice = slice(0, num_decodes)
+                base_decode_slots = non_spec_state_indices_tensor[decode_slice]
+                gathered_last_computed = (
+                    block_idx_last_computed_token_d[:num_decodes]
+                    .clamp(min=0)
+                    .to(torch.long)
+                )
+                gathered_last_scheduled = (
+                    block_idx_last_scheduled_token_d[:num_decodes]
+                    .clamp(min=0)
+                    .to(torch.long)
+                )
+                slot_in = state_indices_tensor_d.gather(
+                    1, gathered_last_computed.unsqueeze(1)
+                ).squeeze(1)
+                slot_out = state_indices_tensor_d.gather(
+                    1, gathered_last_scheduled.unsqueeze(1)
+                ).squeeze(1)
+                valid_in = (block_idx_last_computed_token_d[:num_decodes] >= 0) & (
+                    slot_in >= 0
+                )
+                valid_out = (block_idx_last_scheduled_token_d[:num_decodes] >= 0) & (
+                    slot_out >= 0
+                )
+                slot_out_safe = torch.where(
+                    valid_out,
+                    slot_out,
+                    base_decode_slots,
+                )
+                slot_in_safe = torch.where(
+                    valid_in,
+                    slot_in,
+                    slot_out_safe,
+                )
+                slot_out_copy = slot_out_safe.clamp(min=0).to(
+                    device=conv_state.device, dtype=torch.long
+                )
+                slot_in_copy = slot_in_safe.clamp(min=0).to(
+                    device=conv_state.device, dtype=torch.long
+                )
+                if slot_out_copy.numel() > 0:
+                    # Recycle the previously computed state into the newly
+                    # scheduled slot so we can skip recomputing the prefix.
+                    conv_state.index_copy_(
+                        0,
+                        slot_out_copy,
+                        conv_state.index_select(0, slot_in_copy),
+                    )
+                    ssm_state.index_copy_(
+                        0,
+                        slot_out_copy,
+                        ssm_state.index_select(0, slot_in_copy),
+                    )
+                updated_decode_slots = torch.where(
+                    valid_out,
+                    slot_out,
+                    base_decode_slots,
+                )
+                non_spec_state_indices_runtime[decode_slice] = updated_decode_slots
+                state_indices_decode = updated_decode_slots
+
+            num_prefills = attn_metadata.num_prefills
+            if (
+                num_prefills > 0
+                and state_indices_tensor_p is not None
+                and block_idx_last_computed_token_p is not None
+                and block_idx_last_scheduled_token_p is not None
+            ):
+                start = attn_metadata.num_decodes
+                end = start + num_prefills
+                base_prefill_slots = non_spec_state_indices_tensor[start:end]
+                gathered_last_computed = (
+                    block_idx_last_computed_token_p[:num_prefills]
+                    .clamp(min=0)
+                    .to(torch.long)
+                )
+                gathered_last_scheduled = (
+                    block_idx_last_scheduled_token_p[:num_prefills]
+                    .clamp(min=0)
+                    .to(torch.long)
+                )
+                slot_in = state_indices_tensor_p.gather(
+                    1, gathered_last_computed.unsqueeze(1)
+                ).squeeze(1)
+                slot_out = state_indices_tensor_p.gather(
+                    1, gathered_last_scheduled.unsqueeze(1)
+                ).squeeze(1)
+                valid_in = (block_idx_last_computed_token_p[:num_prefills] >= 0) & (
+                    slot_in >= 0
+                )
+                valid_out = (block_idx_last_scheduled_token_p[:num_prefills] >= 0) & (
+                    slot_out >= 0
+                )
+                slot_out_safe = torch.where(
+                    valid_out,
+                    slot_out,
+                    base_prefill_slots,
+                )
+                slot_in_safe = torch.where(
+                    valid_in,
+                    slot_in,
+                    slot_out_safe,
+                )
+                slot_out_copy = slot_out_safe.clamp(min=0).to(
+                    device=conv_state.device, dtype=torch.long
+                )
+                slot_in_copy = slot_in_safe.clamp(min=0).to(
+                    device=conv_state.device, dtype=torch.long
+                )
+                if slot_out_copy.numel() > 0:
+                    # Mirror the decode path: move cached prefix states into
+                    # the slots assigned to this prefill chunk.
+                    conv_state.index_copy_(
+                        0,
+                        slot_out_copy,
+                        conv_state.index_select(0, slot_in_copy),
+                    )
+                    ssm_state.index_copy_(
+                        0,
+                        slot_out_copy,
+                        ssm_state.index_select(0, slot_in_copy),
+                    )
+
+                updated_prefill_slots = torch.where(
+                    valid_out,
+                    slot_out,
+                    base_prefill_slots,
+                )
+                non_spec_state_indices_runtime[start:end] = updated_prefill_slots
+                state_indices_prefill = updated_prefill_slots
+
+        if state_indices_decode is None and non_spec_state_indices_tensor is not None:
+            state_indices_decode = non_spec_state_indices_tensor[
+                : attn_metadata.num_decodes
+            ]
+
+        if state_indices_prefill is None and non_spec_state_indices_tensor is not None:
+            state_indices_prefill = non_spec_state_indices_tensor[
+                start_non_spec_prefill:end_non_spec_prefill
+            ]
+
+        if attn_metadata.num_decodes > 0:
+            assert state_indices_decode is not None
+
+        if attn_metadata.num_prefills > 0:
+            assert state_indices_prefill is not None
+            assert non_spec_state_indices_runtime is not None
+
         # 1. Set up dimensions for reshapes later
         projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
         projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
@@ -527,20 +722,19 @@ def _forward(
                 activation=self.activation,
                 conv_states=conv_state,
                 has_initial_state=has_initial_state,
-                cache_indices=non_spec_state_indices_tensor,
+                cache_indices=non_spec_state_indices_runtime,
                 query_start_loc=non_spec_query_start_loc,
                 metadata=attn_metadata,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
+            assert state_indices_decode is not None
             mixed_qkv_non_spec = causal_conv1d_update(
                 mixed_qkv_non_spec,
                 conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=non_spec_state_indices_tensor[
-                    : attn_metadata.num_decodes
-                ],
+                conv_state_indices=state_indices_decode,
                 validate_data=True,
             )
         else:
@@ -575,7 +769,7 @@ def _forward(
 
         # 3. Recurrent attention
 
-        # 3.1: process the mutlti-query part
+        # 3.1: process the multi-query part
         if spec_sequence_masks is not None:
             core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                 q=query_spec,
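# Illustrative sketch (not part of the patch): the slot-recycling pattern from the
# hunk above. The state saved for the last computed block (slot_in) is copied into
# the slot of the newly scheduled block (slot_out) so the prefix is not recomputed.
# Tensor sizes are made up.
import torch

num_slots, H, K, V = 6, 2, 4, 4
ssm_state = torch.randn(num_slots, H, K, V)
slot_in = torch.tensor([1, 4])    # slots holding cached prefix states
slot_out = torch.tensor([2, 5])   # slots assigned for this step's writes

ssm_state.index_copy_(0, slot_out, ssm_state.index_select(0, slot_in))
assert torch.equal(ssm_state[2], ssm_state[1])
assert torch.equal(ssm_state[5], ssm_state[4])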
@@ -595,11 +789,41 @@ def _forward(
 
         # 3.2: process the remaining part
         if attn_metadata.num_prefills > 0:
-            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
-            initial_state[~has_initial_state, ...] = 0
+            chunk_state_indices = non_spec_state_indices_runtime[:end_non_spec_prefill]
+            initial_state = ssm_state.new_zeros(
+                (chunk_state_indices.shape[0], *ssm_state.shape[1:])
+            )
+            if chunk_state_indices.numel() > 0:
+                valid_chunk_slots = chunk_state_indices >= 0
+                valid_chunk_positions = torch.nonzero(
+                    valid_chunk_slots, as_tuple=False
+                ).squeeze(-1)
+                if valid_chunk_positions.numel() > 0:
+                    initial_state.index_copy_(
+                        0,
+                        valid_chunk_positions,
+                        ssm_state.index_select(
+                            0,
+                            chunk_state_indices.index_select(
+                                0, valid_chunk_positions
+                            ).to(device=ssm_state.device, dtype=torch.long),
+                        ),
+                    )
+
+            if has_initial_state is not None:
+                chunk_has_initial_state = has_initial_state[:end_non_spec_prefill]
+                initial_state[~chunk_has_initial_state, ...] = 0
+
+            assert query_non_spec is not None
+            assert key_non_spec is not None
+            assert value_non_spec is not None
+            assert g_non_spec is not None
+            assert beta_non_spec is not None
+            cu_seqlens = non_spec_query_start_loc[: end_non_spec_prefill + 1]
             (
                 core_attn_out_non_spec,
                 last_recurrent_state,
+                block_state_history,
             ) = chunk_gated_delta_rule(
                 q=query_non_spec,
                 k=key_non_spec,
@@ -608,15 +832,116 @@ def _forward(
                 beta=beta_non_spec,
                 initial_state=initial_state,
                 output_final_state=True,
-                cu_seqlens=non_spec_query_start_loc,
+                cu_seqlens=cu_seqlens,
                 head_first=False,
                 use_qk_l2norm_in_kernel=True,
+                return_intermediate_states=prefix_caching_enabled,
             )
 
-            # Init cache
-            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
-                ssm_state.dtype
-            )
+            if chunk_state_indices.numel() > 0:
+                valid_chunk_slots = chunk_state_indices >= 0
+                valid_chunk_positions = torch.nonzero(
+                    valid_chunk_slots, as_tuple=False
+                ).squeeze(-1)
+                if valid_chunk_positions.numel() > 0:
+                    dest_slots = chunk_state_indices.index_select(
+                        0, valid_chunk_positions
+                    ).to(device=ssm_state.device, dtype=torch.long)
+                    ssm_state.index_copy_(
+                        0,
+                        dest_slots,
+                        last_recurrent_state.index_select(0, valid_chunk_positions).to(
+                            ssm_state.dtype
+                        ),
+                    )
+
+            if prefix_caching_enabled and (
+                block_state_history is not None
+                and block_state_history.numel() > 0
+                and block_idx_first_scheduled_token_p is not None
+                and block_idx_last_scheduled_token_p is not None
+                and state_indices_tensor_p is not None
+                and attn_metadata.last_chunk_indices_p is not None
+                and attn_metadata.num_computed_tokens_p is not None
+                and attn_metadata.chunk_size is not None
+                and attn_metadata.block_size is not None
+            ):
+                block_history = block_state_history.to(ssm_state.dtype)
+                total_chunks = block_history.shape[0]
+                last_chunk_indices = attn_metadata.last_chunk_indices_p
+                prefill_chunk_count = (
+                    int(last_chunk_indices[-1].item()) + 1
+                    if last_chunk_indices is not None and last_chunk_indices.numel() > 0
+                    else 0
+                )
+                decode_chunk_count = max(total_chunks - prefill_chunk_count, 0)
+                # Prefill chunks trail the decode chunks; skip the actual number of
+                # decode chunk completions so partial decodes (no chunk output) do
+                # not offset the history.
+                block_history_prefill = block_history[decode_chunk_count:]
+                if block_history_prefill.shape[0] > 0:
+                    # The block history contains recurrent states per chunk; we
+                    # replay it into the persistent cache blocks owned by each
+                    # sequence so future steps can hit the prefix cache.
+                    chunk_size = attn_metadata.chunk_size
+                    block_size = attn_metadata.block_size
+                    chunk_stride = block_size // chunk_size
+                    last_chunk_indices = attn_metadata.last_chunk_indices_p
+                    last_chunk_indices_long = last_chunk_indices.to(torch.long)
+                    num_computed_tokens_p = attn_metadata.num_computed_tokens_p
+
+                    for seq_idx in range(attn_metadata.num_prefills):
+                        block_first = int(
+                            block_idx_first_scheduled_token_p[seq_idx].item()
+                        )
+                        block_last = int(
+                            block_idx_last_scheduled_token_p[seq_idx].item()
+                        )
+                        n_blocks_to_fill = block_last - block_first
+                        if n_blocks_to_fill <= 0:
+                            continue
+
+                        cache_blocks = state_indices_tensor_p[
+                            seq_idx, block_first:block_last
+                        ].to(torch.long)
+
+                        first_chunk = (
+                            0
+                            if seq_idx == 0
+                            else int(last_chunk_indices[seq_idx - 1].item()) + 1
+                        )
+                        first_aligned_chunk = first_chunk + chunk_stride - 1
+                        num_unaligned_tokens = int(
+                            num_computed_tokens_p[seq_idx].item() % block_size
+                        )
+                        if num_unaligned_tokens > 0:
+                            first_aligned_chunk -= num_unaligned_tokens // chunk_size
+                        chunk_stop = (
+                            first_aligned_chunk + n_blocks_to_fill * chunk_stride
+                        )
+                        cached_states = block_history_prefill[
+                            first_aligned_chunk:chunk_stop:chunk_stride
+                        ]
+                        ssm_state[cache_blocks] = cached_states
+
+                final_slots = state_indices_tensor_p.gather(
+                    1, block_idx_last_scheduled_token_p.unsqueeze(1)
+                ).squeeze(1)
+                valid_final = final_slots >= 0
+                valid_final_positions = torch.nonzero(
+                    valid_final, as_tuple=False
+                ).squeeze(-1)
+                if valid_final_positions.numel() > 0:
+                    final_slot_ids = final_slots.index_select(
+                        0, valid_final_positions
+                    ).to(device=ssm_state.device, dtype=torch.long)
+                    final_states = block_history_prefill.index_select(
+                        0,
+                        last_chunk_indices_long.index_select(
+                            0, valid_final_positions
+                        ),
+                    )
+                    ssm_state.index_copy_(0, final_slot_ids, final_states)
         elif attn_metadata.num_decodes > 0:
+            assert state_indices_decode is not None
             core_attn_out_non_spec, last_recurrent_state = (
                 fused_recurrent_gated_delta_rule(
                     q=query_non_spec,
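# Illustrative sketch (not part of the patch): the chunk -> cache-block arithmetic
# used in the write-back loop above, with assumed sizes block_size=128 and
# chunk_size=64 (so chunk_stride = 2). A sequence with 192 tokens already cached
# that prefills 320 more should snapshot the states ending logical blocks 1 and 2.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

chunk_size, block_size = 64, 128
chunk_stride = block_size // chunk_size

num_computed, new_tokens = 192, 320
block_first = cdiv(num_computed + 1, block_size) - 1          # 1
block_last = cdiv(num_computed + new_tokens, block_size) - 1  # 3
n_blocks_to_fill = block_last - block_first                   # 2

first_chunk = 0                     # this sequence's first chunk in the history
first_aligned_chunk = first_chunk + chunk_stride - 1
num_unaligned = num_computed % block_size
if num_unaligned > 0:
    first_aligned_chunk -= num_unaligned // chunk_size
chunk_stop = first_aligned_chunk + n_blocks_to_fill * chunk_stride

# Every chunk_stride-th state, starting at the first block-aligned chunk, is the
# snapshot for a newly completed block.
print(list(range(first_aligned_chunk, chunk_stop, chunk_stride)))  # [0, 2]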
@@ -629,7 +954,7 @@ def _forward(
                     cu_seqlens=non_spec_query_start_loc[
                         : attn_metadata.num_decodes + 1
                     ],
-                    ssm_state_indices=non_spec_state_indices_tensor,
+                    ssm_state_indices=state_indices_decode,
                     use_qk_l2norm_in_kernel=True,
                 )
             )
@@ -1096,7 +1421,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
 
 class Qwen3NextForCausalLM(
-    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    MixtureOfExperts,
+    IsHybrid,
+    SupportsMambaPrefixCaching,
 ):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -1111,12 +1442,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
         scheduler_config = vllm_config.scheduler_config
-        assert not cache_config.enable_prefix_caching, (
-            "Qwen3Next currently does not support prefix caching"
-        )
         self.quant_config = vllm_config.quant_config
 
         super().__init__()
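# Illustrative sketch (not part of the patch): the row-wise slot lookup used in the
# decode/prefill paths above -- for each request, pick the physical cache slot that
# backs one logical block index. Shapes and values are made up.
import torch

state_indices = torch.tensor([[10, 11, 12, -1],
                              [20, 21, -1, -1]])      # [num_reqs, max_blocks]
block_idx_last_scheduled = torch.tensor([2, 1])       # logical block per request

slots = state_indices.gather(1, block_idx_last_scheduled.unsqueeze(1)).squeeze(1)
print(slots.tolist())  # [12, 21]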
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 2ca19646911e..754ef3828c80 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -9,6 +9,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
@@ -36,6 +37,8 @@ class GDNAttentionMetadata:
     num_actual_tokens: int
 
     has_initial_state: torch.Tensor | None = None
+    block_size: int | None = None
+    chunk_size: int | None = None
 
     spec_query_start_loc: torch.Tensor | None = None  # shape: [num_spec_decodes + 1,]
     non_spec_query_start_loc: torch.Tensor | None = (
@@ -52,6 +55,21 @@
 
     num_accepted_tokens: torch.Tensor | None = None  # shape: [batch,]
 
+    # Decode-side APC metadata
+    state_indices_tensor_d: torch.Tensor | None = None
+    state_indices_tensor_p: torch.Tensor | None = None
+    block_idx_last_computed_token_d: torch.Tensor | None = None
+    block_idx_last_scheduled_token_d: torch.Tensor | None = None
+
+    # Prefill-side APC metadata
+    block_idx_first_scheduled_token_p: torch.Tensor | None = None
+    block_idx_last_computed_token_p: torch.Tensor | None = None
+    block_idx_last_scheduled_token_p: torch.Tensor | None = None
+    seq_idx_p: torch.Tensor | None = None
+    cu_chunk_seqlen_p: torch.Tensor | None = None
+    last_chunk_indices_p: torch.Tensor | None = None
+    num_computed_tokens_p: torch.Tensor | None = None
+
     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: dict | None = None
     batch_ptr: torch.Tensor | None = None
@@ -75,6 +93,7 @@ def __init__(
         self.compilation_config = vllm_config.compilation_config
         self.speculative_config = vllm_config.speculative_config
         self.kv_cache_spec = kv_cache_spec
+        self.device = device
         if self.speculative_config:
             self.num_spec = self.speculative_config.num_speculative_tokens
         else:
@@ -82,6 +101,17 @@ def __init__(
         self.use_spec_decode = self.num_spec > 0
         self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
+        # 64 is a hardcoded value in the FLA GDN kernel.
+        # https://github.com/fla-org/flash-linear-attention/blob/2e7336262c11f8bc6cd6a94b1eb5ee353ae8b4cd/fla/ops/common/chunk_delta_h.py#L439
+        self.chunk_size = 64
+        if self.vllm_config.cache_config.enable_prefix_caching and (
+            kv_cache_spec.block_size % self.chunk_size != 0
+        ):
+            raise ValueError(
+                "GDN prefix caching requires the mamba block size to be a "
+                "multiple of the kernel chunk size."
+            )
+
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
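# Illustrative sketch (not part of the patch): what the check above enforces. With
# the kernel chunk fixed at 64 tokens, the mamba block size must be a whole number
# of chunks so block boundaries always coincide with chunk boundaries. The block
# sizes below are examples only.
chunk_size = 64
for block_size in (64, 128, 256):
    assert block_size % chunk_size == 0
    print(f"block_size={block_size}: {block_size // chunk_size} chunk(s) per block")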
@@ -90,6 +120,10 @@ def __init__(
             self.compilation_config.max_cudagraph_capture_size,
         )
 
+        self._max_cached_blocks = cdiv(
+            vllm_config.model_config.max_model_len, kv_cache_spec.block_size
+        )
+
         self.spec_state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
             dtype=torch.int32,
@@ -131,6 +165,80 @@ def __init__(
             device=device,
         )
 
+        if self.vllm_config.cache_config.enable_prefix_caching:
+            self.state_indices_tensor_d_buf = torch.empty(
+                (self.decode_cudagraph_max_bs, self._max_cached_blocks),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.state_indices_tensor_p_buf = torch.empty(
+                (self.decode_cudagraph_max_bs, self._max_cached_blocks),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.block_idx_last_computed_token_d_buf = torch.empty(
+                (self.decode_cudagraph_max_bs,),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.block_idx_last_scheduled_token_d_buf = torch.empty(
+                (self.decode_cudagraph_max_bs,),
+                dtype=torch.int32,
+                device=device,
+            )
+
+            max_num_prefill_chunks = (
+                cdiv(vllm_config.model_config.max_model_len, self.chunk_size)
+                * self.decode_cudagraph_max_bs
+            )
+            self.seq_idx_p_buf = torch.empty(
+                (max_num_prefill_chunks,),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.cu_chunk_seqlen_p_buf = torch.empty(
+                (max_num_prefill_chunks + 1,),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.last_chunk_indices_p_buf = torch.empty(
+                (self.decode_cudagraph_max_bs,),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.num_computed_tokens_p_buf = torch.empty(
+                (self.decode_cudagraph_max_bs,),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.block_idx_first_scheduled_token_p_buf = torch.empty(
+                (self.decode_cudagraph_max_bs,),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.block_idx_last_computed_token_p_buf = torch.empty(
+                (self.decode_cudagraph_max_bs,),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.block_idx_last_scheduled_token_p_buf = torch.empty(
+                (self.decode_cudagraph_max_bs,),
+                dtype=torch.int32,
+                device=device,
+            )
+        else:
+            self.state_indices_tensor_d_buf = None
+            self.block_idx_last_computed_token_d_buf = None
+            self.block_idx_last_scheduled_token_d_buf = None
+            self.state_indices_tensor_p_buf = None
+            self.seq_idx_p_buf = None
+            self.cu_chunk_seqlen_p_buf = None
+            self.last_chunk_indices_p_buf = None
+            self.num_computed_tokens_p_buf = None
+            self.block_idx_first_scheduled_token_p_buf = None
+            self.block_idx_last_computed_token_p_buf = None
+            self.block_idx_last_scheduled_token_p_buf = None
+
     def build(  # type: ignore[override]
         self,
         common_prefix_len: int,
@@ -146,6 +254,25 @@ def build(  # type: ignore[override]
         context_lens_tensor = context_lens.to(query_start_loc.device)
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
+        enable_apc = self.vllm_config.cache_config.enable_prefix_caching
+        block_size_value: int | None = None
+        chunk_size_value: int | None = None
+        if enable_apc:
+            block_size_value = self.kv_cache_spec.block_size
+            chunk_size_value = self.chunk_size
+        state_indices_tensor_d: torch.Tensor | None = None
+        state_indices_tensor_p: torch.Tensor | None = None
+        block_idx_last_computed_token_d: torch.Tensor | None = None
+        block_idx_last_scheduled_token_d: torch.Tensor | None = None
+        block_idx_first_scheduled_token_p: torch.Tensor | None = None
+        block_idx_last_computed_token_p: torch.Tensor | None = None
+        block_idx_last_scheduled_token_p: torch.Tensor | None = None
+        num_computed_tokens_p: torch.Tensor | None = None
+        seq_idx_p: torch.Tensor | None = None
+        cu_chunk_seqlen_p: torch.Tensor | None = None
+        last_chunk_indices_p: torch.Tensor | None = None
+        non_spec_query_start_loc_cpu: torch.Tensor | None = None
+
         if (
             not self.use_spec_decode
             or num_decode_draft_tokens_cpu is None
@@ -158,7 +285,7 @@
             num_spec_decodes = 0
         else:
             spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
-            num_spec_decodes = spec_sequence_masks.sum().item()
+            num_spec_decodes = int(spec_sequence_masks.sum().item())
             if num_spec_decodes == 0:
                 spec_sequence_masks = None
             else:
@@ -177,6 +304,7 @@
             non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
             spec_query_start_loc = None
             non_spec_query_start_loc = query_start_loc
+            non_spec_query_start_loc_cpu = m.query_start_loc_cpu
             num_accepted_tokens = None
         else:
             query_lens = query_start_loc[1:] - query_start_loc[:-1]
@@ -241,10 +369,156 @@ def build(  # type: ignore[override]
                 dim=0,
                 out=non_spec_query_start_loc[1:],
             )
+            query_lens_cpu = m.query_start_loc_cpu[1:] - m.query_start_loc_cpu[:-1]
+            non_spec_query_start_loc_cpu = torch.zeros(
+                query_lens_cpu.size(0) - num_spec_decodes + 1,
+                dtype=torch.int32,
+            )
+            torch.cumsum(
+                query_lens_cpu[~spec_sequence_masks.cpu()],
+                dim=0,
+                out=non_spec_query_start_loc_cpu[1:],
+            )
 
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
 
+        if enable_apc:
+            block_table_tensor_full = m.block_table_tensor
+            block_size = self.kv_cache_spec.block_size
+            num_computed_tokens_device = m.num_computed_tokens_cpu.to(
+                self.device, dtype=torch.int32
+            )
+            seq_lens_device = m.seq_lens.to(self.device, dtype=torch.int32)
+
+            block_idx_last_computed_all = (
+                (cdiv(num_computed_tokens_device, block_size) - 1)
+                .clamp(min=0)
+                .to(torch.int32)
+            )
+            block_idx_first_scheduled_all = (
+                cdiv(num_computed_tokens_device + 1, block_size) - 1
+            ).to(torch.int32)
+            block_idx_last_scheduled_all = (cdiv(seq_lens_device, block_size) - 1).to(
+                torch.int32
+            )
+
+            if spec_sequence_masks is not None:
+                non_spec_mask = ~spec_sequence_masks
+                non_spec_block_table = block_table_tensor_full[non_spec_mask]
+                block_idx_last_computed_non_spec = block_idx_last_computed_all[
+                    non_spec_mask
+                ]
+                block_idx_last_scheduled_non_spec = block_idx_last_scheduled_all[
+                    non_spec_mask
+                ]
+                block_idx_first_scheduled_non_spec = block_idx_first_scheduled_all[
+                    non_spec_mask
+                ]
+                num_computed_tokens_non_spec = num_computed_tokens_device[non_spec_mask]
+                spec_sequence_masks_cpu = spec_sequence_masks.cpu()
+                non_spec_mask_cpu = ~spec_sequence_masks_cpu
+                num_computed_tokens_cpu_non_spec = m.num_computed_tokens_cpu[
+                    non_spec_mask_cpu
+                ]
+            else:
+                non_spec_block_table = block_table_tensor_full
+                block_idx_last_computed_non_spec = block_idx_last_computed_all
+                block_idx_last_scheduled_non_spec = block_idx_last_scheduled_all
+                block_idx_first_scheduled_non_spec = block_idx_first_scheduled_all
+                num_computed_tokens_non_spec = num_computed_tokens_device
+                num_computed_tokens_cpu_non_spec = m.num_computed_tokens_cpu
+
+            if num_decodes > 0:
+                state_indices_tensor_d = non_spec_block_table[:num_decodes].contiguous()
+                block_idx_last_computed_token_d = block_idx_last_computed_non_spec[
+                    :num_decodes
+                ].contiguous()
+                block_idx_last_scheduled_token_d = block_idx_last_scheduled_non_spec[
+                    :num_decodes
+                ].contiguous()
+
+            if num_prefills > 0:
+                start = num_decodes
+                end = start + num_prefills
+                state_indices_tensor_p = non_spec_block_table[start:end].contiguous()
+                block_idx_first_scheduled_token_p = block_idx_first_scheduled_non_spec[
+                    start:end
+                ].contiguous()
+                block_idx_last_computed_token_p = block_idx_last_computed_non_spec[
+                    start:end
+                ].contiguous()
+                block_idx_last_scheduled_token_p = block_idx_last_scheduled_non_spec[
+                    start:end
+                ].contiguous()
+                num_computed_tokens_p = num_computed_tokens_non_spec[
+                    start:end
+                ].contiguous()
+
+                if spec_sequence_masks is None:
+                    num_computed_tokens_p_cpu = m.num_computed_tokens_cpu[
+                        m.num_reqs - num_prefills :
+                    ]
+                    query_start_loc_p_cpu = (
+                        m.query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens
+                    )
+                else:
+                    num_computed_tokens_p_cpu = num_computed_tokens_cpu_non_spec[
+                        num_decodes:
+                    ]
+                    assert non_spec_query_start_loc_cpu is not None
+                    query_start_loc_p_cpu = (
+                        non_spec_query_start_loc_cpu[-num_prefills - 1 :]
+                        - num_decode_tokens
+                    )
+
+                cu_chunk_seqlen: list[int] = []
+                seq_idx_list: list[int] = []
+                last_chunk_indices_list: list[int] = []
+                seqlen_pos = 0
+
+                for req_idx in range(num_prefills):
+                    this_num_computed = int(num_computed_tokens_p_cpu[req_idx].item())
+                    this_new_tokens = int(
+                        query_start_loc_p_cpu[req_idx + 1].item()
+                        - query_start_loc_p_cpu[req_idx].item()
+                    )
+
+                    if this_num_computed % self.chunk_size != 0:
+                        seq_idx_list.append(req_idx)
+                        cu_chunk_seqlen.append(seqlen_pos)
+                        chunk_len = (
+                            cdiv(this_num_computed, self.chunk_size) * self.chunk_size
+                            - this_num_computed
+                        )
+                        chunk_len = min(chunk_len, this_new_tokens)
+                        seqlen_pos += chunk_len
+                        this_new_tokens -= chunk_len
+
+                    n_chunks = cdiv(this_new_tokens, self.chunk_size)
+                    for _ in range(n_chunks):
+                        seq_idx_list.append(req_idx)
+                        cu_chunk_seqlen.append(seqlen_pos)
+                        chunk_len = min(self.chunk_size, this_new_tokens)
+                        seqlen_pos += chunk_len
+                        this_new_tokens -= chunk_len
+
+                    assert this_new_tokens == 0
+                    last_chunk_indices_list.append(len(cu_chunk_seqlen) - 1)
+
+                cu_chunk_seqlen.append(seqlen_pos)
+
+                device = query_start_loc.device
+                seq_idx_p = torch.as_tensor(
+                    seq_idx_list, device=device, dtype=torch.int32
+                )
+                cu_chunk_seqlen_p = torch.as_tensor(
+                    cu_chunk_seqlen, device=device, dtype=torch.int32
+                )
+                last_chunk_indices_p = torch.as_tensor(
+                    last_chunk_indices_list, device=device, dtype=torch.int32
+                )
+
         if num_prefills > 0:
             has_initial_state = context_lens_tensor > 0
             if spec_sequence_masks is not None:
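# Illustrative sketch (not part of the patch): a standalone version of the chunk
# enumeration above. Each prefill request first gets an "alignment" chunk up to the
# next 64-token boundary of its already-computed prefix, then full 64-token chunks.
# The inputs below are made-up values.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def enumerate_chunks(computed, new_tokens, chunk_size=64):
    cu_chunk_seqlen, seq_idx, last_chunk_indices, pos = [], [], [], 0
    for req, (done, todo) in enumerate(zip(computed, new_tokens)):
        if done % chunk_size != 0:
            seq_idx.append(req)
            cu_chunk_seqlen.append(pos)
            take = min(cdiv(done, chunk_size) * chunk_size - done, todo)
            pos += take
            todo -= take
        for _ in range(cdiv(todo, chunk_size)):
            seq_idx.append(req)
            cu_chunk_seqlen.append(pos)
            take = min(chunk_size, todo)
            pos += take
            todo -= take
        last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
    cu_chunk_seqlen.append(pos)
    return cu_chunk_seqlen, seq_idx, last_chunk_indices


# ([0, 64, 100, 132, 196, 260], [0, 0, 1, 1, 1], [1, 4])
print(enumerate_chunks(computed=[0, 96], new_tokens=[100, 160]))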
@@ -337,6 +611,35 @@ def build(  # type: ignore[override]
             non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
             non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
 
+            if enable_apc and num_decodes > 0:
+                assert state_indices_tensor_d is not None
+                num_blocks = state_indices_tensor_d.shape[1]
+                self.state_indices_tensor_d_buf[:num_decodes, :num_blocks].copy_(
+                    state_indices_tensor_d, non_blocking=True
+                )
+                state_indices_tensor_d = self.state_indices_tensor_d_buf[
+                    :batch_size, :num_blocks
+                ]
+                state_indices_tensor_d[num_decodes:, :].fill_(PAD_SLOT_ID)
+
+                assert block_idx_last_scheduled_token_d is not None
+                self.block_idx_last_scheduled_token_d_buf[:num_decodes].copy_(
+                    block_idx_last_scheduled_token_d, non_blocking=True
+                )
+                block_idx_last_scheduled_token_d = (
+                    self.block_idx_last_scheduled_token_d_buf[:batch_size]
+                )
+                block_idx_last_scheduled_token_d[num_decodes:] = 0
+
+                assert block_idx_last_computed_token_d is not None
+                self.block_idx_last_computed_token_d_buf[:num_decodes].copy_(
+                    block_idx_last_computed_token_d, non_blocking=True
+                )
+                block_idx_last_computed_token_d = (
+                    self.block_idx_last_computed_token_d_buf[:batch_size]
+                )
+                block_idx_last_computed_token_d[num_decodes:] = 0
+
         attn_metadata = GDNAttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
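# Illustrative sketch (not part of the patch): the persistent-buffer pattern above.
# Live rows are copied into a preallocated buffer and the padded tail is filled with
# a sentinel, so a CUDA graph captured at a fixed batch size can be reused. The
# sentinel value and sizes are assumptions.
import torch

PAD_SLOT_ID = -1
max_bs, max_blocks = 8, 4
buf = torch.empty((max_bs, max_blocks), dtype=torch.int32)

real = torch.tensor([[3, 4, 5, -1], [7, 8, -1, -1]], dtype=torch.int32)  # 2 live decodes
batch_size = 6                      # padded batch size used for the captured graph

buf[: real.shape[0], : real.shape[1]].copy_(real)
view = buf[:batch_size, : real.shape[1]]
view[real.shape[0] :, :].fill_(PAD_SLOT_ID)
print(view)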
@@ -346,6 +649,8 @@ def build(  # type: ignore[override]
             num_spec_decode_tokens=num_spec_decode_tokens,
             num_actual_tokens=num_actual_tokens,
             has_initial_state=has_initial_state,
+            block_size=block_size_value,
+            chunk_size=chunk_size_value,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
             spec_state_indices_tensor=spec_state_indices_tensor,
@@ -354,6 +659,17 @@ def build(  # type: ignore[override]
             spec_token_indx=spec_token_indx,
             non_spec_token_indx=non_spec_token_indx,
             num_accepted_tokens=num_accepted_tokens,
+            state_indices_tensor_d=state_indices_tensor_d,
+            state_indices_tensor_p=state_indices_tensor_p,
+            block_idx_last_computed_token_d=block_idx_last_computed_token_d,
+            block_idx_last_scheduled_token_d=block_idx_last_scheduled_token_d,
+            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
+            block_idx_last_computed_token_p=block_idx_last_computed_token_p,
+            block_idx_last_scheduled_token_p=block_idx_last_scheduled_token_p,
+            seq_idx_p=seq_idx_p,
+            cu_chunk_seqlen_p=cu_chunk_seqlen_p,
+            last_chunk_indices_p=last_chunk_indices_p,
+            num_computed_tokens_p=num_computed_tokens_p,
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,
             token_chunk_offset_ptr=token_chunk_offset_ptr,
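# Illustrative sketch (not part of the patch): with the assert removed above and the
# GDN metadata in place, a Qwen3-Next checkpoint can be served with automatic prefix
# caching turned on. Model name and sampling settings are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-Next-80B-A3B-Instruct", enable_prefix_caching=True)
out = llm.generate(
    ["Summarize gated delta networks in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(out[0].outputs[0].text)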