Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ def mla_decode_fwd(
and (
q.dtype == dtypes.fp8
or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
or (
q.dtype == dtypes.bf16
and kv_buffer.dtype == dtypes.bf16
and nhead == 32
and max_seqlen_q == 1
)
)
)
else torch.empty(
Expand Down Expand Up @@ -275,7 +281,14 @@ def mla_decode_fwd(
)

if num_kv_splits == 1 and (
q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
q.dtype == dtypes.fp8
or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
or (
q.dtype == dtypes.bf16
and kv_buffer.dtype == dtypes.bf16
and nhead == 32
and max_seqlen_q == 1
)
):
lse = final_lse if return_lse else attn_lse
return logits.view(total_s, nhead, v_head_dim), lse
Expand Down
9 changes: 7 additions & 2 deletions csrc/py_itfs_cu/asm_mla.cu
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,13 @@ void mla_decode_stage1_asm_fwd(
}
else
{
args.out_16_nosplit = 0;
args.ptr_RP = nullptr;
// The legacy QH16 m32x1_n16x1 kernel (gqa_ratio=32, decode qseqlen=1)
// writes directly to output via ptr_RP when kv_split==1. Passing
// nullptr causes GPU memory faults on gfx950. Other non-persistent
// kernels (v3, stage1) use split-reduce and expect ptr_RP = nullptr.
bool legacy_qh16 = (gqa_ratio == 32 && max_seqlen_q == 1);
args.out_16_nosplit = legacy_qh16 ? kv_split : 0;
args.ptr_RP = legacy_qh16 ? output->data_ptr() : nullptr;
args.ptr_STP = num_kv_splits_indptr->data_ptr();
}

Expand Down
Loading