From 48f1e3b80132750e925aedf33ed6bae55e4936e2 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 28 Oct 2025 10:04:14 +0000
Subject: [PATCH 1/4] fix aot

---
Reviewer notes (this section is ignored when the patch is applied):
* MAX_JOBS is now converted with int() before it reaches
  ProcessPoolExecutor; the environment value is a string and would
  otherwise raise TypeError. pa_v1.py also gained two docstrings, so the
  blob id on its "index" line no longer matches the file content.
* This series was recovered from a copy whose line breaks were collapsed.
  Indentation and blank context lines were reconstructed; verify them
  against the tree before applying.

 aiter/aot/pa_v1.py     | 115 +++++++++++++++++++++++++++++++++++++++++
 csrc/cpp_itfs/utils.py |  13 +++--
 2 files changed, 123 insertions(+), 5 deletions(-)
 create mode 100644 aiter/aot/pa_v1.py

diff --git a/aiter/aot/pa_v1.py b/aiter/aot/pa_v1.py
new file mode 100644
index 0000000000..8f384d720f
--- /dev/null
+++ b/aiter/aot/pa_v1.py
@@ -0,0 +1,115 @@
+from collections import namedtuple
+import os
+import concurrent.futures
+from csrc.cpp_itfs.pa.pa_v1 import compile
+
+PAConfig = namedtuple(
+    "PAConfig",
+    [
+        "gqa_ratio",
+        "head_size",
+        "npar_loops",
+        "dtype",
+        "kv_dtype",
+        "fp8_kv_dtype",
+        "out_dtype",
+        "block_size",
+        "alibi_enabled",
+        "logits_soft_cap_enabled",
+    ],
+)
+
+
+def process_config(config):
+    """Call compile() with the fields of a PAConfig, in declaration order."""
+    return compile(
+        config.gqa_ratio,
+        config.head_size,
+        config.npar_loops,
+        config.dtype,
+        config.kv_dtype,
+        config.fp8_kv_dtype,
+        config.out_dtype,
+        config.block_size,
+        config.alibi_enabled,
+        config.logits_soft_cap_enabled,
+    )
+
+
+def main():
+    """Build every PAConfig combination below and compile each in a worker pool."""
+    configs = []
+    for gqa_ratio in range(1, 17):
+        for alibi_enabled in [False, True]:
+            for logits_soft_cap_enabled in [False, True]:
+                for block_size in [1, 16, 32]:
+                    for npar_loops in range(1, 9):
+                        for head_size in [64, 128]:
+                            configs.append(
+                                PAConfig(
+                                    gqa_ratio=gqa_ratio,
+                                    head_size=head_size,
+                                    npar_loops=npar_loops,
+                                    dtype="_Float16",
+                                    kv_dtype="_Float16",
+                                    fp8_kv_dtype="auto",
+                                    out_dtype="_Float16",
+                                    block_size=block_size,
+                                    alibi_enabled=alibi_enabled,
+                                    logits_soft_cap_enabled=logits_soft_cap_enabled,
+                                )
+                            )
+                            configs.append(
+                                PAConfig(
+                                    gqa_ratio=gqa_ratio,
+                                    head_size=head_size,
+                                    npar_loops=npar_loops,
+                                    dtype="__hip_bfloat16",
+                                    kv_dtype="__hip_bfloat16",
+                                    fp8_kv_dtype="auto",
+                                    out_dtype="__hip_bfloat16",
+                                    block_size=block_size,
+                                    alibi_enabled=alibi_enabled,
+                                    logits_soft_cap_enabled=logits_soft_cap_enabled,
+                                )
+                            )
+                            configs.append(
+                                PAConfig(
+                                    gqa_ratio=gqa_ratio,
+                                    head_size=head_size,
+                                    npar_loops=npar_loops,
+                                    dtype="_Float16",
+                                    kv_dtype="uint8_t",
+                                    fp8_kv_dtype="fp8",
+                                    out_dtype="_Float16",
+                                    block_size=block_size,
+                                    alibi_enabled=alibi_enabled,
+                                    logits_soft_cap_enabled=logits_soft_cap_enabled,
+                                )
+                            )
+                            configs.append(
+                                PAConfig(
+                                    gqa_ratio=gqa_ratio,
+                                    head_size=head_size,
+                                    npar_loops=npar_loops,
+                                    dtype="__hip_bfloat16",
+                                    kv_dtype="uint8_t",
+                                    fp8_kv_dtype="fp8",
+                                    out_dtype="__hip_bfloat16",
+                                    block_size=block_size,
+                                    alibi_enabled=alibi_enabled,
+                                    logits_soft_cap_enabled=logits_soft_cap_enabled,
+                                )
+                            )
+
+    # Environment values are strings, but ProcessPoolExecutor needs an int
+    # (or None); fall back to the CPU count when MAX_JOBS is unset or empty.
+    max_jobs = os.environ.get("MAX_JOBS")
+    with concurrent.futures.ProcessPoolExecutor(
+        int(max_jobs) if max_jobs else os.cpu_count()
+    ) as executor:
+        executor.map(process_config, configs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/csrc/cpp_itfs/utils.py b/csrc/cpp_itfs/utils.py
index 3df6657647..513e82acc2 100644
--- a/csrc/cpp_itfs/utils.py
+++ b/csrc/cpp_itfs/utils.py
@@ -142,7 +142,8 @@ def compile_lib(src_file, folder, includes=None, sources=None, cxxflags=None):
     start_ts = time.perf_counter()
 
     def main_func(includes=None, sources=None, cxxflags=None):
-        logger.info(f"start build {sub_build_dir}")
+        if AITER_LOG_MORE >= 2:
+            logger.info(f"start build {sub_build_dir}")
         if includes is None:
             includes = []
         if sources is None:
@@ -220,9 +221,10 @@ def main_func(includes=None, sources=None, cxxflags=None):
         )
 
     def final_func():
