diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 5495b4fc1890..2b9b0dc1e6e4 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -126,12 +126,12 @@ def create_and_prepopulate_kv_cache(
         Tuple of (kv_cache, updated_block_table)
     """
     batch_size = len(k_contexts)
-    seq_lens = common_attn_metadata.seq_lens_cpu
+    seq_lens = common_attn_metadata.seq_lens.cpu()
     query_lens = (
         common_attn_metadata.query_start_loc_cpu[1:]
         - common_attn_metadata.query_start_loc_cpu[:-1]
     )
-    context_lens = common_attn_metadata.num_computed_tokens_cpu
+    context_lens = seq_lens - query_lens
     block_table = common_attn_metadata.block_table_tensor
     slot_mapping = common_attn_metadata.slot_mapping

diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 783e02ce89bd..bd2feac41100 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -154,12 +154,12 @@ def create_and_prepopulate_kv_cache(
         MLA KV cache tensor
     """
     batch_size = len(kv_c_contexts)
-    seq_lens = common_attn_metadata.seq_lens_cpu
+    seq_lens = common_attn_metadata.seq_lens.cpu()
     query_lens = (
         common_attn_metadata.query_start_loc_cpu[1:]
         - common_attn_metadata.query_start_loc_cpu[:-1]
     )
-    context_lens = common_attn_metadata.num_computed_tokens_cpu
+    context_lens = seq_lens - query_lens
     block_table = common_attn_metadata.block_table_tensor
     slot_mapping = common_attn_metadata.slot_mapping

diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py
index 9b7c5822db98..f4ca3dccfb5e 100644
--- a/tests/v1/attention/test_sparse_mla_backends.py
+++ b/tests/v1/attention/test_sparse_mla_backends.py
@@ -297,7 +297,7 @@ def test_sparse_backend_decode_correctness(
     positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
         starts[:-1], seg_lengths
     )
-    seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
+    seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32)
     prefix_lengths = seq_lengths - seg_lengths
     positions += np.repeat(prefix_lengths, seg_lengths)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 623ae892ecda..7ef157384be1 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -870,7 +870,9 @@ def build(
         # Guard access to seq_lens_cpu, which may not always be needed
         # and can be expensive to retrieve in async mode.
         needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
+        seq_lens_cpu = (
+            common_attn_metadata.seq_lens.cpu() if needs_seq_lens_cpu else None
+        )
         seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
         num_blocks_np = (
             (seq_lens_np + (page_size - 1)) // page_size
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 8193c05c2b1a..a151a437a76a 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -727,9 +727,7 @@ def build(
             block_table_tensor, seq_lens, block_size, num_gpu_blocks
         )

-        offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
-            self.device, non_blocking=True
-        )
+        offset_tensor = common_attn_metadata.compute_num_computed_tokens()

         out = FlexAttentionMetadata(
             causal=common_attn_metadata.causal,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index e9ec96835f27..a47a2282fe49 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -791,7 +791,9 @@ def build(
         prefill_metadata = None
         if num_prefills > 0:
-            num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
+            num_computed_tokens_cpu = (
+                common_attn_metadata.compute_num_computed_tokens().cpu()
+            )
             reqs_start = num_decodes  # prefill_start
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 9dbb17b78a53..1122538969d6 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -511,7 +511,7 @@ def _build_fp8_separate_prefill_decode(
         # For pure decode batches, prefill_request_id will be None
         # For mixed batches, it will have -1 for decode and request_id for prefill
         if num_prefills > 0:
-            seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+            seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
             seq_lens = common_attn_metadata.seq_lens
             query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index b3dfc55cd059..9bf440a04d06 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -221,7 +221,7 @@ def build(
             prefix_kv_lens = torch.tensor(
                 [common_prefix_len], dtype=torch.int32, device=self.device
             )
-            suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
+            suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
             suffix_kv_lens = suffix_kv_lens.to(self.device)
         else:
             cu_prefix_query_lens = None
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 6b94f786a26b..3cbdafe14da6 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -100,6 +100,8 @@ class CommonAttentionMetadata:
     _seq_lens_cpu: torch.Tensor | None = None
     _num_computed_tokens_cpu: torch.Tensor | None = None

+    _num_computed_tokens_cache: torch.Tensor | None = None
+
     @property
     @deprecated(
         """
@@ -130,6 +132,13 @@ def num_computed_tokens_cpu(self) -> torch.Tensor:
             self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
         return self._num_computed_tokens_cpu

+    def compute_num_computed_tokens(self) -> torch.Tensor:
+        """Compute num_computed_tokens on device (seq_lens - query_lens)."""
+        if self._num_computed_tokens_cache is None:
+            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
+            self._num_computed_tokens_cache = self.seq_lens - query_lens
+        return self._num_computed_tokens_cache
+
     # TODO(lucas): remove once we have FULL-CG spec-decode support
     def unpadded(
         self, num_actual_tokens: int, num_actual_reqs: int
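
Note on the new helper: the sketch below is a minimal, self-contained reproduction of the caching pattern that compute_num_computed_tokens() introduces in vllm/v1/attention/backends/utils.py. _ToyAttentionMetadata is a hypothetical stand-in for CommonAttentionMetadata, stripped down to the two tensors the helper reads; the subtraction runs on whatever device seq_lens and query_start_loc live on, avoiding the host-side retrieval that the diff's flashinfer comment notes "can be expensive ... in async mode".

from dataclasses import dataclass, field

import torch


@dataclass
class _ToyAttentionMetadata:
    # Hypothetical stand-in for vLLM's CommonAttentionMetadata; not the real class.
    # query_start_loc[i + 1] - query_start_loc[i] is the query length of request i.
    query_start_loc: torch.Tensor
    # Total sequence length (already-computed context + new query tokens) per request.
    seq_lens: torch.Tensor
    # Lazily filled cache, mirroring _num_computed_tokens_cache in the diff.
    _num_computed_tokens_cache: torch.Tensor | None = field(default=None, repr=False)

    def compute_num_computed_tokens(self) -> torch.Tensor:
        """Compute num_computed_tokens on device (seq_lens - query_lens)."""
        if self._num_computed_tokens_cache is None:
            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
            self._num_computed_tokens_cache = self.seq_lens - query_lens
        return self._num_computed_tokens_cache


# Two requests with query lengths 3 and 2 and total lengths 7 and 9:
# 4 and 7 tokens respectively are already in the KV cache.
meta = _ToyAttentionMetadata(
    query_start_loc=torch.tensor([0, 3, 5], dtype=torch.int32),
    seq_lens=torch.tensor([7, 9], dtype=torch.int32),
)
print(meta.compute_num_computed_tokens())  # tensor([4, 7], dtype=torch.int32)

Callers that still need host-side values chain .cpu() onto the result (as mla/common.py does above), while device-side consumers such as flex_attention.py use the returned tensor directly and skip the device-to-host round trip.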