From 603f8a9c2644959cf0f307ac42bf1ae41490431e Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 6 Jan 2026 00:09:58 +0000
Subject: [PATCH 1/5] deprecation: replace num_computed_tokens_cpu/seq_lens_cpu in SSM backends

Replace deprecated CommonAttentionMetadata.num_computed_tokens_cpu and
seq_lens_cpu properties with explicit computation:

    num_computed_tokens = seq_lens - query_lens

Changes:
- mamba_attn.py: Add helper method _get_num_computed_tokens_cpu() and use it
- mamba2_attn.py: Use inherited helper from base class
- gdn_attn.py: Inline the computation

This is part of the deprecation effort for seq_lens_cpu and
num_computed_tokens_cpu properties (to be removed in v0.14.0).

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/gdn_attn.py    |  7 +++++--
 vllm/v1/attention/backends/mamba2_attn.py |  5 ++++-
 vllm/v1/attention/backends/mamba_attn.py  | 28 +++++++++++++++++++---------
 3 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index fcde986f48d4..e47e9879ad70 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -142,7 +142,9 @@ def build( # type: ignore[override]
         m = common_attn_metadata
 
         query_start_loc = m.query_start_loc
-        context_lens = m.num_computed_tokens_cpu
+        # Compute num_computed_tokens from query_start_loc and seq_lens
+        query_lens = m.query_start_loc_cpu[1:] - m.query_start_loc_cpu[:-1]
+        context_lens = m.seq_lens.cpu() - query_lens
         context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
 
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
@@ -370,6 +372,7 @@ def build_for_cudagraph_capture(
         num_accepted_tokens = torch.diff(m.query_start_loc)
         num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
 
-        m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
+        # Note: Setting _num_computed_tokens_cpu directly for cudagraph capture
+        m._num_computed_tokens_cpu = m.seq_lens.cpu() - num_accepted_tokens.cpu()
 
         return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index b526f0a32997..48b98389f20b 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -215,7 +215,10 @@ def build(
         num_prefills = common.num_prefills
         num_decode_tokens = common.num_decode_tokens
-        num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
+        num_computed_tokens_cpu = self._get_num_computed_tokens_cpu(
+            common_attn_metadata
+        )
+        num_computed_tokens_p_cpu = num_computed_tokens_cpu[
             num_reqs - num_prefills : num_reqs
         ]
 
         query_start_loc_p_cpu = (
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 4f876d66da14..6b99e14782cb 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -103,6 +103,16 @@ def __init__(
             device=device,
         )
 
+    def _get_num_computed_tokens_cpu(
+        self, common_attn_metadata: CommonAttentionMetadata
+    ) -> torch.Tensor:
+        """Compute num_computed_tokens_cpu from query_start_loc and seq_lens."""
+        query_lens = (
+            common_attn_metadata.query_start_loc_cpu[1:]
+            - common_attn_metadata.query_start_loc_cpu[:-1]
+        )
+        return common_attn_metadata.seq_lens.cpu() - query_lens
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> M:
@@ -138,9 +148,10 @@ def _compute_prefix_caching_block_indices(
         common_attn_metadata: CommonAttentionMetadata,
         mamba_block_size: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-            self.device
+        num_computed_tokens_cpu = self._get_num_computed_tokens_cpu(
+            common_attn_metadata
         )
+        num_computed_tokens = num_computed_tokens_cpu.to(self.device)
         # Block index of the last computed token
         block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
         # which is <= block index for the first scheduled token
@@ -192,14 +203,16 @@ def _compute_common_metadata(
         # for causal_conv1d
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
+        num_computed_tokens_cpu = self._get_num_computed_tokens_cpu(
+            common_attn_metadata
+        )
+
         if self.vllm_config.cache_config.enable_prefix_caching:
             # Return a tensor of shape (#requests, #max blocks)
             state_indices_tensor = common_attn_metadata.block_table_tensor
             # Additional cache-related varaiables:
             mamba_block_size = self.kv_cache_spec.block_size
-            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-                self.device
-            )
+            num_computed_tokens = num_computed_tokens_cpu.to(self.device)
             (
                 block_idx_last_computed_token,
                 block_idx_first_scheduled_token,
@@ -217,10 +230,7 @@ def _compute_common_metadata(
                 - num_decode_tokens
             )
             has_initial_states_cpu = (
-                common_attn_metadata.num_computed_tokens_cpu[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                > 0
+                num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
             )
             has_initial_states_p = has_initial_states_cpu.to(
                 common_attn_metadata.query_start_loc.device

From 12c8b9c796bde04282ef0a7bad8e8f3d4d3524d3 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 6 Jan 2026 05:12:10 +0000
Subject: [PATCH 2/5] cleanup

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/gdn_attn.py    |  3 ---
 vllm/v1/attention/backends/mamba2_attn.py |  4 ++--
 vllm/v1/attention/backends/mamba_attn.py  | 27 ++++++++-------------------
 3 files changed, 10 insertions(+), 24 deletions(-)

diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index e47e9879ad70..8df9b058b050 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -142,7 +142,6 @@ def build( # type: ignore[override]
         m = common_attn_metadata
 
         query_start_loc = m.query_start_loc
-        # Compute num_computed_tokens from query_start_loc and seq_lens
         query_lens = m.query_start_loc_cpu[1:] - m.query_start_loc_cpu[:-1]
         context_lens = m.seq_lens.cpu() - query_lens
         context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
@@ -372,7 +371,5 @@ def build_for_cudagraph_capture(
         num_accepted_tokens = torch.diff(m.query_start_loc)
         num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
 
-        # Note: Setting _num_computed_tokens_cpu directly for cudagraph capture
-        m._num_computed_tokens_cpu = m.seq_lens.cpu() - num_accepted_tokens.cpu()
 
         return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index 48b98389f20b..4600911614b4 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -215,8 +215,8 @@ def build(
         num_prefills = common.num_prefills
         num_decode_tokens = common.num_decode_tokens
-        num_computed_tokens_cpu = self._get_num_computed_tokens_cpu(
-            common_attn_metadata
+        num_computed_tokens_cpu = (
+            common_attn_metadata.compute_num_computed_tokens().cpu()
         )
         num_computed_tokens_p_cpu = num_computed_tokens_cpu[
             num_reqs - num_prefills : num_reqs
         ]
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 6b99e14782cb..9589d3e2d5b0 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -103,16 +103,6 @@ def __init__(
             device=device,
         )
 
-    def _get_num_computed_tokens_cpu(
-        self, common_attn_metadata: CommonAttentionMetadata
-    ) -> torch.Tensor:
-        """Compute num_computed_tokens_cpu from query_start_loc and seq_lens."""
-        query_lens = (
-            common_attn_metadata.query_start_loc_cpu[1:]
-            - common_attn_metadata.query_start_loc_cpu[:-1]
-        )
-        return common_attn_metadata.seq_lens.cpu() - query_lens
-
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> M:
@@ -148,10 +138,7 @@ def _compute_prefix_caching_block_indices(
         common_attn_metadata: CommonAttentionMetadata,
         mamba_block_size: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        num_computed_tokens_cpu = self._get_num_computed_tokens_cpu(
-            common_attn_metadata
-        )
-        num_computed_tokens = num_computed_tokens_cpu.to(self.device)
+        num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
         # Block index of the last computed token
         block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
         # which is <= block index for the first scheduled token
@@ -203,16 +190,13 @@ def _compute_common_metadata(
         # for causal_conv1d
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
-        num_computed_tokens_cpu = self._get_num_computed_tokens_cpu(
-            common_attn_metadata
-        )
-
         if self.vllm_config.cache_config.enable_prefix_caching:
+            num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+
             # Return a tensor of shape (#requests, #max blocks)
             state_indices_tensor = common_attn_metadata.block_table_tensor
             # Additional cache-related varaiables:
             mamba_block_size = self.kv_cache_spec.block_size
-            num_computed_tokens = num_computed_tokens_cpu.to(self.device)
             (
                 block_idx_last_computed_token,
                 block_idx_first_scheduled_token,
@@ -225,6 +209,11 @@ def _compute_common_metadata(
             state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         if num_prefills > 0:
+            num_computed_tokens_cpu = (
+                num_computed_tokens
+                or common_attn_metadata.compute_num_computed_tokens()
+            ).cpu()
+
             query_start_loc_p = (
                 common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                 - num_decode_tokens

From ef81b2a5d09f5e9c37d5b2ad543277bc81ffceb2 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 6 Jan 2026 05:20:15 +0000
Subject: [PATCH 3/5] cleanup

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/gdn_attn.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 8df9b058b050..b2bbbe1c5b14 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -142,9 +142,7 @@ def build( # type: ignore[override]
         m = common_attn_metadata
 
         query_start_loc = m.query_start_loc
-        query_lens = m.query_start_loc_cpu[1:] - m.query_start_loc_cpu[:-1]
-        context_lens = m.seq_lens.cpu() - query_lens
-        context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
+        context_lens_tensor = m.compute_num_computed_tokens()
 
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
         if (
From d603fe74d13376829c938a0932a60648da9bce22 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 6 Jan 2026 10:45:10 -0500
Subject: [PATCH 4/5] Apply suggestion from @DarkLight1337

Co-authored-by: Cyrus Leung
Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/mamba_attn.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 9589d3e2d5b0..708451349cba 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -209,10 +209,10 @@ def _compute_common_metadata(
             state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         if num_prefills > 0:
-            num_computed_tokens_cpu = (
-                num_computed_tokens
-                or common_attn_metadata.compute_num_computed_tokens()
-            ).cpu()
+            if num_computed_tokens is None:
+                num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+
+            num_computed_tokens_cpu = num_computed_tokens.cpu()
 
             query_start_loc_p = (
                 common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                 - num_decode_tokens

From 85d2931147ccb9cf6208bbf60871215e38326f7c Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 6 Jan 2026 15:45:43 +0000
Subject: [PATCH 5/5] cleanup

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/mamba_attn.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 708451349cba..dd7b96e9824a 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -211,7 +211,6 @@ def _compute_common_metadata(
         if num_prefills > 0:
             if num_computed_tokens is None:
                 num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
-
            num_computed_tokens_cpu = num_computed_tokens.cpu()
 
             query_start_loc_p = (
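
Series note (illustrative only, not part of any commit above): the pattern
applied throughout these patches replaces the deprecated
num_computed_tokens_cpu property with an explicit per-request computation.
A minimal sketch, assuming query_start_loc_cpu and seq_lens are the batch
tensors shown in the diffs:

    # query_lens[i] = number of newly scheduled tokens for request i
    query_lens = (
        common_attn_metadata.query_start_loc_cpu[1:]
        - common_attn_metadata.query_start_loc_cpu[:-1]
    )
    # already-computed tokens per request = total sequence length - new tokens
    num_computed_tokens_cpu = common_attn_metadata.seq_lens.cpu() - query_lens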