From 774b43ca7f2dc90fdea0a90450974dce5339408e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 10 Jan 2026 04:42:13 +0000 Subject: [PATCH 01/13] wip Signed-off-by: Lucas Wilkinson --- tests/v1/spec_decode/test_eagle.py | 3 ++- .../layers/attention/cross_attention.py | 11 ++++++---- vllm/v1/attention/backends/mla/indexer.py | 5 +++-- vllm/v1/attention/backends/utils.py | 2 +- vllm/v1/spec_decode/eagle.py | 22 +++++++++---------- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 6 files changed, 26 insertions(+), 22 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 8b180168dffc..f3658c620559 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -166,6 +166,7 @@ def test_prepare_next_token_ids(): block_size=16, device=device, ) + seq_lens_cpu = common_attn_metadata.seq_lens.cpu() expected_valid_sampled_tokens_count = torch.tensor( [2, 5, 0, 0], dtype=torch.int32, device=device @@ -173,7 +174,7 @@ def test_prepare_next_token_ids(): next_token_ids_from_padded, valid_sampled_tokens_count = ( proposer.prepare_next_token_ids_padded( - common_attn_metadata, + seq_lens_cpu, sampled_token_ids_tensor, mock_requests, mock_input_batch, diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index 9333b35e65b5..11ca0de5d70a 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -85,15 +85,18 @@ def build( new_metadata.causal = False max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max()) new_metadata.max_seq_len = max_encoder_len + # Derive num_computed_tokens from seq_lens and query_lens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + seq_lens_cpu = common_attn_metadata.seq_lens.cpu() + num_computed_tokens_cpu = seq_lens_cpu - query_lens_cpu # Any computed tokens indicated decode step>1 (no chunked prefill) - num_cache_decodes = ( - (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item() - ) + num_cache_decodes = (num_computed_tokens_cpu > 0).sum().item() if num_cache_decodes > 0: # CrossAttn KV cache has already been populated on first decoder step, # skip slot_mapping calculation for requests that do not need # reshape_and_cache. - num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy() + num_tokens = num_computed_tokens_cpu.numpy() new_metadata.encoder_seq_lens_cpu = np.where( num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu ) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 368b217f0ba6..3b6db79f46da 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -278,6 +278,7 @@ def build( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens_cpu = common_attn_metadata.seq_lens.cpu() num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold @@ -290,7 +291,7 @@ def build( prefill_metadata = None if num_prefills > 0: chunk_seq_ids = split_prefill_chunks( - common_attn_metadata.seq_lens_cpu[num_decodes:], + seq_lens_cpu[num_decodes:], self.max_prefill_buffer_size, request_offset=num_decodes, ) @@ -299,7 +300,7 @@ def build( reqs_start, reqs_end, query_start_loc_cpu, - common_attn_metadata.seq_lens_cpu, + seq_lens_cpu, common_attn_metadata.block_table_tensor, ) for reqs_start, reqs_end in chunk_seq_ids diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index e0aa2c988a21..3217725a985e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -443,7 +443,7 @@ def split_decodes_prefills_and_extends( num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - seq_lens = common_attn_metadata.seq_lens_cpu + seq_lens = common_attn_metadata.seq_lens.cpu() if max_query_len <= decode_threshold: return num_reqs, 0, 0, num_tokens, 0, 0 diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b5532d652618..52970d78cb0a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -878,7 +878,7 @@ def prepare_next_token_ids_cpu( def prepare_next_token_ids_padded( self, - common_attn_metadata: CommonAttentionMetadata, + seq_lens_cpu: torch.Tensor, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -896,7 +896,7 @@ def prepare_next_token_ids_padded( self.backup_next_token_ids.np[:num_reqs] = np.array( [ requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item() + seq_lens_cpu[i].item() ) for i in range(num_reqs) ], @@ -977,12 +977,10 @@ def prepare_inputs_padded( query_start_loc=common_attn_metadata.query_start_loc, seq_lens=common_attn_metadata.seq_lens, query_start_loc_cpu=query_start_loc_cpu, - _seq_lens_cpu=common_attn_metadata._seq_lens_cpu, - _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), - max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + max_seq_len=common_attn_metadata.max_seq_len, block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], causal=True, @@ -1196,15 +1194,15 @@ def prepare_inputs( # q1, q1 + 1, ..., q1 + q2 - n2 - 1, # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - num_rejected_tokens = [ + num_rejected_tokens_list = [ n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) + num_rejected_tokens = torch.tensor(num_rejected_tokens_list, dtype=torch.int32) device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens + new_seq_lens = common_attn_metadata.seq_lens - num_rejected_tokens.to(device) # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -1254,14 +1252,14 @@ def prepare_inputs( spec_common_attn_metadata = CommonAttentionMetadata( query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), - seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + seq_lens=new_seq_lens, query_start_loc_cpu=new_query_start_loc_cpu, - _seq_lens_cpu=new_seq_lens_cpu, - _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), - max_seq_len=new_seq_lens_cpu.max().item(), + # max_seq_len is just an upper bound; so use a rough estimate that doesn't + # involve a D<>H transfer + max_seq_len=common_attn_metadata.max_seq_len + max(num_draft_tokens), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0e2e381f282d..dc10777366c9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3717,9 +3717,10 @@ def propose_draft_token_ids(sampled_token_ids): propose_draft_token_ids(sampled_token_ids) elif self.valid_sampled_token_count_event is not None: assert spec_decode_common_attn_metadata is not None + num_reqs = self.input_batch.num_reqs next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, + self.seq_lens.cpu[:num_reqs], sampled_token_ids, self.requests, self.input_batch, @@ -4018,7 +4019,7 @@ def propose_draft_token_ids( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, + self.seq_lens.cpu[: self.input_batch.num_reqs], sampled_token_ids, self.requests, self.input_batch, From 52bf37615098b3bbd03a397353e866cec06f9cf5 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 14 Jan 2026 05:34:12 +0000 Subject: [PATCH 02/13] cleanup Signed-off-by: Lucas Wilkinson --- vllm/model_executor/layers/attention/cross_attention.py | 8 +++----- vllm/v1/attention/backend.py | 6 ++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index 11ca0de5d70a..b8ab80754cf3 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -85,11 +85,9 @@ def build( new_metadata.causal = False max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max()) new_metadata.max_seq_len = max_encoder_len - # Derive num_computed_tokens from seq_lens and query_lens - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - seq_lens_cpu = common_attn_metadata.seq_lens.cpu() - num_computed_tokens_cpu = seq_lens_cpu - query_lens_cpu + num_computed_tokens_cpu = ( + common_attn_metadata.compute_num_computed_tokens_cpu() + ) # Any computed tokens indicated decode step>1 (no chunked prefill) num_cache_decodes = (num_computed_tokens_cpu > 0).sum().item() if num_cache_decodes > 0: diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 9c004d7724dd..4734c6940790 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -379,6 +379,12 @@ def compute_num_computed_tokens(self) -> torch.Tensor: self._num_computed_tokens_cache = self.seq_lens - query_lens return self._num_computed_tokens_cache + def compute_num_computed_tokens_cpu(self) -> torch.Tensor: + """Compute num_computed_tokens on CPU (seq_lens - query_lens).""" + query_lens_cpu = self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] + seq_lens_cpu = self.seq_lens.cpu() + return seq_lens_cpu - query_lens_cpu + # TODO(lucas): remove once we have FULL-CG spec-decode support def unpadded( self, num_actual_tokens: int, num_actual_reqs: int From 213569acada827c50b647347097571e288d34dc4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 21 Jan 2026 23:10:16 +0000 Subject: [PATCH 03/13] cleanup Signed-off-by: Lucas Wilkinson --- .../layers/attention/cross_attention.py | 13 +++++----- vllm/v1/attention/backend.py | 26 +++---------------- vllm/v1/attention/backends/mla/indexer.py | 3 ++- vllm/v1/attention/backends/utils.py | 7 +---- vllm/v1/worker/gpu_model_runner.py | 4 --- vllm/v1/worker/ubatch_utils.py | 10 +------ 6 files changed, 14 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index b8ab80754cf3..ef4f3c75ac64 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -85,18 +85,19 @@ def build( new_metadata.causal = False max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max()) new_metadata.max_seq_len = max_encoder_len - num_computed_tokens_cpu = ( - common_attn_metadata.compute_num_computed_tokens_cpu() + seq_lens_cpu = common_attn_metadata.seq_lens.cpu() + query_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] ) # Any computed tokens indicated decode step>1 (no chunked prefill) - num_cache_decodes = (num_computed_tokens_cpu > 0).sum().item() - if num_cache_decodes > 0: + is_decode = seq_lens_cpu >= query_lens_cpu + if torch.any(is_decode): # CrossAttn KV cache has already been populated on first decoder step, # skip slot_mapping calculation for requests that do not need # reshape_and_cache. - num_tokens = num_computed_tokens_cpu.numpy() new_metadata.encoder_seq_lens_cpu = np.where( - num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu + is_decode, 0, new_metadata.encoder_seq_lens_cpu ) # seq_lens is provided by model runner: initial encoder input length is diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 4734c6940790..1f00b85eca51 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -8,7 +8,6 @@ import numpy as np import torch -from typing_extensions import deprecated if TYPE_CHECKING: from vllm.config import VllmConfig @@ -326,12 +325,6 @@ class CommonAttentionMetadata: dcp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" - # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0) - _seq_lens_cpu: torch.Tensor | None = None - _num_computed_tokens_cpu: torch.Tensor | None = None - - _num_computed_tokens_cache: torch.Tensor | None = None - def batch_size(self) -> int: return self.seq_lens.shape[0] @@ -374,16 +367,9 @@ def num_computed_tokens_cpu(self) -> torch.Tensor: def compute_num_computed_tokens(self) -> torch.Tensor: """Compute num_computed_tokens on device (seq_lens - query_lens).""" - if self._num_computed_tokens_cache is None: - query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] - self._num_computed_tokens_cache = self.seq_lens - query_lens - return self._num_computed_tokens_cache - - def compute_num_computed_tokens_cpu(self) -> torch.Tensor: - """Compute num_computed_tokens on CPU (seq_lens - query_lens).""" - query_lens_cpu = self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] - seq_lens_cpu = self.seq_lens.cpu() - return seq_lens_cpu - query_lens_cpu + query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] + num_computed_tokens = self.seq_lens - query_lens + return num_computed_tokens # TODO(lucas): remove once we have FULL-CG spec-decode support def unpadded( @@ -394,12 +380,6 @@ def unpadded( query_start_loc=self.query_start_loc[: num_actual_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], seq_lens=self.seq_lens[:num_actual_reqs], - _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs] - if self._seq_lens_cpu is not None - else None, - _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] - if self._num_computed_tokens_cpu is not None - else None, num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 3b6db79f46da..4b48db4dad3f 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -278,7 +278,6 @@ def build( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - seq_lens_cpu = common_attn_metadata.seq_lens.cpu() num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold @@ -290,6 +289,8 @@ def build( prefill_metadata = None if num_prefills > 0: + seq_lens_cpu = common_attn_metadata.seq_lens.cpu() + chunk_seq_ids = split_prefill_chunks( seq_lens_cpu[num_decodes:], self.max_prefill_buffer_size, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3217725a985e..42b21d732523 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -222,7 +222,7 @@ def make_local_attention_virtual_batches( block_size: int = 0, ) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]: query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy() - seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy() + seq_lens_np = common_attn_metadata.seq_lens.cpu().numpy() block_table = common_attn_metadata.block_table_tensor device = common_attn_metadata.query_start_loc.device @@ -285,7 +285,6 @@ def make_local_attention_virtual_batches( # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block - num_computed_tokens_local = seqlens_k_local - seqlens_q_local k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) @@ -354,8 +353,6 @@ def make_local_attention_virtual_batches( block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, - _seq_lens_cpu=seq_lens_cpu, - _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), ), make_block_table @@ -412,8 +409,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, causal=True, - _seq_lens_cpu=common_attn_metadata._seq_lens_cpu, - _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, ) return common_attn_metadata diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dc10777366c9..b3bfb1c14a01 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1745,10 +1745,6 @@ def _get_block_table(kv_cache_gid: int): query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], seq_lens=self.seq_lens.gpu[:num_reqs_padded], - _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], - _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs_padded - ], num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, max_query_len=max_query_len, diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 7c41726472d5..147bc522e813 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -177,7 +177,6 @@ def _make_metadata_with_slice( query_start_loc[1:] -= tokens_skipped query_start_loc_cpu[1:] -= tokens_skipped seq_lens = attn_metadata.seq_lens[request_slice] - seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] if splits_last_request: # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate @@ -190,12 +189,7 @@ def _make_metadata_with_slice( # Make sure we don't modify the seq_lens tensors # (not cudagraph compatible) seq_lens = seq_lens.clone() - seq_lens_cpu = seq_lens_cpu.clone() seq_lens[-1] -= tokens_skipped - seq_lens_cpu[-1] -= tokens_skipped - - max_seq_len = int(seq_lens_cpu.max()) - num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice] num_requests = request_slice.stop - request_slice.start num_actual_tokens = token_slice.stop - token_slice.start @@ -218,11 +212,9 @@ def _make_metadata_with_slice( num_reqs=num_requests, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, - max_seq_len=max_seq_len, + max_seq_len=attn_metadata.max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, - _seq_lens_cpu=seq_lens_cpu, - _num_computed_tokens_cpu=num_computed_tokens_cpu, ) From 2ce1130de09981c4d6d89623c639afbc1c9f3b34 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 23 Jan 2026 01:14:49 -0500 Subject: [PATCH 04/13] fix tests Signed-off-by: Lucas Wilkinson --- tests/v1/attention/utils.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 3cff52929146..d58e10d30dc7 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -67,15 +67,7 @@ def create_common_attn_metadata( # Create sequence lengths seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device) - seq_lens_cpu = seq_lens.cpu() - max_seq_len = int(seq_lens_cpu.max()) - - # Create computed tokens (context length for each sequence) - context_lens = [ - batch_spec.seq_lens[i] - batch_spec.query_lens[i] - for i in range(batch_spec.batch_size) - ] - num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) + max_seq_len = int(seq_lens.max().item()) # Create block table and slot mapping max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size @@ -106,8 +98,6 @@ def create_common_attn_metadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, seq_lens=seq_lens, - _seq_lens_cpu=seq_lens_cpu, - _num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=batch_spec.batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, From d3c7bdf957f4c3f8b12943564de5ab940b537c71 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 23 Jan 2026 01:16:19 -0500 Subject: [PATCH 05/13] clean Signed-off-by: Lucas Wilkinson --- vllm/v1/spec_decode/eagle.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 52970d78cb0a..524447645bf1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -567,9 +567,6 @@ def propose( # (i.e., not the first proposal). if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None: common_attn_metadata.seq_lens -= num_rejected_tokens_gpu - # Invalidate the CPU-side shadows to avoid H<>D sync. - common_attn_metadata._seq_lens_cpu = None - common_attn_metadata._num_computed_tokens_cpu = None for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. @@ -613,13 +610,6 @@ def propose( common_attn_metadata.max_seq_len + 1, self.max_model_len ) - # Also update the CPU-side shadow; NOTE: this is hacky and should be - # removed in when common_attn_metadata.seq_lens_cpu is deprecated. - if common_attn_metadata._seq_lens_cpu is not None: - common_attn_metadata._seq_lens_cpu += 1 - if common_attn_metadata._num_computed_tokens_cpu is not None: - common_attn_metadata._num_computed_tokens_cpu += 1 - # Compute the slot mapping. block_size = attn_metadata_builder.kv_cache_spec.block_size if self.uses_mrope: From 7632fe85a90a2d76eec795312c382e1524d9e0d6 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 23 Jan 2026 01:18:46 -0500 Subject: [PATCH 06/13] fix Signed-off-by: Lucas Wilkinson --- vllm/model_executor/layers/attention/cross_attention.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index ef4f3c75ac64..d9a5f5c9673c 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -104,9 +104,6 @@ def build( # needed here to know how many tokens to attend to from the cached # cross-attention KV cache. new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens - new_metadata._seq_lens_cpu = torch.from_numpy( - common_attn_metadata.encoder_seq_lens_cpu - ) # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here slot_mapping = _get_cross_slot_mapping( From 65a7c2695137e2438912ea0553172fe9270ba6d3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 23 Jan 2026 17:17:43 +0000 Subject: [PATCH 07/13] fix Signed-off-by: Lucas Wilkinson --- vllm/model_executor/models/whisper_causal.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py index c43c00840192..d4b7a536c804 100644 --- a/vllm/model_executor/models/whisper_causal.py +++ b/vllm/model_executor/models/whisper_causal.py @@ -133,8 +133,6 @@ def build( new_common_attn_metadata.query_start_loc *= block_pool_size new_common_attn_metadata.query_start_loc_cpu *= block_pool_size new_common_attn_metadata.seq_lens *= block_pool_size - new_common_attn_metadata._seq_lens_cpu *= block_pool_size - new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size new_common_attn_metadata.num_actual_tokens *= block_pool_size new_common_attn_metadata.max_query_len *= block_pool_size new_common_attn_metadata.max_seq_len *= block_pool_size From 9c56eb451471c9a2ecec905852e4b7a222d38b5e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 23 Jan 2026 18:00:16 +0000 Subject: [PATCH 08/13] fixes Signed-off-by: Lucas Wilkinson --- tests/v1/attention/test_chunked_local_attention.py | 2 +- tests/v1/spec_decode/test_tree_attention.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py index 4529c2cfc29b..ab66c8ac6365 100644 --- a/tests/v1/attention/test_chunked_local_attention.py +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -178,7 +178,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): # Convert to numpy for easier comparison actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy()) - actual_k_seqlens = result.seq_lens_cpu.numpy() + actual_k_seqlens = result.seq_lens.cpu().numpy() # Check that all query lengths are less than or equal to attn_chunk_size assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index bd7005540618..9b66d4c93a5f 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -54,14 +54,12 @@ def forward_attention( query_start_loc = q_len * torch.arange( batch_size + 1, device=q.device, dtype=torch.int32 ) - query_lens = torch.diff(query_start_loc) seq_lens = torch.full( (batch_size,), seqlen_k, device=q.device, dtype=torch.int32, ) - context_lens = seq_lens - query_lens max_seq_len = int(seq_lens.max()) max_query_len = q_len num_actual_tokens = query_start_loc[-1] @@ -89,8 +87,6 @@ def forward_attention( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc.cpu(), seq_lens=seq_lens, - _seq_lens_cpu=seq_lens.cpu(), - _num_computed_tokens_cpu=context_lens.cpu(), num_reqs=batch_size, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, From 9d685d253a4fca8058d73e4ae6565cdb085121e0 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 26 Jan 2026 12:59:29 +0000 Subject: [PATCH 09/13] fix Signed-off-by: Lucas Wilkinson --- tests/v1/e2e/test_async_spec_decode.py | 31 ++++++++++++-------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py index 4bf76da452f3..ffeeb90bc564 100644 --- a/tests/v1/e2e/test_async_spec_decode.py +++ b/tests/v1/e2e/test_async_spec_decode.py @@ -16,36 +16,33 @@ @pytest.fixture def sync_tracker(): """ - Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect - lazy init syncs. Prints stack traces immediately when syncs occur. + Fixture that patches torch.Tensor.cpu to detect GPU-CPU syncs + during speculative decoding generation. Prints stack traces + immediately when syncs occur. """ - from vllm.v1.attention.backend import CommonAttentionMetadata - # Shared counter for cross-process communication (inherited by fork) sync_count = multiprocessing.Value("i", 0) - # Save original property - original_prop = CommonAttentionMetadata.seq_lens_cpu - original_fget = original_prop.fget + original_cpu = torch.Tensor.cpu - # Create tracking wrapper - def tracking_seq_lens_cpu(self): - if self._seq_lens_cpu is None: - # Increment counter + def tracking_cpu(self, *args, **kwargs): + if self.is_cuda: with sync_count.get_lock(): sync_count.value += 1 count = sync_count.value # Print stack trace immediately (shows in subprocess output) print(f"\n{'=' * 60}", file=sys.stderr) - print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr) + print(f"SYNC #{count}: tensor.cpu() called on CUDA tensor!", + file=sys.stderr) + print(f"Shape: {self.shape}, dtype: {self.dtype}", file=sys.stderr) print(f"{'=' * 60}", file=sys.stderr) traceback.print_stack(file=sys.stderr) print(f"{'=' * 60}\n", file=sys.stderr) sys.stderr.flush() - return original_fget(self) + return original_cpu(self, *args, **kwargs) # Apply patch - CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu) + torch.Tensor.cpu = tracking_cpu class SyncTracker: @property @@ -55,14 +52,14 @@ def count(self) -> int: def assert_no_sync(self, msg: str = ""): count = sync_count.value assert count == 0, ( - f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered " + f"Unexpected GPU-CPU sync: tensor.cpu() called " f"{count} times. See stack traces above. {msg}" ) yield SyncTracker() - # Restore original property - CommonAttentionMetadata.seq_lens_cpu = original_prop + # Restore original method + torch.Tensor.cpu = original_cpu torch._dynamo.reset() From 589d5535aa1e28de1b70c86664668bb58630dbbf Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 26 Jan 2026 14:23:19 +0000 Subject: [PATCH 10/13] fix pre commit Signed-off-by: Lucas Wilkinson --- tests/v1/e2e/test_async_spec_decode.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py index ffeeb90bc564..92a856beab04 100644 --- a/tests/v1/e2e/test_async_spec_decode.py +++ b/tests/v1/e2e/test_async_spec_decode.py @@ -32,8 +32,9 @@ def tracking_cpu(self, *args, **kwargs): count = sync_count.value # Print stack trace immediately (shows in subprocess output) print(f"\n{'=' * 60}", file=sys.stderr) - print(f"SYNC #{count}: tensor.cpu() called on CUDA tensor!", - file=sys.stderr) + print( + f"SYNC #{count}: tensor.cpu() called on CUDA tensor!", file=sys.stderr + ) print(f"Shape: {self.shape}, dtype: {self.dtype}", file=sys.stderr) print(f"{'=' * 60}", file=sys.stderr) traceback.print_stack(file=sys.stderr) From f5f10faee37a004f50fe02eac3999cabfa228e35 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 27 Jan 2026 22:20:02 +0000 Subject: [PATCH 11/13] test fix Signed-off-by: Lucas Wilkinson --- tests/v1/e2e/test_async_spec_decode.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py index 92a856beab04..e6fe206cdee1 100644 --- a/tests/v1/e2e/test_async_spec_decode.py +++ b/tests/v1/e2e/test_async_spec_decode.py @@ -50,6 +50,11 @@ class SyncTracker: def count(self) -> int: return sync_count.value + def start_tracking(self): + """Start tracking syncs from this point. Call after model loading.""" + with sync_count.get_lock(): + sync_count.value = 0 + def assert_no_sync(self, msg: str = ""): count = sync_count.value assert count == 0, ( @@ -114,6 +119,9 @@ def test_no_sync_with_spec_decode( async_scheduling=True, ) + # Start tracking after model loading - we only care about syncs during generation + sync_tracker.start_tracking() + outputs = llm.generate( ["Hello, my name is"], SamplingParams(temperature=0, max_tokens=10), From e8546582d76be28f243a5abb9b8b120d2a4e8dbf Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 3 Feb 2026 21:49:08 -0500 Subject: [PATCH 12/13] fix tests Signed-off-by: Matthew Bonanni --- tests/v1/e2e/test_async_spec_decode.py | 77 +++++++++++++------ .../layers/attention/cross_attention.py | 2 +- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py index e6fe206cdee1..63b836c6a30a 100644 --- a/tests/v1/e2e/test_async_spec_decode.py +++ b/tests/v1/e2e/test_async_spec_decode.py @@ -16,33 +16,59 @@ @pytest.fixture def sync_tracker(): """ - Fixture that patches torch.Tensor.cpu to detect GPU-CPU syncs - during speculative decoding generation. Prints stack traces - immediately when syncs occur. + Fixture that patches CommonAttentionMetadata.seq_lens to detect .cpu() calls. + This tracks when code accesses seq_lens and converts it to CPU, which causes + a GPU-CPU sync that breaks async scheduling. """ + from vllm.v1.attention.backend import CommonAttentionMetadata + # Shared counter for cross-process communication (inherited by fork) sync_count = multiprocessing.Value("i", 0) original_cpu = torch.Tensor.cpu - def tracking_cpu(self, *args, **kwargs): - if self.is_cuda: - with sync_count.get_lock(): - sync_count.value += 1 - count = sync_count.value - # Print stack trace immediately (shows in subprocess output) - print(f"\n{'=' * 60}", file=sys.stderr) - print( - f"SYNC #{count}: tensor.cpu() called on CUDA tensor!", file=sys.stderr - ) - print(f"Shape: {self.shape}, dtype: {self.dtype}", file=sys.stderr) - print(f"{'=' * 60}", file=sys.stderr) - traceback.print_stack(file=sys.stderr) - print(f"{'=' * 60}\n", file=sys.stderr) - sys.stderr.flush() - return original_cpu(self, *args, **kwargs) - - # Apply patch + # Create a wrapper that tracks .cpu() calls on seq_lens tensors + tracked_tensors: set = set() + + original_getattribute = CommonAttentionMetadata.__getattribute__ + + def tracking_getattribute(self, name): + value = original_getattribute(self, name) + if name == "seq_lens" and isinstance(value, torch.Tensor): + # Mark this tensor as one we want to track + tracked_tensors.add(id(value)) + return value + + # Backends that intentionally call .cpu() for their operations + ALLOWED_BACKENDS = ["flashinfer.py", "mla/indexer.py", "mla/flashmla_sparse.py"] + + def tracking_cpu(tensor_self, *args, **kwargs): + if tensor_self.is_cuda and id(tensor_self) in tracked_tensors: + # Check if this is from an allowed backend + stack = traceback.format_stack() + stack_str = "".join(stack) + is_allowed = any(backend in stack_str for backend in ALLOWED_BACKENDS) + if not is_allowed: + with sync_count.get_lock(): + sync_count.value += 1 + count = sync_count.value + print(f"\n{'=' * 60}", file=sys.stderr) + print( + f"SYNC #{count}: .cpu() called on CommonAttentionMetadata.seq_lens", + file=sys.stderr, + ) + print( + f"Shape: {tensor_self.shape}, dtype: {tensor_self.dtype}", + file=sys.stderr, + ) + print(f"{'=' * 60}", file=sys.stderr) + traceback.print_stack(file=sys.stderr) + print(f"{'=' * 60}\n", file=sys.stderr) + sys.stderr.flush() + return original_cpu(tensor_self, *args, **kwargs) + + # Apply patches + CommonAttentionMetadata.__getattribute__ = tracking_getattribute torch.Tensor.cpu = tracking_cpu class SyncTracker: @@ -54,17 +80,20 @@ def start_tracking(self): """Start tracking syncs from this point. Call after model loading.""" with sync_count.get_lock(): sync_count.value = 0 + tracked_tensors.clear() def assert_no_sync(self, msg: str = ""): count = sync_count.value assert count == 0, ( - f"Unexpected GPU-CPU sync: tensor.cpu() called " - f"{count} times. See stack traces above. {msg}" + f"Unexpected GPU-CPU sync: .cpu() called on " + f"CommonAttentionMetadata.seq_lens {count} times. " + f"See stack traces above. {msg}" ) yield SyncTracker() - # Restore original method + # Restore original methods + CommonAttentionMetadata.__getattribute__ = original_getattribute torch.Tensor.cpu = original_cpu torch._dynamo.reset() diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index d9a5f5c9673c..90075525184c 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -91,7 +91,7 @@ def build( - common_attn_metadata.query_start_loc_cpu[:-1] ) # Any computed tokens indicated decode step>1 (no chunked prefill) - is_decode = seq_lens_cpu >= query_lens_cpu + is_decode = seq_lens_cpu > query_lens_cpu if torch.any(is_decode): # CrossAttn KV cache has already been populated on first decoder step, # skip slot_mapping calculation for requests that do not need From 2f3f420e74481fcf1d13e51e2bbe35fd9301ca08 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 12 Feb 2026 05:24:53 +0000 Subject: [PATCH 13/13] cleanup, keep seqlens for now Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backend.py | 24 +++++++----------------- vllm/v1/spec_decode/eagle.py | 12 ++++++++++++ vllm/v1/worker/gpu_model_runner.py | 1 + 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 1f00b85eca51..1cad1a97de8c 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -8,6 +8,7 @@ import numpy as np import torch +from typing_extensions import deprecated if TYPE_CHECKING: from vllm.config import VllmConfig @@ -335,6 +336,12 @@ def naive_query_lens(self) -> torch.Tensor: def replace(self, **kwargs) -> "CommonAttentionMetadata": return replace(self, **kwargs) + # WARNING: Deprecated fields. Will be removed in a future release + # Keep seq_lens_cpu for now to avoid performance regressions with FlashInfer on + # sm120 machines, will remove once FA4 is performant enough on sm120. + # see: https://github.com/vllm-project/vllm/pull/33771 + _seq_lens_cpu: torch.Tensor | None = None + @property @deprecated( """ @@ -348,23 +355,6 @@ def seq_lens_cpu(self) -> torch.Tensor: self._seq_lens_cpu = self.seq_lens.to("cpu") return self._seq_lens_cpu - @property - @deprecated( - """ - Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full - async scheduling. If a CPU copy is needed, it can be derived from - query_start_loc_cpu and seq_lens. - Will be removed in a future release, please migrate as soon as possible. - """ - ) - def num_computed_tokens_cpu(self) -> torch.Tensor: - if self._num_computed_tokens_cpu is None: - query_seq_lens = ( - self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] - ) - self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens - return self._num_computed_tokens_cpu - def compute_num_computed_tokens(self) -> torch.Tensor: """Compute num_computed_tokens on device (seq_lens - query_lens).""" query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 524447645bf1..1ce602e14d55 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -567,6 +567,8 @@ def propose( # (i.e., not the first proposal). if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None: common_attn_metadata.seq_lens -= num_rejected_tokens_gpu + # Invalidate the CPU-side shadows to avoid H<>D sync. + common_attn_metadata._seq_lens_cpu = None for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. @@ -610,6 +612,11 @@ def propose( common_attn_metadata.max_seq_len + 1, self.max_model_len ) + # Also update the CPU-side shadow; NOTE: this is hacky and should be + # removed in when common_attn_metadata.seq_lens_cpu is deprecated. + if common_attn_metadata._seq_lens_cpu is not None: + common_attn_metadata._seq_lens_cpu += 1 + # Compute the slot mapping. block_size = attn_metadata_builder.kv_cache_spec.block_size if self.uses_mrope: @@ -966,6 +973,7 @@ def prepare_inputs_padded( spec_common_attn_metadata = CommonAttentionMetadata( query_start_loc=common_attn_metadata.query_start_loc, seq_lens=common_attn_metadata.seq_lens, + _seq_lens_cpu=common_attn_metadata._seq_lens_cpu, query_start_loc_cpu=query_start_loc_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, @@ -1193,6 +1201,9 @@ def prepare_inputs( device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_seq_lens = common_attn_metadata.seq_lens - num_rejected_tokens.to(device) + new_seq_lens_cpu: torch.Tensor | None = None + if common_attn_metadata._seq_lens_cpu is not None: + new_seq_lens_cpu = common_attn_metadata._seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -1243,6 +1254,7 @@ def prepare_inputs( spec_common_attn_metadata = CommonAttentionMetadata( query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens, + _seq_lens_cpu=new_seq_lens_cpu, query_start_loc_cpu=new_query_start_loc_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b3bfb1c14a01..aa19b6cb605b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1745,6 +1745,7 @@ def _get_block_table(kv_cache_gid: int): query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], seq_lens=self.seq_lens.gpu[:num_reqs_padded], + _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, max_query_len=max_query_len,