-        logger.info(
-            f"finish build {sub_build_dir}, cost {time.perf_counter()-start_ts:.8f}s"
-        )
+        if AITER_LOG_MORE >= 2:
+            logger.info(
+                f"finish build {sub_build_dir}, cost {time.perf_counter()-start_ts:.8f}s"
+            )
 
     main_func = partial(
         main_func, includes=includes, sources=sources, cxxflags=cxxflags
@@ -276,8 +278,9 @@ def compile_template_op(
         sources = []
     if cxxflags is None:
         cxxflags = []
+    if AITER_LOG_MORE >= 2:
+        logger.info(f"compile_template_op {func_name = } with {locals()}...")
     src_file = src_template.render(func_name=func_name, **kwargs)
-    logger.info(f"compile_template_op {func_name = } with {locals()}...")
     compile_lib(src_file, folder, includes, sources, cxxflags)
     return run_lib(func_name, folder)
 

From 3f8dbbdda0415101541a9051ef4f6d0ef0318eea Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 29 Oct 2025 03:05:39 +0000
Subject: [PATCH 2/4] remove other kernels path

---
 aiter/paged_attn.py | 176 +++++++++++---------------------------------
 1 file changed, 45 insertions(+), 131 deletions(-)

diff --git a/aiter/paged_attn.py b/aiter/paged_attn.py
index 6e07794cab..b471d4d5df 100644
--- a/aiter/paged_attn.py
+++ b/aiter/paged_attn.py
@@ -248,139 +248,53 @@ def forward_decode(
         # Whether to use rocm custom paged attention or not
         num_seqs, num_heads, head_size = query.shape
         block_size = key_cache.size(3)
-        gqa_ratio = num_heads // num_kv_heads
-        use_custom = _use_rocm_custom_paged_attention(
-            query.dtype, head_size, block_size, gqa_ratio, max_seq_len
-        )
 
         output = torch.empty_like(query, dtype=output_dtype)
-        if use_custom:
-            max_num_partitions = (
-                max_seq_len + _PARTITION_SIZE_ROCM - 1
-            ) // _PARTITION_SIZE_ROCM
-            tmp_output = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions, head_size),
-                dtype=output.dtype,
-                device=output.device,
-            )
-            exp_sums = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions),
-                dtype=dtypes.fp32,
-                device=output.device,
-            )
-            max_logits = torch.empty_like(exp_sums)
-            cpa_fp8_out = False
-            if fp8_out_scale is not None:
-                output = torch.empty_like(output, dtype=dtypes.fp8)
-                cpa_fp8_out = True
-            torch.ops.aiter.paged_attention_rocm(
-                output,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                num_kv_heads,
-                scale,
-                block_tables,
-                seq_lens,
-                block_size,
-                max_seq_len,
-                alibi_slopes,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-                fp8_out_scale if cpa_fp8_out else None,
-                _PARTITION_SIZE_ROCM,
-                q_scale=q_scale,
-                mtp=mtp,
-            )
-            if cpa_fp8_out:
-                return output.view(num_seqs, num_heads * head_size)
-        else:
-            max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-            if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
-                # use blocksparse paged attention
-                assert (
-                    blocksparse_block_size > 0
-                    and blocksparse_block_size % block_size == 0
-                ), (
-                    f"{blocksparse_block_size=} needs to be a multiple of"
-                    f"{block_size=} used in block_tables."
-                )
-
-            # NOTE(woosuk): We use a simple heuristic to decide whether to use
-            # PagedAttention V1 or V2. If the number of partitions is 1, we use
-            # V1 to avoid the overhead of reduction. Also, if the number of
-            # sequences or heads is large, we use V1 since there is enough work
-            # to parallelize.
-            # TODO(woosuk): Tune this heuristic.
-            # For context len > 8192, use V2 kernel to avoid shared memory shortage.
-            use_v1 = max_seq_len <= 8192 and (
-                max_num_partitions == 1 or num_seqs * num_heads > 512
-            )
-            if use_v1:
-                # Run PagedAttention V1.
-                ops.paged_attention_v1(
-                    output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    seq_lens,
-                    block_size,
-                    max_seq_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    k_scale,
-                    v_scale,
-                    tp_rank,
-                    blocksparse_local_blocks,
-                    blocksparse_vert_stride,
-                    blocksparse_block_size,
-                    blocksparse_head_sliding_step,
-                )
-            else:
-                # Run PagedAttention V2.
-                assert _PARTITION_SIZE % block_size == 0
-                tmp_output = torch.empty(
-                    size=(num_seqs, num_heads, max_num_partitions, head_size),
-                    dtype=output.dtype,
-                    device=output.device,
-                )
-                exp_sums = torch.empty(
-                    size=(num_seqs, num_heads, max_num_partitions),
-                    dtype=dtypes.fp32,
-                    device=output.device,
-                )
-                max_logits = torch.empty_like(exp_sums)
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    seq_lens,
-                    block_size,
-                    max_seq_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    k_scale,
-                    v_scale,
-                    tp_rank,
-                    blocksparse_local_blocks,
-                    blocksparse_vert_stride,
-                    blocksparse_block_size,
-                    blocksparse_head_sliding_step,
-                )
+        max_num_partitions = (
+            max_seq_len + _PARTITION_SIZE_ROCM - 1
+        ) // _PARTITION_SIZE_ROCM
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=dtypes.fp32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        cpa_fp8_out = False
+        if fp8_out_scale is not None:
+            output = torch.empty_like(output, dtype=dtypes.fp8)
+            cpa_fp8_out = True
+        if scale is None:
+            scale = float(1.0 / (head_size**0.5))
+        torch.ops.aiter.paged_attention_rocm(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            num_kv_heads,
+            scale,
+            block_tables,
+            seq_lens,
+            block_size,
+            max_seq_len,
+            alibi_slopes,
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+            fp8_out_scale if cpa_fp8_out else None,
+            _PARTITION_SIZE_ROCM,
+            q_scale=q_scale,
+            mtp=mtp,
+        )
+        if cpa_fp8_out:
+            return output.view(num_seqs, num_heads * head_size)
         return output
 
     # @staticmethod

