# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar

import numpy as np
import torch

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch


@dataclass
class CommonAttentionMetadata:
@@ -19,6 +27,52 @@ class CommonAttentionMetadata:
    and newly scheduled tokens"""


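# Backend-specific attention metadata type produced by a builder.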
M = TypeVar("M")


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
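    """
    Abstract builder that turns ``CommonAttentionMetadata`` into the
    backend-specific attention metadata type ``M``.

    A minimal subclass might look like the following sketch (the names are
    illustrative only, not an actual vLLM backend)::

        @dataclass
        class _ExampleMetadata:
            num_actual_tokens: int
            max_query_len: int

        class _ExampleBuilder(AttentionMetadataBuilder[_ExampleMetadata]):

            def build(self, num_reqs: int, num_actual_tokens: int,
                      max_query_len: int, common_prefix_len: int,
                      common_attn_metadata: CommonAttentionMetadata
                      ) -> _ExampleMetadata:
                return _ExampleMetadata(num_actual_tokens=num_actual_tokens,
                                        max_query_len=max_query_len)
    """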

    @abstractmethod
    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
            self, num_reqs: int, num_tokens: int,
            common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build.
        """
        return self.build(num_reqs, num_tokens, num_tokens, 0,
                          common_attn_metadata)

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        num_sms: int,
    ) -> bool:
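        """
        Return whether cascade attention should be used for the given query
        lengths and head configuration. The base implementation always
        returns False.
        """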
        return False

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """
        This method can reorder the batch if desired by the backend.
        :return: Has the batch been reordered (default False).
        """
        return False


def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
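    """
    Validate the KV sharing target layer specified for current_layer_name.
    """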
    error_msg = (f"Specified KV sharing target layer for {current_layer_name} "