Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions aiter/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,116 @@ def pa_fwd_asm(
) -> torch.Tensor: ...


def _should_use_asm_kernel(
num_seqs: int,
num_heads: int,
kv_cache_tensor_dtype: torch.dtype,
) -> bool:

if kv_cache_tensor_dtype == torch.int8:
return True

# Get GPU compute units (CUs)
gpu = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count
# ASM kernel becomes relevant, once the total_heads is sufficiently large compared to CUs
total_heads = num_seqs * num_heads
return total_heads > 2 * cu_num


def paged_attention_common(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    exp_sums: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables_stride0: int,
    scale: float,
    max_qlen: int = 1,
    max_seq_len: int = 1,
    K_QScale_hip: Optional[torch.Tensor] = None,  # [num_seqs, num_heads]
    V_QScale_hip: Optional[torch.Tensor] = None,
    K_QScale_asm: Optional[
        torch.Tensor
    ] = None,  # [num_blocks, num_kv_heads, block_size]
    V_QScale_asm: Optional[torch.Tensor] = None,
    out_: Optional[torch.Tensor] = None,
    qo_indptr: Optional[torch.Tensor] = None,
    high_precision: Optional[
        int
    ] = 1,  # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
    kernelName: Optional[str] = None,
    kv_cache_dtype: str = "auto",
    kv_cache_tensor_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Paged attention forward pass with automatic kernel selection.
    ASM is favored for int8 kv caches, for short ctx_len, or when the workload exceeds
    the heuristic thresholds for larger ctx_len values.
    PA is normally using per tensor quant and this is what has been tested, however,
    per head quant can be supported as well in principle, but not tested.
    """
    # When no explicit storage dtype is given, fall back to K's dtype.
    storage_dtype = (
        K.dtype if kv_cache_tensor_dtype is None else kv_cache_tensor_dtype
    )
    num_seqs, num_heads, head_size = Q.shape

    # highest-precision mode (2) forces the ASM path regardless of workload.
    take_asm_path = (
        _should_use_asm_kernel(num_seqs, num_heads, storage_dtype)
        or high_precision == 2
    )

    if take_asm_path:
        return pa_fwd_asm(
            Q,
            K,
            V,
            block_tables,
            context_lens,
            block_tables_stride0,
            max_qlen,
            K_QScale_asm,
            V_QScale_asm,
            out_,
            qo_indptr,
            high_precision,
            kernelName,
        )

    # Use ROCm paged attention kernel for smaller workloads / common path.
    result = torch.empty_like(Q) if out_ is None else out_

    paged_attention_rocm(
        out=result,
        exp_sums=exp_sums,
        max_logits=max_logits,
        tmp_out=tmp_out,
        query=Q,
        key_cache=K,
        value_cache=V,
        num_kv_heads=int(K.size(1)),
        scale=scale,
        block_tables=block_tables,
        context_lens=context_lens,
        block_size=int(K.size(3)),
        max_context_len=max_seq_len,
        alibi_slopes=None,
        kv_cache_dtype=kv_cache_dtype,
        k_scale=K_QScale_hip,
        v_scale=V_QScale_hip,
        fp8_out_scale=None,
        partition_size=256,
        mtp=1,
        q_scale=None,
    )
    return result


def gen_pa_ps_fwd_asm(
Q: torch.Tensor,
K: torch.Tensor,
Expand Down
163 changes: 162 additions & 1 deletion op_tests/test_pa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
benchmark,
)
from aiter import pertoken_quant
from aiter.ops import attention

import argparse
import pandas as pd

Expand Down Expand Up @@ -404,6 +406,110 @@ def run_aiter_asm(
)


@perftest()
def run_aiter_common(
    query,
    k_cache,
    v_cache,
    block_tables,
    seq_lens,
    max_seq_len,
    kv_cache_dtype,
    num_kv_heads,
    scale,
    alibi_slopes,
    block_tables_stride0,
    # ROCm/HIP (scalar) scales
    k_scale_hip=None,
    v_scale_hip=None,
    # ASM (expanded) scales
    k_scale_asm=None,
    v_scale_asm=None,
    high_precision=0,
    kv_cache_tensor_dtype=None,
):
    """
    Test paged_attention_common which automatically switches between ASM and HIP kernels.
    """
    num_seqs, num_heads, head_size = query.shape

    # Workspace buffers that the ROCm paged-attention path expects the
    # client to allocate up front.
    _PARTITION_SIZE_ROCM = 256
    n_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM
    tmp_out = torch.empty(
        (num_seqs, num_heads, n_partitions, head_size),
        dtype=query.dtype,
        device=query.device,
    )
    exp_sums = torch.empty(
        (num_seqs, num_heads, n_partitions),
        dtype=dtypes.fp32,
        device=query.device,
    )
    max_logits = torch.empty_like(exp_sums)

    def _as_fp32_tensor(value):
        # None passes through; tensors are moved/cast, python scalars wrapped.
        if value is None:
            return None
        if not isinstance(value, torch.Tensor):
            return torch.tensor(float(value), device=query.device, dtype=dtypes.fp32)
        return value.to(device=query.device, dtype=dtypes.fp32)

    k_scale_hip_tensor = _as_fp32_tensor(k_scale_hip)
    v_scale_hip_tensor = _as_fp32_tensor(v_scale_hip)
    # ASM scales are already tensors in the expected layout; just ensure fp32 on device.
    k_scale_asm_tensor = _as_fp32_tensor(k_scale_asm)
    v_scale_asm_tensor = _as_fp32_tensor(v_scale_asm)

    def _stored_as_fp8(dt: torch.dtype) -> bool:
        # int8/uint8 storage counts as fp8 for the kv_cache_dtype string.
        if dt == torch.int8 or dt == torch.uint8:
            return True
        # torch float8 dtypes (guard for older torch builds)
        fp8_names = (
            "float8_e4m3fnuz",
            "float8_e4m3fn",
            "float8_e5m2fnuz",
            "float8_e5m2",
        )
        return any(
            hasattr(torch, name) and dt == getattr(torch, name)
            for name in fp8_names
        )

    # Determine kv_cache_dtype string from the effective storage dtype.
    storage_dt = (
        k_cache.dtype if kv_cache_tensor_dtype is None else kv_cache_tensor_dtype
    )
    kv_cache_dtype_str = "fp8" if _stored_as_fp8(storage_dt) else "auto"

    return attention.paged_attention_common(
        Q=query.contiguous(),
        K=k_cache,
        V=v_cache,
        exp_sums=exp_sums,
        max_logits=max_logits,
        tmp_out=tmp_out,
        block_tables=block_tables,
        context_lens=seq_lens,
        block_tables_stride0=block_tables_stride0,
        scale=scale,
        max_qlen=1,
        max_seq_len=max_seq_len,
        K_QScale_hip=k_scale_hip_tensor,
        V_QScale_hip=v_scale_hip_tensor,
        K_QScale_asm=k_scale_asm_tensor,
        V_QScale_asm=v_scale_asm_tensor,
        out_=None,
        qo_indptr=None,
        high_precision=high_precision,
        kernelName=None,
        kv_cache_dtype=kv_cache_dtype_str,
        kv_cache_tensor_dtype=kv_cache_tensor_dtype,
    )


def dump_input(
path,
query: torch.Tensor,
Expand Down Expand Up @@ -587,6 +693,32 @@ def test_paged_attention(
)
# tensor_dump(out_aiter, 'out_aiter')

# Test paged_attention_common which automatically switches between ASM and HIP
# The routing is internal, so we just test the common API regardless of which path it takes
time_aiter_common = None
if dtype == dtypes.bf16:
try:
out_aiter_common, time_aiter_common = run_aiter_common(
query.contiguous(),
k_cache,
asm_V_shuffle(v_cache), # Shuffle V cache, same as run_aiter_asm
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
alibi_slopes,
block_tables.stride(0),
)
checkAllclose(
out_golden,
out_aiter_common,
msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......",
)
except Exception as e:
print(f"Warning: Could not test aiter_common: {e}")

for quant_algo_, cache_type_ in [
(0, k_cache.dtype),
(2, dtypes.fp8),
Expand Down Expand Up @@ -705,6 +837,31 @@ def test_paged_attention(
msg=f"golden vs aiter_asm:{time_aiter_asm:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})",
)

if quant_algo_ == 4:
# Test paged_attention_common with quantized cache
out_aiter_common, time_aiter_common = run_aiter_common(
query.contiguous(),
k_quant_,
asm_V_shuffle(v_quant_),
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
alibi_slopes,
block_tables.stride(0),
k_scale_hip=k_scale_,
v_scale_hip=v_scale_,
k_scale_asm=k_scale_asm,
v_scale_asm=v_scale_asm,
)
checkAllclose(
out_golden,
out_aiter_common,
msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})",
)

if (
dtype in [dtypes.bf16, dtypes.fp16]
and quant_algo_ == 2
Expand Down Expand Up @@ -809,7 +966,11 @@ def test_paged_attention(
print(
f"finish~ {ctx_lens=}, {num_seqs=}, {num_heads=}, {head_size=}, {use_alibi=}, {block_size=}, {dtype=}, {kv_cache_dtype=}\n"
)
return {"aiter_shomy": time_aiter, "aiter_asm": time_aiter_asm}
return {
"aiter_shomy": time_aiter,
"aiter_asm": time_aiter_asm,
"aiter_common": time_aiter_common,
}


df = []
Expand Down