diff --git a/aiter/mla.py b/aiter/mla.py index 6b6e1081c8..1beafe4997 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -270,12 +270,14 @@ def mla_decode_fwd( assert False, f"{nhead=} and {max_seqlen_q=} not supported" logits = torch.empty( - (total_s, num_kv_splits, nhead, v_head_dim), + (reduce_partial_map.size(0) * max_seqlen_q, 1, nhead, v_head_dim), dtype=dtypes.fp32, device=device, ) attn_lse = torch.empty( - (total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device + (reduce_partial_map.size(0) * max_seqlen_q, 1, nhead, 1), + dtype=dtypes.fp32, + device=device, ) final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) @@ -310,7 +312,10 @@ def mla_decode_fwd( ) if io_transformed: - logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim) + if persistent_mode: + logits = logits.view(-1, 1, ori_nhead, v_head_dim) + else: + logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim) q = q.view(ori_total_s, ori_nhead, -1) o = o.view(ori_total_s, ori_nhead, -1) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index bb91d8be80..03ea084642 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -13,7 +13,7 @@ paged_attention_ragged as paged_attention_ragged_core, ) from csrc.cpp_itfs.torch_utils import direct_register_custom_op -from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype +from aiter import dtypes MD_NAME = "module_attention" @@ -359,33 +359,30 @@ def get_mla_metadata_info_v1( device_properties = torch.cuda.get_device_properties(gpu) cu_num = device_properties.multi_processor_count - reduce_batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size max_qo_tiles_per_batch = ( - int(math.ceil(max_seqlen_qo * num_head_qo / 64)) - if num_head_qo == 16 - or (num_head_qo == 128 and kv_dtype == get_fp8_e4m3_dtype()) + int(math.ceil(max_seqlen_qo * num_head_qo / 128)) + if num_head_qo == 16 or (num_head_qo == 128 and kv_dtype == dtypes.fp8) else int(math.ceil(max_seqlen_qo * num_head_qo / 16)) ) + batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size + tile_cnt = batch_size * max_qo_tiles_per_batch + + if fast_mode: + max_work = tile_cnt + cu_num - 1 + max_split_tiles = ( + min(batch_size + cu_num - 1, (cu_num - 1) * 2) * max_qo_tiles_per_batch + ) + else: + max_work = tile_cnt * cu_num + max_split_tiles = tile_cnt * cu_num return ( ((2), torch.uint64), # work_metadata_ptrs ((cu_num + 1), torch.int32), # work_indptr - ( - (batch_size * max_qo_tiles_per_batch * cu_num, 8), - torch.int32, - ), # work_info_set - ( - (reduce_batch_size * max_qo_tiles_per_batch + 1), - torch.int32, - ), # reduce_indptr - ( - (reduce_batch_size * max_qo_tiles_per_batch, 2), - torch.int32, - ), # reduce_final_map - ( - (reduce_batch_size * max_qo_tiles_per_batch * cu_num), - torch.int32, - ), # reduce_partial_map + ((max_work, 8), torch.int32), # work_info_set + ((tile_cnt + 1), torch.int32), # reduce_indptr + ((tile_cnt, 2), torch.int32), # reduce_final_map + (max_split_tiles, torch.int32), # reduce_partial_map ) diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py index c38611a768..548ef811da 100644 --- a/op_tests/test_mla_persistent.py +++ b/op_tests/test_mla_persistent.py @@ -249,7 +249,7 @@ def test_mla( work_info_set_size, dtype=work_info_set_type, device="cuda", - ).fill_(-1) + ) reduce_indptr = torch.empty( reduce_indptr_size, dtype=reduce_indptr_type, device="cuda" ) @@ -384,7 +384,9 @@ def test_absorb_decode_fp8(): err = None us_asm_decode = 1e12 - if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and nhead == 16: + if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and ( + nhead == 16 or (nhead in range(32, 128, 16) and mtp == 1) + ): err, us_asm_decode = test_absorb_decode_bf16() elif kvtype == dtypes.fp8 and nhead in [16, 128]: err, us_asm_decode = test_absorb_decode_fp8() @@ -413,7 +415,7 @@ def test_absorb_decode_fp8(): block_size = 1 list_dtype = ["bf16", "fp8"] l_kv_dtype = ["bf16", "fp8"] -list_nhead = [(16, 1), (16, 2), (16, 4), (128, 2)] +list_nhead = [(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)] parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, diff --git a/op_tests/test_mla_sparse.py b/op_tests/test_mla_sparse.py index 6d343d7a4b..88328c394b 100644 --- a/op_tests/test_mla_sparse.py +++ b/op_tests/test_mla_sparse.py @@ -424,7 +424,7 @@ def test_mla( work_info_set_size, dtype=work_info_set_type, device="cuda", - ).fill_(-1) + ) reduce_indptr = torch.empty( reduce_indptr_size, dtype=reduce_indptr_type, device="cuda" )