diff --git a/aiter/mla.py b/aiter/mla.py
index cef8181a7a..0b020facc2 100644
--- a/aiter/mla.py
+++ b/aiter/mla.py
@@ -118,6 +118,8 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
     num_kv_splits = sorted(tmp, key=lambda x: x[0], reverse=True)[0][1]
 
 get_block_n_fp8 = {
+    4: 128,
+    8: 128,
     16: 128,
     32: 128,
     48: 64,
@@ -188,6 +190,18 @@ def mla_decode_fwd(
     bs = qo_indptr.shape[0] - 1
     total_kv = kv_indices.shape[0]
 
+    _head_pad_factor = 1
+    _o_unpadded = None
+    if nhead < 16 and nhead > 0 and 16 % nhead == 0:
+        _head_pad_factor = 16 // nhead
+        q = q.repeat_interleave(_head_pad_factor, dim=1)
+        _o_unpadded = o
+        nhead = 16
+        ori_nhead = 16
+        o = torch.empty(
+            total_s, nhead, v_head_dim, dtype=_o_unpadded.dtype, device=device
+        )
+
     persistent_mode = work_meta_data is not None
     io_transformed = False
 
@@ -266,6 +280,12 @@ def mla_decode_fwd(
             and nhead in [32, 64]
         )
     ):
+        if _o_unpadded is not None:
+            _o_unpadded.copy_(o[:, ::_head_pad_factor, :])
+            return (
+                logits.view(total_s, nhead, v_head_dim)[:, ::_head_pad_factor, :],
+                attn_lse[:, :, ::_head_pad_factor, :],
+            )
         return logits.view(total_s, nhead, v_head_dim), attn_lse
 
     Lv = v_head_dim
@@ -495,6 +515,12 @@ def mla_decode_fwd(
             .contiguous()
         )
 
+    if _o_unpadded is not None:
+        _o_unpadded.copy_(o[:, ::_head_pad_factor, :])
+        if final_lse is not None:
+            final_lse = final_lse[:, ::_head_pad_factor]
+        logits = logits[:, :, ::_head_pad_factor, :]
+
     return logits, final_lse
 
 
diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py
index 572c522593..f11e95cce0 100644
--- a/aiter/ops/attention.py
+++ b/aiter/ops/attention.py
@@ -916,7 +916,11 @@ def get_mla_metadata_info_v1(
     6. Shape of reduce_partial_map followed by its scalar type.
     """
 
-    assert num_head_qo % 8 == 0
+    effective_num_head = num_head_qo
+    if num_head_qo < 16 and num_head_qo > 0 and 16 % num_head_qo == 0:
+        effective_num_head = 16
+    assert effective_num_head % 8 == 0
+
     gpu = torch.cuda.current_device()
     device_properties = torch.cuda.get_device_properties(gpu)
     cu_num = device_properties.multi_processor_count
@@ -934,11 +938,11 @@ def get_mla_metadata_info_v1(
     )
 
     max_qo_tiles_per_batch = (
-        int(math.ceil(max_seqlen_qo * num_head_qo / 128))
-        if num_head_qo == 16
+        int(math.ceil(max_seqlen_qo * effective_num_head / 128))
+        if effective_num_head == 16
         or (
             get_gfx() == "gfx942"
-            and num_head_qo == 128
+            and effective_num_head == 128
             and kv_dtype == dtypes.fp8
             and q_dtype == dtypes.fp8
         )
@@ -950,7 +954,7 @@ def get_mla_metadata_info_v1(
             and max_seqlen_qo == 1
         )
         or use_qseqlen_fold
-        else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
+        else int(math.ceil(max_seqlen_qo * effective_num_head / 16))
     )
     batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
     tile_cnt = batch_size * max_qo_tiles_per_batch
diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh
index 81cca36aec..003a41dbc8 100644
--- a/csrc/kernels/mla/metadata/v1_2_device.cuh
+++ b/csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -471,19 +471,14 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
         !natively_supported && (arch_id == "gfx950") && q_is_fp8 && kv_is_fp8 &&
         (num_heads > 16) &&
         ((uni_seqlen_qo * (num_heads / 16) == 4) || ((num_heads == 64) && (uni_seqlen_qo == 2)));
-    if(use_qseqlen_fold && (num_heads == 64) && (uni_seqlen_qo == 2))
-    {
-        qk_seqlen_ratio = num_heads / 32;
-        num_heads       = 32;
-        uni_seqlen_qo  *= qk_seqlen_ratio;
-    }
-    else if(use_qseqlen_fold && (uni_seqlen_qo * (num_heads / 16) == 4))
+    const bool pad_to_qh16 = (!natively_supported) && (num_heads < 16) &&
+                             (num_heads > 0) && (16 % num_heads == 0);
+
+    if(pad_to_qh16)
     {
-        qk_seqlen_ratio = num_heads / 16;
-        num_heads       = 16;
-        uni_seqlen_qo  *= qk_seqlen_ratio;
+        num_heads = 16;
     }
-    else if(!natively_supported && (num_heads % 16 == 0))
+    else if((natively_supported == false) && (num_heads % 16 == 0))
     {
         qk_batch_ratio = num_heads / 16;
         num_heads      = 16;
diff --git a/op_tests/test_mla.py b/op_tests/test_mla.py
index e81121b36f..a3ee519888 100644
--- a/op_tests/test_mla.py
+++ b/op_tests/test_mla.py
@@ -458,13 +458,15 @@ def test_absorb_decode_fp8():
         err = None
         us_asm_decode = 1e12
         if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and nhead in [
+            4,
+            8,
             16,
             32,
             64,
             128,
         ]:
             err, us_asm_decode = test_absorb_decode_bf16()
-        elif kvtype == dtypes.fp8 and nhead in [16, 128]:
+        elif kvtype == dtypes.fp8 and nhead in [4, 8, 16, 128]:
             err, us_asm_decode = test_absorb_decode_fp8()
 
         ret["decode:err"] = err
@@ -573,10 +575,10 @@ def test_absorb_decode_fp8():
         "-n",
         "--nhead",
         type=dtypes.str2tuple,
-        choices=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)],
+        choices=[(4, 1), (16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)],
         nargs="*",
         const=None,
-        default=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)],
+        default=[(4, 1), (16, 1), (16, 2), (16, 4), (128, 1), (128, 2)],
         help="""Number of nhead and decode_qlen. e.g.: -n 16,1""",
     )
 
diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py
index 6427616f04..95ef956ac3 100644
--- a/op_tests/test_mla_persistent.py
+++ b/op_tests/test_mla_persistent.py
@@ -1702,7 +1702,7 @@ def test_absorb_decode_3buffer():
         type=dtypes.str2tuple,
         nargs="*",
         const=None,
-        default=[(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)],
+        default=[(4, 1), (16, 1), (16, 2), (16, 4), (48, 1), (128, 2)],
         help="""Number of heads. e.g.: -n 16,1""",
     )
 
diff --git a/op_tests/test_mla_sparse.py b/op_tests/test_mla_sparse.py
index a699ae0573..76a5c1c3ad 100644
--- a/op_tests/test_mla_sparse.py
+++ b/op_tests/test_mla_sparse.py
@@ -16,8 +16,8 @@
 torch.set_printoptions(sci_mode=False)
 
 # current supported case in ps decode MLA: mtp == 0, 1, 2, 3 (decode_qlen = 1, 2, 3, 4)
-# qdtype bf16, kdtype bf16: nhead16
-# qdtype fp8, kdtype fp8: nhead16, nhead128
+# qdtype bf16, kdtype bf16: nhead4, nhead8, nhead16
+# qdtype fp8, kdtype fp8: nhead4, nhead8, nhead16, nhead128
 # qdtype fp8, kdtype bf16: nhead16
 
 
@@ -722,7 +722,7 @@ def test_sparse_mla_fp8():
         "--nhead",
         type=dtypes.str2tuple,
         nargs="*",
-        default=[(16, 2), (48, 1), (128, 2)],
+        default=[(4, 1), (8, 1), (16, 2), (48, 1), (128, 2)],
         help="""Number of heads. e.g.: -n 16,1""",
    )
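A minimal standalone sketch of the head-padding round trip that mla_decode_fwd performs above, assuming plain PyTorch; the helper names pad_heads/unpad_heads are hypothetical, and only the repeat_interleave/strided-slice pattern mirrors the patch:

import torch

def pad_heads(q: torch.Tensor, min_heads: int = 16):
    # q: [total_s, nhead, head_dim] -> (padded q, pad factor).
    nhead = q.shape[1]
    if 0 < nhead < min_heads and min_heads % nhead == 0:
        factor = min_heads // nhead
        # repeat_interleave places the copies of head i at indices
        # i * factor .. i * factor + factor - 1 along dim 1.
        return q.repeat_interleave(factor, dim=1), factor
    return q, 1

def unpad_heads(o: torch.Tensor, factor: int):
    # The strided slice keeps index i * factor, i.e. one copy per head.
    return o[:, ::factor, :] if factor > 1 else o

q = torch.randn(8, 4, 128)        # 4 query heads, as in the new test cases
q_pad, factor = pad_heads(q)      # -> [8, 16, 128], factor == 4
assert torch.equal(unpad_heads(q_pad, factor), q)

Since absorbed MLA decode shares a single latent KV across query heads, each duplicated head attends to the same keys and values and should produce an identical output row, so dropping the copies recovers the small-head result; the redundant compute is the cost of reaching the kernel's minimum supported head count.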
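On the metadata side, get_mla_metadata_info_v1 and the pad_to_qh16 flag in v1_2_device.cuh apply the same divides-16 rule, so host and device agree on the padded head count. A sketch of just that rule (effective_head_count is a hypothetical name; the patch inlines the check):

def effective_head_count(num_head_qo: int) -> int:
    # Head counts that evenly divide 16 (1, 2, 4, 8) are promoted to 16,
    # matching the q-tensor padding; everything else passes through.
    if 0 < num_head_qo < 16 and 16 % num_head_qo == 0:
        return 16
    return num_head_qo

assert [effective_head_count(n) for n in (4, 8, 16, 48, 128)] == [16, 16, 16, 48, 128]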