
Commit 2fd221d

Refactor common attention metadata

Signed-off-by: luka <[email protected]>

1 parent a63e6d6 · commit 2fd221d

File tree

8 files changed: +67 / -37 lines

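At a glance, the change moves the per-batch scalars num_reqs, num_actual_tokens, and max_query_len off every builder's build() signature and onto CommonAttentionMetadata, so call sites construct the metadata once and builders read the scalars from it. A minimal sketch of the before/after calling convention, using a simplified stand-in dataclass rather than the real vLLM types:

from dataclasses import dataclass

@dataclass
class CommonAttentionMetadata:
    """Simplified stand-in for the vLLM dataclass; tensor fields omitted."""
    num_reqs: int            # number of requests in the batch
    num_actual_tokens: int   # total number of tokens in the batch
    max_query_len: int       # longest query in the batch

def build_old(num_reqs, num_actual_tokens, max_query_len, common_prefix_len,
              common_attn_metadata):
    # Before this commit: call sites threaded the three scalars explicitly.
    return (num_reqs, num_actual_tokens, max_query_len)

def build_new(common_prefix_len, common_attn_metadata):
    # After this commit: the scalars are read from the shared metadata object.
    m = common_attn_metadata
    return (m.num_reqs, m.num_actual_tokens, m.max_query_len)

meta = CommonAttentionMetadata(num_reqs=4, num_actual_tokens=4, max_query_len=1)
assert build_old(4, 4, 1, 0, meta) == build_new(0, meta)

The per-file diffs below apply this same mechanical change to each attention backend and its call sites.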

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 5 additions & 2 deletions
@@ -119,9 +119,12 @@ def reorder_batch(self, input_batch: InputBatch,
 
         return True
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         runner = self.runner
         block_table = self.block_table
         seq_lens_np = runner.seq_lens_np[:num_reqs]

vllm/v1/attention/backends/flash_attn.py

Lines changed: 6 additions & 2 deletions
@@ -341,9 +341,13 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.aot_sliding_window: Optional[tuple[int, int]] = None
 
     def build(
-        self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-        common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata
+        self, common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata
     ) -> FlashAttentionMetadata:
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens

vllm/v1/attention/backends/flashinfer.py

Lines changed: 4 additions & 2 deletions
@@ -400,9 +400,11 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             kv_data_type=attn_metadata.data_type,
         )
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)

vllm/v1/attention/backends/flex_attention.py

Lines changed: 5 additions & 2 deletions
@@ -272,9 +272,12 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens

vllm/v1/attention/backends/mla/common.py

Lines changed: 13 additions & 8 deletions
@@ -459,26 +459,31 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
         )
 
     def build_for_cudagraph_capture(
-            self, num_reqs: int, num_tokens: int,
-            common_attn_metadata: CommonAttentionMetadata) -> M:
+            self, common_attn_metadata: CommonAttentionMetadata) -> M:
         """
         This method builds the metadata for full cudagraph capture.
         Currently, only decode is supported for full cudagraphs with MLA.
         """
-        assert num_reqs == num_tokens, \
+        m = common_attn_metadata
+        assert m.num_reqs == m.num_actual_tokens, \
             "MLA only supports decode-only full CUDAGraph capture. " \
             "Make sure all cudagraph capture sizes <= max_num_seq."
 
+        m.max_query_len = 1  # decode-only
+
         # Update state usually set in reorder_batch.
-        self._num_decodes = num_tokens
-        self._num_decode_tokens = num_tokens
+        self._num_decodes = m.num_reqs
+        self._num_decode_tokens = m.num_actual_tokens
         self._num_prefills = 0
         self._num_prefill_tokens = 0
-        return self.build(num_tokens, num_tokens, 1, 0, common_attn_metadata)
+        return self.build(0, m)
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata) -> M:
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
         assert self._num_decodes + self._num_prefills == num_reqs
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
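For context on the assertion above: during full CUDA graph capture the MLA path assumes a pure decode batch, so every request schedules exactly one token, num_reqs must equal num_actual_tokens, and max_query_len can be pinned to 1 before delegating to build(). A standalone sketch of that invariant (illustrative only, not the real builder class):

def decode_only_max_query_len(num_reqs: int, num_actual_tokens: int) -> int:
    # A pure-decode batch contributes one new token per request, so the
    # token count must equal the request count...
    assert num_reqs == num_actual_tokens, (
        "MLA only supports decode-only full CUDAGraph capture.")
    # ...and the longest query in the batch is therefore a single token.
    return 1

assert decode_only_max_query_len(num_reqs=8, num_actual_tokens=8) == 1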

vllm/v1/attention/backends/utils.py

Lines changed: 13 additions & 7 deletions
@@ -26,15 +26,21 @@ class CommonAttentionMetadata:
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""
 
+    num_reqs: int
+    """Number of requests"""
+    num_actual_tokens: int
+    """Total number of tokens in batch"""
+    max_query_len: int
+    """Longest query in batch"""
+
 
 M = TypeVar("M")
 
 
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
 
     @abstractmethod
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
+    def build(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata) -> M:
         """
         Central method that builds attention metadata.
@@ -43,14 +49,14 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         raise NotImplementedError
 
     def build_for_cudagraph_capture(
-            self, num_reqs: int, num_tokens: int,
-            common_attn_metadata: CommonAttentionMetadata) -> M:
+            self, common_attn_metadata: CommonAttentionMetadata) -> M:
         """
         Build attention metadata for CUDA graph capture. Uses build by default.
-        Subclasses that override this method should call self.build.
+        Subclasses that override this method should call self.build or
+        super().build_for_cudagraph_capture.
         """
-        return self.build(num_reqs, num_tokens, num_tokens, 0,
-                          common_attn_metadata)
+        return self.build(common_prefix_len=0,
+                          common_attn_metadata=common_attn_metadata)
 
     def use_cascade_attention(
         self,
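To see how the abstract interface and the default build_for_cudagraph_capture fit together after this change, here is a self-contained toy example; the dataclass and builder below are simplified stand-ins with scalar fields only, not the actual vLLM classes:

import abc
from dataclasses import dataclass
from typing import Generic, TypeVar

M = TypeVar("M")

@dataclass
class CommonAttentionMetadata:
    """Stand-in carrying only the scalar fields added by this commit."""
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int

class AttentionMetadataBuilder(abc.ABC, Generic[M]):

    @abc.abstractmethod
    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        raise NotImplementedError

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        # Default capture path: reuse build() with no shared prefix.
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)

class ToyBuilder(AttentionMetadataBuilder[dict]):

    def build(self, common_prefix_len, common_attn_metadata):
        m = common_attn_metadata
        # The batch scalars now come from the metadata object, not the signature.
        return {"num_reqs": m.num_reqs,
                "num_actual_tokens": m.num_actual_tokens,
                "max_query_len": m.max_query_len,
                "common_prefix_len": common_prefix_len}

meta = CommonAttentionMetadata(num_reqs=2, num_actual_tokens=2, max_query_len=1)
print(ToyBuilder().build_for_cudagraph_capture(meta))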

vllm/v1/spec_decode/eagle.py

Lines changed: 6 additions & 4 deletions
@@ -138,15 +138,17 @@ def propose(
         max_query_len = query_lens.max().item()
 
         common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+            query_start_loc=cu_num_tokens,
+            seq_lens=seq_lens,
+            num_reqs=batch_size,
+            num_actual_tokens=num_tokens,
+            max_query_len=max_query_len,
+        )
 
         assert self.runner is not None
 
         # FIXME: need to consider multiple kv_cache_groups
         attn_metadata = self.runner.attn_metadata_builder.build(
-            num_reqs=batch_size,
-            num_actual_tokens=num_tokens,
-            max_query_len=max_query_len,
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )

vllm/v1/worker/gpu_model_runner.py

Lines changed: 15 additions & 10 deletions
@@ -669,7 +669,12 @@ def _prepare_inputs(
         seq_lens = self.seq_lens[:num_reqs]
 
         common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=query_start_loc, seq_lens=seq_lens)
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            num_reqs=num_reqs,
+            num_actual_tokens=total_num_scheduled_tokens,
+            max_query_len=max_num_scheduled_tokens,
+        )
 
         attn_metadata: dict[str, Any] = {}
         # Prepare the attention metadata for each KV cache group and make layers
@@ -690,11 +695,9 @@ def _prepare_inputs(
 
             attn_metadata_i = (
                 self.attn_metadata_builders[kv_cache_group_id].build(
-                    num_reqs=num_reqs,
-                    num_actual_tokens=total_num_scheduled_tokens,
-                    max_query_len=max_num_scheduled_tokens,
                     common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata))
+                    common_attn_metadata=common_attn_metadata,
+                ))
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
@@ -1809,18 +1812,20 @@ def _dummy_run(
         seq_lens = self.seq_lens[:num_reqs]
 
         common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=query_start_loc, seq_lens=seq_lens)
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            num_reqs=num_reqs,
+            num_actual_tokens=num_tokens,
+            max_query_len=num_tokens,
+        )
 
         attn_metadata = {}
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
 
             attn_metadata_i = self.attn_metadata_builders[
                 kv_cache_group_id].build_for_cudagraph_capture(
-                    num_reqs=num_reqs,
-                    num_tokens=num_tokens,
-                    common_attn_metadata=common_attn_metadata,
-                )
+                    common_attn_metadata)
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
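The net effect at the call sites in _prepare_inputs and _dummy_run is a single CommonAttentionMetadata constructed per step and shared across all per-KV-cache-group builders. A rough sketch of that loop; builders, kv_cache_groups, and the echo builder are illustrative placeholders rather than the runner's real attributes:

from types import SimpleNamespace

def prepare_attn_metadata(builders, kv_cache_groups, common_attn_metadata,
                          common_prefix_len=0):
    # One shared metadata object per step; each group's builder reads the
    # batch scalars from it instead of receiving them as arguments.
    attn_metadata = {}
    for group_id, group in enumerate(kv_cache_groups):
        metadata_i = builders[group_id].build(
            common_prefix_len=common_prefix_len,
            common_attn_metadata=common_attn_metadata,
        )
        for layer_name in group.layer_names:
            attn_metadata[layer_name] = metadata_i
    return attn_metadata

# Toy usage: one group with two layers and a builder that returns a constant.
group = SimpleNamespace(layer_names=["layers.0.attn", "layers.1.attn"])
builder = SimpleNamespace(
    build=lambda common_prefix_len, common_attn_metadata: "metadata")
print(prepare_attn_metadata([builder], [group], common_attn_metadata=None))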
