vllm/v1/attention/backends/gdn_attn.py (4 changes: 1 addition & 3 deletions)

@@ -142,8 +142,7 @@ def build( # type: ignore[override]
         m = common_attn_metadata

         query_start_loc = m.query_start_loc
-        context_lens = m.num_computed_tokens_cpu
-        context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
+        context_lens_tensor = m.compute_num_computed_tokens()
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

         if (
@@ -370,6 +369,5 @@ def build_for_cudagraph_capture(

         num_accepted_tokens = torch.diff(m.query_start_loc)
         num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
-        m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

         return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
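Both hunks replace a host-side read of num_computed_tokens_cpu, followed by an explicit copy to the device, with one call to the new compute_num_computed_tokens() method on the metadata object. Below is a minimal sketch of what such a method could look like, assuming it derives the value on device from seq_lens and query_start_loc; the body is inferred from the deleted build_for_cudagraph_capture line (seq_lens minus this step's tokens) and is not the actual vLLM implementation.

import torch

class CommonAttentionMetadataSketch:
    # Hypothetical stand-in for vLLM's CommonAttentionMetadata, reduced to
    # the two fields this sketch needs.
    def __init__(self, seq_lens: torch.Tensor, query_start_loc: torch.Tensor):
        self.seq_lens = seq_lens                # total tokens per request (device)
        self.query_start_loc = query_start_loc  # cumulative query offsets (device)

    def compute_num_computed_tokens(self) -> torch.Tensor:
        # Assumed semantics: tokens already computed = total sequence length
        # minus the tokens scheduled this step, derived entirely on device,
        # so the call sites above need no host-to-device copy.
        num_scheduled_tokens = torch.diff(self.query_start_loc)
        return self.seq_lens - num_scheduled_tokens

# Example: two requests with 7 and 9 tokens total, scheduling 2 and 3 of them now.
m = CommonAttentionMetadataSketch(
    seq_lens=torch.tensor([7, 9]),
    query_start_loc=torch.tensor([0, 2, 5]),
)
print(m.compute_num_computed_tokens())  # tensor([5, 6])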
vllm/v1/attention/backends/mamba2_attn.py (5 changes: 4 additions & 1 deletion)

@@ -215,7 +215,10 @@ def build(
         num_prefills = common.num_prefills
         num_decode_tokens = common.num_decode_tokens

-        num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
+        num_computed_tokens_cpu = (
+            common_attn_metadata.compute_num_computed_tokens().cpu()
+        )
+        num_computed_tokens_p_cpu = num_computed_tokens_cpu[
             num_reqs - num_prefills : num_reqs
         ]
         query_start_loc_p_cpu = (
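Here the per-request tensor is computed once on device, copied to the host once, and then sliced. The suffix slice works under an assumption consistent with this code and with query_start_loc[-num_prefills - 1 :] in mamba_attn.py below: decode requests occupy the front of the batch and prefill requests the tail. A toy illustration with made-up numbers:

import torch

num_reqs, num_prefills = 5, 2
# One entry per request; the two prefill requests sit at the end of the batch.
num_computed_tokens_cpu = torch.tensor([7, 9, 4, 0, 16])

# Per-prefill view: the last num_prefills entries.
num_computed_tokens_p_cpu = num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs]
print(num_computed_tokens_p_cpu)  # tensor([ 0, 16])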
vllm/v1/attention/backends/mamba_attn.py (18 changes: 8 additions & 10 deletions)

@@ -138,9 +138,7 @@ def _compute_prefix_caching_block_indices(
         common_attn_metadata: CommonAttentionMetadata,
         mamba_block_size: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-            self.device
-        )
+        num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
         # Block index of the last computed token
         block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
         # which is <= block index for the first scheduled token
@@ -193,13 +191,12 @@ def _compute_common_metadata(
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

         if self.vllm_config.cache_config.enable_prefix_caching:
+            num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+
             # Return a tensor of shape (#requests, #max blocks)
             state_indices_tensor = common_attn_metadata.block_table_tensor
             # Additional cache-related varaiables:
             mamba_block_size = self.kv_cache_spec.block_size
-            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-                self.device
-            )
             (
                 block_idx_last_computed_token,
                 block_idx_first_scheduled_token,
@@ -212,15 +209,16 @@ def _compute_common_metadata(
             state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

         if num_prefills > 0:
+            if num_computed_tokens is None:
+                num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+            num_computed_tokens_cpu = num_computed_tokens.cpu()
+
             query_start_loc_p = (
                 common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                 - num_decode_tokens
             )
             has_initial_states_cpu = (
-                common_attn_metadata.num_computed_tokens_cpu[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                > 0
+                num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
             )
             has_initial_states_p = has_initial_states_cpu.to(
                 common_attn_metadata.query_start_loc.device
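All three hunks in this file follow one pattern: num_computed_tokens is materialized lazily, at most once per call, and crosses to the host exactly once for the prefill predicate (the old code read a precomputed CPU tensor instead). Below is a hypothetical outline of the resulting control flow in _compute_common_metadata, with unrelated details elided; the None initialization is an assumption, since the diff does not show where the variable is first bound.

import torch

def _compute_common_metadata_sketch(
    metadata, enable_prefix_caching: bool, num_prefills: int, num_reqs: int
):
    # Assumed initialization: the diff's "if num_computed_tokens is None" guard
    # implies the variable starts out unset when prefix caching is disabled.
    num_computed_tokens = None

    if enable_prefix_caching:
        # First possible producer: prefix caching needs the device tensor.
        num_computed_tokens = metadata.compute_num_computed_tokens()
        # ... derive block indices for prefix caching ...

    if num_prefills > 0:
        # Second possible producer: only compute if the branch above did not.
        if num_computed_tokens is None:
            num_computed_tokens = metadata.compute_num_computed_tokens()
        # Single explicit device-to-host copy for the host-side predicate.
        num_computed_tokens_cpu = num_computed_tokens.cpu()
        has_initial_states_cpu = (
            num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
        )
        # ... build prefill metadata from has_initial_states_cpu ...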