Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
def _fwd_kernel_stage2_asm(
Mid_O,
Mid_lse,
O,
O, # noqa: E741
Final_lse,
qo_indptr,
kv_indptr,
num_kv_splits_indptr,
Expand All @@ -29,7 +30,9 @@ def _fwd_kernel_stage2_asm(
stride_mid_os: tl.int64,
stride_obs: tl.int64,
stride_oh: tl.int64,
stride_lse_bs: tl.int64,
MAYBE_FINAL_OUT: tl.constexpr,
HAS_FINAL_LSE: tl.constexpr,
BATCH_NUM: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
Expand Down Expand Up @@ -57,7 +60,6 @@ def _fwd_kernel_stage2_asm(
if FINAL_OUT:
input_ptr = Mid_O.to(tl.pointer_type(O.type.element_ty))
out = tl.load(
# input_ptr + offs_v + stride_mid_ob * Lv,
input_ptr
+ Lv * (cur_qo * stride_mid_os + cur_head * stride_mid_oh)
+ offs_d,
Expand Down Expand Up @@ -96,6 +98,11 @@ def _fwd_kernel_stage2_asm(
acc / e_sum,
mask=mask_d,
)
if HAS_FINAL_LSE:
tl.store(
Final_lse + cur_qo * stride_lse_bs + cur_head,
e_max + tl.log(e_sum),
)


@functools.lru_cache()
Expand Down Expand Up @@ -205,6 +212,12 @@ def mla_decode_fwd(
if (
nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
)
or (
nhead == 64
and q.dtype == dtypes.bf16
and kv_buffer.dtype == dtypes.bf16
and max_seqlen_q == 1
)
else mgc
)

Expand Down Expand Up @@ -232,7 +245,11 @@ def mla_decode_fwd(
attn_lse = torch.empty(
(total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
final_lse = (
torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
if return_lse
else None
)

aiter.mla_decode_stage1_asm_fwd(
q,
Expand All @@ -252,31 +269,34 @@ def mla_decode_fwd(
logits,
attn_lse,
o,
None,
final_lse,
q_scale,
kv_scale,
)

if num_kv_splits == 1 and (
q.dtype == dtypes.fp8
or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
or (
q.dtype == dtypes.bf16
and kv_buffer.dtype == dtypes.bf16
and nhead in [32, 64]
)
q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
):
return logits.view(total_s, nhead, v_head_dim), attn_lse
lse = final_lse if return_lse else attn_lse
return logits.view(total_s, nhead, v_head_dim), lse

Lv = v_head_dim
BLOCK_DV = triton.next_power_of_2(Lv)
grid = (bs, nhead)
extra_kargs = {"waves_per_eu": 4}

has_final_lse = final_lse is not None
final_lse_buf = (
final_lse
if has_final_lse
else torch.empty((1,), dtype=dtypes.fp32, device=device)
)

_fwd_kernel_stage2_asm[grid](
logits,
attn_lse,
o,
final_lse_buf,
qo_indptr,
kv_indptr,
num_kv_splits_indptr,
Expand All @@ -285,7 +305,9 @@ def mla_decode_fwd(
attn_lse.stride(1),
o.stride(0),
o.stride(1),
final_lse_buf.stride(0) if has_final_lse else 0,
MAYBE_FINAL_OUT=MAYBE_FINAL_OUT,
HAS_FINAL_LSE=has_final_lse,
BATCH_NUM=bs,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
Expand Down Expand Up @@ -512,11 +534,10 @@ def mla_prefill_fwd(
num_kv_splits=None, # for experts only!!!
):
device = q.device
num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape
assert logit_cap <= 0, f"{logit_cap=} is not support yet"
if sm_scale is None:
sm_scale = 1.0 / (qk_head_dim**0.5)

num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape
bs, nhead, v_head_dim = o.shape

num_kv_splits = 1
Expand Down
25 changes: 16 additions & 9 deletions csrc/py_itfs_cu/asm_mla.cu
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void mla_decode_stage1_asm_fwd(
int stride_Page = KV->stride(0) * KV->element_size();
uint32_t log2_page = (uint32_t)log2f(page_size);

KernelArgs args;
KernelArgs args = {};
size_t arg_size = sizeof(args);
args.ptr_R = splitData->data_ptr();
args.ptr_LSE = splitLse->data_ptr();
Expand All @@ -149,10 +149,17 @@ void mla_decode_stage1_asm_fwd(
args.s_Q_Bs = stride_Q;
args.s_Bs = stride_Page;
args.s_log2_plen = log2_page;
args.out_16_nosplit = kv_split;
args.ptr_LSEP = nullptr;
if (lse != nullptr)
{
args.ptr_LSEP = lse->data_ptr();
}

if (persistent)
{
args.out_16_nosplit = kv_split;
args.ptr_RP = output->data_ptr();

if (work_meta_data != nullptr)
{
args.ptr_STP = work_meta_data->data_ptr();
Expand All @@ -178,14 +185,10 @@ void mla_decode_stage1_asm_fwd(
}
else
{
args.out_16_nosplit = 0;
args.ptr_RP = nullptr;
args.ptr_STP = num_kv_splits_indptr->data_ptr();
}
args.ptr_RP = output->data_ptr(); //final output
args.ptr_LSEP = nullptr;
if (lse != nullptr)
{
args.ptr_LSEP = lse->data_ptr(); //final lse
}

// std::cout << "mla args" << std::endl;
// std::cout << "ptr_R: " << args.ptr_R << std::endl;
Expand Down Expand Up @@ -325,7 +328,11 @@ void mla_decode_stage1_asm_fwd(
} else if (gqa_ratio == 64){
if (q_type == "bf16" && kv_type == "bf16"){
if(!persistent){
config_max_seqlen_q = 0;
if(max_seqlen_q == 1){
config_max_seqlen_q = 1;
} else {
config_max_seqlen_q = 0;
}
sub_Q = 64;
}
} else if (q_type == "fp8" && kv_type == "fp8"){
Expand Down
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions hsa/gfx950/mla/mla_asm.csv
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ fp8,fp8,1,1,0,1,1,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal1E,mla_pfl
fp8,fp8,1,1,0,1,0,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal0E,mla_pfl_qh192_vh128_m32x8_n128x1_causal0.co
bf16,bf16,32,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m32x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co
bf16,bf16,64,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m64x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16.co
bf16,bf16,64,0,1,0,0,0,_ZN5aiter38mla_a16w16_qh64_qseqlen1_gqaratio64_v3E,mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co
bf16,bf16,64,0,1,0,0,1,_ZN5aiter42mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3E,mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co
fp8,fp8,8,1,4,0,0,0,_ZN5aiter35mla_a8w8_qh32_qseqlen4_gqaratio8_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_ps.co
fp8,fp8,8,1,4,0,0,1,_ZN5aiter39mla_a8w8_qh32_qseqlen4_gqaratio8_lse_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_lse_ps.co
40 changes: 27 additions & 13 deletions op_tests/test_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ def ref_masked_attention(
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weights += attn_bias
lse = attn_weights.logsumexp(dim=-1)
attn_weights = torch.softmax(attn_weights, dim=-1)

out = torch.einsum("hqk,khd->qhd", attn_weights.float(), value.float())
return out.to(dtype)
return out.to(dtype), lse


def torch_mha_extend(
Expand All @@ -82,7 +83,7 @@ def torch_mha_extend(
q = qs[i]
k = ks[i]
v = vs[i]
o = ref_masked_attention(q, k, v, sm_scale, dtype)
o, _ = ref_masked_attention(q, k, v, sm_scale, dtype)
os.append(o)
o = torch.concat(os)
return o
Expand All @@ -106,15 +107,18 @@ def torch_mla_extend(
bs = qo_indptr.shape[0] - 1

os = []
lses = []
for i in range(bs):
kvc = kvs[i]
q = qs[i]
k = kvc
v, _ = torch.split(kvc, [kv_lora_rank, qk_rope_head_dim], dim=-1)
o = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal)
o, lse = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal)
os.append(o)
lses.append(lse)
o = torch.concat(os)
return o
lse = torch.concat(lses, dim=1).transpose(0, 1)
return o, lse


@benchmark()
Expand All @@ -132,6 +136,7 @@ def test_mla(
varlen,
decode_qlen,
split_per_batch=None,
return_lse=False,
):
ret = {}

Expand Down Expand Up @@ -236,7 +241,7 @@ def test_normal_prefill():
def test_absorb_prefill():
q = torch.randn((total_qo, nhead, qk_head_dim), dtype=torch.bfloat16)

out_ref = torch_mla_extend(
out_ref, _ = torch_mla_extend(
q,
kv_buffer,
qo_indptr,
Expand Down Expand Up @@ -326,7 +331,7 @@ def test_absorb_prefill():
q = torch.randn((total_q, nhead, qk_head_dim), dtype=torch.bfloat16)

    # torch implementation
out_ref = torch_mla_extend(
out_ref, lse_ref = torch_mla_extend(
q,
kv_buffer,
qo_indptr,
Expand Down Expand Up @@ -390,19 +395,20 @@ def test_absorb_decode_bf16():
nhead_kv,
sm_scale,
num_kv_splits=split_per_batch,
return_lse=return_lse,
)

# print(f"{out_ref.view(total_q, -1)=}")
# print(f"{out_asm.view(total_q, -1)=}")
# checkAllclose(logits_ref, attn_logits,
# msg=f'attn_logits [golden vs aiter_asm]')
# checkAllclose(lse_ref, attn_lse,
# msg=f'attn_lse [golden vs aiter_asm]')
err = checkAllclose(
out_ref,
out_asm,
msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......",
)
if return_lse and attn_lse is not None:
checkAllclose(
lse_ref,
attn_lse.reshape(total_q, nhead),
msg=f"mla_decode-absorb [lse_ref vs attn_lse]: {us_asm_decode:>8.2f} us......",
)
return err, us_asm_decode

def test_absorb_decode_fp8():
Expand Down Expand Up @@ -573,7 +579,7 @@ def test_absorb_decode_fp8():
"-n",
"--nhead",
type=dtypes.str2tuple,
choices=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)],
choices=[(16, 1), (16, 2), (16, 4), (64, 1), (128, 1), (128, 2), (128, 4)],
nargs="*",
const=None,
default=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)],
Expand All @@ -595,6 +601,13 @@ def test_absorb_decode_fp8():
help="""variable kv seqlens per batch. Default: False.
--varlen # True""",
)
parser.add_argument(
"-lse",
"--return_lse",
action="store_true",
help="""return lse. Default: False.
--lse # True""",
)


args = parser.parse_args()
Expand All @@ -619,6 +632,7 @@ def test_absorb_decode_fp8():
varlen=args.varlen,
decode_qlen=decode_qlen,
split_per_batch=split_per_batch,
return_lse=args.return_lse,
)
df.append(ret)
df = pd.DataFrame(df)
Expand Down
Loading