From 55512297747aff3cbb6f8565cfdb1a1540dd0830 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 30 Oct 2025 02:55:29 +0000
Subject: [PATCH 3/4] fix aot

---
 csrc/cpp_itfs/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/cpp_itfs/utils.py b/csrc/cpp_itfs/utils.py
index 513e82acc2..226b7bafbc 100644
--- a/csrc/cpp_itfs/utils.py
+++ b/csrc/cpp_itfs/utils.py
@@ -217,7 +217,7 @@ def main_func(includes=None, sources=None, cxxflags=None):
         with open(f"{sub_build_dir}/Makefile", "w") as f:
             f.write(makefile_file)
         subprocess.run(
-            f"cd {sub_build_dir} && make build -j{len(sources)}", shell=True, check=True
+            f"cd {sub_build_dir} && make build -j{len(sources)}", shell=True, capture_output=AITER_LOG_MORE<2, check=True
         )
 
     def final_func():

From 245b278690c21489f86359502744881ada2d162d Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 30 Oct 2025 03:35:07 +0000
Subject: [PATCH 4/4] format code

---
 csrc/cpp_itfs/utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/csrc/cpp_itfs/utils.py b/csrc/cpp_itfs/utils.py
index 226b7bafbc..d6c9acef59 100644
--- a/csrc/cpp_itfs/utils.py
+++ b/csrc/cpp_itfs/utils.py
@@ -217,7 +217,10 @@ def main_func(includes=None, sources=None, cxxflags=None):
         with open(f"{sub_build_dir}/Makefile", "w") as f:
             f.write(makefile_file)
         subprocess.run(
-            f"cd {sub_build_dir} && make build -j{len(sources)}", shell=True, capture_output=AITER_LOG_MORE<2, check=True
+            f"cd {sub_build_dir} && make build -j{len(sources)}",
+            shell=True,
+            capture_output=AITER_LOG_MORE < 2,
+            check=True,
         )
 
     def final_func():