diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py
index 4529c2cfc29b..ab66c8ac6365 100644
--- a/tests/v1/attention/test_chunked_local_attention.py
+++ b/tests/v1/attention/test_chunked_local_attention.py
@@ -178,7 +178,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
 
     # Convert to numpy for easier comparison
     actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
-    actual_k_seqlens = result.seq_lens_cpu.numpy()
+    actual_k_seqlens = result.seq_lens.cpu().numpy()
 
     # Check that all query lengths are less than or equal to attn_chunk_size
     assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index 3cff52929146..d58e10d30dc7 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -67,15 +67,7 @@ def create_common_attn_metadata(
 
     # Create sequence lengths
     seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
-    seq_lens_cpu = seq_lens.cpu()
-    max_seq_len = int(seq_lens_cpu.max())
-
-    # Create computed tokens (context length for each sequence)
-    context_lens = [
-        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
-        for i in range(batch_spec.batch_size)
-    ]
-    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
+    max_seq_len = int(seq_lens.max().item())
 
     # Create block table and slot mapping
     max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
@@ -106,8 +98,6 @@ def create_common_attn_metadata(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        _seq_lens_cpu=seq_lens_cpu,
-        _num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=batch_spec.batch_size,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,
diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py
index 4bf76da452f3..63b836c6a30a 100644
--- a/tests/v1/e2e/test_async_spec_decode.py
+++ b/tests/v1/e2e/test_async_spec_decode.py
@@ -16,53 +16,85 @@
 @pytest.fixture
 def sync_tracker():
     """
-    Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
-    lazy init syncs. Prints stack traces immediately when syncs occur.
+    Fixture that patches CommonAttentionMetadata.seq_lens to detect .cpu() calls.
+    This tracks when code accesses seq_lens and converts it to CPU, which causes
+    a GPU-CPU sync that breaks async scheduling.
""" from vllm.v1.attention.backend import CommonAttentionMetadata # Shared counter for cross-process communication (inherited by fork) sync_count = multiprocessing.Value("i", 0) - # Save original property - original_prop = CommonAttentionMetadata.seq_lens_cpu - original_fget = original_prop.fget - - # Create tracking wrapper - def tracking_seq_lens_cpu(self): - if self._seq_lens_cpu is None: - # Increment counter - with sync_count.get_lock(): - sync_count.value += 1 - count = sync_count.value - # Print stack trace immediately (shows in subprocess output) - print(f"\n{'=' * 60}", file=sys.stderr) - print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr) - print(f"{'=' * 60}", file=sys.stderr) - traceback.print_stack(file=sys.stderr) - print(f"{'=' * 60}\n", file=sys.stderr) - sys.stderr.flush() - return original_fget(self) - - # Apply patch - CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu) + original_cpu = torch.Tensor.cpu + + # Create a wrapper that tracks .cpu() calls on seq_lens tensors + tracked_tensors: set = set() + + original_getattribute = CommonAttentionMetadata.__getattribute__ + + def tracking_getattribute(self, name): + value = original_getattribute(self, name) + if name == "seq_lens" and isinstance(value, torch.Tensor): + # Mark this tensor as one we want to track + tracked_tensors.add(id(value)) + return value + + # Backends that intentionally call .cpu() for their operations + ALLOWED_BACKENDS = ["flashinfer.py", "mla/indexer.py", "mla/flashmla_sparse.py"] + + def tracking_cpu(tensor_self, *args, **kwargs): + if tensor_self.is_cuda and id(tensor_self) in tracked_tensors: + # Check if this is from an allowed backend + stack = traceback.format_stack() + stack_str = "".join(stack) + is_allowed = any(backend in stack_str for backend in ALLOWED_BACKENDS) + if not is_allowed: + with sync_count.get_lock(): + sync_count.value += 1 + count = sync_count.value + print(f"\n{'=' * 60}", file=sys.stderr) + print( + f"SYNC #{count}: .cpu() called on CommonAttentionMetadata.seq_lens", + file=sys.stderr, + ) + print( + f"Shape: {tensor_self.shape}, dtype: {tensor_self.dtype}", + file=sys.stderr, + ) + print(f"{'=' * 60}", file=sys.stderr) + traceback.print_stack(file=sys.stderr) + print(f"{'=' * 60}\n", file=sys.stderr) + sys.stderr.flush() + return original_cpu(tensor_self, *args, **kwargs) + + # Apply patches + CommonAttentionMetadata.__getattribute__ = tracking_getattribute + torch.Tensor.cpu = tracking_cpu class SyncTracker: @property def count(self) -> int: return sync_count.value + def start_tracking(self): + """Start tracking syncs from this point. Call after model loading.""" + with sync_count.get_lock(): + sync_count.value = 0 + tracked_tensors.clear() + def assert_no_sync(self, msg: str = ""): count = sync_count.value assert count == 0, ( - f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered " - f"{count} times. See stack traces above. {msg}" + f"Unexpected GPU-CPU sync: .cpu() called on " + f"CommonAttentionMetadata.seq_lens {count} times. " + f"See stack traces above. 
             )
 
     yield SyncTracker()
 
-    # Restore original property
-    CommonAttentionMetadata.seq_lens_cpu = original_prop
+    # Restore original methods
+    CommonAttentionMetadata.__getattribute__ = original_getattribute
+    torch.Tensor.cpu = original_cpu
 
     torch._dynamo.reset()
@@ -116,6 +148,9 @@ def test_no_sync_with_spec_decode(
         async_scheduling=True,
     )
 
+    # Start tracking after model loading - we only care about syncs during generation
+    sync_tracker.start_tracking()
+
     outputs = llm.generate(
         ["Hello, my name is"],
         SamplingParams(temperature=0, max_tokens=10),
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 8b180168dffc..f3658c620559 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -166,6 +166,7 @@ def test_prepare_next_token_ids():
         block_size=16,
         device=device,
     )
+    seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
 
     expected_valid_sampled_tokens_count = torch.tensor(
         [2, 5, 0, 0], dtype=torch.int32, device=device
@@ -173,7 +174,7 @@
     )
     next_token_ids_from_padded, valid_sampled_tokens_count = (
         proposer.prepare_next_token_ids_padded(
-            common_attn_metadata,
+            seq_lens_cpu,
             sampled_token_ids_tensor,
             mock_requests,
             mock_input_batch,
diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index bd7005540618..9b66d4c93a5f 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -54,14 +54,12 @@ def forward_attention(
     query_start_loc = q_len * torch.arange(
         batch_size + 1, device=q.device, dtype=torch.int32
     )
-    query_lens = torch.diff(query_start_loc)
     seq_lens = torch.full(
         (batch_size,),
         seqlen_k,
         device=q.device,
         dtype=torch.int32,
     )
-    context_lens = seq_lens - query_lens
     max_seq_len = int(seq_lens.max())
     max_query_len = q_len
     num_actual_tokens = query_start_loc[-1]
@@ -89,8 +87,6 @@ def forward_attention(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc.cpu(),
         seq_lens=seq_lens,
-        _seq_lens_cpu=seq_lens.cpu(),
-        _num_computed_tokens_cpu=context_lens.cpu(),
         num_reqs=batch_size,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py
index 9333b35e65b5..90075525184c 100644
--- a/vllm/model_executor/layers/attention/cross_attention.py
+++ b/vllm/model_executor/layers/attention/cross_attention.py
@@ -85,26 +85,25 @@ def build(
         new_metadata.causal = False
         max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
         new_metadata.max_seq_len = max_encoder_len
-        # Any computed tokens indicated decode step>1 (no chunked prefill)
-        num_cache_decodes = (
-            (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
+        seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
+        query_lens_cpu = (
+            common_attn_metadata.query_start_loc_cpu[1:]
+            - common_attn_metadata.query_start_loc_cpu[:-1]
         )
-        if num_cache_decodes > 0:
+        # Any computed tokens indicated decode step>1 (no chunked prefill)
+        is_decode = seq_lens_cpu > query_lens_cpu
+        if torch.any(is_decode):
             # CrossAttn KV cache has already been populated on first decoder step,
             # skip slot_mapping calculation for requests that do not need
             # reshape_and_cache.
-            num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
             new_metadata.encoder_seq_lens_cpu = np.where(
-                num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
+                is_decode, 0, new_metadata.encoder_seq_lens_cpu
             )
 
         # seq_lens is provided by model runner: initial encoder input length is
         # needed here to know how many tokens to attend to from the cached
         # cross-attention KV cache.
         new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
-        new_metadata._seq_lens_cpu = torch.from_numpy(
-            common_attn_metadata.encoder_seq_lens_cpu
-        )
 
         # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
         slot_mapping = _get_cross_slot_mapping(
diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py
index c43c00840192..d4b7a536c804 100644
--- a/vllm/model_executor/models/whisper_causal.py
+++ b/vllm/model_executor/models/whisper_causal.py
@@ -133,8 +133,6 @@ def build(
         new_common_attn_metadata.query_start_loc *= block_pool_size
         new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
         new_common_attn_metadata.seq_lens *= block_pool_size
-        new_common_attn_metadata._seq_lens_cpu *= block_pool_size
-        new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
         new_common_attn_metadata.num_actual_tokens *= block_pool_size
         new_common_attn_metadata.max_query_len *= block_pool_size
         new_common_attn_metadata.max_seq_len *= block_pool_size
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index 9c004d7724dd..1cad1a97de8c 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -326,12 +326,6 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""
 
-    # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
-    _seq_lens_cpu: torch.Tensor | None = None
-    _num_computed_tokens_cpu: torch.Tensor | None = None
-
-    _num_computed_tokens_cache: torch.Tensor | None = None
-
     def batch_size(self) -> int:
         return self.seq_lens.shape[0]
 
@@ -342,6 +336,12 @@ def naive_query_lens(self) -> torch.Tensor:
     def replace(self, **kwargs) -> "CommonAttentionMetadata":
         return replace(self, **kwargs)
 
+    # WARNING: Deprecated fields. Will be removed in a future release
+    # Keep seq_lens_cpu for now to avoid performance regressions with FlashInfer on
+    # sm120 machines, will remove once FA4 is performant enough on sm120.
+    # see: https://github.com/vllm-project/vllm/pull/33771
+    _seq_lens_cpu: torch.Tensor | None = None
+
     @property
     @deprecated(
         """
@@ -355,29 +355,11 @@ def seq_lens_cpu(self) -> torch.Tensor:
             self._seq_lens_cpu = self.seq_lens.to("cpu")
         return self._seq_lens_cpu
 
-    @property
-    @deprecated(
-        """
-    Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
-    async scheduling. If a CPU copy is needed, it can be derived from
-    query_start_loc_cpu and seq_lens.
-    Will be removed in a future release, please migrate as soon as possible.
- """ - ) - def num_computed_tokens_cpu(self) -> torch.Tensor: - if self._num_computed_tokens_cpu is None: - query_seq_lens = ( - self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] - ) - self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens - return self._num_computed_tokens_cpu - def compute_num_computed_tokens(self) -> torch.Tensor: """Compute num_computed_tokens on device (seq_lens - query_lens).""" - if self._num_computed_tokens_cache is None: - query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] - self._num_computed_tokens_cache = self.seq_lens - query_lens - return self._num_computed_tokens_cache + query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] + num_computed_tokens = self.seq_lens - query_lens + return num_computed_tokens # TODO(lucas): remove once we have FULL-CG spec-decode support def unpadded( @@ -388,12 +370,6 @@ def unpadded( query_start_loc=self.query_start_loc[: num_actual_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], seq_lens=self.seq_lens[:num_actual_reqs], - _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs] - if self._seq_lens_cpu is not None - else None, - _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] - if self._num_computed_tokens_cpu is not None - else None, num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 368b217f0ba6..4b48db4dad3f 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -289,8 +289,10 @@ def build( prefill_metadata = None if num_prefills > 0: + seq_lens_cpu = common_attn_metadata.seq_lens.cpu() + chunk_seq_ids = split_prefill_chunks( - common_attn_metadata.seq_lens_cpu[num_decodes:], + seq_lens_cpu[num_decodes:], self.max_prefill_buffer_size, request_offset=num_decodes, ) @@ -299,7 +301,7 @@ def build( reqs_start, reqs_end, query_start_loc_cpu, - common_attn_metadata.seq_lens_cpu, + seq_lens_cpu, common_attn_metadata.block_table_tensor, ) for reqs_start, reqs_end in chunk_seq_ids diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index e0aa2c988a21..42b21d732523 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -222,7 +222,7 @@ def make_local_attention_virtual_batches( block_size: int = 0, ) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]: query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy() - seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy() + seq_lens_np = common_attn_metadata.seq_lens.cpu().numpy() block_table = common_attn_metadata.block_table_tensor device = common_attn_metadata.query_start_loc.device @@ -285,7 +285,6 @@ def make_local_attention_virtual_batches( # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block - num_computed_tokens_local = seqlens_k_local - seqlens_q_local k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) @@ -354,8 +353,6 @@ def make_local_attention_virtual_batches( block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, - _seq_lens_cpu=seq_lens_cpu, - _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), ), 
         make_block_table
     )
@@ -412,8 +409,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
-        _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
-        _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
     )
 
     return common_attn_metadata
@@ -443,7 +438,7 @@ def split_decodes_prefills_and_extends(
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
-    seq_lens = common_attn_metadata.seq_lens_cpu
+    seq_lens = common_attn_metadata.seq_lens.cpu()
 
     if max_query_len <= decode_threshold:
         return num_reqs, 0, 0, num_tokens, 0, 0
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index b5532d652618..1ce602e14d55 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -569,7 +569,6 @@ def propose(
             common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
             # Invalidate the CPU-side shadows to avoid H<>D sync.
             common_attn_metadata._seq_lens_cpu = None
-            common_attn_metadata._num_computed_tokens_cpu = None
 
         for token_index in range(self.num_speculative_tokens - 1):
             # Update the inputs.
@@ -617,8 +616,6 @@ def propose(
             # removed in when common_attn_metadata.seq_lens_cpu is deprecated.
             if common_attn_metadata._seq_lens_cpu is not None:
                 common_attn_metadata._seq_lens_cpu += 1
-            if common_attn_metadata._num_computed_tokens_cpu is not None:
-                common_attn_metadata._num_computed_tokens_cpu += 1
 
             # Compute the slot mapping.
             block_size = attn_metadata_builder.kv_cache_spec.block_size
@@ -878,7 +875,7 @@ def prepare_next_token_ids_cpu(
 
     def prepare_next_token_ids_padded(
         self,
-        common_attn_metadata: CommonAttentionMetadata,
+        seq_lens_cpu: torch.Tensor,
         sampled_token_ids: torch.Tensor,
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
@@ -896,7 +893,7 @@ def prepare_next_token_ids_padded(
         self.backup_next_token_ids.np[:num_reqs] = np.array(
             [
                 requests[gpu_input_batch.req_ids[i]].get_token_id(
-                    common_attn_metadata.seq_lens_cpu[i].item()
+                    seq_lens_cpu[i].item()
                 )
                 for i in range(num_reqs)
             ],
@@ -976,13 +973,12 @@ def prepare_inputs_padded(
         spec_common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=common_attn_metadata.query_start_loc,
             seq_lens=common_attn_metadata.seq_lens,
-            query_start_loc_cpu=query_start_loc_cpu,
             _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
-            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
+            query_start_loc_cpu=query_start_loc_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
-            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
+            max_seq_len=common_attn_metadata.max_seq_len,
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
             causal=True,
@@ -1196,15 +1192,18 @@ def prepare_inputs(
         # q1, q1 + 1, ..., q1 + q2 - n2 - 1,
         # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
 
-        num_rejected_tokens = [
+        num_rejected_tokens_list = [
             n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
             for i, n in enumerate(num_draft_tokens)
         ]
-        num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)
+        num_rejected_tokens = torch.tensor(num_rejected_tokens_list, dtype=torch.int32)
 
         device = common_attn_metadata.query_start_loc.device
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
+        new_seq_lens = common_attn_metadata.seq_lens - num_rejected_tokens.to(device)
+        new_seq_lens_cpu: torch.Tensor | None = None
+        if common_attn_metadata._seq_lens_cpu is not None:
+            new_seq_lens_cpu = common_attn_metadata._seq_lens_cpu - num_rejected_tokens
 
         # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
         new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@@ -1254,14 +1253,15 @@ def prepare_inputs(
 
         spec_common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
-            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
-            query_start_loc_cpu=new_query_start_loc_cpu,
+            seq_lens=new_seq_lens,
             _seq_lens_cpu=new_seq_lens_cpu,
-            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
+            query_start_loc_cpu=new_query_start_loc_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
-            max_seq_len=new_seq_lens_cpu.max().item(),
+            # max_seq_len is just an upper bound; so use a rough estimate that doesn't
+            # involve a D<>H transfer
+            max_seq_len=common_attn_metadata.max_seq_len + max(num_draft_tokens),
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[token_indices],
             causal=True,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0e2e381f282d..aa19b6cb605b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1746,9 +1746,6 @@ def _get_block_table(kv_cache_gid: int):
             query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
             seq_lens=self.seq_lens.gpu[:num_reqs_padded],
             _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
-            _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
-                :num_reqs_padded
-            ],
             num_reqs=num_reqs_padded,
             num_actual_tokens=num_tokens_padded,
             max_query_len=max_query_len,
@@ -3717,9 +3714,10 @@ def propose_draft_token_ids(sampled_token_ids):
             propose_draft_token_ids(sampled_token_ids)
         elif self.valid_sampled_token_count_event is not None:
             assert spec_decode_common_attn_metadata is not None
+            num_reqs = self.input_batch.num_reqs
             next_token_ids, valid_sampled_tokens_count = (
                 self.drafter.prepare_next_token_ids_padded(
-                    spec_decode_common_attn_metadata,
+                    self.seq_lens.cpu[:num_reqs],
                     sampled_token_ids,
                     self.requests,
                     self.input_batch,
@@ -4018,7 +4016,7 @@ def propose_draft_token_ids(
             )
         next_token_ids, valid_sampled_tokens_count = (
             self.drafter.prepare_next_token_ids_padded(
-                common_attn_metadata,
+                self.seq_lens.cpu[: self.input_batch.num_reqs],
                 sampled_token_ids,
                 self.requests,
                 self.input_batch,
diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py
index 7c41726472d5..147bc522e813 100644
--- a/vllm/v1/worker/ubatch_utils.py
+++ b/vllm/v1/worker/ubatch_utils.py
@@ -177,7 +177,6 @@ def _make_metadata_with_slice(
     query_start_loc[1:] -= tokens_skipped
     query_start_loc_cpu[1:] -= tokens_skipped
     seq_lens = attn_metadata.seq_lens[request_slice]
-    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
 
     if splits_last_request:
         # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
@@ -190,12 +189,7 @@ def _make_metadata_with_slice(
         # Make sure we don't modify the seq_lens tensors
         # (not cudagraph compatible)
         seq_lens = seq_lens.clone()
-        seq_lens_cpu = seq_lens_cpu.clone()
         seq_lens[-1] -= tokens_skipped
-        seq_lens_cpu[-1] -= tokens_skipped
-
-    max_seq_len = int(seq_lens_cpu.max())
-    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
 
     num_requests = request_slice.stop - request_slice.start
     num_actual_tokens = token_slice.stop - token_slice.start
@@ -218,11 +212,9 @@ def _make_metadata_with_slice(
         num_reqs=num_requests,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
-        max_seq_len=max_seq_len,
+        max_seq_len=attn_metadata.max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
-        _seq_lens_cpu=seq_lens_cpu,
-        _num_computed_tokens_cpu=num_computed_tokens_cpu,
     )
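
Migration note (not part of the patch): the deprecation message on CommonAttentionMetadata.seq_lens_cpu points callers at the pattern below. A minimal sketch, assuming `meta` stands in for a CommonAttentionMetadata instance; the variable names are illustrative only.

    # Preferred: stay on the device, so no implicit GPU->CPU sync is introduced
    # (this is what async scheduling requires).
    query_lens = meta.query_start_loc[1:] - meta.query_start_loc[:-1]
    num_computed_tokens = meta.seq_lens - query_lens  # same result as meta.compute_num_computed_tokens()

    # If a CPU copy is genuinely needed (e.g. for numpy post-processing), make the
    # sync explicit and local, as the updated call sites in this diff do:
    seq_lens_cpu = meta.seq_lens.cpu()

The explicit .cpu() call is exactly what the sync_tracker fixture above counts, which is why only whitelisted backends are allowed to do it on the hot path.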