1 change: 1 addition & 0 deletions tests/v1/attention/utils.py
@@ -107,6 +107,7 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size,
4 changes: 3 additions & 1 deletion tests/v1/spec_decode/test_tree_attention.py
@@ -241,11 +241,13 @@ def forward_attention(
)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
seq_lens_cpu = seq_lens.cpu()
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
_seq_lens_cpu=seq_lens.cpu(),
seq_lens_cpu_upper_bound=seq_lens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
16 changes: 12 additions & 4 deletions vllm/model_executor/layers/attention/cross_attention.py
@@ -90,15 +90,23 @@ def build(
assert new_metadata.encoder_seq_lens_cpu is not None
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill)
num_cache_decodes = (
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
# Any computed tokens indicates decode step>1 (no chunked prefill).
# The upper bound is exact for this `> 0` test: prefill rows have
# num_computed == 0 and decode rows have num_computed > 0.
query_lens_cpu = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
num_computed_tokens_cpu = (
common_attn_metadata.seq_lens_cpu_upper_bound - query_lens_cpu
)
num_cache_decodes = (num_computed_tokens_cpu > 0).sum().item()
if num_cache_decodes > 0:
# CrossAttn KV cache has already been populated on first decoder step,
# skip slot_mapping calculation for requests that do not need
# reshape_and_cache.
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
num_tokens = num_computed_tokens_cpu.numpy()
new_metadata.encoder_seq_lens_cpu = np.where(
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
)
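The arithmetic in this hunk can be sketched standalone (simplified names and illustrative numbers, not the actual vLLM API):

```python
import torch

# Hypothetical 3-request batch: one prefill, two decodes.
query_start_loc_cpu = torch.tensor([0, 8, 9, 13])
seq_lens_upper_bound = torch.tensor([8, 11, 24])  # may overestimate decodes

query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # [8, 1, 4]
num_computed = seq_lens_upper_bound - query_lens                 # [0, 10, 20]

# Prefill rows have num_computed == 0 and decode rows > 0, so the
# `> 0` count is exact even where the bound overestimates.
num_cache_decodes = int((num_computed > 0).sum())                # 2
```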
15 changes: 10 additions & 5 deletions vllm/model_executor/layers/attention/mla_attention.py
@@ -1822,13 +1822,18 @@ def build(

prefill_metadata = None
if num_prefills > 0:
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)

reqs_start = num_decodes # prefill_start

context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Upper bound is exact for prefill rows (no D2H sync).
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
assert seq_lens_cpu is not None
prefill_query_lens_cpu = (
query_start_loc_cpu[reqs_start + 1 : num_reqs + 1]
- query_start_loc_cpu[reqs_start:num_reqs]
)
context_lens_cpu = (
seq_lens_cpu[reqs_start:num_reqs] - prefill_query_lens_cpu
)
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = (
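A rough sketch of the prefill-side context-length computation above, assuming the usual layout of decode rows first, then prefill rows (numbers hypothetical):

```python
import torch

num_decodes, num_reqs = 2, 4
query_start_loc_cpu = torch.tensor([0, 1, 2, 10, 30])
seq_lens_upper_bound = torch.tensor([17, 25, 10, 30])  # exact on prefill rows

reqs_start = num_decodes
prefill_query_lens = (
    query_start_loc_cpu[reqs_start + 1 : num_reqs + 1]
    - query_start_loc_cpu[reqs_start:num_reqs]
)                                                       # [8, 20]
context_lens = seq_lens_upper_bound[reqs_start:num_reqs] - prefill_query_lens
# [2, 10]: chunked-prefill context lengths, no .cpu() copy required
```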
6 changes: 6 additions & 0 deletions vllm/v1/attention/backend.py
@@ -397,6 +397,12 @@ class CommonAttentionMetadata:
(num_computed_tokens < num_prompt_tokens). Used by some backends to
distinguish actual decodes from short extends."""

seq_lens_cpu_upper_bound: torch.Tensor | None = None
"""(batch_size,) CPU upper bound on seq_lens. Precise for prefill rows
and for all rows outside async spec decode; optimistic for async-spec
decode rows (assumes every draft was accepted). Not safe for kernels
that need exact per-row context lengths on decode rows."""

# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None
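A worked example of the new field's semantics (numbers are hypothetical): with async spec decode, the bound counts every scheduled draft as accepted, so it can only overestimate.

```python
import torch

num_computed = torch.tensor([0, 10, 20])   # exact, lives on the GPU
num_scheduled = torch.tensor([8, 1, 4])    # query length per request

# CPU-side bound: assume every scheduled token lands.
seq_lens_cpu_upper_bound = num_computed + num_scheduled   # [8, 11, 24]

# If request 2 were an async-spec decode and 2 of its 4 tokens were
# rejected, the true seq_len would be 22 <= 24: exact for rows 0
# and 1, optimistic only for the async-spec row.
```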
9 changes: 5 additions & 4 deletions vllm/v1/attention/backends/flex_attention.py
@@ -782,10 +782,11 @@ def __init__(
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> FlexAttentionMetadata:
# Use actual max_seq_len instead of max_model_len to avoid
# torch.compile recompilation during CUDA graph capture.
common_attn_metadata.max_seq_len = (
common_attn_metadata.seq_lens_cpu.max().item()
# Use actual max_seq_len (not max_model_len) to avoid torch.compile
# recompilation during CUDA graph capture.
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
common_attn_metadata.max_seq_len = int(
common_attn_metadata.seq_lens_cpu_upper_bound.max().item()
)
return self.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -364,7 +364,10 @@ def _build_fp8_separate_prefill_decode(
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
# Upper bound is exact for prefill rows (the `[num_decodes:]`
# slice below), so no D2H sync is needed.
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
assert seq_lens_cpu is not None
seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

8 changes: 6 additions & 2 deletions vllm/v1/attention/backends/mla/indexer.py
@@ -554,8 +554,12 @@ def build(
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
)
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
# Upper bound is exact for prefill rows (the `[num_decodes:]`
# slice below).
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
chunk_specs = split_indexer_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
seq_lens_cpu[num_decodes:],
prefill_query_lens_cpu,
self.max_prefill_buffer_size,
max_logits_bytes,
@@ -566,7 +570,7 @@
req_slice,
query_slice,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
seq_lens_cpu,
common_attn_metadata.block_table_tensor,
skip_kv_gather=query_slice.start > 0,
)
8 changes: 7 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -356,6 +356,7 @@ def make_local_attention_virtual_batches(
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
seq_lens_cpu_upper_bound=seq_lens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
), make_block_table
@@ -414,6 +415,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
@@ -445,7 +447,11 @@ def split_decodes_prefills_and_extends(
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu
# Upper bound is exact for prefill rows; decode rows still satisfy
# seq_len > query_len under the optimistic bound, so `seq_lens ==
# query_lens` identifies prefills correctly either way.
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
seq_lens = common_attn_metadata.seq_lens_cpu_upper_bound

if max_query_len <= decode_threshold:
return num_reqs, 0, 0, num_tokens, 0, 0
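Why the optimistic bound is safe for this particular split, as a minimal sketch (names simplified):

```python
import torch

# A prefill row satisfies seq_len == query_len exactly; a decode row
# has seq_len >= num_computed + query_len > query_len even when the
# bound overestimates, so the test below never misclassifies.
query_lens = torch.tensor([1, 4, 16])
seq_lens_ub = torch.tensor([9, 24, 16])    # row 2 is a fresh prefill

is_prefill = seq_lens_ub == query_lens     # [False, False, True]
```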
7 changes: 7 additions & 0 deletions vllm/v1/spec_decode/dflash.py
@@ -151,6 +151,12 @@ def set_inputs_first_pass(
if has_num_rejected:
effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu

# Skip num_rejected_tokens (GPU-only); overestimating is fine here.
new_seq_lens_cpu_upper_bound = (
cad.seq_lens_cpu_upper_bound + num_query_per_req
if cad.seq_lens_cpu_upper_bound is not None
else None
)
new_cad = CommonAttentionMetadata(
query_start_loc=new_query_start_loc,
seq_lens=effective_seq_lens + num_query_per_req,
@@ -160,6 +166,7 @@
),
_seq_lens_cpu=None,
_num_computed_tokens_cpu=None,
seq_lens_cpu_upper_bound=new_seq_lens_cpu_upper_bound,
num_reqs=cad.num_reqs,
num_actual_tokens=num_query_total,
max_query_len=num_query_per_req,
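The skipped subtraction keeps the field a valid bound; illustratively (hypothetical numbers):

```python
import torch

# Exact lengths would subtract the GPU-only rejected counts; adding
# the per-request query length alone keeps a CPU-side upper bound,
# which is all this field promises.
prev_upper_bound = torch.tensor([105, 60])
num_query_per_req = 4
new_upper_bound = prev_upper_bound + num_query_per_req  # [109, 64]
```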
10 changes: 9 additions & 1 deletion vllm/v1/spec_decode/llm_base_proposer.py
@@ -593,6 +593,8 @@ def propose(
common_attn_metadata._seq_lens_cpu += 1
if common_attn_metadata._num_computed_tokens_cpu is not None:
common_attn_metadata._num_computed_tokens_cpu += 1
if common_attn_metadata.seq_lens_cpu_upper_bound is not None:
common_attn_metadata.seq_lens_cpu_upper_bound += 1

# Rebuild attention metadata
_, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(
@@ -959,6 +961,7 @@ def prepare_inputs_padded(
query_start_loc_cpu=query_start_loc_cpu,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
@@ -1183,7 +1186,11 @@ def prepare_inputs(

device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
# upper_bound - rejected = actual post-rejection seq_lens (no D2H sync).
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
new_seq_lens_cpu = (
common_attn_metadata.seq_lens_cpu_upper_bound - num_rejected_tokens
)

# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@@ -1237,6 +1244,7 @@
query_start_loc_cpu=new_query_start_loc_cpu,
_seq_lens_cpu=new_seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
seq_lens_cpu_upper_bound=new_seq_lens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
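The `prepare_inputs` change relies on the identity `upper_bound - rejected == exact` after verification; a small sketch with hypothetical numbers:

```python
import torch

# The bound assumed every scheduled token was kept, so subtracting
# the verifier's per-request rejected counts recovers the exact
# post-verification lengths without copying seq_lens off the GPU.
seq_lens_upper_bound = torch.tensor([105, 60])  # computed + scheduled
num_rejected_tokens = torch.tensor([3, 0])
new_seq_lens_cpu = seq_lens_upper_bound - num_rejected_tokens  # [102, 60]
```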
4 changes: 4 additions & 0 deletions vllm/v1/worker/gpu/attn_utils.py
@@ -227,12 +227,15 @@ def build_attn_metadata(
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
seq_lens_cpu_upper_bound: torch.Tensor | None = None,
dcp_local_seq_lens: torch.Tensor | None = None,
encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
if seq_lens_cpu_upper_bound is not None:
seq_lens_cpu_upper_bound = seq_lens_cpu_upper_bound[:num_reqs]

attn_metadata: dict[str, Any] = {}
num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
@@ -244,6 +247,7 @@
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
max_seq_len=max_seq_len,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu/input_batch.py
@@ -60,6 +60,8 @@ class InputBatch:
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
# [num_reqs] CPU upper bound on seq_lens (see CommonAttentionMetadata).
seq_lens_cpu_upper_bound: torch.Tensor
# [num_reqs]
dcp_local_seq_lens: torch.Tensor | None

@@ -121,6 +123,8 @@ def make_dummy(
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
# Dummy: seq_len == query_len (fresh-prefill shape).
seq_lens_cpu_upper_bound = torch.from_numpy(num_scheduled_tokens.copy())
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
Expand All @@ -136,6 +140,7 @@ def make_dummy(
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=None,
input_ids=input_ids,
positions=positions,
18 changes: 18 additions & 0 deletions vllm/v1/worker/gpu/model_runner.py
@@ -799,6 +799,15 @@ def prepare_inputs(
total_num_logits,
)

# CPU upper bound on seq_lens; padded entries left at zero.
seq_lens_cpu_upper_bound_np = np.zeros(num_reqs_padded, dtype=np.int32)
np.add(
self.req_states.num_computed_tokens_np[idx_mapping_np],
num_scheduled_tokens,
out=seq_lens_cpu_upper_bound_np[:num_reqs],
)
seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np)

return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -814,6 +823,7 @@
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
@@ -927,6 +937,10 @@ def postprocess(
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
# Advance the CPU mirror optimistically (assume every scheduled token was accepted).
self.req_states.num_computed_tokens_np[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)

@torch.inference_mode()
def execute_model(
@@ -1297,6 +1311,10 @@ def postprocess_pool(self, input_batch: InputBatch) -> None:
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
# Advance the CPU mirror optimistically (assume every scheduled token was accepted).
self.req_states.num_computed_tokens_np[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)

########### EPLB methods start ###########
@property
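The CPU-mirror pattern used in `prepare_inputs`/`postprocess` above, reduced to a sketch (simplified names; assumes the same padded-batch layout):

```python
import numpy as np

max_num_reqs = 8
num_reqs, num_reqs_padded = 3, 4
num_computed_np = np.zeros(max_num_reqs, dtype=np.int32)
num_computed_np[:3] = [100, 20, 7]            # optimistic CPU mirror
idx_mapping = np.array([0, 1, 2])
num_scheduled = np.array([5, 1, 1], dtype=np.int32)

# Build the bound in one shot; the padded tail stays zero.
seq_lens_ub = np.zeros(num_reqs_padded, dtype=np.int32)
np.add(num_computed_np[idx_mapping], num_scheduled,
       out=seq_lens_ub[:num_reqs])            # [105, 21, 8, 0]

# After the step, advance the mirror as if all tokens were accepted,
# keeping it an upper bound on the GPU-side counter.
num_computed_np[idx_mapping] += num_scheduled
```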
9 changes: 8 additions & 1 deletion vllm/v1/worker/gpu/model_states/default.py
@@ -173,6 +173,12 @@ def prepare_attn(
num_tokens = input_batch.num_tokens
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
if for_capture:
# Capture with worst-case max_seq_len so the graph is valid at any replay.
max_seq_len = self.max_model_len
else:
max_seq_len = int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
@@ -181,10 +187,11 @@
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
max_seq_len=max_seq_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
)
return attn_metadata
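The `for_capture` branch above pins `max_seq_len` to the model maximum; a sketch of the choice (hypothetical helper, not the actual code):

```python
import torch

def pick_max_seq_len(
    for_capture: bool,
    max_model_len: int,
    seq_lens_cpu_upper_bound: torch.Tensor,
    num_reqs: int,
) -> int:
    # A captured CUDA graph must stay valid for any later replay, so
    # it is built against the worst case; eager builds can use the
    # tight per-batch bound instead.
    if for_capture:
        return max_model_len
    return int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
```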
8 changes: 7 additions & 1 deletion vllm/v1/worker/gpu/model_states/whisper.py
@@ -117,6 +117,11 @@ def prepare_attn(

query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
if for_capture:
max_seq_len = self.max_model_len
else:
max_seq_len = int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
@@ -125,10 +130,11 @@
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
max_seq_len=max_seq_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
encoder_seq_lens=encoder_seq_lens,
)
3 changes: 3 additions & 0 deletions vllm/v1/worker/gpu/states.py
@@ -57,6 +57,8 @@ def __init__(
self.num_computed_tokens = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Optimistic CPU mirror of num_computed_tokens (upper bound on GPU value).
self.num_computed_tokens_np = np.zeros(self.max_num_reqs, dtype=np.int32)

# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
@@ -100,6 +102,7 @@ def add_request(
self.total_len.stage_write_elem(req_idx, prefill_len)
self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens_np[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)

if num_computed_tokens > 0 and num_computed_tokens <= prefill_len: