
Commit a63e6d6

Superclass for attention metadata builder.
Signed-off-by: luka <[email protected]>
1 parent 682c6b6 commit a63e6d6

7 files changed: +76 lines, -38 lines

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 3 additions & 2 deletions

@@ -7,7 +7,8 @@
                              TorchSDPAMetadata)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
@@ -53,7 +54,7 @@ def use_cascade_attention(*args, **kwargs) -> bool:
         return False


-class TorchSDPAMetadataBuilderV1:
+class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

     def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable) -> None:

vllm/v1/attention/backends/flash_attn.py

Lines changed: 4 additions & 15 deletions

@@ -21,13 +21,12 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 if current_platform.is_cuda():
@@ -310,7 +309,8 @@ def _get_sliding_window_configs(
     return sliding_window_configs


-class FlashAttentionMetadataBuilder:
+class FlashAttentionMetadataBuilder(
+        AttentionMetadataBuilder[FlashAttentionMetadata]):

     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
@@ -340,17 +340,6 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
-
-    def build_for_cudagraph_capture(
-        self, num_reqs: int, num_tokens: int,
-        common_attn_metadata: CommonAttentionMetadata
-    ) -> FlashAttentionMetadata:
-        return self.build(num_reqs, num_tokens, num_tokens, 0,
-                          common_attn_metadata)
-
     def build(
         self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata

vllm/v1/attention/backends/flashinfer.py

Lines changed: 3 additions & 2 deletions

@@ -18,7 +18,8 @@
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -202,7 +203,7 @@ def __post_init__(self):
                          f" received {self.head_dim}.")


-class FlashInferMetadataBuilder:
+class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

     def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):

vllm/v1/attention/backends/flex_attention.py

Lines changed: 4 additions & 11 deletions

@@ -15,7 +15,8 @@
                          is_quantized_kv_cache)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -25,8 +26,6 @@
 logger = init_logger(__name__)

 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 create_block_mask_compiled = torch.compile(create_block_mask,
@@ -256,7 +255,8 @@ def __post_init__(self):
         self.block_mask = self.build_block_mask()


-class FlexAttentionMetadataBuilder:
+class FlexAttentionMetadataBuilder(
+        AttentionMetadataBuilder[FlexAttentionMetadata]):

     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
@@ -272,10 +272,6 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
-
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
@@ -332,9 +328,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         )
         return out

-    def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return False
-

 class FlexAttentionImpl(AttentionImpl):
     sliding_window: Optional[tuple[int, int]]

vllm/v1/attention/backends/mla/common.py

Lines changed: 5 additions & 5 deletions

@@ -207,7 +207,8 @@
                                                UnquantizedLinearMethod)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -336,7 +337,7 @@ def cuda_graph_supported(self):
 M = TypeVar("M", bound=MLACommonMetadata)


-class MLACommonMetadataBuilder(Generic[M]):
+class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -467,6 +468,8 @@ def build_for_cudagraph_capture(
         assert num_reqs == num_tokens, \
             "MLA only supports decode-only full CUDAGraph capture. " \
             "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        # Update state usually set in reorder_batch.
         self._num_decodes = num_tokens
         self._num_decode_tokens = num_tokens
         self._num_prefills = 0
@@ -590,9 +593,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             decode=decode_metadata,
         )

-    def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return False
-

 class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     """

vllm/v1/attention/backends/utils.py

Lines changed: 54 additions & 0 deletions

@@ -1,9 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import abc
+from abc import abstractmethod
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Generic, TypeVar

+import numpy as np
 import torch

+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.worker.gpu_input_batch import InputBatch
+

 @dataclass
 class CommonAttentionMetadata:
@@ -19,6 +27,52 @@ class CommonAttentionMetadata:
     and newly scheduled tokens"""


+M = TypeVar("M")
+
+
+class AttentionMetadataBuilder(abc.ABC, Generic[M]):
+
+    @abstractmethod
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        Central method that builds attention metadata.
+        Some builders (MLA) require reorder_batch to be called prior to build.
+        """
+        raise NotImplementedError
+
+    def build_for_cudagraph_capture(
+            self, num_reqs: int, num_tokens: int,
+            common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        Build attention metadata for CUDA graph capture. Uses build by default.
+        Subclasses that override this method should call self.build.
+        """
+        return self.build(num_reqs, num_tokens, num_tokens, 0,
+                          common_attn_metadata)
+
+    def use_cascade_attention(
+        self,
+        common_prefix_len: int,
+        query_lens: np.ndarray,
+        num_query_heads: int,
+        num_kv_heads: int,
+        use_alibi: bool,
+        use_sliding_window: bool,
+        num_sms: int,
+    ) -> bool:
+        return False
+
+    def reorder_batch(self, input_batch: "InputBatch",
+                      scheduler_output: "SchedulerOutput") -> bool:
+        """
+        This method can reorder the batch if desired by the backend.
+        :return: Has the batch been reordered (default False).
+        """
+        return False
+
+
 def validate_kv_sharing_target(current_layer_name, target_layer_name,
                                static_forward_context):
     error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
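
The new base class above only requires build to be implemented; reorder_batch, use_cascade_attention, and build_for_cudagraph_capture all ship with defaults, which is what lets the FlashAttention and FlexAttention builders delete their boilerplate in the hunks earlier in this commit. A minimal sketch of a concrete subclass, assuming vLLM is importable; MinimalMetadata and MinimalMetadataBuilder are hypothetical illustrations, not part of this commit.

# Illustrative only: a hypothetical backend builder that relies entirely on
# the inherited defaults and implements just the abstract build() method.
from dataclasses import dataclass

from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)


@dataclass
class MinimalMetadata:
    # Stand-in for a backend-specific metadata type (e.g. FlashAttentionMetadata).
    num_reqs: int
    num_actual_tokens: int


class MinimalMetadataBuilder(AttentionMetadataBuilder[MinimalMetadata]):

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> MinimalMetadata:
        # A real builder would derive device tensors from common_attn_metadata.
        return MinimalMetadata(num_reqs, num_actual_tokens)

With this, reorder_batch and use_cascade_attention return False unless overridden, and build_for_cudagraph_capture(num_reqs, num_tokens, m) reduces to self.build(num_reqs, num_tokens, num_tokens, 0, m), matching the behavior the per-backend copies used to implement by hand.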

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 3 deletions

@@ -16,8 +16,7 @@

 import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
-from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadataBuilder)
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
@@ -41,7 +40,8 @@
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, is_pin_memory_available)
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
