@@ -7,8 +7,10 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
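+    # Full CUDA graphs are limited to pure-decode batches for Mamba2.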
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     reorder_batch_threshold: ClassVar[int] = 1
 
@@ -90,8 +94,18 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
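+        # Persistent buffer sized for the largest decode batch that can be
+        # captured in a full CUDA graph, so graph replays always read state
+        # indices from the same device memory.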
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
 
     def build(self,
               common_prefix_len: int,
@@ -144,6 +158,14 @@ def build(self,
                     query_start_loc_p, self.chunk_size,
                     num_prefill_tokens))
 
+        elif num_decodes <= self.decode_cudagraph_max_bs:
+            # Pad the state tensor for CUDA graph replay: copy the live state
+            # indices into the persistent buffer and fill the padded tail with
+            # PAD_SLOT_ID so padded slots are ignored by the kernels.
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
@@ -160,3 +182,23 @@ def build(self,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        Build the metadata for full CUDA graph capture.
+        Currently, only pure-decode batches are supported for full CUDA
+        graphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seqs."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
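
A minimal sketch of how the new hooks are expected to be driven by the model
runner (`builder`, `common_attn_metadata`, and `run_captured_graph` below are
illustrative names, not part of this diff):

    # Capture: the runner pads a pure-decode batch to a capture size and
    # builds metadata against the persistent state_indices_tensor.
    metadata = builder.build_for_cudagraph_capture(common_attn_metadata)

    # Replay: only pure-decode batches (max_query_len == 1) may reuse the
    # captured graph; batches containing prefills fall back to eager execution.
    if builder.can_run_in_cudagraph(common_attn_metadata):
        run_captured_graph(metadata)  # hypothetical replay entry point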