58 changes: 29 additions & 29 deletions aiter/mla.py
@@ -5,6 +5,7 @@

import functools
from typing import Optional

import torch
import triton
import triton.language as tl
@@ -43,6 +44,7 @@ def _fwd_kernel_stage2_asm(
cur_split_end = tl.load(num_kv_splits_indptr + cur_batch + 1)
num_max_kv_splits = tl.load(num_kv_splits_indptr + BATCH_NUM)
cur_kv_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch)

offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv

@@ -118,6 +120,8 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
num_kv_splits = sorted(tmp, key=lambda x: x[0], reverse=True)[0][1]

get_block_n_fp8 = {
4: 128,
8: 128,
16: 128,
32: 128,
48: 64,
@@ -133,11 +137,6 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
num_kv_splits = min(
num_kv_splits, int(total_kv / bs + min_block_n - 1) // min_block_n
)
if num_kv_splits > 1:
num_kv_splits = min(
num_kv_splits,
(abs(total_kv / bs - max_seqlen_q) // min_block_n + 1),
)

num_kv_splits_indptr = torch.arange(
0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device="cuda"
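
Reviewer note: the `num_kv_splits_indptr` built here is a uniform CSR-style offset table, with each sequence owning exactly `num_kv_splits` split slots. A minimal sketch of what it holds (shapes and values illustrative, CPU tensor for brevity):

```python
import torch

bs, num_kv_splits = 3, 4
indptr = torch.arange(0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int)
# -> tensor([ 0,  4,  8, 12], dtype=torch.int32)
# Batch i owns split slots indptr[i] .. indptr[i+1], so the stage-2
# reduction kernel reads a fixed number of partial results per sequence.
assert indptr.tolist() == [0, 4, 8, 12]
```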
@@ -188,10 +187,24 @@ def mla_decode_fwd(
bs = qo_indptr.shape[0] - 1
total_kv = kv_indices.shape[0]

# Handle nhead < 16 by padding to 16 via repeat_interleave so that models
# with small head counts (e.g. Kimi-Linear-48B-A3B with TP=8, nhead=4)
# can reuse the nhead=16 ASM kernel transparently.
_head_pad_factor = 1
_o_unpadded = None
if nhead < 16 and nhead > 0 and 16 % nhead == 0:
_head_pad_factor = 16 // nhead
q = q.repeat_interleave(_head_pad_factor, dim=1)
_o_unpadded = o
nhead = 16
ori_nhead = 16
o = torch.empty(
total_s, nhead, v_head_dim, dtype=_o_unpadded.dtype, device=device
)

persistent_mode = work_meta_data is not None

io_transformed = False
qseqlen_folded = False

if not persistent_mode:
if num_kv_splits is None or num_kv_splits_indptr is None:
@@ -200,13 +213,6 @@
)

mgc = 64 if max_seqlen_q == 1 and nhead == 16 else 16
mgc = (
32
if (
nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
)
else mgc
)

MAYBE_FINAL_OUT = True
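
Reviewer note: with the fp8 nhead=128 override deleted above, the `mgc` selection collapses to a single conditional. As a standalone sketch (the constant's meaning is internal to the ASM kernel; the logic is copied from the surviving line):

```python
def pick_mgc(max_seqlen_q: int, nhead: int) -> int:
    # Decode (qlen == 1) with the 16-head kernel uses the larger
    # granularity constant; every other case stays at 16.
    return 64 if max_seqlen_q == 1 and nhead == 16 else 16

assert pick_mgc(1, 16) == 64
assert pick_mgc(4, 16) == 16
```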

@@ -266,6 +272,8 @@
and nhead in [32, 64]
)
):
if _o_unpadded is not None:
_o_unpadded.copy_(o[:, ::_head_pad_factor, :])
return logits.view(total_s, nhead, v_head_dim), attn_lse

Lv = v_head_dim
@@ -318,22 +326,9 @@ def mla_decode_fwd(
elif nhead in range(32, 128 + 1, 16) and persistent_mode:
# we use nhead=16 to simulate such cases by customized metadata
# metadata also views qo's tensor as shape (total_s * (nhead // 16), 16, ...)
fold_factor = ori_nhead // 16
use_qseqlen_fold = (
get_gfx() == "gfx950"
and q.dtype == dtypes.fp8
and kv_buffer.dtype == dtypes.fp8
and max_seqlen_q * fold_factor == 4
)

total_s = ori_total_s * fold_factor
total_s = ori_total_s * (ori_nhead // 16)
nhead = 16

if use_qseqlen_fold:
max_seqlen_q = max_seqlen_q * fold_factor
q = q.view(total_s, nhead, -1)
qseqlen_folded = True
elif max_seqlen_q == 1:
if max_seqlen_q == 1:
q = q.view(total_s, nhead, -1)
else:
q = (
@@ -433,7 +428,7 @@ def mla_decode_fwd(
if return_logits:
logits = logits.view(-1, 1, ori_nhead, v_head_dim)

if max_seqlen_q == 1 or qseqlen_folded:
if max_seqlen_q == 1:
q = q.view(ori_total_s, ori_nhead, -1)
o = o.view(ori_total_s, ori_nhead, -1)
if final_lse is not None:
@@ -467,6 +462,11 @@ def mla_decode_fwd(
.contiguous()
)

if _o_unpadded is not None:
_o_unpadded.copy_(o[:, ::_head_pad_factor, :])
if final_lse is not None:
final_lse = final_lse[:, ::_head_pad_factor]

return logits, final_lse


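Reviewer note on the head-padding path added in this file: a self-contained sketch of the pad/unpad round trip (all shapes are illustrative; `pad_factor` mirrors `_head_pad_factor`, and the strided slice is the same one used in the `_o_unpadded.copy_` calls):

```python
import torch

# Hypothetical shapes: 8 tokens, 4 query heads, head_dim 128 (illustration only;
# nhead=4 matches the Kimi-Linear-48B-A3B TP=8 case named in the comment).
total_s, nhead, head_dim = 8, 4, 128
q = torch.randn(total_s, nhead, head_dim)

# Pad heads 4 -> 16 so the nhead=16 ASM kernel can be reused:
# each original head is repeated pad_factor times along the head axis.
pad_factor = 16 // nhead
q_padded = q.repeat_interleave(pad_factor, dim=1)
assert q_padded.shape == (total_s, 16, head_dim)

# After the kernel writes 16 heads of output, each group of pad_factor
# heads holds identical results, so a strided slice recovers the
# original 4 heads -- what `o[:, ::_head_pad_factor, :]` does in the diff.
o_padded = q_padded  # stand-in for the kernel output
o = o_padded[:, ::pad_factor, :]
assert torch.equal(o, q)
```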
155 changes: 21 additions & 134 deletions aiter/ops/attention.py
@@ -4,7 +4,6 @@
import math
from typing import Optional, Tuple

from aiter.ops.enum import QuantType, Enum
import torch
import triton
import triton.language as tl
@@ -106,26 +105,7 @@ def pa_fwd_naive(
) -> torch.Tensor: ...


@compile_ops(
"module_attention_asm", fc_name="pa_fwd", ffi_type="ctypes", gen_fake=gen_pa_fwd_asm
)
def _pa_fwd_asm(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_tables_stride0: int,
max_qlen: int = 1,
K_QScale: Optional[torch.Tensor] = None,
V_QScale: Optional[torch.Tensor] = None,
out_: Optional[torch.Tensor] = None,
qo_indptr: Optional[torch.Tensor] = None,
high_precision: Optional[int] = 1,
kernelName: Optional[str] = None,
) -> None: ...


@compile_ops("module_attention_asm", gen_fake=gen_pa_fwd_asm)
def pa_fwd_asm(
Q: torch.Tensor,
K: torch.Tensor,
@@ -142,24 +122,7 @@ def pa_fwd_asm(
int
] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
kernelName: Optional[str] = None,
) -> torch.Tensor:
output = out_ if out_ is not None else torch.empty_like(Q)
_pa_fwd_asm(
Q,
K,
V,
block_tables,
context_lens,
block_tables_stride0,
max_qlen,
K_QScale,
V_QScale,
output,
qo_indptr,
high_precision,
kernelName,
)
return output
) -> torch.Tensor: ...


def _should_use_asm_kernel(
@@ -303,57 +266,28 @@ def gen_pa_ps_fwd_asm(
int
] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
kernelName: Optional[str] = None,
quant_type: Optional[Enum] = QuantType.per_Token.value,
) -> torch.Tensor:
if out_ is not None:
return out_
else:
return torch.empty_like(Q)


@compile_ops(
"module_attention_asm",
fc_name="pa_ps_fwd",
ffi_type="ctypes",
gen_fake=gen_pa_ps_fwd_asm,
)
def _pa_ps_fwd_asm(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
kv_indptr: torch.Tensor,
kv_page_indices: torch.Tensor,
context_lens: torch.Tensor,
softmax_scale: float,
max_qlen: int = 1,
K_QScale: Optional[torch.Tensor] = None,
V_QScale: Optional[torch.Tensor] = None,
out_: Optional[torch.Tensor] = None,
qo_indptr: Optional[torch.Tensor] = None,
work_indptr: Optional[torch.Tensor] = None,
work_info: Optional[torch.Tensor] = None,
splitData: Optional[torch.Tensor] = None,
splitLse: Optional[torch.Tensor] = None,
mask: int = 0,
high_precision: Optional[int] = 1,
kernelName: Optional[str] = None,
quant_type: Optional[Enum] = QuantType.per_Token.value,
) -> None: ...


@compile_ops("module_attention_asm", gen_fake=gen_pa_fwd_asm)
def pa_ps_fwd_asm(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
kv_indptr: torch.Tensor,
kv_page_indices: torch.Tensor,
context_lens: torch.Tensor,
softmax_scale: float,
softmax_scale: float, # better have ?
max_qlen: int = 1,
K_QScale: Optional[torch.Tensor] = None,
V_QScale: Optional[torch.Tensor] = None,
out_: Optional[torch.Tensor] = None,
qo_indptr: Optional[torch.Tensor] = None,
# work_meta_data: Optional[torch.Tensor] = None,
work_indptr: Optional[torch.Tensor] = None,
work_info: Optional[torch.Tensor] = None,
splitData: Optional[torch.Tensor] = None,
@@ -363,32 +297,7 @@ def pa_ps_fwd_asm(
int
] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
kernelName: Optional[str] = None,
quant_type: Optional[Enum] = QuantType.per_Token.value,
) -> torch.Tensor:
output = out_ if out_ is not None else torch.empty_like(Q)
_pa_ps_fwd_asm(
Q,
K,
V,
kv_indptr,
kv_page_indices,
context_lens,
softmax_scale,
max_qlen,
K_QScale,
V_QScale,
output,
qo_indptr,
work_indptr,
work_info,
splitData,
splitLse,
mask,
high_precision,
kernelName,
quant_type,
)
return output
) -> torch.Tensor: ...


def pa_reduce_v1(
Expand Down Expand Up @@ -433,7 +342,6 @@ def pa_persistent_fwd(
V_QScale: Optional[torch.Tensor] = None, # [num_blocks, kv_heads, block_size]
softmax_scale: Optional[float] = None,
mask: int = 0,
quant_type: QuantType = QuantType.per_Token,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = Q.device
total_s, nhead, v_head_dim = output.shape
@@ -469,7 +377,6 @@
logits,
splitLse,
mask,
quant_type=quant_type,
)
pa_reduce_v1(
logits,
@@ -652,7 +559,7 @@ def paged_attention_ragged(
MD_NAME = "module_mla_asm"


@compile_ops(MD_NAME, ffi_type="ctypes")
@compile_ops(MD_NAME)
def mla_decode_stage1_asm_fwd(
# [num_seqs, num_heads, head_size]
Q: torch.Tensor,
@@ -688,7 +595,7 @@ def mla_decode_stage1_asm_fwd(
) -> None: ...


@compile_ops(MD_NAME, ffi_type="ctypes")
@compile_ops(MD_NAME)
def mla_prefill_asm_fwd(
# [num_seqs, num_heads, head_size]
Q: torch.Tensor,
@@ -873,7 +780,7 @@ def get_ps_metadata_v1(
) -> None: ...


@compile_ops(MD_NAME, ffi_type="ctypes")
@compile_ops(MD_NAME)
def mla_prefill_ps_asm_fwd(
Q: torch.Tensor,
K: torch.Tensor,
@@ -916,30 +823,26 @@ def get_mla_metadata_info_v1(
6. Shape of reduce_partial_map followed by its scalar type.
"""

assert num_head_qo % 16 == 0
# Pad nhead < 16 to 16 for metadata sizing (mirrors C++ metadata logic)
effective_num_head = num_head_qo
if num_head_qo < 16 and num_head_qo > 0 and 16 % num_head_qo == 0:
effective_num_head = 16
assert effective_num_head % 16 == 0

gpu = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count

use_qseqlen_fold = (
get_gfx() == "gfx950"
and q_dtype == dtypes.fp8
and kv_dtype == dtypes.fp8
and num_head_qo > 16
and max_seqlen_qo * (num_head_qo // 16) == 4
)

max_qo_tiles_per_batch = (
int(math.ceil(max_seqlen_qo * num_head_qo / 128))
if num_head_qo == 16
int(math.ceil(max_seqlen_qo * effective_num_head / 128))
if effective_num_head == 16
or (
get_gfx() == "gfx942"
and num_head_qo == 128
and effective_num_head == 128
and kv_dtype == dtypes.fp8
and q_dtype == dtypes.fp8
)
or use_qseqlen_fold
else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
else int(math.ceil(max_seqlen_qo * effective_num_head / 16))
)
batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
tile_cnt = batch_size * max_qo_tiles_per_batch
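
Reviewer note: the tile-count expression above is dense; the sizing rule it implements can be read as the small helper below. A sketch, with the 128- vs 16-wide tile choice collapsed into one flag; only the 16-head padding and the ceil arithmetic are taken from the hunk:

```python
import math

def max_qo_tiles(max_seqlen_qo: int, num_head_qo: int, wide_tile: bool) -> int:
    # Pad small head counts up to 16, as the hunk does, so metadata
    # sizing matches the padded kernel launch.
    effective = 16 if 0 < num_head_qo < 16 and 16 % num_head_qo == 0 else num_head_qo
    assert effective % 16 == 0
    # nhead=16 (and the gfx942 fp8 nhead=128 case) uses 128-wide qo
    # tiles; everything else uses 16-wide tiles.
    tile = 128 if wide_tile else 16
    return math.ceil(max_seqlen_qo * effective / tile)

# e.g. nhead=4 decode (max_seqlen_qo=1) pads to 16 heads -> one 128-wide tile
assert max_qo_tiles(1, 4, wide_tile=True) == 1
```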
Expand Down Expand Up @@ -1220,24 +1123,8 @@ def decode_update_mla_metadata_v1(
assert kv_granularity >= 16
assert page_size == 1
# assert not (dtype_q == dtypes.bf16 and dtype_kv == dtypes.bf16 and num_heads_per_head_k == 128), "In this case, use get_mla_metadata_v1 instead"
q_is_fp8 = dtype_q == dtypes.fp8
kv_is_fp8 = dtype_kv == dtypes.fp8
arch_id = get_gfx()
natively_supported = (
(num_heads_per_head_k == 16)
or (
arch_id == "gfx950"
and num_heads_per_head_k == 32
and q_is_fp8
and kv_is_fp8
and max_seqlen_qo == 4
)
or (
arch_id == "gfx942"
and num_heads_per_head_k == 128
and q_is_fp8
and kv_is_fp8
)
natively_supported = (num_heads_per_head_k == 16) or (
num_heads_per_head_k == 128 and dtype_q == dtypes.fp8 and dtype_kv == dtypes.fp8
)
cu_num = work_indptr.shape[0] - 1
tile_reduce_cnt = reduce_indptr.shape[0] - 1
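Reviewer note: after this change the native-support predicate in `decode_update_mla_metadata_v1` no longer branches on `get_gfx()`. Reduced to a standalone function it reads as the sketch below (dtype strings stand in for the `aiter.dtypes` constants):

```python
def natively_supported(num_heads_per_head_k: int, dtype_q: str, dtype_kv: str) -> bool:
    # Only nhead=16, or nhead=128 with an fp8 q/kv pair, runs natively;
    # other head counts fall back to the nhead=16 simulation path.
    return num_heads_per_head_k == 16 or (
        num_heads_per_head_k == 128 and dtype_q == "fp8" and dtype_kv == "fp8"
    )

assert natively_supported(16, "bf16", "bf16")
assert not natively_supported(32, "fp8", "fp8")  # previously a gfx950-only special case
```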