@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
5353
5454@dataclass
5555class AiterMLAMetadata (MLACommonMetadata ):
56- # The following 4 tensors are for current version of AITER MLA
56+ # The following 5 tensors are for current version of AITER MLA
5757 block_table_bound : Optional [torch .Tensor ] = None
5858 # The indptr of the paged kv cache, shape: [batch_size + 1]
5959 paged_kv_indptr : Optional [torch .Tensor ] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
6363 # the paged kv cache, shape: [batch_size]
6464 paged_kv_last_page_lens : Optional [torch .Tensor ] = None
6565
66+ # This is just to make the new AITER MLA API work
67+ # -- MTP support is not added yet.
68+ qo_indptr : Optional [torch .Tensor ] = None
69+
6670 @property
6771 def prefill_metadata (self ):
6872 prefill_metadata = super ().prefill_metadata
@@ -74,6 +78,7 @@ def prefill_metadata(self):
7478 prefill_metadata \
7579 .paged_kv_last_page_lens = self .paged_kv_last_page_lens
7680 prefill_metadata .block_table_bound = self .block_table_bound
81+ prefill_metadata .qo_indptr = self .qo_indptr
7782
7883 # update the cache
7984 self ._cached_prefill_metadata = self .__class__ (
@@ -93,6 +98,7 @@ def decode_metadata(self):
9398 decode_metadata \
9499 .paged_kv_last_page_lens = self .paged_kv_last_page_lens
95100 decode_metadata .block_table_bound = self .block_table_bound
101+ decode_metadata .qo_indptr = self .qo_indptr
96102
97103 # update the cache
98104 self ._cached_decode_metadata = self .__class__ (
@@ -136,6 +142,7 @@ def prepare(self):
136142 self .paged_kv_indptr : list [int ] = [0 ]
137143 self .paged_kv_last_page_lens : list [int ] = []
138144 self .total_blocks = 0
145+ self .qo_indptr : list [int ] = [0 ]
139146
140147 def _add_seq_group (self , inter_data , chunked_prefill_enabled : bool ,
141148 prefix_cache_hit : bool ):
@@ -208,6 +215,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
208215 self .paged_kv_indices .extend (block_table [:block_table_bound ])
209216 self .paged_kv_indptr .append (self .paged_kv_indptr [- 1 ] +
210217 block_table_bound )
218+ self .qo_indptr .append (self .qo_indptr [- 1 ] + 1 )
211219
212220 last_page_len = seq_len % self .block_size
213221 if last_page_len == 0 :
@@ -226,6 +234,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
226234 self .paged_kv_indptr .extend ([last_paged_kv_indptr ] *
227235 cuda_graph_pad_size )
228236 self .paged_kv_last_page_lens .extend ([0 ] * cuda_graph_pad_size )
237+ last_qo_indptr = self .qo_indptr [- 1 ]
238+ self .qo_indptr .extend ([last_qo_indptr ] * cuda_graph_pad_size )
229239
230240 # For current version of AITER MLA
231241 if len (self .paged_kv_indptr ) > 0 :
@@ -245,16 +255,22 @@ def build(self, seq_lens: list[int], query_lens: list[int],
245255 1 ,
246256 device = device ,
247257 dtype = torch .int )
258+
259+ qo_indptr = torch .tensor (self .qo_indptr ,
260+ device = device ,
261+ dtype = torch .int )
248262 else :
249263 paged_kv_indices_tensor = None
250264 paged_kv_indptr_tensor = None
251265 paged_kv_last_page_lens_tensor = None
252266 block_table_bound_tensor = None
267+ qo_indptr = None
253268
254269 metadata .paged_kv_indptr = paged_kv_indptr_tensor
255270 metadata .paged_kv_indices = paged_kv_indices_tensor
256271 metadata .paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
257272 metadata .block_table_bound = block_table_bound_tensor
273+ metadata .qo_indptr = qo_indptr
258274
259275 return metadata
260276
@@ -263,21 +279,25 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
263279
264280 @contextmanager
265281 def graph_capture (self , max_batch_size : int ):
266- kv_indices , kv_indptr , last_page_lens = get_aiter_mla_metadata (
267- max_batch_size = max_batch_size ,
268- block_size = self .runner .block_size ,
269- max_block_per_batch = self .runner .get_max_block_per_batch (),
270- device = self .runner .device )
282+ kv_indices , kv_indptr , last_page_lens , qo_indptr = \
283+ get_aiter_mla_metadata (
284+ max_batch_size = max_batch_size ,
285+ block_size = self .runner .block_size ,
286+ max_block_per_batch = \
287+ self .runner .get_max_block_per_batch (),
288+ device = self .runner .device )
271289 self ._paged_kv_indices_tensor = kv_indices
272290 self ._paged_kv_indptr_tensor = kv_indptr
273291 self ._paged_kv_last_page_lens_tensor = last_page_lens
292+ self ._qo_indptr_tensor = qo_indptr
274293
275294 with super ().graph_capture (max_batch_size ):
276295 yield
277296
278297 del self ._paged_kv_indices_tensor
279298 del self ._paged_kv_indptr_tensor
280299 del self ._paged_kv_last_page_lens_tensor
300+ del self ._qo_indptr_tensor
281301
282302 def graph_capture_get_metadata_for_batch (
283303 self ,
@@ -291,10 +311,12 @@ def graph_capture_get_metadata_for_batch(
291311 paged_kv_indices = self ._paged_kv_indices_tensor
292312 paged_kv_last_page_lens = self ._paged_kv_last_page_lens_tensor [:
293313 batch_size ]
314+ qo_indptr = self ._qo_indptr_tensor [:batch_size + 1 ]
294315
295316 metadata .paged_kv_indptr = paged_kv_indptr
296317 metadata .paged_kv_indices = paged_kv_indices
297318 metadata .paged_kv_last_page_lens = paged_kv_last_page_lens
319+ metadata .qo_indptr = qo_indptr
298320
299321 return metadata
300322
@@ -311,6 +333,7 @@ def get_graph_input_buffers(self,
311333 input_buffers [
312334 "paged_kv_last_page_lens" ] = attn_metadata .\
313335 decode_metadata .paged_kv_last_page_lens
336+ input_buffers ['qo_indptr' ] = attn_metadata .qo_indptr
314337
315338 return input_buffers
316339
@@ -330,6 +353,8 @@ def prepare_graph_input_buffers(self,
330353 input_buffers ["paged_kv_last_page_lens" ].copy_ (
331354 attn_metadata .decode_metadata .paged_kv_last_page_lens ,
332355 non_blocking = True )
356+ input_buffers ["qo_indptr" ].copy_ (
357+ attn_metadata .decode_metadata .qo_indptr , non_blocking = True )
333358
334359
335360class AiterMLAImpl (MLACommonImpl [AiterMLAMetadata ]):
@@ -370,11 +395,9 @@ def _flash_attn_varlen_diff_headdims(
370395 softmax_scale : float , return_softmax_lse : bool ,
371396 ** kwargs ) -> Union [tuple [torch .Tensor , ...], torch .Tensor ]:
372397 output = self .flash_attn_varlen_func (
373- q = q ,
374- k = k ,
375- v = v ,
376- softmax_scale = softmax_scale ,
377- return_lse = return_softmax_lse ,
398+ q ,
399+ k ,
400+ v ,
378401 ** kwargs ,
379402 )
380403
@@ -394,7 +417,7 @@ def _forward_decode(
394417 B = q_nope .shape [0 ]
395418
396419 q = torch .cat ([q_nope , q_pe ], dim = - 1 )
397- o = torch .zeros (B ,
420+ o = torch .empty (B ,
398421 self .num_heads ,
399422 self .kv_lora_rank ,
400423 dtype = q .dtype ,
@@ -403,6 +426,8 @@ def _forward_decode(
403426 kv_buffer = kv_c_and_k_pe_cache .unsqueeze (2 )
404427
405428 aiter_mla_decode_fwd (q , kv_buffer , o , self .scale ,
429+ attn_metadata .qo_indptr ,
430+ attn_metadata .max_query_len ,
406431 attn_metadata .paged_kv_indptr ,
407432 attn_metadata .paged_kv_indices ,
408433 attn_metadata .paged_kv_last_page_lens )
0 commit comments