Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,14 @@ def mla_decode_fwd(
assert False, f"{nhead=} and {max_seqlen_q=} not supported"

logits = torch.empty(
(total_s, num_kv_splits, nhead, v_head_dim),
(reduce_partial_map.size(0) * max_seqlen_q, 1, nhead, v_head_dim),
dtype=dtypes.fp32,
device=device,
)
attn_lse = torch.empty(
(total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
(reduce_partial_map.size(0) * max_seqlen_q, 1, nhead, 1),
dtype=dtypes.fp32,
device=device,
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)

Expand Down Expand Up @@ -310,7 +312,10 @@ def mla_decode_fwd(
)

if io_transformed:
logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim)
if persistent_mode:
logits = logits.view(-1, 1, ori_nhead, v_head_dim)
else:
logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim)
q = q.view(ori_total_s, ori_nhead, -1)
o = o.view(ori_total_s, ori_nhead, -1)

Expand Down
39 changes: 18 additions & 21 deletions aiter/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
paged_attention_ragged as paged_attention_ragged_core,
)
from csrc.cpp_itfs.torch_utils import direct_register_custom_op
from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype
from aiter import dtypes

MD_NAME = "module_attention"

Expand Down Expand Up @@ -359,33 +359,30 @@ def get_mla_metadata_info_v1(
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count

reduce_batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
max_qo_tiles_per_batch = (
int(math.ceil(max_seqlen_qo * num_head_qo / 64))
if num_head_qo == 16
or (num_head_qo == 128 and kv_dtype == get_fp8_e4m3_dtype())
int(math.ceil(max_seqlen_qo * num_head_qo / 128))
if num_head_qo == 16 or (num_head_qo == 128 and kv_dtype == dtypes.fp8)
else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
)
batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
tile_cnt = batch_size * max_qo_tiles_per_batch

if fast_mode:
max_work = tile_cnt + cu_num - 1
max_split_tiles = (
min(batch_size + cu_num - 1, (cu_num - 1) * 2) * max_qo_tiles_per_batch
)
else:
max_work = tile_cnt * cu_num
max_split_tiles = tile_cnt * cu_num

return (
((2), torch.uint64), # work_metadata_ptrs
((cu_num + 1), torch.int32), # work_indptr
(
(batch_size * max_qo_tiles_per_batch * cu_num, 8),
torch.int32,
), # work_info_set
(
(reduce_batch_size * max_qo_tiles_per_batch + 1),
torch.int32,
), # reduce_indptr
(
(reduce_batch_size * max_qo_tiles_per_batch, 2),
torch.int32,
), # reduce_final_map
(
(reduce_batch_size * max_qo_tiles_per_batch * cu_num),
torch.int32,
), # reduce_partial_map
((max_work, 8), torch.int32), # work_info_set
((tile_cnt + 1), torch.int32), # reduce_indptr
((tile_cnt, 2), torch.int32), # reduce_final_map
(max_split_tiles, torch.int32), # reduce_partial_map
)


Expand Down
8 changes: 5 additions & 3 deletions op_tests/test_mla_persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_mla(
work_info_set_size,
dtype=work_info_set_type,
device="cuda",
).fill_(-1)
)
reduce_indptr = torch.empty(
reduce_indptr_size, dtype=reduce_indptr_type, device="cuda"
)
Expand Down Expand Up @@ -384,7 +384,9 @@ def test_absorb_decode_fp8():

err = None
us_asm_decode = 1e12
if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and nhead == 16:
if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and (
nhead == 16 or (nhead in range(32, 128, 16) and mtp == 1)
):
err, us_asm_decode = test_absorb_decode_bf16()
elif kvtype == dtypes.fp8 and nhead in [16, 128]:
err, us_asm_decode = test_absorb_decode_fp8()
Expand Down Expand Up @@ -413,7 +415,7 @@ def test_absorb_decode_fp8():
block_size = 1
list_dtype = ["bf16", "fp8"]
l_kv_dtype = ["bf16", "fp8"]
list_nhead = [(16, 1), (16, 2), (16, 4), (128, 2)]
list_nhead = [(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)]

parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter,
Expand Down
2 changes: 1 addition & 1 deletion op_tests/test_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def test_mla(
work_info_set_size,
dtype=work_info_set_type,
device="cuda",
).fill_(-1)
)
reduce_indptr = torch.empty(
reduce_indptr_size, dtype=reduce_indptr_type, device="cuda"
)
Expand Down