@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     # The number of entries in the last page of each request in
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_len: Optional[torch.Tensor] = None
+    # The query indptr, shape: [num_decode + 1]
+    qo_indptr: Optional[torch.Tensor] = None


 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -75,27 +77,33 @@ def _get_paged_kv_tensors(
             seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
         page_size = self.runner.block_size
         block_table_bounds = (seq_lens + page_size - 1) // page_size
+        device = self.runner.device

         mask = (torch.arange(block_table.size(1),
                              dtype=block_table.dtype,
-                             device=block_table.device).unsqueeze(0)
+                             device=device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
         paged_kv_indices = block_table[mask]

         paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
+            torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])

         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
+        qo_indptr = torch.arange(0,
+                                 self._num_decodes + 1,
+                                 step=1,
+                                 dtype=torch.int32,
+                                 device=device)
+
         return (
             paged_kv_indices,
             paged_kv_indptr,
             paged_kv_last_page_len,
+            qo_indptr,
         )

     def _build_decode(self, block_table_tensor: torch.Tensor,
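
For reference, here is a minimal standalone sketch of what the updated `_get_paged_kv_tensors` now returns for a toy two-request decode batch; the `page_size`, `seq_lens`, and `block_table` values are invented for illustration and everything runs on CPU rather than `self.runner.device`:

```python
# Illustrative sketch only (not part of the commit): mirrors the tensor math
# in _get_paged_kv_tensors for two decode requests with page_size = 4.
import torch

page_size = 4
seq_lens = torch.tensor([5, 8], dtype=torch.int32)   # cached tokens per request
block_table = torch.tensor([[0, 1, 0], [2, 3, 0]])   # last column is padding
num_decodes = seq_lens.numel()

block_table_bounds = (seq_lens + page_size - 1) // page_size     # tensor([2, 2])
mask = (torch.arange(block_table.size(1)).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]                             # tensor([0, 1, 2, 3])
paged_kv_indptr = torch.cat([
    torch.zeros(1, dtype=block_table_bounds.dtype),
    block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])                                                               # tensor([0, 2, 4])
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size,
                                     paged_kv_last_page_len)     # tensor([1, 4])
qo_indptr = torch.arange(0, num_decodes + 1, dtype=torch.int32)  # tensor([0, 1, 2])
```
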
@@ -105,14 +113,16 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
             paged_kv_indices,
             paged_kv_indptr,
             paged_last_page_len,
+            qo_indptr,
         ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)

         attn_metadata = AiterMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
             paged_kv_indptr=paged_kv_indptr,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_last_page_len)
+            paged_kv_last_page_len=paged_last_page_len,
+            qo_indptr=qo_indptr)

         return attn_metadata

@@ -137,7 +147,10 @@ def __init__(
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
                          **mla_args)
-
+        assert (num_heads == 16 or num_heads == 128), (
+            f"Aiter MLA only supports 16 or 128 heads.\n"
+            f"Provided {num_heads} heads.\n"
+            "Try adjusting the tensor_parallel_size value.")
         unsupported_features = [
             alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
         ]
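
The new head-count check operates on the per-rank head count, i.e. the model's attention heads divided across tensor-parallel ranks, which is why the message suggests adjusting `tensor_parallel_size`. A rough sketch of that relationship (the 128-head total below is an assumption chosen to match typical DeepSeek-style MLA models):

```python
# Illustrative only: num_heads as seen by the backend is the per-rank count,
# i.e. the model's attention heads split evenly across tensor-parallel ranks.
total_heads = 128  # assumed model-wide head count (e.g. a DeepSeek-style MLA model)
for tp_size in (1, 2, 4, 8):
    per_rank_heads = total_heads // tp_size
    ok = per_rank_heads in (16, 128)
    print(f"tp_size={tp_size}: num_heads={per_rank_heads} -> supported={ok}")
```
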
@@ -189,7 +202,18 @@ def _forward_decode(

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

+        if self.num_heads == 16:
+            # AITER MLA decode kernel only supports
+            # max_seqlen_q=1 when using 16 heads.
+            max_seqlen_qo = 1
+        else:
+            # AITER MLA decode kernel handles arbitrary
+            # max_seqlen_q values when using 128 heads.
+            assert attn_metadata.prefill is not None
+            max_seqlen_qo = attn_metadata.prefill.max_query_len
+
         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.decode.qo_indptr, max_seqlen_qo,
                              attn_metadata.decode.paged_kv_indptr,
                              attn_metadata.decode.paged_kv_indices,
                              attn_metadata.decode.paged_kv_last_page_len)
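
As a sanity check on the two new arguments, the sketch below (illustrative only, with a made-up batch size) shows how `qo_indptr` and `max_seqlen_qo` relate for a pure decode batch, where each request contributes exactly one query token and the 16-head path can therefore fix `max_seqlen_qo = 1`:

```python
# Illustrative only: qo_indptr for a pure decode batch of 4 requests marks one
# query token per request, so the longest query segment has length 1.
import torch

num_decodes = 4
qo_indptr = torch.arange(0, num_decodes + 1, dtype=torch.int32)  # tensor([0, 1, 2, 3, 4])
query_lens = qo_indptr[1:] - qo_indptr[:-1]                      # tensor([1, 1, 1, 1])
assert int(query_lens.max()) == 1   # consistent with max_seqlen_qo = 1 above
```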