2626from vllm .attention .ops .common import cp_lse_ag_out_rs
2727from vllm .attention .ops .merge_attn_states import merge_attn_states
2828from vllm .config import CUDAGraphMode , VllmConfig
29- from vllm .distributed .parallel_state import get_dcp_group
3029from vllm .config .cache import CacheDType
30+ from vllm .distributed .parallel_state import get_dcp_group
3131from vllm .logger import init_logger
3232from vllm .model_executor .layers .batch_invariant import (
3333 vllm_is_batch_invariant ,
5151 AttentionCGSupport ,
5252 AttentionMetadataBuilder ,
5353 CommonAttentionMetadata ,
54- get_dcp_local_seq_lens ,
5554 KVCacheLayoutType ,
55+ get_dcp_local_seq_lens ,
5656 get_kv_cache_layout ,
5757 get_per_layer_parameters ,
5858 infer_global_hyperparameters ,
@@ -163,9 +163,11 @@ def trtllm_prefill_attn_kvfp8_dequant(
163163 )
164164 return mock_kv_cache , mock_block_table
165165
166+
166167@dataclass
167168class BatchDCPPrefillPlanConfig :
168169 """Parameters for BatchDCPPrefillWrapper.plan() method."""
170+
169171 qo_indptr_cpu : torch .Tensor
170172 paged_kv_indptr_cpu : torch .Tensor
171173 paged_kv_indices : torch .Tensor
@@ -203,7 +205,7 @@ def plan(self, cfg: BatchDCPPrefillPlanConfig):
203205 cfg .qo_indptr_cpu ,
204206 cfg .paged_kv_indptr_cpu ,
205207 cfg .paged_kv_indices ,
206- cfg .paged_kv_last_page_len_cpu [cfg .prefill_start :],
208+ cfg .paged_kv_last_page_len_cpu [cfg .prefill_start :],
207209 cfg .num_qo_heads * cfg .dcp_world_size ,
208210 cfg .num_kv_heads ,
209211 cfg .head_dim ,
@@ -230,15 +232,15 @@ def plan(self, cfg: BatchDCPPrefillPlanConfig):
230232 logits_soft_cap = cfg .logits_soft_cap ,
231233 q_data_type = cfg .q_data_type ,
232234 )
233-
235+
234236 def run (
235- self ,
236- layer : torch .nn .Module ,
237- prefill_query : torch .Tensor ,
238- kv_cache_permute : torch .Tensor ,
239- key : torch .Tensor ,
240- value : torch .Tensor ,
241- ):
237+ self ,
238+ layer : torch .nn .Module ,
239+ prefill_query : torch .Tensor ,
240+ kv_cache_permute : torch .Tensor ,
241+ key : torch .Tensor ,
242+ value : torch .Tensor ,
243+ ):
242244 prefill_query_across_dcp = get_dcp_group ().all_gather (
243245 prefill_query .contiguous (), dim = 1
244246 )
@@ -272,6 +274,7 @@ def run(
272274 )
273275 return output
274276
277+
275278class FlashInferBackend (AttentionBackend ):
276279 accept_output_buffer : bool = True
277280 supported_dtypes : ClassVar [list [torch .dtype ]] = [torch .float16 , torch .bfloat16 ]
@@ -379,9 +382,7 @@ class FlashInferMetadata:
379382 use_cascade : bool
380383
381384 prefill_wrapper : (
382- BatchPrefillWithPagedKVCacheWrapper
383- | BatchDCPPrefillWrapper
384- | None
385+ BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
385386 ) = None
386387 decode_wrapper : BatchDecodeWithPagedKVCacheWrapper | None = None
387388 cascade_wrapper : MultiLevelCascadeAttentionWrapper | None = None
@@ -409,13 +410,7 @@ def __init__(
409410 self .model_config = vllm_config .model_config
410411 self ._workspace_buffer = None
411412 self ._prefill_wrapper : (
412- BatchPrefillWithPagedKVCacheWrapper
413- | dict [
414- str ,
415- BatchPrefillWithPagedKVCacheWrapper
416- | BatchPrefillWithRaggedKVCacheWrapper ,
417- ]
418- | None
413+ BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
419414 ) = None # Wrapper for prefill/append
420415 self ._decode_wrapper = None # Wrapper for decode (general shape)
421416
@@ -563,13 +558,7 @@ def _get_workspace_buffer(self):
563558
564559 def _get_prefill_wrapper (
565560 self ,
566- ) -> (
567- BatchPrefillWithPagedKVCacheWrapper
568- | dict [
569- str ,
570- BatchPrefillWithPagedKVCacheWrapper | BatchPrefillWithRaggedKVCacheWrapper ,
571- ]
572- ):
561+ ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper :
573562 if self ._prefill_wrapper is None :
574563 if self .dcp_world_size > 1 :
575564 self ._prefill_wrapper = BatchDCPPrefillWrapper (
@@ -855,7 +844,9 @@ def build(
855844
856845 if not attn_metadata .prefill_use_trtllm :
857846 if self .dcp_world_size > 1 :
858- assert isinstance (attn_metadata .prefill_wrapper , BatchDCPPrefillWrapper )
847+ assert isinstance (
848+ attn_metadata .prefill_wrapper , BatchDCPPrefillWrapper
849+ )
859850 plan_cfgs = BatchDCPPrefillPlanConfig (
860851 qo_indptr_cpu = qo_indptr_cpu ,
861852 paged_kv_indptr_cpu = paged_kv_indptr_cpu ,
@@ -877,6 +868,10 @@ def build(
877868 )
878869 attn_metadata .prefill_wrapper .plan (plan_cfgs )
879870 else :
871+ assert isinstance (
872+ attn_metadata .prefill_wrapper ,
873+ BatchPrefillWithPagedKVCacheWrapper ,
874+ )
880875 attn_metadata .prefill_wrapper .plan (
881876 qo_indptr_cpu ,
882877 paged_kv_indptr_cpu ,
@@ -1196,11 +1191,15 @@ def forward(
11961191 if self .dcp_world_size > 1 :
11971192 assert isinstance (prefill_wrapper , BatchDCPPrefillWrapper )
11981193 assert prefill_wrapper ._context ._window_left == self .window_left
1199- assert prefill_wrapper ._context ._logits_soft_cap == (self .logits_soft_cap or 0.0 )
1194+ assert prefill_wrapper ._context ._logits_soft_cap == (
1195+ self .logits_soft_cap or 0.0
1196+ )
12001197 assert prefill_wrapper ._context ._sm_scale == self .scale
12011198 assert not prefill_wrapper ._context .causal
12021199 assert prefill_wrapper ._new_tokens ._window_left == self .window_left
1203- assert prefill_wrapper ._new_tokens ._logits_soft_cap == (self .logits_soft_cap or 0.0 )
1200+ assert prefill_wrapper ._new_tokens ._logits_soft_cap == (
1201+ self .logits_soft_cap or 0.0
1202+ )
12041203 assert prefill_wrapper ._new_tokens ._sm_scale == self .scale
12051204 assert prefill_wrapper ._new_tokens .causal
12061205
@@ -1212,8 +1211,13 @@ def forward(
12121211 value [num_decode_tokens :],
12131212 )
12141213 else :
1214+ assert isinstance (
1215+ prefill_wrapper , BatchPrefillWithPagedKVCacheWrapper
1216+ )
12151217 assert prefill_wrapper ._window_left == self .window_left
1216- assert prefill_wrapper ._logits_soft_cap == (self .logits_soft_cap or 0.0 )
1218+ assert prefill_wrapper ._logits_soft_cap == (
1219+ self .logits_soft_cap or 0.0
1220+ )
12171221 assert prefill_wrapper ._sm_scale == self .scale
12181222 assert prefill_wrapper ._causal
12191223 prefill_wrapper .run (
0 commit comments