diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index fcde986f48d4..b2bbbe1c5b14 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -142,8 +142,7 @@ def build( # type: ignore[override]
         m = common_attn_metadata
 
         query_start_loc = m.query_start_loc
-        context_lens = m.num_computed_tokens_cpu
-        context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
+        context_lens_tensor = m.compute_num_computed_tokens()
 
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
         if (
@@ -370,6 +369,5 @@ def build_for_cudagraph_capture(
 
         num_accepted_tokens = torch.diff(m.query_start_loc)
         num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
-        m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
 
         return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index b526f0a32997..4600911614b4 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -215,7 +215,10 @@ def build(
         num_prefills = common.num_prefills
         num_decode_tokens = common.num_decode_tokens
 
-        num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
+        num_computed_tokens_cpu = (
+            common_attn_metadata.compute_num_computed_tokens().cpu()
+        )
+        num_computed_tokens_p_cpu = num_computed_tokens_cpu[
             num_reqs - num_prefills : num_reqs
         ]
         query_start_loc_p_cpu = (
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 4f876d66da14..dd7b96e9824a 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -138,9 +138,7 @@ def _compute_prefix_caching_block_indices(
         common_attn_metadata: CommonAttentionMetadata,
         mamba_block_size: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-            self.device
-        )
+        num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
         # Block index of the last computed token
         block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
         # which is <= block index for the first scheduled token
@@ -193,13 +191,12 @@ def _compute_common_metadata(
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
         if self.vllm_config.cache_config.enable_prefix_caching:
+            num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+
             # Return a tensor of shape (#requests, #max blocks)
             state_indices_tensor = common_attn_metadata.block_table_tensor
             # Additional cache-related varaiables:
             mamba_block_size = self.kv_cache_spec.block_size
-            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-                self.device
-            )
             (
                 block_idx_last_computed_token,
                 block_idx_first_scheduled_token,
@@ -212,15 +209,16 @@ def _compute_common_metadata(
             state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         if num_prefills > 0:
+            if num_computed_tokens is None:
+                num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+            num_computed_tokens_cpu = num_computed_tokens.cpu()
+
             query_start_loc_p = (
                 common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                 - num_decode_tokens
             )
             has_initial_states_cpu = (
-                common_attn_metadata.num_computed_tokens_cpu[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                > 0
+                num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
             )
             has_initial_states_p = has_initial_states_cpu.to(
                 common_attn_metadata.query_start_loc.device
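
Context for reviewers: every call site above switches from reading the cached `num_computed_tokens_cpu` tensor to calling `compute_num_computed_tokens()` on the metadata, which also lets `build_for_cudagraph_capture` drop the `m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()` workaround. A minimal sketch of the idea, assuming the metadata carries `query_start_loc` (shape `(num_reqs + 1,)`) and per-request `seq_lens`; the standalone helper below is hypothetical and not the actual `CommonAttentionMetadata` implementation:

```python
import torch

# Hypothetical sketch (names taken from the diff, not vLLM's real method):
# derive per-request computed-token counts on demand instead of caching
# a separate CPU tensor that can go stale.
def compute_num_computed_tokens(
    query_start_loc: torch.Tensor,  # (num_reqs + 1,) cumulative scheduled tokens
    seq_lens: torch.Tensor,  # (num_reqs,) total sequence length per request
) -> torch.Tensor:
    # Tokens scheduled for this step, per request.
    num_scheduled_tokens = torch.diff(query_start_loc)
    # Whatever is not scheduled this step has already been computed,
    # i.e. it is already present in the KV / SSM state cache.
    return seq_lens - num_scheduled_tokens
```

For example, `query_start_loc = [0, 1, 2, 6]` with `seq_lens = [10, 1, 6]` yields `[9, 0, 2]`: a decode with 9 cached tokens, a fresh prefill, and a chunked prefill with 2 computed tokens, which matches how `has_initial_states_cpu` is derived above as `num_computed_tokens_cpu[...] > 0`.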