Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
num_kv_splits = sorted(tmp, key=lambda x: x[0], reverse=True)[0][1]

get_block_n_fp8 = {
4: 128,
8: 128,
16: 128,
32: 128,
48: 64,
Expand Down Expand Up @@ -188,6 +190,18 @@ def mla_decode_fwd(
bs = qo_indptr.shape[0] - 1
total_kv = kv_indices.shape[0]

_head_pad_factor = 1
_o_unpadded = None
if nhead < 16 and nhead > 0 and 16 % nhead == 0:
_head_pad_factor = 16 // nhead
q = q.repeat_interleave(_head_pad_factor, dim=1)
_o_unpadded = o
nhead = 16
ori_nhead = 16
o = torch.empty(
total_s, nhead, v_head_dim, dtype=_o_unpadded.dtype, device=device
)

persistent_mode = work_meta_data is not None

io_transformed = False
Expand Down Expand Up @@ -266,6 +280,12 @@ def mla_decode_fwd(
and nhead in [32, 64]
)
):
if _o_unpadded is not None:
_o_unpadded.copy_(o[:, ::_head_pad_factor, :])
return (
logits.view(total_s, nhead, v_head_dim)[:, ::_head_pad_factor, :],
attn_lse[:, :, ::_head_pad_factor, :],
)
return logits.view(total_s, nhead, v_head_dim), attn_lse

Lv = v_head_dim
Expand Down Expand Up @@ -495,6 +515,12 @@ def mla_decode_fwd(
.contiguous()
)

if _o_unpadded is not None:
_o_unpadded.copy_(o[:, ::_head_pad_factor, :])
if final_lse is not None:
final_lse = final_lse[:, ::_head_pad_factor]
logits = logits[:, :, ::_head_pad_factor, :]

return logits, final_lse


Expand Down
14 changes: 9 additions & 5 deletions aiter/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,11 @@ def get_mla_metadata_info_v1(
6. Shape of reduce_partial_map followed by its scalar type.
"""

assert num_head_qo % 8 == 0
effective_num_head = num_head_qo
if num_head_qo < 16 and num_head_qo > 0 and 16 % num_head_qo == 0:
effective_num_head = 16
assert effective_num_head % 8 == 0

gpu = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count
Expand All @@ -934,11 +938,11 @@ def get_mla_metadata_info_v1(
)

max_qo_tiles_per_batch = (
int(math.ceil(max_seqlen_qo * num_head_qo / 128))
if num_head_qo == 16
int(math.ceil(max_seqlen_qo * effective_num_head / 128))
if effective_num_head == 16
or (
get_gfx() == "gfx942"
and num_head_qo == 128
and effective_num_head == 128
and kv_dtype == dtypes.fp8
and q_dtype == dtypes.fp8
)
Expand All @@ -950,7 +954,7 @@ def get_mla_metadata_info_v1(
and max_seqlen_qo == 1
)
or use_qseqlen_fold
else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
else int(math.ceil(max_seqlen_qo * effective_num_head / 16))
)
batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
tile_cnt = batch_size * max_qo_tiles_per_batch
Expand Down
17 changes: 6 additions & 11 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -471,19 +471,14 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
!natively_supported && (arch_id == "gfx950") && q_is_fp8 && kv_is_fp8 && (num_heads > 16) &&
((uni_seqlen_qo * (num_heads / 16) == 4) || ((num_heads == 64) && (uni_seqlen_qo == 2)));

if(use_qseqlen_fold && (num_heads == 64) && (uni_seqlen_qo == 2))
{
qk_seqlen_ratio = num_heads / 32;
num_heads = 32;
uni_seqlen_qo *= qk_seqlen_ratio;
}
else if(use_qseqlen_fold && (uni_seqlen_qo * (num_heads / 16) == 4))
const bool pad_to_qh16 = (!natively_supported) && (num_heads < 16) &&
(num_heads > 0) && (16 % num_heads == 0);

if(pad_to_qh16)
{
qk_seqlen_ratio = num_heads / 16;
num_heads = 16;
uni_seqlen_qo *= qk_seqlen_ratio;
num_heads = 16;
}
else if(!natively_supported && (num_heads % 16 == 0))
else if((natively_supported == false) && (num_heads % 16 == 0))
{
qk_batch_ratio = num_heads / 16;
num_heads = 16;
Expand Down
8 changes: 5 additions & 3 deletions op_tests/test_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,15 @@ def test_absorb_decode_fp8():
err = None
us_asm_decode = 1e12
if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and nhead in [
4,
8,
16,
32,
64,
128,
]:
err, us_asm_decode = test_absorb_decode_bf16()
elif kvtype == dtypes.fp8 and nhead in [16, 128]:
elif kvtype == dtypes.fp8 and nhead in [4, 8, 16, 128]:
err, us_asm_decode = test_absorb_decode_fp8()

ret["decode:err"] = err
Expand Down Expand Up @@ -573,10 +575,10 @@ def test_absorb_decode_fp8():
"-n",
"--nhead",
type=dtypes.str2tuple,
choices=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)],
choices=[(4, 1), (16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)],
nargs="*",
const=None,
default=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)],
default=[(4, 1), (16, 1), (16, 2), (16, 4), (128, 1), (128, 2)],
help="""Number of nhead and decode_qlen.
e.g.: -n 16,1""",
)
Expand Down
2 changes: 1 addition & 1 deletion op_tests/test_mla_persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,7 +1702,7 @@ def test_absorb_decode_3buffer():
type=dtypes.str2tuple,
nargs="*",
const=None,
default=[(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)],
default=[(4, 1), (16, 1), (16, 2), (16, 4), (48, 1), (128, 2)],
help="""Number of heads.
e.g.: -n 16,1""",
)
Expand Down
6 changes: 3 additions & 3 deletions op_tests/test_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
torch.set_printoptions(sci_mode=False)

# current supported case in ps decode MLA: mtp == 0, 1, 2, 3 (decode_qlen = 1, 2, 3, 4)
# qdtype bf16, kdtype bf16: nhead16
# qdtype fp8, kdtype fp8: nhead16, nhead128
# qdtype bf16, kdtype bf16: nhead4, nhead8, nhead16
# qdtype fp8, kdtype fp8: nhead4, nhead8, nhead16, nhead128
# qdtype fp8, kdtype bf16: nhead16


Expand Down Expand Up @@ -722,7 +722,7 @@ def test_sparse_mla_fp8():
"--nhead",
type=dtypes.str2tuple,
nargs="*",
default=[(16, 2), (48, 1), (128, 2)],
default=[(4, 1), (8, 1), (16, 2), (48, 1), (128, 2)],
help="""Number of heads.
e.g.: -n 16,1""",
)
Expand Down
Loading