diff --git a/aiter/mla.py b/aiter/mla.py index 32fe8aed36..2be2a228aa 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -233,6 +233,12 @@ def mla_decode_fwd( and ( q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) + or ( + q.dtype == dtypes.bf16 + and kv_buffer.dtype == dtypes.bf16 + and nhead == 32 + and max_seqlen_q == 1 + ) ) ) else torch.empty( @@ -275,7 +281,14 @@ def mla_decode_fwd( ) if num_kv_splits == 1 and ( - q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) + q.dtype == dtypes.fp8 + or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) + or ( + q.dtype == dtypes.bf16 + and kv_buffer.dtype == dtypes.bf16 + and nhead == 32 + and max_seqlen_q == 1 + ) ): lse = final_lse if return_lse else attn_lse return logits.view(total_s, nhead, v_head_dim), lse diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu index 622bdbc65e..56d31357f2 100644 --- a/csrc/py_itfs_cu/asm_mla.cu +++ b/csrc/py_itfs_cu/asm_mla.cu @@ -185,8 +185,13 @@ void mla_decode_stage1_asm_fwd( } else { - args.out_16_nosplit = 0; - args.ptr_RP = nullptr; + // The legacy QH16 m32x1_n16x1 kernel (gqa_ratio=32, decode qseqlen=1) + // writes directly to output via ptr_RP when kv_split==1. Passing + // nullptr causes GPU memory faults on gfx950. Other non-persistent + // kernels (v3, stage1) use split-reduce and expect ptr_RP = nullptr. + bool legacy_qh16 = (gqa_ratio == 32 && max_seqlen_q == 1); + args.out_16_nosplit = legacy_qh16 ? kv_split : 0; + args.ptr_RP = legacy_qh16 ? output->data_ptr() : nullptr; args.ptr_STP = num_kv_splits_indptr->data_ptr(); }