@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
5353
5454@dataclass
5555class AiterMLAMetadata (MLACommonMetadata ):
56- # The following 5 tensors are for current version of AITER MLA
56+ # The following 4 tensors are for current version of AITER MLA
5757 block_table_bound : Optional [torch .Tensor ] = None
5858 # The indptr of the paged kv cache, shape: [batch_size + 1]
5959 paged_kv_indptr : Optional [torch .Tensor ] = None
@@ -63,10 +63,6 @@ class AiterMLAMetadata(MLACommonMetadata):
6363 # the paged kv cache, shape: [batch_size]
6464 paged_kv_last_page_lens : Optional [torch .Tensor ] = None
6565
66- # This is just to make new AITER MLA API work
67- # -- MTP support is not added yet.
68- qo_indptr : Optional [torch .Tensor ] = None
69-
7066 @property
7167 def prefill_metadata (self ):
7268 prefill_metadata = super ().prefill_metadata
@@ -78,7 +74,6 @@ def prefill_metadata(self):
7874 prefill_metadata \
7975 .paged_kv_last_page_lens = self .paged_kv_last_page_lens
8076 prefill_metadata .block_table_bound = self .block_table_bound
81- prefill_metadata .qo_indptr = self .qo_indptr
8277
8378 # update the cache
8479 self ._cached_prefill_metadata = self .__class__ (
@@ -98,7 +93,6 @@ def decode_metadata(self):
9893 decode_metadata \
9994 .paged_kv_last_page_lens = self .paged_kv_last_page_lens
10095 decode_metadata .block_table_bound = self .block_table_bound
101- decode_metadata .qo_indptr = self .qo_indptr
10296
10397 # update the cache
10498 self ._cached_decode_metadata = self .__class__ (
@@ -142,7 +136,6 @@ def prepare(self):
142136 self .paged_kv_indptr : list [int ] = [0 ]
143137 self .paged_kv_last_page_lens : list [int ] = []
144138 self .total_blocks = 0
145- self .qo_indptr : list [int ] = [0 ]
146139
147140 def _add_seq_group (self , inter_data , chunked_prefill_enabled : bool ,
148141 prefix_cache_hit : bool ):
@@ -215,7 +208,6 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
215208 self .paged_kv_indices .extend (block_table [:block_table_bound ])
216209 self .paged_kv_indptr .append (self .paged_kv_indptr [- 1 ] +
217210 block_table_bound )
218- self .qo_indptr .append (self .qo_indptr [- 1 ] + 1 )
219211
220212 last_page_len = seq_len % self .block_size
221213 if last_page_len == 0 :
@@ -234,8 +226,6 @@ def build(self, seq_lens: list[int], query_lens: list[int],
234226 self .paged_kv_indptr .extend ([last_paged_kv_indptr ] *
235227 cuda_graph_pad_size )
236228 self .paged_kv_last_page_lens .extend ([0 ] * cuda_graph_pad_size )
237- last_qo_indptr = self .qo_indptr [- 1 ]
238- self .qo_indptr .extend ([last_qo_indptr ] * cuda_graph_pad_size )
239229
240230 # For current version of AITER MLA
241231 if len (self .paged_kv_indptr ) > 0 :
@@ -255,22 +245,16 @@ def build(self, seq_lens: list[int], query_lens: list[int],
255245 1 ,
256246 device = device ,
257247 dtype = torch .int )
258-
259- qo_indptr = torch .tensor (self .qo_indptr ,
260- device = device ,
261- dtype = torch .int )
262248 else :
263249 paged_kv_indices_tensor = None
264250 paged_kv_indptr_tensor = None
265251 paged_kv_last_page_lens_tensor = None
266252 block_table_bound_tensor = None
267- qo_indptr = None
268253
269254 metadata .paged_kv_indptr = paged_kv_indptr_tensor
270255 metadata .paged_kv_indices = paged_kv_indices_tensor
271256 metadata .paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
272257 metadata .block_table_bound = block_table_bound_tensor
273- metadata .qo_indptr = qo_indptr
274258
275259 return metadata
276260
@@ -279,25 +263,21 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
279263
280264 @contextmanager
281265 def graph_capture (self , max_batch_size : int ):
282- kv_indices , kv_indptr , last_page_lens , qo_indptr = \
283- get_aiter_mla_metadata (
284- max_batch_size = max_batch_size ,
285- block_size = self .runner .block_size ,
286- max_block_per_batch = \
287- self .runner .get_max_block_per_batch (),
288- device = self .runner .device )
266+ kv_indices , kv_indptr , last_page_lens = get_aiter_mla_metadata (
267+ max_batch_size = max_batch_size ,
268+ block_size = self .runner .block_size ,
269+ max_block_per_batch = self .runner .get_max_block_per_batch (),
270+ device = self .runner .device )
289271 self ._paged_kv_indices_tensor = kv_indices
290272 self ._paged_kv_indptr_tensor = kv_indptr
291273 self ._paged_kv_last_page_lens_tensor = last_page_lens
292- self ._qo_indptr_tensor = qo_indptr
293274
294275 with super ().graph_capture (max_batch_size ):
295276 yield
296277
297278 del self ._paged_kv_indices_tensor
298279 del self ._paged_kv_indptr_tensor
299280 del self ._paged_kv_last_page_lens_tensor
300- del self ._qo_indptr_tensor
301281
302282 def graph_capture_get_metadata_for_batch (
303283 self ,
@@ -311,12 +291,10 @@ def graph_capture_get_metadata_for_batch(
311291 paged_kv_indices = self ._paged_kv_indices_tensor
312292 paged_kv_last_page_lens = self ._paged_kv_last_page_lens_tensor [:
313293 batch_size ]
314- qo_indptr = self ._qo_indptr_tensor [:batch_size + 1 ]
315294
316295 metadata .paged_kv_indptr = paged_kv_indptr
317296 metadata .paged_kv_indices = paged_kv_indices
318297 metadata .paged_kv_last_page_lens = paged_kv_last_page_lens
319- metadata .qo_indptr = qo_indptr
320298
321299 return metadata
322300
@@ -333,7 +311,6 @@ def get_graph_input_buffers(self,
333311 input_buffers [
334312 "paged_kv_last_page_lens" ] = attn_metadata .\
335313 decode_metadata .paged_kv_last_page_lens
336- input_buffers ['qo_indptr' ] = attn_metadata .qo_indptr
337314
338315 return input_buffers
339316
@@ -353,8 +330,6 @@ def prepare_graph_input_buffers(self,
353330 input_buffers ["paged_kv_last_page_lens" ].copy_ (
354331 attn_metadata .decode_metadata .paged_kv_last_page_lens ,
355332 non_blocking = True )
356- input_buffers ["qo_indptr" ].copy_ (
357- attn_metadata .decode_metadata .qo_indptr , non_blocking = True )
358333
359334
360335class AiterMLAImpl (MLACommonImpl [AiterMLAMetadata ]):
@@ -395,9 +370,11 @@ def _flash_attn_varlen_diff_headdims(
395370 softmax_scale : float , return_softmax_lse : bool ,
396371 ** kwargs ) -> Union [tuple [torch .Tensor , ...], torch .Tensor ]:
397372 output = self .flash_attn_varlen_func (
398- q ,
399- k ,
400- v ,
373+ q = q ,
374+ k = k ,
375+ v = v ,
376+ softmax_scale = softmax_scale ,
377+ return_lse = return_softmax_lse ,
401378 ** kwargs ,
402379 )
403380
@@ -417,7 +394,7 @@ def _forward_decode(
417394 B = q_nope .shape [0 ]
418395
419396 q = torch .cat ([q_nope , q_pe ], dim = - 1 )
420- o = torch .empty (B ,
397+ o = torch .zeros (B ,
421398 self .num_heads ,
422399 self .kv_lora_rank ,
423400 dtype = q .dtype ,
@@ -426,8 +403,6 @@ def _forward_decode(
426403 kv_buffer = kv_c_and_k_pe_cache .unsqueeze (2 )
427404
428405 aiter_mla_decode_fwd (q , kv_buffer , o , self .scale ,
429- attn_metadata .qo_indptr ,
430- attn_metadata .max_query_len ,
431406 attn_metadata .paged_kv_indptr ,
432407 attn_metadata .paged_kv_indices ,
433408 attn_metadata .paged_kv_last_page_lens )
0 commit comments