Skip to content
4 changes: 2 additions & 2 deletions tests/v1/attention/test_attention_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ def create_and_prepopulate_kv_cache(
Tuple of (kv_cache, updated_block_table)
"""
batch_size = len(k_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
context_lens = common_attn_metadata.num_computed_tokens_cpu
context_lens = seq_lens - query_lens
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping

Expand Down
4 changes: 2 additions & 2 deletions tests/v1/attention/test_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,12 @@ def create_and_prepopulate_kv_cache(
MLA KV cache tensor
"""
batch_size = len(kv_c_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
context_lens = common_attn_metadata.num_computed_tokens_cpu
context_lens = seq_lens - query_lens
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping

Expand Down
2 changes: 1 addition & 1 deletion tests/v1/attention/test_sparse_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_sparse_backend_decode_correctness(
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
starts[:-1], seg_lengths
)
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32)
prefix_lengths = seq_lengths - seg_lengths
positions += np.repeat(prefix_lengths, seg_lengths)

Expand Down
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,9 @@ def build(
# Guard access to seq_lens_cpu, which may not always be needed
# and can be expensive to retrieve in async mode.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
seq_lens_cpu = (
common_attn_metadata.seq_lens.cpu() if needs_seq_lens_cpu else None
)
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
num_blocks_np = (
(seq_lens_np + (page_size - 1)) // page_size
Expand Down
4 changes: 1 addition & 3 deletions vllm/v1/attention/backends/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,9 +727,7 @@ def build(
block_table_tensor, seq_lens, block_size, num_gpu_blocks
)

offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
self.device, non_blocking=True
)
offset_tensor = common_attn_metadata.compute_num_computed_tokens()

out = FlexAttentionMetadata(
causal=common_attn_metadata.causal,
Expand Down
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,9 @@ def build(

prefill_metadata = None
if num_prefills > 0:
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)

reqs_start = num_decodes # prefill_start

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _build_fp8_separate_prefill_decode(
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def build(
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
suffix_kv_lens = suffix_kv_lens.to(self.device)
else:
cu_prefix_query_lens = None
Expand Down
9 changes: 9 additions & 0 deletions vllm/v1/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class CommonAttentionMetadata:
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None

_num_computed_tokens_cache: torch.Tensor | None = None

@property
@deprecated(
"""
Expand Down Expand Up @@ -130,6 +132,13 @@ def num_computed_tokens_cpu(self) -> torch.Tensor:
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
return self._num_computed_tokens_cpu

def compute_num_computed_tokens(self) -> torch.Tensor:
"""Compute num_computed_tokens on device (seq_lens - query_lens)."""
if self._num_computed_tokens_cache is None:
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
self._num_computed_tokens_cache = self.seq_lens - query_lens
return self._num_computed_tokens_cache

# TODO(lucas): remove once we have FULL-CG spec-decode support
def unpadded(
self, num_actual_tokens: int, num_actual_reqs: int
Expand Down
Loading