Skip to content
Merged

fix aot #1279

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions aiter/aot/pa_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from collections import namedtuple
import os
import concurrent.futures
from csrc.cpp_itfs.pa.pa_v1 import compile

# Immutable record describing one ahead-of-time compilation variant of the
# paged-attention v1 kernel.  Field order deliberately matches the positional
# signature of csrc.cpp_itfs.pa.pa_v1.compile so a config can be splatted
# straight into the call.
PAConfig = namedtuple(
    "PAConfig",
    "gqa_ratio head_size npar_loops dtype kv_dtype "
    "fp8_kv_dtype out_dtype block_size alibi_enabled "
    "logits_soft_cap_enabled",
)


def process_config(config):
    """Compile the paged-attention v1 kernel described by *config*.

    The fields of the ``PAConfig`` namedtuple are declared in the same
    order as the positional parameters of ``compile``, so the whole tuple
    is forwarded by unpacking.
    """
    return compile(*config)


def main():
    """Ahead-of-time compile every supported paged-attention v1 kernel variant.

    Builds the full cross product of kernel parameters, then compiles each
    variant in a pool of worker processes sized by the ``MAX_JOBS``
    environment variable (defaulting to the CPU count).
    """
    # (dtype, kv_dtype, fp8_kv_dtype, out_dtype) combinations the kernel
    # supports: fp16/bf16 KV caches plus fp8 caches stored as uint8_t.
    dtype_combos = [
        ("_Float16", "_Float16", "auto", "_Float16"),
        ("__hip_bfloat16", "__hip_bfloat16", "auto", "__hip_bfloat16"),
        ("_Float16", "uint8_t", "fp8", "_Float16"),
        ("__hip_bfloat16", "uint8_t", "fp8", "__hip_bfloat16"),
    ]
    # Comprehension preserves the original enumeration order: the dtype
    # combos vary fastest, head_size next, and gqa_ratio slowest.
    configs = [
        PAConfig(
            gqa_ratio=gqa_ratio,
            head_size=head_size,
            npar_loops=npar_loops,
            dtype=dtype,
            kv_dtype=kv_dtype,
            fp8_kv_dtype=fp8_kv_dtype,
            out_dtype=out_dtype,
            block_size=block_size,
            alibi_enabled=alibi_enabled,
            logits_soft_cap_enabled=logits_soft_cap_enabled,
        )
        for gqa_ratio in range(1, 17)
        for alibi_enabled in (False, True)
        for logits_soft_cap_enabled in (False, True)
        for block_size in (1, 16, 32)
        for npar_loops in range(1, 9)
        for head_size in (64, 128)
        for dtype, kv_dtype, fp8_kv_dtype, out_dtype in dtype_combos
    ]

    # os.environ.get returns a *string* when MAX_JOBS is set, but
    # ProcessPoolExecutor requires max_workers to be an int — convert
    # explicitly (the original code raised TypeError whenever MAX_JOBS
    # was present in the environment).
    max_jobs = int(os.environ.get("MAX_JOBS", os.cpu_count()))
    with concurrent.futures.ProcessPoolExecutor(max_jobs) as executor:
        # Drain the map iterator so that any exception raised in a worker
        # is re-raised here instead of being silently discarded.
        list(executor.map(process_config, configs))


if __name__ == "__main__":
    main()
176 changes: 45 additions & 131 deletions aiter/paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,139 +248,53 @@ def forward_decode(
# Whether to use rocm custom paged attention or not
num_seqs, num_heads, head_size = query.shape
block_size = key_cache.size(3)
gqa_ratio = num_heads // num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
query.dtype, head_size, block_size, gqa_ratio, max_seq_len
)
output = torch.empty_like(query, dtype=output_dtype)
if use_custom:
max_num_partitions = (
max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=dtypes.fp32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
cpa_fp8_out = False
if fp8_out_scale is not None:
output = torch.empty_like(output, dtype=dtypes.fp8)
cpa_fp8_out = True
torch.ops.aiter.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
fp8_out_scale if cpa_fp8_out else None,
_PARTITION_SIZE_ROCM,
q_scale=q_scale,
mtp=mtp,
)
if cpa_fp8_out:
return output.view(num_seqs, num_heads * head_size)
else:
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
assert (
blocksparse_block_size > 0
and blocksparse_block_size % block_size == 0
), (
f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables."
)

# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = max_seq_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)

if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=dtypes.fp32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
max_num_partitions = (
max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=dtypes.fp32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
cpa_fp8_out = False
if fp8_out_scale is not None:
output = torch.empty_like(output, dtype=dtypes.fp8)
cpa_fp8_out = True
if scale is None:
scale = float(1.0 / (head_size**0.5))
torch.ops.aiter.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
fp8_out_scale if cpa_fp8_out else None,
_PARTITION_SIZE_ROCM,
q_scale=q_scale,
mtp=mtp,
)
if cpa_fp8_out:
return output.view(num_seqs, num_heads * head_size)
return output

# @staticmethod
Expand Down
18 changes: 12 additions & 6 deletions csrc/cpp_itfs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def compile_lib(src_file, folder, includes=None, sources=None, cxxflags=None):
start_ts = time.perf_counter()

def main_func(includes=None, sources=None, cxxflags=None):
logger.info(f"start build {sub_build_dir}")
if AITER_LOG_MORE >= 2:
logger.info(f"start build {sub_build_dir}")
if includes is None:
includes = []
if sources is None:
Expand Down Expand Up @@ -216,13 +217,17 @@ def main_func(includes=None, sources=None, cxxflags=None):
with open(f"{sub_build_dir}/Makefile", "w") as f:
f.write(makefile_file)
subprocess.run(
f"cd {sub_build_dir} && make build -j{len(sources)}", shell=True, check=True
f"cd {sub_build_dir} && make build -j{len(sources)}",
shell=True,
capture_output=AITER_LOG_MORE < 2,
check=True,
)

def final_func():
logger.info(
f"finish build {sub_build_dir}, cost {time.perf_counter()-start_ts:.8f}s"
)
if AITER_LOG_MORE >= 2:
logger.info(
f"finish build {sub_build_dir}, cost {time.perf_counter()-start_ts:.8f}s"
)

main_func = partial(
main_func, includes=includes, sources=sources, cxxflags=cxxflags
Expand Down Expand Up @@ -276,8 +281,9 @@ def compile_template_op(
sources = []
if cxxflags is None:
cxxflags = []
if AITER_LOG_MORE >= 2:
logger.info(f"compile_template_op {func_name = } with {locals()}...")
src_file = src_template.render(func_name=func_name, **kwargs)
logger.info(f"compile_template_op {func_name = } with {locals()}...")
compile_lib(src_file, folder, includes, sources, cxxflags)
return run_lib(func_name, folder)

Expand Down