diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index aac4a46be3b3..1d5eba74693a 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -107,6 +107,7 @@ def create_common_attn_metadata(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
+        seq_lens_cpu_upper_bound=seq_lens_cpu,
         _seq_lens_cpu=seq_lens_cpu,
         _num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=batch_spec.batch_size,
diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index 1b6fa4f6f484..3c126c49f8cc 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -241,11 +241,13 @@ def forward_attention(
     )
     kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
     builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
+    seq_lens_cpu = seq_lens.cpu()
     common_attn_metadata = CommonAttentionMetadata(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc.cpu(),
         seq_lens=seq_lens,
-        _seq_lens_cpu=seq_lens.cpu(),
+        seq_lens_cpu_upper_bound=seq_lens_cpu,
+        _seq_lens_cpu=seq_lens_cpu,
         _num_computed_tokens_cpu=context_lens.cpu(),
         num_reqs=batch_size,
         num_actual_tokens=num_actual_tokens,
diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py
index 312f906abacc..091f0a1856d4 100644
--- a/vllm/model_executor/layers/attention/cross_attention.py
+++ b/vllm/model_executor/layers/attention/cross_attention.py
@@ -90,15 +90,23 @@ def build(
         assert new_metadata.encoder_seq_lens_cpu is not None
         max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
         new_metadata.max_seq_len = max_encoder_len
-        # Any computed tokens indicated decode step>1 (no chunked prefill)
-        num_cache_decodes = (
-            (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
+        # Any computed tokens indicates decode step>1 (no chunked prefill).
+        # The upper bound is exact for this `> 0` test: prefill rows have
+        # num_computed == 0 and decode rows have num_computed > 0.
+        query_lens_cpu = (
+            common_attn_metadata.query_start_loc_cpu[1:]
+            - common_attn_metadata.query_start_loc_cpu[:-1]
         )
+        assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
+        num_computed_tokens_cpu = (
+            common_attn_metadata.seq_lens_cpu_upper_bound - query_lens_cpu
+        )
+        num_cache_decodes = (num_computed_tokens_cpu > 0).sum().item()
         if num_cache_decodes > 0:
             # CrossAttn KV cache has already been populated on first decoder step,
             # skip slot_mapping calculation for requests that do not need
             # reshape_and_cache.
-            num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
+            num_tokens = num_computed_tokens_cpu.numpy()
             new_metadata.encoder_seq_lens_cpu = np.where(
                 num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
             )
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 5c7dc60fe15c..e649d790e82a 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -1822,13 +1822,18 @@ def build(
         prefill_metadata = None
         if num_prefills > 0:
-            num_computed_tokens_cpu = (
-                common_attn_metadata.compute_num_computed_tokens().cpu()
-            )
             reqs_start = num_decodes  # prefill_start
-            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
+            # Upper bound is exact for prefill rows (no D2H sync).
+            seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
+            assert seq_lens_cpu is not None
+            prefill_query_lens_cpu = (
+                query_start_loc_cpu[reqs_start + 1 : num_reqs + 1]
+                - query_start_loc_cpu[reqs_start:num_reqs]
+            )
+            context_lens_cpu = (
+                seq_lens_cpu[reqs_start:num_reqs] - prefill_query_lens_cpu
+            )
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = (
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index 7d6bba4189de..16535ee3c6c1 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -397,6 +397,12 @@ class CommonAttentionMetadata:
     (num_computed_tokens < num_prompt_tokens). Used by some backends to
     distinguish actual decodes from short extends."""

+    seq_lens_cpu_upper_bound: torch.Tensor | None = None
+    """(batch_size,) CPU upper bound on seq_lens. Precise for prefill rows
+    and for all rows outside async spec decode; optimistic for async-spec
+    decode rows (assumes every draft was accepted). Not safe for kernels
+    that need exact per-row context lengths on decode rows."""
+
     # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
     _seq_lens_cpu: torch.Tensor | None = None
     _num_computed_tokens_cpu: torch.Tensor | None = None
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index a027fe52441f..a917235ed8cb 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -782,10 +782,11 @@ def __init__(
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> FlexAttentionMetadata:
-        # Use actual max_seq_len instead of max_model_len to avoid
-        # torch.compile recompilation during CUDA graph capture.
-        common_attn_metadata.max_seq_len = (
-            common_attn_metadata.seq_lens_cpu.max().item()
+        # Use actual max_seq_len (not max_model_len) to avoid torch.compile
+        # recompilation during CUDA graph capture.
+        assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
+        common_attn_metadata.max_seq_len = int(
+            common_attn_metadata.seq_lens_cpu_upper_bound.max().item()
         )
         return self.build(
             common_prefix_len=0, common_attn_metadata=common_attn_metadata
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 1d981717cbff..e67282aab8cc 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -364,7 +364,10 @@ def _build_fp8_separate_prefill_decode(
         # For pure decode batches, prefill_request_id will be None
        # For mixed batches, it will have -1 for decode and request_id for prefill
         if num_prefills > 0:
-            seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
+            # Upper bound is exact for prefill rows (the `[num_decodes:]`
+            # slice below), so no D2H sync is needed.
+            seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
+            assert seq_lens_cpu is not None
             seq_lens = common_attn_metadata.seq_lens
             query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 3b719d10ff89..237ccfeb4729 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -554,8 +554,12 @@ def build(
             query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
         )
         max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+        # Upper bound is exact for prefill rows (the `[num_decodes:]`
+        # slice below).
+        assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
         chunk_specs = split_indexer_prefill_chunks(
-            common_attn_metadata.seq_lens_cpu[num_decodes:],
+            seq_lens_cpu[num_decodes:],
             prefill_query_lens_cpu,
             self.max_prefill_buffer_size,
             max_logits_bytes,
@@ -566,7 +570,7 @@
                 req_slice,
                 query_slice,
                 query_start_loc_cpu,
-                common_attn_metadata.seq_lens_cpu,
+                seq_lens_cpu,
                 common_attn_metadata.block_table_tensor,
                 skip_kv_gather=query_slice.start > 0,
             )
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 0a36e6fd490a..b4bdce876d81 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -356,6 +356,7 @@ def make_local_attention_virtual_batches(
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
+        seq_lens_cpu_upper_bound=seq_lens_cpu,
         _seq_lens_cpu=seq_lens_cpu,
         _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
     ), make_block_table
@@ -414,6 +415,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
+        seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound,
         _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
         _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
     )
@@ -445,7 +447,11 @@ def split_decodes_prefills_and_extends(
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
-    seq_lens = common_attn_metadata.seq_lens_cpu
+    # Upper bound is exact for prefill rows; decode rows still satisfy
+    # seq_len > query_len under the optimistic bound, so `seq_lens ==
+    # query_lens` identifies prefills correctly either way.
+    assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
+    seq_lens = common_attn_metadata.seq_lens_cpu_upper_bound

     if max_query_len <= decode_threshold:
         return num_reqs, 0, 0, num_tokens, 0, 0
diff --git a/vllm/v1/spec_decode/dflash.py b/vllm/v1/spec_decode/dflash.py
index cb31a97a1312..0d9d6809680b 100644
--- a/vllm/v1/spec_decode/dflash.py
+++ b/vllm/v1/spec_decode/dflash.py
@@ -151,6 +151,12 @@ def set_inputs_first_pass(
         if has_num_rejected:
             effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu
+        # Skip num_rejected_tokens (GPU-only); overestimating is fine here.
+        new_seq_lens_cpu_upper_bound = (
+            cad.seq_lens_cpu_upper_bound + num_query_per_req
+            if cad.seq_lens_cpu_upper_bound is not None
+            else None
+        )
         new_cad = CommonAttentionMetadata(
             query_start_loc=new_query_start_loc,
             seq_lens=effective_seq_lens + num_query_per_req,
             query_start_loc_cpu=(
@@ -160,6 +166,7 @@
             ),
             _seq_lens_cpu=None,
             _num_computed_tokens_cpu=None,
+            seq_lens_cpu_upper_bound=new_seq_lens_cpu_upper_bound,
             num_reqs=cad.num_reqs,
             num_actual_tokens=num_query_total,
             max_query_len=num_query_per_req,
diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py
index 1764ae8db4d0..44156b60c0da 100644
--- a/vllm/v1/spec_decode/llm_base_proposer.py
+++ b/vllm/v1/spec_decode/llm_base_proposer.py
@@ -593,6 +593,8 @@ def propose(
                 common_attn_metadata._seq_lens_cpu += 1
             if common_attn_metadata._num_computed_tokens_cpu is not None:
                 common_attn_metadata._num_computed_tokens_cpu += 1
+            if common_attn_metadata.seq_lens_cpu_upper_bound is not None:
+                common_attn_metadata.seq_lens_cpu_upper_bound += 1

             # Rebuild attention metadata
             _, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(
@@ -959,6 +961,7 @@ def prepare_inputs_padded(
             query_start_loc_cpu=query_start_loc_cpu,
             _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
             _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
+            seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
@@ -1183,7 +1186,11 @@ def prepare_inputs(
         device = common_attn_metadata.query_start_loc.device
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
+        # upper_bound - rejected = actual post-rejection seq_lens (no D2H sync).
+        assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
+        new_seq_lens_cpu = (
+            common_attn_metadata.seq_lens_cpu_upper_bound - num_rejected_tokens
+        )

         # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
         new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@@ -1237,6 +1244,7 @@
             query_start_loc_cpu=new_query_start_loc_cpu,
             _seq_lens_cpu=new_seq_lens_cpu,
             _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
+            seq_lens_cpu_upper_bound=new_seq_lens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index ee6244c42a04..354be3cd2a40 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -227,12 +227,15 @@ def build_attn_metadata(
     block_tables: Sequence[torch.Tensor],
     slot_mappings: torch.Tensor,
     kv_cache_config: KVCacheConfig,
+    seq_lens_cpu_upper_bound: torch.Tensor | None = None,
     dcp_local_seq_lens: torch.Tensor | None = None,
     encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None,
 ) -> dict[str, Any]:
     seq_lens = seq_lens[:num_reqs]
     if dcp_local_seq_lens is not None:
         dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
+    if seq_lens_cpu_upper_bound is not None:
+        seq_lens_cpu_upper_bound = seq_lens_cpu_upper_bound[:num_reqs]

     attn_metadata: dict[str, Any] = {}
     num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
@@ -244,6 +247,7 @@ def build_attn_metadata(
             query_start_loc=query_start_loc_gpu,
             query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=seq_lens,
+            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
             max_seq_len=max_seq_len,
             num_reqs=num_reqs,
             num_actual_tokens=num_tokens,
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 24df137cb31e..be14de272a42 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -60,6 +60,8 @@ class InputBatch:
     query_start_loc_np: np.ndarray
     # [num_reqs]
     seq_lens: torch.Tensor
+    # [num_reqs] CPU upper bound on seq_lens (see CommonAttentionMetadata).
+    seq_lens_cpu_upper_bound: torch.Tensor
     # [num_reqs]
     dcp_local_seq_lens: torch.Tensor | None

@@ -121,6 +123,8 @@ def make_dummy(
         logits_indices = query_start_loc[1:] - 1
         cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
         cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
+        # Dummy: seq_len == query_len (fresh-prefill shape).
+        seq_lens_cpu_upper_bound = torch.from_numpy(num_scheduled_tokens.copy())
         return cls(
             req_ids=req_ids,
             num_reqs=num_reqs,
@@ -136,6 +140,7 @@ def make_dummy(
             query_start_loc=query_start_loc,
             query_start_loc_np=query_start_loc_np,
             seq_lens=seq_lens,
+            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
             dcp_local_seq_lens=None,
             input_ids=input_ids,
             positions=positions,
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index a0025d8c795f..820704ecff3b 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -799,6 +799,15 @@ def prepare_inputs(
             total_num_logits,
         )

+        # CPU upper bound on seq_lens; padded entries left at zero.
+        seq_lens_cpu_upper_bound_np = np.zeros(num_reqs_padded, dtype=np.int32)
+        np.add(
+            self.req_states.num_computed_tokens_np[idx_mapping_np],
+            num_scheduled_tokens,
+            out=seq_lens_cpu_upper_bound_np[:num_reqs],
+        )
+        seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np)
+
         return InputBatch(
             req_ids=req_ids,
             num_reqs=num_reqs,
@@ -814,6 +823,7 @@ def prepare_inputs(
             query_start_loc=query_start_loc,
             query_start_loc_np=query_start_loc_np,
             seq_lens=seq_lens,
+            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
             dcp_local_seq_lens=dcp_local_seq_lens,
             input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
             positions=self.input_buffers.positions[:num_tokens_after_padding],
@@ -927,6 +937,10 @@ def postprocess(
         np.minimum(
             computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
         )
+        # Advance the CPU mirror optimistically (assume all scheduled accepted).
+        self.req_states.num_computed_tokens_np[idx_mapping_np] += (
+            input_batch.num_scheduled_tokens
+        )

     @torch.inference_mode()
     def execute_model(
@@ -1297,6 +1311,10 @@ def postprocess_pool(self, input_batch: InputBatch) -> None:
         np.minimum(
             computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
         )
+        # Advance the CPU mirror optimistically (assume all scheduled accepted).
+        self.req_states.num_computed_tokens_np[idx_mapping_np] += (
+            input_batch.num_scheduled_tokens
+        )

     ########### EPLB methods start ###########
     @property
diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py
index 8e73867deb2e..5d36b12f9c27 100644
--- a/vllm/v1/worker/gpu/model_states/default.py
+++ b/vllm/v1/worker/gpu/model_states/default.py
@@ -173,6 +173,12 @@ def prepare_attn(
         num_tokens = input_batch.num_tokens
         query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
         max_query_len = input_batch.num_scheduled_tokens.max().item()
+        seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
+        if for_capture:
+            # Capture with worst-case max_seq_len so the graph is valid at any replay.
+            max_seq_len = self.max_model_len
+        else:
+            max_seq_len = int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
         attn_metadata = build_attn_metadata(
             attn_groups=attn_groups,
             num_reqs=num_reqs,
@@ -181,10 +187,11 @@ def prepare_attn(
             query_start_loc_cpu=query_start_loc_cpu,
             max_query_len=max_query_len,
             seq_lens=input_batch.seq_lens,
-            max_seq_len=self.max_model_len,
+            max_seq_len=max_seq_len,
             block_tables=block_tables,
             slot_mappings=slot_mappings,
             kv_cache_config=kv_cache_config,
+            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
             dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
         )
         return attn_metadata
diff --git a/vllm/v1/worker/gpu/model_states/whisper.py b/vllm/v1/worker/gpu/model_states/whisper.py
index 1268fee88210..a6faea482c25 100644
--- a/vllm/v1/worker/gpu/model_states/whisper.py
+++ b/vllm/v1/worker/gpu/model_states/whisper.py
@@ -117,6 +117,11 @@ def prepare_attn(
         query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
         max_query_len = input_batch.num_scheduled_tokens.max().item()
+        seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
+        if for_capture:
+            max_seq_len = self.max_model_len
+        else:
+            max_seq_len = int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
         attn_metadata = build_attn_metadata(
             attn_groups=attn_groups,
             num_reqs=num_reqs,
@@ -125,10 +130,11 @@ def prepare_attn(
             query_start_loc_cpu=query_start_loc_cpu,
             max_query_len=max_query_len,
             seq_lens=input_batch.seq_lens,
-            max_seq_len=self.max_model_len,
+            max_seq_len=max_seq_len,
             block_tables=block_tables,
             slot_mappings=slot_mappings,
             kv_cache_config=kv_cache_config,
+            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
             dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
             encoder_seq_lens=encoder_seq_lens,
         )
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index cc371d32a913..b2683966b315 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -57,6 +57,8 @@ def __init__(
         self.num_computed_tokens = StagedWriteTensor(
             self.max_num_reqs, dtype=torch.int32, device=device
         )
+        # Optimistic CPU mirror of num_computed_tokens (upper bound on GPU value).
+        self.num_computed_tokens_np = np.zeros(self.max_num_reqs, dtype=np.int32)

         # Last sampled tokens.
         self.last_sampled_tokens = torch.zeros(
@@ -100,6 +102,8 @@ def add_request(
         self.total_len.stage_write_elem(req_idx, prefill_len)
         self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
         self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
+        self.num_computed_tokens_np[req_idx] = num_computed_tokens
         self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)

         if num_computed_tokens > 0 and num_computed_tokens <= prefill_len:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 386db4fecd4b..8aca4594137a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2155,6 +2155,7 @@ def _get_block_table(kv_cache_gid: int):
             :num_reqs_padded
         ]
         seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded]
+        seq_lens_cpu_upper_bound = seq_lens_cpu

         # is_prefilling: True if request is still in prefill phase.
         # Used by mamba backends to distinguish actual decodes from
@@ -2172,6 +2173,7 @@ def _get_block_table(kv_cache_gid: int):
             seq_lens=self.seq_lens[:num_reqs_padded],
             _seq_lens_cpu=seq_lens_cpu,
             _num_computed_tokens_cpu=num_computed_tokens_cpu,
+            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
             num_reqs=num_reqs_padded,
             num_actual_tokens=num_tokens_padded,
             max_query_len=max_query_len,
diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py
index 7c41726472d5..1338b46996fc 100644
--- a/vllm/v1/worker/ubatch_utils.py
+++ b/vllm/v1/worker/ubatch_utils.py
@@ -177,7 +177,22 @@ def _make_metadata_with_slice(
     query_start_loc[1:] -= tokens_skipped
     query_start_loc_cpu[1:] -= tokens_skipped
     seq_lens = attn_metadata.seq_lens[request_slice]
-    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+    # Read raw fields to avoid triggering the deprecated D2H-syncing properties.
+    seq_lens_cpu = (
+        attn_metadata._seq_lens_cpu[request_slice]
+        if attn_metadata._seq_lens_cpu is not None
+        else None
+    )
+    seq_lens_cpu_upper_bound = (
+        attn_metadata.seq_lens_cpu_upper_bound[request_slice]
+        if attn_metadata.seq_lens_cpu_upper_bound is not None
+        else None
+    )
+    num_computed_tokens_cpu = (
+        attn_metadata._num_computed_tokens_cpu[request_slice]
+        if attn_metadata._num_computed_tokens_cpu is not None
+        else None
+    )

     if splits_last_request:
         # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
@@ -190,12 +205,16 @@ def _make_metadata_with_slice(
         # Make sure we don't modify the seq_lens tensors
         # (not cudagraph compatible)
         seq_lens = seq_lens.clone()
-        seq_lens_cpu = seq_lens_cpu.clone()
         seq_lens[-1] -= tokens_skipped
-        seq_lens_cpu[-1] -= tokens_skipped
+        if seq_lens_cpu is not None:
+            seq_lens_cpu = seq_lens_cpu.clone()
+            seq_lens_cpu[-1] -= tokens_skipped
+        if seq_lens_cpu_upper_bound is not None:
+            seq_lens_cpu_upper_bound = seq_lens_cpu_upper_bound.clone()
+            seq_lens_cpu_upper_bound[-1] -= tokens_skipped

-    max_seq_len = int(seq_lens_cpu.max())
-    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
+    assert seq_lens_cpu_upper_bound is not None
+    max_seq_len = int(seq_lens_cpu_upper_bound.max())

     num_requests = request_slice.stop - request_slice.start
     num_actual_tokens = token_slice.stop - token_slice.start
@@ -221,6 +240,7 @@ def _make_metadata_with_slice(
         max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
+        seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
         _seq_lens_cpu=seq_lens_cpu,
         _num_computed_tokens_cpu=num_computed_tokens_cpu,
     )
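
Reviewer note (not part of the patch): a minimal, self-contained sketch of the arithmetic the new `seq_lens_cpu_upper_bound` field relies on, assuming illustrative tensor values. Only the names `query_start_loc_cpu` / `seq_lens_cpu_upper_bound` and the subtraction pattern mirror the diff above; everything else is made up for the example.

    import torch

    # query_start_loc_cpu = [0, q1, q1+q2, ...]; adjacent differences give the
    # per-request query lengths (illustrative values, not from the patch).
    query_start_loc_cpu = torch.tensor([0, 1, 2, 7], dtype=torch.int32)
    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

    # Optimistic CPU mirror of num_computed_tokens: every token scheduled in
    # the previous step is assumed to have landed in the KV cache.
    num_computed_tokens_cpu = torch.tensor([9, 14, 0], dtype=torch.int32)
    seq_lens_cpu_upper_bound = num_computed_tokens_cpu + query_lens_cpu

    # Per-row context lengths fall out without any device-to-host copy.
    context_lens_cpu = seq_lens_cpu_upper_bound - query_lens_cpu
    assert torch.equal(context_lens_cpu, num_computed_tokens_cpu)

    # Prefill/decode classification also survives the optimistic bound:
    # a row with no cached context is a pure prefill.
    num_cache_decodes = int((context_lens_cpu > 0).sum())  # 2 decodes, 1 prefill

The bound is exact whenever the previously scheduled tokens were all committed to the KV cache, i.e. for every prefill row and for all rows outside async speculative decode; for async-spec decode rows it can only overestimate, since rejected draft tokens are subtracted from the GPU-side seq_lens only. That is why the new docstring warns against consuming it where exact per-row decode context lengths are required.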