From 763f4d7e13853e96719d8cad04a43f02f3f0215d Mon Sep 17 00:00:00 2001 From: Sergey Solo Date: Tue, 16 Dec 2025 14:13:15 +0000 Subject: [PATCH 01/11] Implement a new api that will be switching between asm and hip pa Inference engines should be calling paged_attention_common now with shuffled kv cache layout and aiter internally will decide between asm or hip kernel. HIP is more performant for lower concurrencies ( < 128). Also a unit test has been updated to include the new interface. Note that support for the shuffled scales in HIP is not supported and is always redirected to asm now when KV cache is in int8 or fp8 formats. --- aiter/ops/attention.py | 74 ++++++ csrc/cpp_itfs/pa/pa_kernels.cuh | 158 ++++++++----- csrc/cpp_itfs/pa/pa_v1.cpp.jinja | 3 +- csrc/cpp_itfs/pa/pa_v1.cuh | 19 +- csrc/cpp_itfs/pa/pa_v1.py | 56 ++++- op_tests/README_pa_merged_tests.md | 85 +++++++ op_tests/test_pa.py | 139 +++++++++++- op_tests/test_pa_merged.py | 350 +++++++++++++++++++++++++++++ 8 files changed, 805 insertions(+), 79 deletions(-) create mode 100644 op_tests/README_pa_merged_tests.md create mode 100644 op_tests/test_pa_merged.py diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 67ba14f5a4..fb287b313e 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -122,6 +122,80 @@ def pa_fwd_asm( ) -> torch.Tensor: ... +def _should_use_asm_kernel( + num_seqs: int, + num_heads: int, + kv_cache_tensor_dtype: torch.dtype, +) -> bool: + #TODO: HIP kernel yet isn't supporting fp8 scales in asm layout. + if kv_cache_tensor_dtype == torch.int8 or kv_cache_tensor_dtype == torch.float8_e4m3fnuz: + return True + + # Get GPU compute units (CUs) + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + # ASM kernel becomes relevant, once the total_heads is sufficiently large compared to CUs + total_heads = num_seqs * num_heads + return total_heads > 2 * cu_num + + +def paged_attention_common( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + workspace_buffer: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_tables_stride0: int, + logits_soft_cap: float, + scale: float, + max_qlen: int = 1, + max_seq_len: int = 1, + cu_query_lens: Optional[torch.Tensor] = None, + K_QScale: Optional[torch.Tensor] = None, + V_QScale: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + qo_indptr: Optional[torch.Tensor] = None, + high_precision: Optional[ + int + ] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache + kernelName: Optional[str] = None, + kv_cache_dtype: str = "auto", + kv_cache_tensor_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Paged attention forward pass with automatic kernel selection. + ASM is favored for int8 kv caches, for short ctx_len, or when the workload exceeds + the heuristic thresholds for larger ctx_len values. + """ + kv_cache_tensor_dtype = kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else K.dtype + num_seqs, num_heads, head_size = Q.shape + + # Route to ASM kernel based on the heuristic above. + use_asm_kernel = _should_use_asm_kernel( + num_seqs, num_heads, kv_cache_tensor_dtype + ) + + if use_asm_kernel: + output = pa_fwd_asm( + Q, K, V, block_tables, context_lens, block_tables_stride0, + max_qlen, K_QScale, V_QScale, out_, qo_indptr, high_precision, kernelName + ) + return output + + # Use HIP kernel for smaller workloads (5D V cache) + output = out_ if out_ is not None else torch.empty_like(Q) + paged_attention_v1( + output, workspace_buffer, Q, K, V, scale, + block_tables, cu_query_lens, context_lens, max_seq_len, + None, # alibi_slopes + kv_cache_dtype, "HND", logits_soft_cap, + K_QScale, V_QScale, + ) + return output + + def gen_pa_ps_fwd_asm( Q: torch.Tensor, K: torch.Tensor, diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 44959d0d05..420884c6a8 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -12,7 +12,8 @@ template + bool SLIDING_WINDOW_ENABLED, + bool USE_5D_VCACHE = false> __inline__ __device__ void _paged_attention_kernel(const int* block_table_seq, const int64_t query_loc, @@ -221,10 +222,20 @@ _paged_attention_kernel(const int* block_table_seq, for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { const int head_elem = - row_head_elem + qkhe_depth * QKHE_PER_FETCH + head_loop * HEAD_SIZE_PER_LOOP; - const int offset1 = head_elem / KX; - const int offset2 = head_elem % KX; - const cache_t* k_fetch_ptr = k_ptr3 + offset1 * KX + offset2; + row_head_elem + qkhe_depth * QKHE_PER_FETCH + head_loop * HEAD_SIZE_PER_LOOP; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = [&] { + if constexpr(USE_5D_VCACHE) + { + const int head_stride = BLOCK_SIZE * kv_seq_stride; + return k_ptr3 + offset1 * head_stride + offset2; + } + else + { + return k_ptr3 + offset1 * KX + offset2; + } + }(); const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); if constexpr(NT_KV_LOAD) { @@ -291,27 +302,62 @@ _paged_attention_kernel(const int* block_table_seq, static_assert(VBLOCKS_PER_LANE == VTLANELOOP, "make sure we can keep un-shuffled data in Vlocal as well"); - const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + - ((threadIdx.x / n_thread_per_block) % BLOCK_SIZE) * kv_seq_stride; + constexpr int V_X = CONTIGUOUS_KV_ELEMS_16B_LOAD; - // v fetches are 16head elems across lanes x 16 tokens per lane - for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + if constexpr(USE_5D_VCACHE) { - for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + constexpr int V5D_TOKEN_GRP_STRIDE = HEAD_SIZE * V_X; + constexpr int V5D_HEAD_STRIDE = V_X; + + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - const int vlds_col_idx = laneid % n_thread_per_block; - const int vhead_elem = - vhe_depth * NWARPS * 16 + vlds_col_idx * CONTIGUOUS_KV_ELEMS_16B_LOAD; - const cache_t* v_ptr2 = v_ptr + vhead_elem; - - const int64_t vblock_number = - static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); - const cache_t* v_fetch_ptr = v_ptr2 + (vblock_number * kv_block_stride); + for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) + { + const int vlocal_token_idx = rowid * VTOKENS_PER_LANE + vfetch_depth * V_X; + const int global_token_idx = partition_start_token_idx + + vtoken_depth * TOKENS_PER_WARP + + vlocal_token_idx; + const int block_idx = global_token_idx / BLOCK_SIZE; + const int token_in_block = global_token_idx % BLOCK_SIZE; + const int token_grp = token_in_block / V_X; + const int safe_block_idx = + (global_token_idx < context_len) ? block_idx : last_ctx_block; + const int physical_block = block_table_seq[safe_block_idx]; + const int head_elem = (warpid * 16 + lane16id) + vhe_depth * NWARPS * 16; + const int64_t v_offset = static_cast(physical_block) * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride + + token_grp * V5D_TOKEN_GRP_STRIDE + + head_elem * V5D_HEAD_STRIDE; + const cache_t* v_fetch_ptr = v_cache + v_offset; + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = + *reinterpret_cast(v_fetch_ptr); + } + } + } + } + else + { + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((threadIdx.x / n_thread_per_block) % BLOCK_SIZE) * kv_seq_stride; - Vlocal[vtoken_depth][vhe_depth][vblock_depth] = - *reinterpret_cast(v_fetch_ptr); + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + { + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + { + for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) + { + const int vlds_col_idx = laneid % n_thread_per_block; + const int vhead_elem = + vhe_depth * NWARPS * 16 + vlds_col_idx * V_X; + const cache_t* v_ptr2 = v_ptr + vhead_elem; + const int64_t vblock_number = + static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_fetch_ptr = v_ptr2 + (vblock_number * kv_block_stride); + Vlocal[vtoken_depth][vhe_depth][vblock_depth] = + *reinterpret_cast(v_fetch_ptr); + } } } } @@ -675,54 +721,52 @@ _paged_attention_kernel(const int* block_table_seq, constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; - for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + if constexpr(!USE_5D_VCACHE) { - for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - // 1. store data into LDS - for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - const int vlds_col_idx = laneid % n_thread_per_block; - const int vlocal_token_idx = - vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block; - *reinterpret_cast<_B16x8*>(vlds_ptr + - (/*row=*/vlocal_token_idx * n_thread_per_block + - /*col=*/vlds_col_idx) * - 16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth]; - } - __syncthreads(); + for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) + { + const int vlds_col_idx = laneid % n_thread_per_block; + const int vlocal_token_idx = + vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block; + *reinterpret_cast<_B16x8*>(vlds_ptr + + (/*row=*/vlocal_token_idx * n_thread_per_block + + /*col=*/vlds_col_idx) * + 16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth]; + } + __syncthreads(); - // 2. load data from LDS (transposed), then do multification - for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) - { - const int vlocal_head_elem = warpid * 16 + lane16id; + for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) + { + const int vlocal_head_elem = warpid * 16 + lane16id; - const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD; - const int vlds_elem_idx = vlocal_head_elem % CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int vlds_elem_idx = vlocal_head_elem % CONTIGUOUS_KV_ELEMS_16B_LOAD; - const int vlocal_token_idx = - rowid * VTOKENS_PER_LANE + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int vlocal_token_idx = + rowid * VTOKENS_PER_LANE + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; - // read data points individually and save them into array - cache_t elems[CONTIGUOUS_KV_ELEMS_16B_LOAD]; - for(int d2 = 0; d2 < CONTIGUOUS_KV_ELEMS_16B_LOAD; ++d2) - { - const cache_t* fetched_elems = reinterpret_cast( - vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block + - /*col=*/vlds_col_idx) * - 16); + cache_t elems[CONTIGUOUS_KV_ELEMS_16B_LOAD]; + for(int d2 = 0; d2 < CONTIGUOUS_KV_ELEMS_16B_LOAD; ++d2) + { + const cache_t* fetched_elems = reinterpret_cast( + vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block + + /*col=*/vlds_col_idx) * + 16); + elems[d2] = fetched_elems[vlds_elem_idx]; + } - elems[d2] = fetched_elems[vlds_elem_idx]; + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = + *reinterpret_cast(elems); } - - // copy all the read data points together - Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = - *reinterpret_cast(elems); + __syncthreads(); } - __syncthreads(); } } - + // For 5D, Vlocal is already in the correct format from the load phase _B16x4 outelems[GQA_RATIO_LOOP][MTP_PER_THREAD][VHELOOP]; // Softmax V mfma diff --git a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja index 96d53c61d9..1803c13116 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja +++ b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja @@ -90,7 +90,8 @@ void {{func_name}}(void* out_ptr, gqa_ratio, {{mtp}}, decltype(variant), - {{"true" if sliding_window_enabled else "false"}}> + {{"true" if sliding_window_enabled else "false"}}, + {{"true" if use_5d_vcache else "false"}}> <<(stream)>>>(reinterpret_cast<{{dtype}}*>(query_ptr), reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr), reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr), diff --git a/csrc/cpp_itfs/pa/pa_v1.cuh b/csrc/cpp_itfs/pa/pa_v1.cuh index d00308ee16..af6c988781 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cuh +++ b/csrc/cpp_itfs/pa/pa_v1.cuh @@ -35,13 +35,13 @@ template + bool SLIDING_WINDOW_ENABLED, + bool USE_5D_VCACHE = false> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, block_size, num_kv_heads, - // head_size] - const cache_t* __restrict__ v_cache, // [num_blocks, block_size, num_kv_heads, - // head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // 4D: [num_blocks, num_kv_heads, head_size, block_size] + // 5D: [num_blocks, num_kv_heads, block_size/x, head_size, x] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ cu_query_lens, // [num_seqs+1] @@ -84,7 +84,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ return; } const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; - _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); + _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); } // Grid: (num_heads, num_seqs). @@ -136,13 +136,14 @@ template + bool SLIDING_WINDOW_ENABLED, + bool USE_5D_VCACHE = false> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const cache_t* __restrict__ v_cache, // 4D: [num_blocks, num_kv_heads, head_size, block_size] + // 5D: [num_blocks, num_kv_heads, block_size/x, head_size, x] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ cu_query_lens, // [num_seqs+1] diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index cad347d8e9..3c56dc2a7a 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -23,6 +23,7 @@ def compile( partition_size: int = 256, mtp: int = 1, sliding_window_enabled: bool = False, + use_5d_vcache: bool = False, folder: str = None, ): return compile_template_op( @@ -49,6 +50,7 @@ def compile( partition_size=partition_size, mtp=mtp, sliding_window_enabled=sliding_window_enabled, + use_5d_vcache=use_5d_vcache, folder=folder, ) @@ -108,20 +110,38 @@ def paged_attention_v1( else: raise ValueError(f"Unsupported data type: {out.dtype}") - num_kv_heads = key_cache.size(1) if kv_cache_layout == "HND" else key_cache.size(2) + # Handle both 4D and 5D layouts (key and value cache always have same layout) + is_5d_cache = value_cache.dim() == 5 + + if is_5d_cache: + # 5D V cache layout: [num_blocks, num_kv_heads, block_size/x, head_size, x] + # K cache is kept in ASM layout: [num_blocks, num_kv_heads, head_size/x, block_size, x] + # Pass raw strides so the HIP kernel can consume ASM layout directly. + num_kv_heads = value_cache.size(1) + x = value_cache.size(4) # e.g., 8 for bf16, 16 for fp8 + block_size = value_cache.size(2) * x # block_size/x * x = block_size + + # K uses ASM layout strides; V uses its own 5D strides. + kv_block_stride = value_cache.stride(0) # stride over blocks for V cache + kv_head_stride = value_cache.stride(1) # stride over heads for V cache + kv_seq_stride = key_cache.stride(3) # stride of block_size dimension in ASM K layout + else: + # 4D layout: [num_blocks, num_heads, block_size, head_size] or [num_blocks, block_size, num_heads, head_size] + num_kv_heads = key_cache.size(1) if kv_cache_layout == "HND" else key_cache.size(2) + block_size = key_cache.size(2) if kv_cache_layout == "HND" else key_cache.size(1) + kv_block_stride = key_cache.stride(0) + kv_head_stride = ( + key_cache.stride(1) if kv_cache_layout == "HND" else key_cache.stride(2) + ) + kv_seq_stride = ( + key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) + ) + num_seqs = block_tables.size(0) num_heads = query.size(1) head_size = query.size(2) q_stride = query.stride(0) - block_size = key_cache.size(2) if kv_cache_layout == "HND" else key_cache.size(1) max_num_blocks_per_seq = block_tables.size(1) - kv_block_stride = key_cache.stride(0) - kv_head_stride = ( - key_cache.stride(1) if kv_cache_layout == "HND" else key_cache.stride(2) - ) - kv_seq_stride = ( - key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) - ) gqa_ratio = int(num_heads / num_kv_heads) max_num_partitions = int(math.ceil(max_context_len / partition_size)) npar_loops = int(math.ceil(max_num_partitions / warpSize)) @@ -142,6 +162,7 @@ def paged_attention_v1( partition_size, mtp, sliding_window_enabled=sliding_window_enabled, + use_5d_vcache=is_5d_cache, ) alibi_slopes_ptr = ( @@ -210,6 +231,19 @@ def paged_attention_v1( if q_scale is not None else ctypes.POINTER(ctypes.c_float)() ) + + k_scale_ptr = ( + ctypes.cast(k_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)) + if k_scale is not None + else ctypes.POINTER(ctypes.c_float)() + ) + v_scale_ptr = ( + ctypes.cast(v_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)) + if v_scale is not None + else ctypes.POINTER(ctypes.c_float)() + ) + + func( out_ptr, workspace_buffer_ptr, @@ -221,8 +255,8 @@ def paged_attention_v1( context_lens_ptr, alibi_slopes_ptr, q_scale_ptr, - ctypes.cast(k_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)), - ctypes.cast(v_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)), + k_scale_ptr, + v_scale_ptr, fp8_out_scale_ptr, scale, max_num_blocks_per_seq, diff --git a/op_tests/README_pa_merged_tests.md b/op_tests/README_pa_merged_tests.md new file mode 100644 index 0000000000..d244021854 --- /dev/null +++ b/op_tests/README_pa_merged_tests.md @@ -0,0 +1,85 @@ +# Paged Attention HIP vs ASM comparison tests + +This repo contains multiple PA (Paged Attention) implementations. The file +`op_tests/test_pa_merged.py` is a **focused** test module whose only goal is to: + +- **Compare correctness**: HIP (`paged_attention_v1_core`) vs ASM (`pa_fwd_asm`) +- **Use ASM-compatible KV-cache layouts for BOTH paths** + - HIP is exercised through the **5D-cache** codepath that can consume these layouts + - ASM is exercised directly via `pa_fwd_asm` on the same layouts +- Optionally **measure performance** using AITER’s standard `@perftest()` harness + +The tests are **opt-in** (they skip by default) to avoid running long GPU workloads in CI. + +## What’s inside `op_tests/test_pa_merged.py` + +- **Repro test (`-k repro`)** + - Uses a fixed set of shapes (from an EngineCore log) + - Compares `paged_attention_v1_core` vs `pa_fwd_asm` + - Optional perf: runs `@perftest(num_iters=200)` for HIP and ASM and prints a winner line + +- **Stress test (`-k stress`)** + - Randomly generates ASM-compatible KV-cache inputs + - Compares `paged_attention_v1_core` vs `pa_fwd_asm` + +## Environment variables (“defines”) you may set + +### Required (choose one, otherwise tests skip) + +- **`AITER_RUN_REPRO_SHAPES=1`** + - Enables the repro test (fixed shapes) +- **`AITER_RUN_STRESS=1`** + - Enables the stress test (random trials) + +### Optional knobs + +- **`AITER_REPRO_PERF=1`** (default in the test is `"1"`) + - If enabled, the repro test runs perf for HIP and ASM and prints: + - `HIP(paged_attention_v1_core) avg_us/iter=...` + - `ASM(pa_fwd_asm) avg_us/iter=...` + - `winner=...` + - Set **`AITER_REPRO_PERF=0`** to run correctness only. + +- **`AITER_STRESS_TRIALS=N`** (default: `25`) + - Number of randomized stress trials. + +- **`AITER_LOG_MORE=1`** + - Makes AITER logging more verbose and enables profiler table printing inside `@perftest()`. + - Under pytest you often want `-s` (or `-o log_cli=true`) to see logs. + +## How to run + +Run from repo root. + +### Repro (fixed shapes) — correctness only + +```bash +AITER_RUN_REPRO_SHAPES=1 AITER_REPRO_PERF=0 pytest -q op_tests/test_pa_merged.py -k repro +``` + +### Repro (fixed shapes) — correctness + perf (prints who is faster) + +```bash +AITER_RUN_REPRO_SHAPES=1 AITER_REPRO_PERF=1 pytest -q op_tests/test_pa_merged.py -k repro -s +``` + +### Stress (random ASM-layout inputs) — 1 quick trial + +```bash +AITER_RUN_STRESS=1 AITER_STRESS_TRIALS=1 pytest -q op_tests/test_pa_merged.py -k stress +``` + +### Stress (random ASM-layout inputs) — more trials + +```bash +AITER_RUN_STRESS=1 AITER_STRESS_TRIALS=25 pytest -q op_tests/test_pa_merged.py -k stress +``` + +## Notes + +- If you see `ss` in pytest output, that means **both tests were skipped** because you didn’t set + `AITER_RUN_REPRO_SHAPES=1` and/or `AITER_RUN_STRESS=1`. +- `@perftest()` uses `torch.profiler` to attribute time to kernels and reports an average at the end. +- The ASM path (`pa_fwd_asm`) has tighter supported-config constraints (e.g., typically `bf16`, `head_size=128`, + `block_size=16`); this module intentionally sticks to those constraints. + diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py index c9ae6890db..95809bb81a 100644 --- a/op_tests/test_pa.py +++ b/op_tests/test_pa.py @@ -15,6 +15,8 @@ benchmark, ) from aiter import pertoken_quant +from aiter.ops import attention + import argparse import pandas as pd @@ -404,6 +406,89 @@ def run_aiter_asm( ) +@perftest() +def run_aiter_common( + query, + k_cache, + v_cache, + block_tables, + seq_lens, + max_seq_len, + kv_cache_dtype, + num_kv_heads, + scale, + alibi_slopes, + block_tables_stride0, + k_scale=None, + v_scale=None, + high_precision=0, + kv_cache_tensor_dtype=None, +): + """ + Test paged_attention_common which automatically switches between ASM and HIP kernels. + """ + + num_seqs, num_heads, head_size = query.shape + # Create workspace buffer for HIP kernel path + # Workspace buffer size calculation from test_pa_v1.py + _PARTITION_SIZE_ROCM = 256 + max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM + nbyes_per_qo_elem = torch.finfo(query.dtype).bits // 8 + workspace_buffer = torch.empty( + (num_seqs * num_heads * max_num_partitions * head_size) * nbyes_per_qo_elem + + 2 * (num_seqs * num_heads * max_num_partitions) * 4, + dtype=torch.uint8, + device=query.device, + ) + + def _normalize_scale(s): + if s is None: + return None + if isinstance(s, torch.Tensor): + return s.to(device=query.device, dtype=dtypes.fp32) + # python scalar + return torch.tensor(float(s), device=query.device, dtype=dtypes.fp32) + + k_scale_tensor = _normalize_scale(k_scale) + v_scale_tensor = _normalize_scale(v_scale) + + # Determine kv_cache_dtype string. + def _is_fp8_storage(dt: torch.dtype) -> bool: + if dt == torch.int8 or dt == torch.uint8: + return True + # torch float8 dtypes (guard for older torch builds) + for name in ("float8_e4m3fnuz", "float8_e4m3fn", "float8_e5m2fnuz", "float8_e5m2"): + if hasattr(torch, name) and dt == getattr(torch, name): + return True + return False + + cache_dt = kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else k_cache.dtype + kv_cache_dtype_str = "fp8" if _is_fp8_storage(cache_dt) else "auto" + + return attention.paged_attention_common( + Q=query.contiguous(), + K=k_cache, + V=v_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + context_lens=seq_lens, + block_tables_stride0=block_tables_stride0, + logits_soft_cap=0.0, + scale=scale, + max_qlen=1, + max_seq_len=max_seq_len, + cu_query_lens=None, + K_QScale=k_scale_tensor, + V_QScale=v_scale_tensor, + out_=None, + qo_indptr=None, + high_precision=high_precision, + kernelName=None, + kv_cache_dtype=kv_cache_dtype_str, + kv_cache_tensor_dtype=kv_cache_tensor_dtype, + ) + + def dump_input( path, query: torch.Tensor, @@ -587,6 +672,32 @@ def test_paged_attention( ) # tensor_dump(out_aiter, 'out_aiter') + # Test paged_attention_common which automatically switches between ASM and HIP + # The routing is internal, so we just test the common API regardless of which path it takes + time_aiter_common = None + if dtype == dtypes.bf16: + try: + out_aiter_common, time_aiter_common = run_aiter_common( + query.contiguous(), + k_cache, + asm_V_shuffle(v_cache), # Shuffle V cache, same as run_aiter_asm + block_tables, + seq_lens, + max_seq_len, + kv_cache_dtype, + num_kv_heads, + scale, + alibi_slopes, + block_tables.stride(0), + ) + checkAllclose( + out_golden, + out_aiter_common, + msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......", + ) + except Exception as e: + print(f"Warning: Could not test aiter_common: {e}") + for quant_algo_, cache_type_ in [ (0, k_cache.dtype), (2, dtypes.fp8), @@ -705,6 +816,28 @@ def test_paged_attention( msg=f"golden vs aiter_asm:{time_aiter_asm:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})", ) + # Test paged_attention_common with quantized cache + out_aiter_common, time_aiter_common = run_aiter_common( + query.contiguous(), + k_quant_, + asm_V_shuffle(v_quant_), + block_tables, + seq_lens, + max_seq_len, + kv_cache_dtype, + num_kv_heads, + scale, + alibi_slopes, + block_tables.stride(0), + k_scale_asm, + v_scale_asm, + ) + checkAllclose( + out_golden, + out_aiter_common, + msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})", + ) + if ( dtype in [dtypes.bf16, dtypes.fp16] and quant_algo_ == 2 @@ -809,7 +942,11 @@ def test_paged_attention( print( f"finish~ {ctx_lens=}, {num_seqs=}, {num_heads=}, {head_size=}, {use_alibi=}, {block_size=}, {dtype=}, {kv_cache_dtype=}\n" ) - return {"aiter_shomy": time_aiter, "aiter_asm": time_aiter_asm} + return { + "aiter_shomy": time_aiter, + "aiter_asm": time_aiter_asm, + "aiter_common": time_aiter_common, + } df = [] diff --git a/op_tests/test_pa_merged.py b/op_tests/test_pa_merged.py new file mode 100644 index 0000000000..53354d2e83 --- /dev/null +++ b/op_tests/test_pa_merged.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: MIT +# Consolidated PA tests (repro + stress) added during this session. + +import os +import random +from typing import Tuple + +import pytest +import torch + +from aiter import dtypes +from aiter.ops.attention import pa_fwd_asm, paged_attention_v1_core +from aiter.test_common import checkAllclose, perftest + + +def _make_alias_view_5d(base_5d: torch.Tensor, num_blocks: int) -> torch.Tensor: + """ + Create a view with first dimension = num_blocks without allocating num_blocks storage. + We do this by setting stride(0)=0 so every "block" aliases block 0. + """ + assert base_5d.dim() == 5 and base_5d.size(0) == 1 + _, s1, s2, s3, s4 = base_5d.stride() + return base_5d.as_strided( + size=(num_blocks, base_5d.size(1), base_5d.size(2), base_5d.size(3), base_5d.size(4)), + stride=(0, s1, s2, s3, s4), + ) + + +def _asm_v_shuffle(value_cache_4d: torch.Tensor) -> torch.Tensor: + """Convert V from [B, H, D, BS] -> ASM 5D [B, H, BS/x, D, x].""" + x = 16 // value_cache_4d.element_size() + num_blocks, num_kv_heads, head_size, block_size = value_cache_4d.shape + assert block_size % x == 0 + v5 = value_cache_4d.view(num_blocks, num_kv_heads, head_size, block_size // x, x) + return v5.permute(0, 1, 3, 2, 4).contiguous() + + +def _kv_cache_factory_asm_kv( + *, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + seed: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Create KV in ASM-compatible layouts: + - K: 5D ASM layout [B, H, D/x, BS, x] + - V: 5D ASM layout [B, H, BS/x, D, x] (via shuffle from 4D) + """ + torch.manual_seed(seed) + random.seed(seed) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + assert head_size % x == 0 + + k = torch.empty( + (num_blocks, num_kv_heads, head_size // x, block_size, x), + device=device, + dtype=dtype, + ) + k.uniform_(-1, 1) + + v4 = torch.empty( + (num_blocks, num_kv_heads, head_size, block_size), + device=device, + dtype=dtype, + ) + v4.uniform_(-1, 1) + v5 = _asm_v_shuffle(v4) + return k, v5 + + +def _make_workspace_buffer( + *, + num_seqs: int, + num_heads: int, + head_size: int, + max_seq_len: int, + dtype: torch.dtype, + device: torch.device, + partition_size: int = 256, +) -> torch.Tensor: + # Mirrors workspace sizing used by `paged_attention_v1` wrapper. + max_num_partitions = (max_seq_len + partition_size - 1) // partition_size + nbytes_per_elem = torch.finfo(dtype).bits // 8 + nbytes = ( + (num_seqs * num_heads * max_num_partitions * head_size) * nbytes_per_elem + + 2 * (num_seqs * num_heads * max_num_partitions) * 4 + ) + return torch.empty((nbytes,), dtype=torch.uint8, device=device) + + +def _random_block_tables( + *, + num_seqs: int, + max_seq_len: int, + block_size: int, + num_blocks: int, + device: torch.device, +) -> torch.Tensor: + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + bt = torch.empty((num_seqs, max_num_blocks_per_seq), dtype=torch.int32, device="cpu") + for i in range(num_seqs): + bt[i].random_(0, num_blocks) + return bt.to(device=device) + + +@perftest(num_iters=200, num_warmup=2, num_rotate_args=1) +def _perf_v1( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache_asm5d: torch.Tensor, + workspace: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_seq_len: int, +) -> torch.Tensor: + out = torch.empty_like(query) + paged_attention_v1_core( + out, + workspace, + query, + key_cache, + value_cache_asm5d, + float(1.0 / (query.size(2) ** 0.5)), + block_tables, + None, # cu_query_lens + context_lens, + int(max_seq_len), + None, # alibi_slopes + "auto", + "HND", + 0.0, + None, + None, + ) + return out + + +@perftest(num_iters=200, num_warmup=2, num_rotate_args=1) +def _perf_asm( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache_asm5d: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, +) -> torch.Tensor: + return pa_fwd_asm( + query.contiguous(), + key_cache, + value_cache_asm5d, + block_tables, + context_lens, + int(block_tables.stride(0)), + max_qlen=1, + K_QScale=None, + V_QScale=None, + out_=None, + qo_indptr=None, + high_precision=1, + kernelName=None, + ) + + +@pytest.mark.repro +def test_pa_repro_enginecore_shapes_compare_v1_vs_asm() -> None: + """ + Repro using your logged shapes, then compare HIP v1 vs ASM outputs. + Opt-in: + AITER_RUN_REPRO_SHAPES=1 pytest -q op_tests/test_pa_merged.py -k repro -s + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm device not available") + if os.getenv("AITER_RUN_REPRO_SHAPES", "0") != "1": + pytest.skip("Set AITER_RUN_REPRO_SHAPES=1 to enable this repro test") + + device = torch.device("cuda:0") + + # Exact shapes from your log + Q = torch.empty((1, 32, 128), device=device, dtype=torch.bfloat16).uniform_(-1, 1) + K_base = torch.empty((1, 4, 16, 16, 8), device=device, dtype=torch.bfloat16).uniform_(-1, 1) + V_base = torch.empty((1, 4, 2, 128, 8), device=device, dtype=torch.bfloat16).uniform_(-1, 1) + K = _make_alias_view_5d(K_base, 134921) + V = _make_alias_view_5d(V_base, 134921) + + max_seq_len = 17 + block_tables = torch.zeros((1, 16384), device=device, dtype=torch.int32) + assert block_tables.stride(0) == 16384 + context_lens = torch.tensor([17], device=device, dtype=torch.int32) + + workspace = torch.empty((8448,), device=device, dtype=torch.uint8) + + # 1) correctness compare + out_v1 = torch.empty_like(Q) + paged_attention_v1_core( + out_v1, + workspace, + Q, + K, + V, + float(1.0 / (128**0.5)), + block_tables, + None, + context_lens, + int(max_seq_len), + None, + "auto", + "HND", + 0.0, + None, + None, + ) + out_asm = pa_fwd_asm( + Q.contiguous(), + K, + V, + block_tables, + context_lens, + 16384, + max_qlen=1, + K_QScale=None, + V_QScale=None, + out_=None, + qo_indptr=None, + high_precision=1, + kernelName=None, + ) + checkAllclose(out_v1, out_asm, msg="repro shapes: v1_core vs pa_fwd_asm") + + # 2) perf (200 iters each), env-gated so repro correctness can run fast if desired + if os.getenv("AITER_REPRO_PERF", "1") == "1": + _, v1_us = _perf_v1(Q, K, V, workspace, block_tables, context_lens, max_seq_len) + _, asm_us = _perf_asm(Q, K, V, block_tables, context_lens) + print( + f"[perf repro] HIP(paged_attention_v1_core) avg_us/iter={v1_us:.2f} | " + f"ASM(pa_fwd_asm) avg_us/iter={asm_us:.2f} | " + f"winner={'ASM' if asm_us < v1_us else 'HIP' if v1_us < asm_us else 'tie'}" + ) + assert v1_us > 0 and asm_us > 0 + + +@pytest.mark.stress +def test_pa_stress_asm_layout_compare_v1_vs_asm() -> None: + """ + Randomized stress: compare HIP v1 core vs ASM kernel on ASM-shaped KV-cache layouts. + Opt-in: + AITER_RUN_STRESS=1 pytest -q op_tests/test_pa_merged.py -k stress + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm device not available") + if os.getenv("AITER_RUN_STRESS", "0") != "1": + pytest.skip("Set AITER_RUN_STRESS=1 to enable stress tests") + + device = torch.device("cuda:0") + head_size = 128 + block_size = 16 + dtype = dtypes.bf16 if torch.cuda.is_bf16_supported() else torch.float16 + + num_heads = 32 + num_kv_heads = 4 + assert num_heads % num_kv_heads == 0 + + trials = int(os.getenv("AITER_STRESS_TRIALS", "25")) + max_seq_len_choices = [64, 256, 1024, 2048, 4096] + num_seqs_choices = [1, 2, 4, 8] + + for t in range(trials): + seed = 1337 + t + max_seq_len = random.choice(max_seq_len_choices) + num_seqs = random.choice(num_seqs_choices) + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + num_blocks = num_seqs * max_num_blocks_per_seq + if num_blocks > 4096: + num_blocks = 4096 + + query = torch.empty((num_seqs, num_heads, head_size), device=device, dtype=dtype).uniform_( + -1, 1 + ) + context_lens = torch.randint( + low=1, + high=max_seq_len + 1, + size=(num_seqs,), + device=device, + dtype=torch.int32, + ) + block_tables = _random_block_tables( + num_seqs=num_seqs, + max_seq_len=max_seq_len, + block_size=block_size, + num_blocks=num_blocks, + device=device, + ) + key_cache, value_cache_asm5d = _kv_cache_factory_asm_kv( + num_blocks=num_blocks, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + seed=seed, + ) + workspace = _make_workspace_buffer( + num_seqs=num_seqs, + num_heads=num_heads, + head_size=head_size, + max_seq_len=max_seq_len, + dtype=dtype, + device=device, + ) + + out_v1 = torch.empty_like(query) + paged_attention_v1_core( + out_v1, + workspace, + query, + key_cache, + value_cache_asm5d, + float(1.0 / (head_size**0.5)), + block_tables, + None, + context_lens, + int(max_seq_len), + None, + "auto", + "HND", + 0.0, + None, + None, + ) + out_asm = pa_fwd_asm( + query.contiguous(), + key_cache, + value_cache_asm5d, + block_tables, + context_lens, + int(block_tables.stride(0)), + max_qlen=1, + K_QScale=None, + V_QScale=None, + out_=None, + qo_indptr=None, + high_precision=1, + kernelName=None, + ) + checkAllclose(out_v1, out_asm, msg=f"trial {t}: v1_core vs pa_fwd_asm") + From 04d6687f645d48868974084dd33dadcf6d653528 Mon Sep 17 00:00:00 2001 From: Sergey Solovyev Date: Mon, 12 Jan 2026 20:09:12 +0000 Subject: [PATCH 02/11] Delete op_tests/README_pa_merged_tests.md --- op_tests/README_pa_merged_tests.md | 85 ------------------------------ 1 file changed, 85 deletions(-) delete mode 100644 op_tests/README_pa_merged_tests.md diff --git a/op_tests/README_pa_merged_tests.md b/op_tests/README_pa_merged_tests.md deleted file mode 100644 index d244021854..0000000000 --- a/op_tests/README_pa_merged_tests.md +++ /dev/null @@ -1,85 +0,0 @@ -# Paged Attention HIP vs ASM comparison tests - -This repo contains multiple PA (Paged Attention) implementations. The file -`op_tests/test_pa_merged.py` is a **focused** test module whose only goal is to: - -- **Compare correctness**: HIP (`paged_attention_v1_core`) vs ASM (`pa_fwd_asm`) -- **Use ASM-compatible KV-cache layouts for BOTH paths** - - HIP is exercised through the **5D-cache** codepath that can consume these layouts - - ASM is exercised directly via `pa_fwd_asm` on the same layouts -- Optionally **measure performance** using AITER’s standard `@perftest()` harness - -The tests are **opt-in** (they skip by default) to avoid running long GPU workloads in CI. - -## What’s inside `op_tests/test_pa_merged.py` - -- **Repro test (`-k repro`)** - - Uses a fixed set of shapes (from an EngineCore log) - - Compares `paged_attention_v1_core` vs `pa_fwd_asm` - - Optional perf: runs `@perftest(num_iters=200)` for HIP and ASM and prints a winner line - -- **Stress test (`-k stress`)** - - Randomly generates ASM-compatible KV-cache inputs - - Compares `paged_attention_v1_core` vs `pa_fwd_asm` - -## Environment variables (“defines”) you may set - -### Required (choose one, otherwise tests skip) - -- **`AITER_RUN_REPRO_SHAPES=1`** - - Enables the repro test (fixed shapes) -- **`AITER_RUN_STRESS=1`** - - Enables the stress test (random trials) - -### Optional knobs - -- **`AITER_REPRO_PERF=1`** (default in the test is `"1"`) - - If enabled, the repro test runs perf for HIP and ASM and prints: - - `HIP(paged_attention_v1_core) avg_us/iter=...` - - `ASM(pa_fwd_asm) avg_us/iter=...` - - `winner=...` - - Set **`AITER_REPRO_PERF=0`** to run correctness only. - -- **`AITER_STRESS_TRIALS=N`** (default: `25`) - - Number of randomized stress trials. - -- **`AITER_LOG_MORE=1`** - - Makes AITER logging more verbose and enables profiler table printing inside `@perftest()`. - - Under pytest you often want `-s` (or `-o log_cli=true`) to see logs. - -## How to run - -Run from repo root. - -### Repro (fixed shapes) — correctness only - -```bash -AITER_RUN_REPRO_SHAPES=1 AITER_REPRO_PERF=0 pytest -q op_tests/test_pa_merged.py -k repro -``` - -### Repro (fixed shapes) — correctness + perf (prints who is faster) - -```bash -AITER_RUN_REPRO_SHAPES=1 AITER_REPRO_PERF=1 pytest -q op_tests/test_pa_merged.py -k repro -s -``` - -### Stress (random ASM-layout inputs) — 1 quick trial - -```bash -AITER_RUN_STRESS=1 AITER_STRESS_TRIALS=1 pytest -q op_tests/test_pa_merged.py -k stress -``` - -### Stress (random ASM-layout inputs) — more trials - -```bash -AITER_RUN_STRESS=1 AITER_STRESS_TRIALS=25 pytest -q op_tests/test_pa_merged.py -k stress -``` - -## Notes - -- If you see `ss` in pytest output, that means **both tests were skipped** because you didn’t set - `AITER_RUN_REPRO_SHAPES=1` and/or `AITER_RUN_STRESS=1`. -- `@perftest()` uses `torch.profiler` to attribute time to kernels and reports an average at the end. -- The ASM path (`pa_fwd_asm`) has tighter supported-config constraints (e.g., typically `bf16`, `head_size=128`, - `block_size=16`); this module intentionally sticks to those constraints. - From a744e4a817d9cedc6e643ba9acadec81dad96d73 Mon Sep 17 00:00:00 2001 From: Sergey Solovyev Date: Mon, 12 Jan 2026 20:09:40 +0000 Subject: [PATCH 03/11] Delete op_tests/test_pa_merged.py --- op_tests/test_pa_merged.py | 350 ------------------------------------- 1 file changed, 350 deletions(-) delete mode 100644 op_tests/test_pa_merged.py diff --git a/op_tests/test_pa_merged.py b/op_tests/test_pa_merged.py deleted file mode 100644 index 53354d2e83..0000000000 --- a/op_tests/test_pa_merged.py +++ /dev/null @@ -1,350 +0,0 @@ -# SPDX-License-Identifier: MIT -# Consolidated PA tests (repro + stress) added during this session. - -import os -import random -from typing import Tuple - -import pytest -import torch - -from aiter import dtypes -from aiter.ops.attention import pa_fwd_asm, paged_attention_v1_core -from aiter.test_common import checkAllclose, perftest - - -def _make_alias_view_5d(base_5d: torch.Tensor, num_blocks: int) -> torch.Tensor: - """ - Create a view with first dimension = num_blocks without allocating num_blocks storage. - We do this by setting stride(0)=0 so every "block" aliases block 0. - """ - assert base_5d.dim() == 5 and base_5d.size(0) == 1 - _, s1, s2, s3, s4 = base_5d.stride() - return base_5d.as_strided( - size=(num_blocks, base_5d.size(1), base_5d.size(2), base_5d.size(3), base_5d.size(4)), - stride=(0, s1, s2, s3, s4), - ) - - -def _asm_v_shuffle(value_cache_4d: torch.Tensor) -> torch.Tensor: - """Convert V from [B, H, D, BS] -> ASM 5D [B, H, BS/x, D, x].""" - x = 16 // value_cache_4d.element_size() - num_blocks, num_kv_heads, head_size, block_size = value_cache_4d.shape - assert block_size % x == 0 - v5 = value_cache_4d.view(num_blocks, num_kv_heads, head_size, block_size // x, x) - return v5.permute(0, 1, 3, 2, 4).contiguous() - - -def _kv_cache_factory_asm_kv( - *, - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - seed: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Create KV in ASM-compatible layouts: - - K: 5D ASM layout [B, H, D/x, BS, x] - - V: 5D ASM layout [B, H, BS/x, D, x] (via shuffle from 4D) - """ - torch.manual_seed(seed) - random.seed(seed) - - x = 16 // torch.tensor([], dtype=dtype).element_size() - assert head_size % x == 0 - - k = torch.empty( - (num_blocks, num_kv_heads, head_size // x, block_size, x), - device=device, - dtype=dtype, - ) - k.uniform_(-1, 1) - - v4 = torch.empty( - (num_blocks, num_kv_heads, head_size, block_size), - device=device, - dtype=dtype, - ) - v4.uniform_(-1, 1) - v5 = _asm_v_shuffle(v4) - return k, v5 - - -def _make_workspace_buffer( - *, - num_seqs: int, - num_heads: int, - head_size: int, - max_seq_len: int, - dtype: torch.dtype, - device: torch.device, - partition_size: int = 256, -) -> torch.Tensor: - # Mirrors workspace sizing used by `paged_attention_v1` wrapper. - max_num_partitions = (max_seq_len + partition_size - 1) // partition_size - nbytes_per_elem = torch.finfo(dtype).bits // 8 - nbytes = ( - (num_seqs * num_heads * max_num_partitions * head_size) * nbytes_per_elem - + 2 * (num_seqs * num_heads * max_num_partitions) * 4 - ) - return torch.empty((nbytes,), dtype=torch.uint8, device=device) - - -def _random_block_tables( - *, - num_seqs: int, - max_seq_len: int, - block_size: int, - num_blocks: int, - device: torch.device, -) -> torch.Tensor: - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - bt = torch.empty((num_seqs, max_num_blocks_per_seq), dtype=torch.int32, device="cpu") - for i in range(num_seqs): - bt[i].random_(0, num_blocks) - return bt.to(device=device) - - -@perftest(num_iters=200, num_warmup=2, num_rotate_args=1) -def _perf_v1( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache_asm5d: torch.Tensor, - workspace: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_seq_len: int, -) -> torch.Tensor: - out = torch.empty_like(query) - paged_attention_v1_core( - out, - workspace, - query, - key_cache, - value_cache_asm5d, - float(1.0 / (query.size(2) ** 0.5)), - block_tables, - None, # cu_query_lens - context_lens, - int(max_seq_len), - None, # alibi_slopes - "auto", - "HND", - 0.0, - None, - None, - ) - return out - - -@perftest(num_iters=200, num_warmup=2, num_rotate_args=1) -def _perf_asm( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache_asm5d: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, -) -> torch.Tensor: - return pa_fwd_asm( - query.contiguous(), - key_cache, - value_cache_asm5d, - block_tables, - context_lens, - int(block_tables.stride(0)), - max_qlen=1, - K_QScale=None, - V_QScale=None, - out_=None, - qo_indptr=None, - high_precision=1, - kernelName=None, - ) - - -@pytest.mark.repro -def test_pa_repro_enginecore_shapes_compare_v1_vs_asm() -> None: - """ - Repro using your logged shapes, then compare HIP v1 vs ASM outputs. - Opt-in: - AITER_RUN_REPRO_SHAPES=1 pytest -q op_tests/test_pa_merged.py -k repro -s - """ - if not torch.cuda.is_available(): - pytest.skip("CUDA/ROCm device not available") - if os.getenv("AITER_RUN_REPRO_SHAPES", "0") != "1": - pytest.skip("Set AITER_RUN_REPRO_SHAPES=1 to enable this repro test") - - device = torch.device("cuda:0") - - # Exact shapes from your log - Q = torch.empty((1, 32, 128), device=device, dtype=torch.bfloat16).uniform_(-1, 1) - K_base = torch.empty((1, 4, 16, 16, 8), device=device, dtype=torch.bfloat16).uniform_(-1, 1) - V_base = torch.empty((1, 4, 2, 128, 8), device=device, dtype=torch.bfloat16).uniform_(-1, 1) - K = _make_alias_view_5d(K_base, 134921) - V = _make_alias_view_5d(V_base, 134921) - - max_seq_len = 17 - block_tables = torch.zeros((1, 16384), device=device, dtype=torch.int32) - assert block_tables.stride(0) == 16384 - context_lens = torch.tensor([17], device=device, dtype=torch.int32) - - workspace = torch.empty((8448,), device=device, dtype=torch.uint8) - - # 1) correctness compare - out_v1 = torch.empty_like(Q) - paged_attention_v1_core( - out_v1, - workspace, - Q, - K, - V, - float(1.0 / (128**0.5)), - block_tables, - None, - context_lens, - int(max_seq_len), - None, - "auto", - "HND", - 0.0, - None, - None, - ) - out_asm = pa_fwd_asm( - Q.contiguous(), - K, - V, - block_tables, - context_lens, - 16384, - max_qlen=1, - K_QScale=None, - V_QScale=None, - out_=None, - qo_indptr=None, - high_precision=1, - kernelName=None, - ) - checkAllclose(out_v1, out_asm, msg="repro shapes: v1_core vs pa_fwd_asm") - - # 2) perf (200 iters each), env-gated so repro correctness can run fast if desired - if os.getenv("AITER_REPRO_PERF", "1") == "1": - _, v1_us = _perf_v1(Q, K, V, workspace, block_tables, context_lens, max_seq_len) - _, asm_us = _perf_asm(Q, K, V, block_tables, context_lens) - print( - f"[perf repro] HIP(paged_attention_v1_core) avg_us/iter={v1_us:.2f} | " - f"ASM(pa_fwd_asm) avg_us/iter={asm_us:.2f} | " - f"winner={'ASM' if asm_us < v1_us else 'HIP' if v1_us < asm_us else 'tie'}" - ) - assert v1_us > 0 and asm_us > 0 - - -@pytest.mark.stress -def test_pa_stress_asm_layout_compare_v1_vs_asm() -> None: - """ - Randomized stress: compare HIP v1 core vs ASM kernel on ASM-shaped KV-cache layouts. - Opt-in: - AITER_RUN_STRESS=1 pytest -q op_tests/test_pa_merged.py -k stress - """ - if not torch.cuda.is_available(): - pytest.skip("CUDA/ROCm device not available") - if os.getenv("AITER_RUN_STRESS", "0") != "1": - pytest.skip("Set AITER_RUN_STRESS=1 to enable stress tests") - - device = torch.device("cuda:0") - head_size = 128 - block_size = 16 - dtype = dtypes.bf16 if torch.cuda.is_bf16_supported() else torch.float16 - - num_heads = 32 - num_kv_heads = 4 - assert num_heads % num_kv_heads == 0 - - trials = int(os.getenv("AITER_STRESS_TRIALS", "25")) - max_seq_len_choices = [64, 256, 1024, 2048, 4096] - num_seqs_choices = [1, 2, 4, 8] - - for t in range(trials): - seed = 1337 + t - max_seq_len = random.choice(max_seq_len_choices) - num_seqs = random.choice(num_seqs_choices) - - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - num_blocks = num_seqs * max_num_blocks_per_seq - if num_blocks > 4096: - num_blocks = 4096 - - query = torch.empty((num_seqs, num_heads, head_size), device=device, dtype=dtype).uniform_( - -1, 1 - ) - context_lens = torch.randint( - low=1, - high=max_seq_len + 1, - size=(num_seqs,), - device=device, - dtype=torch.int32, - ) - block_tables = _random_block_tables( - num_seqs=num_seqs, - max_seq_len=max_seq_len, - block_size=block_size, - num_blocks=num_blocks, - device=device, - ) - key_cache, value_cache_asm5d = _kv_cache_factory_asm_kv( - num_blocks=num_blocks, - block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - device=device, - seed=seed, - ) - workspace = _make_workspace_buffer( - num_seqs=num_seqs, - num_heads=num_heads, - head_size=head_size, - max_seq_len=max_seq_len, - dtype=dtype, - device=device, - ) - - out_v1 = torch.empty_like(query) - paged_attention_v1_core( - out_v1, - workspace, - query, - key_cache, - value_cache_asm5d, - float(1.0 / (head_size**0.5)), - block_tables, - None, - context_lens, - int(max_seq_len), - None, - "auto", - "HND", - 0.0, - None, - None, - ) - out_asm = pa_fwd_asm( - query.contiguous(), - key_cache, - value_cache_asm5d, - block_tables, - context_lens, - int(block_tables.stride(0)), - max_qlen=1, - K_QScale=None, - V_QScale=None, - out_=None, - qo_indptr=None, - high_precision=1, - kernelName=None, - ) - checkAllclose(out_v1, out_asm, msg=f"trial {t}: v1_core vs pa_fwd_asm") - From 476286ec6456f0cb9bd7b4ffdda2462cdb520d8c Mon Sep 17 00:00:00 2001 From: Sergey Solovyev Date: Mon, 12 Jan 2026 20:26:27 +0000 Subject: [PATCH 04/11] Fix formatting according to Black requirements --- aiter/ops/attention.py | 51 +++++++++++++++++++++++++++++---------- csrc/cpp_itfs/pa/pa_v1.py | 23 +++++++++++------- op_tests/test_pa.py | 23 ++++++++++++------ 3 files changed, 68 insertions(+), 29 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index fb287b313e..5739232be3 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -127,8 +127,11 @@ def _should_use_asm_kernel( num_heads: int, kv_cache_tensor_dtype: torch.dtype, ) -> bool: - #TODO: HIP kernel yet isn't supporting fp8 scales in asm layout. - if kv_cache_tensor_dtype == torch.int8 or kv_cache_tensor_dtype == torch.float8_e4m3fnuz: + # TODO: HIP kernel yet isn't supporting fp8 scales in asm layout. + if ( + kv_cache_tensor_dtype == torch.int8 + or kv_cache_tensor_dtype == torch.float8_e4m3fnuz + ): return True # Get GPU compute units (CUs) @@ -169,29 +172,51 @@ def paged_attention_common( ASM is favored for int8 kv caches, for short ctx_len, or when the workload exceeds the heuristic thresholds for larger ctx_len values. """ - kv_cache_tensor_dtype = kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else K.dtype + kv_cache_tensor_dtype = ( + kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else K.dtype + ) num_seqs, num_heads, head_size = Q.shape # Route to ASM kernel based on the heuristic above. - use_asm_kernel = _should_use_asm_kernel( - num_seqs, num_heads, kv_cache_tensor_dtype - ) + use_asm_kernel = _should_use_asm_kernel(num_seqs, num_heads, kv_cache_tensor_dtype) if use_asm_kernel: output = pa_fwd_asm( - Q, K, V, block_tables, context_lens, block_tables_stride0, - max_qlen, K_QScale, V_QScale, out_, qo_indptr, high_precision, kernelName + Q, + K, + V, + block_tables, + context_lens, + block_tables_stride0, + max_qlen, + K_QScale, + V_QScale, + out_, + qo_indptr, + high_precision, + kernelName, ) return output - + # Use HIP kernel for smaller workloads (5D V cache) output = out_ if out_ is not None else torch.empty_like(Q) paged_attention_v1( - output, workspace_buffer, Q, K, V, scale, - block_tables, cu_query_lens, context_lens, max_seq_len, + output, + workspace_buffer, + Q, + K, + V, + scale, + block_tables, + cu_query_lens, + context_lens, + max_seq_len, None, # alibi_slopes - kv_cache_dtype, "HND", logits_soft_cap, - K_QScale, V_QScale, + kv_cache_dtype, + "HND", + logits_soft_cap, + K_QScale, + V_QScale, ) return output diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index 3c56dc2a7a..af23315188 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -112,7 +112,7 @@ def paged_attention_v1( # Handle both 4D and 5D layouts (key and value cache always have same layout) is_5d_cache = value_cache.dim() == 5 - + if is_5d_cache: # 5D V cache layout: [num_blocks, num_kv_heads, block_size/x, head_size, x] # K cache is kept in ASM layout: [num_blocks, num_kv_heads, head_size/x, block_size, x] @@ -123,20 +123,26 @@ def paged_attention_v1( # K uses ASM layout strides; V uses its own 5D strides. kv_block_stride = value_cache.stride(0) # stride over blocks for V cache - kv_head_stride = value_cache.stride(1) # stride over heads for V cache - kv_seq_stride = key_cache.stride(3) # stride of block_size dimension in ASM K layout + kv_head_stride = value_cache.stride(1) # stride over heads for V cache + kv_seq_stride = key_cache.stride( + 3 + ) # stride of block_size dimension in ASM K layout else: # 4D layout: [num_blocks, num_heads, block_size, head_size] or [num_blocks, block_size, num_heads, head_size] - num_kv_heads = key_cache.size(1) if kv_cache_layout == "HND" else key_cache.size(2) - block_size = key_cache.size(2) if kv_cache_layout == "HND" else key_cache.size(1) + num_kv_heads = ( + key_cache.size(1) if kv_cache_layout == "HND" else key_cache.size(2) + ) + block_size = ( + key_cache.size(2) if kv_cache_layout == "HND" else key_cache.size(1) + ) kv_block_stride = key_cache.stride(0) kv_head_stride = ( key_cache.stride(1) if kv_cache_layout == "HND" else key_cache.stride(2) ) kv_seq_stride = ( - key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) - ) - + key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) + ) + num_seqs = block_tables.size(0) num_heads = query.size(1) head_size = query.size(2) @@ -243,7 +249,6 @@ def paged_attention_v1( else ctypes.POINTER(ctypes.c_float)() ) - func( out_ptr, workspace_buffer_ptr, diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py index 95809bb81a..78dc29ccb6 100644 --- a/op_tests/test_pa.py +++ b/op_tests/test_pa.py @@ -427,12 +427,14 @@ def run_aiter_common( """ Test paged_attention_common which automatically switches between ASM and HIP kernels. """ - + num_seqs, num_heads, head_size = query.shape # Create workspace buffer for HIP kernel path # Workspace buffer size calculation from test_pa_v1.py _PARTITION_SIZE_ROCM = 256 - max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM + max_num_partitions = ( + max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM nbyes_per_qo_elem = torch.finfo(query.dtype).bits // 8 workspace_buffer = torch.empty( (num_seqs * num_heads * max_num_partitions * head_size) * nbyes_per_qo_elem @@ -440,7 +442,7 @@ def run_aiter_common( dtype=torch.uint8, device=query.device, ) - + def _normalize_scale(s): if s is None: return None @@ -451,20 +453,27 @@ def _normalize_scale(s): k_scale_tensor = _normalize_scale(k_scale) v_scale_tensor = _normalize_scale(v_scale) - + # Determine kv_cache_dtype string. def _is_fp8_storage(dt: torch.dtype) -> bool: if dt == torch.int8 or dt == torch.uint8: return True # torch float8 dtypes (guard for older torch builds) - for name in ("float8_e4m3fnuz", "float8_e4m3fn", "float8_e5m2fnuz", "float8_e5m2"): + for name in ( + "float8_e4m3fnuz", + "float8_e4m3fn", + "float8_e5m2fnuz", + "float8_e5m2", + ): if hasattr(torch, name) and dt == getattr(torch, name): return True return False - cache_dt = kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else k_cache.dtype + cache_dt = ( + kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else k_cache.dtype + ) kv_cache_dtype_str = "fp8" if _is_fp8_storage(cache_dt) else "auto" - + return attention.paged_attention_common( Q=query.contiguous(), K=k_cache, From b95df2efd0be5ac542e0fdfcb9fb35e2b2dc9fc1 Mon Sep 17 00:00:00 2001 From: Sergey Solovyev Date: Mon, 12 Jan 2026 20:32:13 +0000 Subject: [PATCH 05/11] Fix one last place with broken formatting --- csrc/cpp_itfs/pa/pa_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index af23315188..589ad12009 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -140,8 +140,8 @@ def paged_attention_v1( key_cache.stride(1) if kv_cache_layout == "HND" else key_cache.stride(2) ) kv_seq_stride = ( - key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) - ) + key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) + ) num_seqs = block_tables.size(0) num_heads = query.size(1) From 85d998496b097c64ece216cc789985ac12981670 Mon Sep 17 00:00:00 2001 From: Sergey Sol Date: Tue, 13 Jan 2026 10:49:55 +0000 Subject: [PATCH 06/11] Remove modification to pa_v1, we already have pa for 5D kv cache --- aiter/ops/attention.py | 46 +++++---- csrc/cpp_itfs/pa/pa_kernels.cuh | 158 +++++++++++-------------------- csrc/cpp_itfs/pa/pa_v1.cpp.jinja | 3 +- csrc/cpp_itfs/pa/pa_v1.cuh | 19 ++-- csrc/cpp_itfs/pa/pa_v1.py | 61 +++--------- op_tests/test_pa.py | 25 ++--- 6 files changed, 119 insertions(+), 193 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 5739232be3..3d82757247 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -147,7 +147,9 @@ def paged_attention_common( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, - workspace_buffer: torch.Tensor, + exp_sums: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, block_tables: torch.Tensor, context_lens: torch.Tensor, block_tables_stride0: int, @@ -198,25 +200,31 @@ def paged_attention_common( ) return output - # Use HIP kernel for smaller workloads (5D V cache) + # Use ROCm paged attention kernel for smaller workloads / common path. output = out_ if out_ is not None else torch.empty_like(Q) - paged_attention_v1( - output, - workspace_buffer, - Q, - K, - V, - scale, - block_tables, - cu_query_lens, - context_lens, - max_seq_len, - None, # alibi_slopes - kv_cache_dtype, - "HND", - logits_soft_cap, - K_QScale, - V_QScale, + + paged_attention_rocm( + out=output, + exp_sums=exp_sums, + max_logits=max_logits, + tmp_out=tmp_out, + query=Q, + key_cache=K, + value_cache=V, + num_kv_heads=int(K.size(1)), + scale=scale, + block_tables=block_tables, + context_lens=context_lens, + block_size=int(K.size(3)), + max_context_len=max_seq_len, + alibi_slopes=None, + kv_cache_dtype=kv_cache_dtype, + k_scale=K_QScale, + v_scale=V_QScale, + fp8_out_scale=None, + partition_size=256, + mtp=1, + q_scale=None, ) return output diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 420884c6a8..44959d0d05 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -12,8 +12,7 @@ template + bool SLIDING_WINDOW_ENABLED> __inline__ __device__ void _paged_attention_kernel(const int* block_table_seq, const int64_t query_loc, @@ -222,20 +221,10 @@ _paged_attention_kernel(const int* block_table_seq, for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { const int head_elem = - row_head_elem + qkhe_depth * QKHE_PER_FETCH + head_loop * HEAD_SIZE_PER_LOOP; - const int offset1 = head_elem / KX; - const int offset2 = head_elem % KX; - const cache_t* k_fetch_ptr = [&] { - if constexpr(USE_5D_VCACHE) - { - const int head_stride = BLOCK_SIZE * kv_seq_stride; - return k_ptr3 + offset1 * head_stride + offset2; - } - else - { - return k_ptr3 + offset1 * KX + offset2; - } - }(); + row_head_elem + qkhe_depth * QKHE_PER_FETCH + head_loop * HEAD_SIZE_PER_LOOP; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * KX + offset2; const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); if constexpr(NT_KV_LOAD) { @@ -302,62 +291,27 @@ _paged_attention_kernel(const int* block_table_seq, static_assert(VBLOCKS_PER_LANE == VTLANELOOP, "make sure we can keep un-shuffled data in Vlocal as well"); - constexpr int V_X = CONTIGUOUS_KV_ELEMS_16B_LOAD; + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((threadIdx.x / n_thread_per_block) % BLOCK_SIZE) * kv_seq_stride; - if constexpr(USE_5D_VCACHE) + // v fetches are 16head elems across lanes x 16 tokens per lane + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - constexpr int V5D_TOKEN_GRP_STRIDE = HEAD_SIZE * V_X; - constexpr int V5D_HEAD_STRIDE = V_X; - - for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) { - for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) - { - const int vlocal_token_idx = rowid * VTOKENS_PER_LANE + vfetch_depth * V_X; - const int global_token_idx = partition_start_token_idx + - vtoken_depth * TOKENS_PER_WARP + - vlocal_token_idx; - const int block_idx = global_token_idx / BLOCK_SIZE; - const int token_in_block = global_token_idx % BLOCK_SIZE; - const int token_grp = token_in_block / V_X; - const int safe_block_idx = - (global_token_idx < context_len) ? block_idx : last_ctx_block; - const int physical_block = block_table_seq[safe_block_idx]; - const int head_elem = (warpid * 16 + lane16id) + vhe_depth * NWARPS * 16; - const int64_t v_offset = static_cast(physical_block) * kv_block_stride + - wg_start_kv_head_idx * kv_head_stride + - token_grp * V5D_TOKEN_GRP_STRIDE + - head_elem * V5D_HEAD_STRIDE; - const cache_t* v_fetch_ptr = v_cache + v_offset; - Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = - *reinterpret_cast(v_fetch_ptr); - } - } - } - } - else - { - const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + - ((threadIdx.x / n_thread_per_block) % BLOCK_SIZE) * kv_seq_stride; + const int vlds_col_idx = laneid % n_thread_per_block; + const int vhead_elem = + vhe_depth * NWARPS * 16 + vlds_col_idx * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const cache_t* v_ptr2 = v_ptr + vhead_elem; - for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) - { - for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) - { - for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) - { - const int vlds_col_idx = laneid % n_thread_per_block; - const int vhead_elem = - vhe_depth * NWARPS * 16 + vlds_col_idx * V_X; - const cache_t* v_ptr2 = v_ptr + vhead_elem; - const int64_t vblock_number = - static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); - const cache_t* v_fetch_ptr = v_ptr2 + (vblock_number * kv_block_stride); - Vlocal[vtoken_depth][vhe_depth][vblock_depth] = - *reinterpret_cast(v_fetch_ptr); - } + const int64_t vblock_number = + static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_fetch_ptr = v_ptr2 + (vblock_number * kv_block_stride); + + Vlocal[vtoken_depth][vhe_depth][vblock_depth] = + *reinterpret_cast(v_fetch_ptr); } } } @@ -721,52 +675,54 @@ _paged_attention_kernel(const int* block_table_seq, constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; - if constexpr(!USE_5D_VCACHE) + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + // 1. store data into LDS + for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) { - for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) - { - const int vlds_col_idx = laneid % n_thread_per_block; - const int vlocal_token_idx = - vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block; - *reinterpret_cast<_B16x8*>(vlds_ptr + - (/*row=*/vlocal_token_idx * n_thread_per_block + - /*col=*/vlds_col_idx) * - 16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth]; - } - __syncthreads(); + const int vlds_col_idx = laneid % n_thread_per_block; + const int vlocal_token_idx = + vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block; + *reinterpret_cast<_B16x8*>(vlds_ptr + + (/*row=*/vlocal_token_idx * n_thread_per_block + + /*col=*/vlds_col_idx) * + 16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth]; + } + __syncthreads(); - for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) - { - const int vlocal_head_elem = warpid * 16 + lane16id; + // 2. load data from LDS (transposed), then do multification + for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) + { + const int vlocal_head_elem = warpid * 16 + lane16id; - const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD; - const int vlds_elem_idx = vlocal_head_elem % CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int vlds_elem_idx = vlocal_head_elem % CONTIGUOUS_KV_ELEMS_16B_LOAD; - const int vlocal_token_idx = - rowid * VTOKENS_PER_LANE + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int vlocal_token_idx = + rowid * VTOKENS_PER_LANE + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; - cache_t elems[CONTIGUOUS_KV_ELEMS_16B_LOAD]; - for(int d2 = 0; d2 < CONTIGUOUS_KV_ELEMS_16B_LOAD; ++d2) - { - const cache_t* fetched_elems = reinterpret_cast( - vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block + - /*col=*/vlds_col_idx) * - 16); - elems[d2] = fetched_elems[vlds_elem_idx]; - } + // read data points individually and save them into array + cache_t elems[CONTIGUOUS_KV_ELEMS_16B_LOAD]; + for(int d2 = 0; d2 < CONTIGUOUS_KV_ELEMS_16B_LOAD; ++d2) + { + const cache_t* fetched_elems = reinterpret_cast( + vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block + + /*col=*/vlds_col_idx) * + 16); - Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = - *reinterpret_cast(elems); + elems[d2] = fetched_elems[vlds_elem_idx]; } - __syncthreads(); + + // copy all the read data points together + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = + *reinterpret_cast(elems); } + __syncthreads(); } } - // For 5D, Vlocal is already in the correct format from the load phase + _B16x4 outelems[GQA_RATIO_LOOP][MTP_PER_THREAD][VHELOOP]; // Softmax V mfma diff --git a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja index 1803c13116..96d53c61d9 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja +++ b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja @@ -90,8 +90,7 @@ void {{func_name}}(void* out_ptr, gqa_ratio, {{mtp}}, decltype(variant), - {{"true" if sliding_window_enabled else "false"}}, - {{"true" if use_5d_vcache else "false"}}> + {{"true" if sliding_window_enabled else "false"}}> <<(stream)>>>(reinterpret_cast<{{dtype}}*>(query_ptr), reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr), reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr), diff --git a/csrc/cpp_itfs/pa/pa_v1.cuh b/csrc/cpp_itfs/pa/pa_v1.cuh index af6c988781..d00308ee16 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cuh +++ b/csrc/cpp_itfs/pa/pa_v1.cuh @@ -35,13 +35,13 @@ template + bool SLIDING_WINDOW_ENABLED> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // 4D: [num_blocks, num_kv_heads, head_size, block_size] - // 5D: [num_blocks, num_kv_heads, block_size/x, head_size, x] + const cache_t* __restrict__ k_cache, // [num_blocks, block_size, num_kv_heads, + // head_size] + const cache_t* __restrict__ v_cache, // [num_blocks, block_size, num_kv_heads, + // head_size] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ cu_query_lens, // [num_seqs+1] @@ -84,7 +84,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ return; } const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; - _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); + _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); } // Grid: (num_heads, num_seqs). @@ -136,14 +136,13 @@ template + bool SLIDING_WINDOW_ENABLED> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // 4D: [num_blocks, num_kv_heads, head_size, block_size] - // 5D: [num_blocks, num_kv_heads, block_size/x, head_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ cu_query_lens, // [num_seqs+1] diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index 589ad12009..cad347d8e9 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -23,7 +23,6 @@ def compile( partition_size: int = 256, mtp: int = 1, sliding_window_enabled: bool = False, - use_5d_vcache: bool = False, folder: str = None, ): return compile_template_op( @@ -50,7 +49,6 @@ def compile( partition_size=partition_size, mtp=mtp, sliding_window_enabled=sliding_window_enabled, - use_5d_vcache=use_5d_vcache, folder=folder, ) @@ -110,44 +108,20 @@ def paged_attention_v1( else: raise ValueError(f"Unsupported data type: {out.dtype}") - # Handle both 4D and 5D layouts (key and value cache always have same layout) - is_5d_cache = value_cache.dim() == 5 - - if is_5d_cache: - # 5D V cache layout: [num_blocks, num_kv_heads, block_size/x, head_size, x] - # K cache is kept in ASM layout: [num_blocks, num_kv_heads, head_size/x, block_size, x] - # Pass raw strides so the HIP kernel can consume ASM layout directly. - num_kv_heads = value_cache.size(1) - x = value_cache.size(4) # e.g., 8 for bf16, 16 for fp8 - block_size = value_cache.size(2) * x # block_size/x * x = block_size - - # K uses ASM layout strides; V uses its own 5D strides. - kv_block_stride = value_cache.stride(0) # stride over blocks for V cache - kv_head_stride = value_cache.stride(1) # stride over heads for V cache - kv_seq_stride = key_cache.stride( - 3 - ) # stride of block_size dimension in ASM K layout - else: - # 4D layout: [num_blocks, num_heads, block_size, head_size] or [num_blocks, block_size, num_heads, head_size] - num_kv_heads = ( - key_cache.size(1) if kv_cache_layout == "HND" else key_cache.size(2) - ) - block_size = ( - key_cache.size(2) if kv_cache_layout == "HND" else key_cache.size(1) - ) - kv_block_stride = key_cache.stride(0) - kv_head_stride = ( - key_cache.stride(1) if kv_cache_layout == "HND" else key_cache.stride(2) - ) - kv_seq_stride = ( - key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) - ) - + num_kv_heads = key_cache.size(1) if kv_cache_layout == "HND" else key_cache.size(2) num_seqs = block_tables.size(0) num_heads = query.size(1) head_size = query.size(2) q_stride = query.stride(0) + block_size = key_cache.size(2) if kv_cache_layout == "HND" else key_cache.size(1) max_num_blocks_per_seq = block_tables.size(1) + kv_block_stride = key_cache.stride(0) + kv_head_stride = ( + key_cache.stride(1) if kv_cache_layout == "HND" else key_cache.stride(2) + ) + kv_seq_stride = ( + key_cache.stride(2) if kv_cache_layout == "HND" else key_cache.stride(1) + ) gqa_ratio = int(num_heads / num_kv_heads) max_num_partitions = int(math.ceil(max_context_len / partition_size)) npar_loops = int(math.ceil(max_num_partitions / warpSize)) @@ -168,7 +142,6 @@ def paged_attention_v1( partition_size, mtp, sliding_window_enabled=sliding_window_enabled, - use_5d_vcache=is_5d_cache, ) alibi_slopes_ptr = ( @@ -237,18 +210,6 @@ def paged_attention_v1( if q_scale is not None else ctypes.POINTER(ctypes.c_float)() ) - - k_scale_ptr = ( - ctypes.cast(k_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)) - if k_scale is not None - else ctypes.POINTER(ctypes.c_float)() - ) - v_scale_ptr = ( - ctypes.cast(v_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)) - if v_scale is not None - else ctypes.POINTER(ctypes.c_float)() - ) - func( out_ptr, workspace_buffer_ptr, @@ -260,8 +221,8 @@ def paged_attention_v1( context_lens_ptr, alibi_slopes_ptr, q_scale_ptr, - k_scale_ptr, - v_scale_ptr, + ctypes.cast(k_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)), + ctypes.cast(v_scale.data_ptr(), ctypes.POINTER(ctypes.c_float)), fp8_out_scale_ptr, scale, max_num_blocks_per_seq, diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py index 78dc29ccb6..e39204f374 100644 --- a/op_tests/test_pa.py +++ b/op_tests/test_pa.py @@ -429,19 +429,20 @@ def run_aiter_common( """ num_seqs, num_heads, head_size = query.shape - # Create workspace buffer for HIP kernel path - # Workspace buffer size calculation from test_pa_v1.py + # Client-side allocations required by ROCm paged attention path. _PARTITION_SIZE_ROCM = 256 - max_num_partitions = ( - max_seq_len + _PARTITION_SIZE_ROCM - 1 - ) // _PARTITION_SIZE_ROCM - nbyes_per_qo_elem = torch.finfo(query.dtype).bits // 8 - workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) * nbyes_per_qo_elem - + 2 * (num_seqs * num_heads * max_num_partitions) * 4, - dtype=torch.uint8, + max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM + tmp_out = torch.empty( + (num_seqs, num_heads, max_num_partitions, head_size), + dtype=query.dtype, device=query.device, ) + exp_sums = torch.empty( + (num_seqs, num_heads, max_num_partitions), + dtype=dtypes.fp32, + device=query.device, + ) + max_logits = torch.empty_like(exp_sums) def _normalize_scale(s): if s is None: @@ -478,7 +479,9 @@ def _is_fp8_storage(dt: torch.dtype) -> bool: Q=query.contiguous(), K=k_cache, V=v_cache, - workspace_buffer=workspace_buffer, + exp_sums=exp_sums, + max_logits=max_logits, + tmp_out=tmp_out, block_tables=block_tables, context_lens=seq_lens, block_tables_stride0=block_tables_stride0, From d759438d55c69c9b7fa95d04619619584bdac066 Mon Sep 17 00:00:00 2001 From: Sergey Sol Date: Tue, 13 Jan 2026 10:55:56 +0000 Subject: [PATCH 07/11] Fix another formatting issue --- op_tests/test_pa.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py index e39204f374..5c16504070 100644 --- a/op_tests/test_pa.py +++ b/op_tests/test_pa.py @@ -431,7 +431,9 @@ def run_aiter_common( num_seqs, num_heads, head_size = query.shape # Client-side allocations required by ROCm paged attention path. _PARTITION_SIZE_ROCM = 256 - max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM + max_num_partitions = ( + max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM tmp_out = torch.empty( (num_seqs, num_heads, max_num_partitions, head_size), dtype=query.dtype, From fa7634da6cbf0bd87c1a0b33b21f5bda3ab4d061 Mon Sep 17 00:00:00 2001 From: Sergey Sol Date: Tue, 13 Jan 2026 15:42:41 +0000 Subject: [PATCH 08/11] Add proper quant support for the common API --- aiter/ops/attention.py | 29 ++++++++++--------- op_tests/test_pa.py | 66 +++++++++++++++++++++++++----------------- 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 3d82757247..90492c8208 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -127,11 +127,8 @@ def _should_use_asm_kernel( num_heads: int, kv_cache_tensor_dtype: torch.dtype, ) -> bool: - # TODO: HIP kernel yet isn't supporting fp8 scales in asm layout. - if ( - kv_cache_tensor_dtype == torch.int8 - or kv_cache_tensor_dtype == torch.float8_e4m3fnuz - ): + + if kv_cache_tensor_dtype == torch.int8: return True # Get GPU compute units (CUs) @@ -158,8 +155,10 @@ def paged_attention_common( max_qlen: int = 1, max_seq_len: int = 1, cu_query_lens: Optional[torch.Tensor] = None, - K_QScale: Optional[torch.Tensor] = None, - V_QScale: Optional[torch.Tensor] = None, + K_QScale_hip: Optional[torch.Tensor] = None, # [num_seqs, num_heads] + V_QScale_hip: Optional[torch.Tensor] = None, + K_QScale_asm: Optional[torch.Tensor] = None, # [num_blocks, num_kv_heads, block_size] + V_QScale_asm: Optional[torch.Tensor] = None, out_: Optional[torch.Tensor] = None, qo_indptr: Optional[torch.Tensor] = None, high_precision: Optional[ @@ -173,14 +172,18 @@ def paged_attention_common( Paged attention forward pass with automatic kernel selection. ASM is favored for int8 kv caches, for short ctx_len, or when the workload exceeds the heuristic thresholds for larger ctx_len values. + PA is normally using per tensor quant and this is what has been tested, however, + per head quant can be supported as well in principle, but not tested. """ kv_cache_tensor_dtype = ( kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else K.dtype ) num_seqs, num_heads, head_size = Q.shape - # Route to ASM kernel based on the heuristic above. - use_asm_kernel = _should_use_asm_kernel(num_seqs, num_heads, kv_cache_tensor_dtype) + use_asm_kernel = ( + _should_use_asm_kernel(num_seqs, num_heads, kv_cache_tensor_dtype) + or high_precision == 2 + ) if use_asm_kernel: output = pa_fwd_asm( @@ -191,8 +194,8 @@ def paged_attention_common( context_lens, block_tables_stride0, max_qlen, - K_QScale, - V_QScale, + K_QScale_asm, + V_QScale_asm, out_, qo_indptr, high_precision, @@ -219,8 +222,8 @@ def paged_attention_common( max_context_len=max_seq_len, alibi_slopes=None, kv_cache_dtype=kv_cache_dtype, - k_scale=K_QScale, - v_scale=V_QScale, + k_scale=K_QScale_hip, + v_scale=V_QScale_hip, fp8_out_scale=None, partition_size=256, mtp=1, diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py index 5c16504070..cc1048a1b0 100644 --- a/op_tests/test_pa.py +++ b/op_tests/test_pa.py @@ -419,8 +419,12 @@ def run_aiter_common( scale, alibi_slopes, block_tables_stride0, - k_scale=None, - v_scale=None, + # ROCm/HIP (scalar) scales + k_scale_hip=None, + v_scale_hip=None, + # ASM (expanded) scales + k_scale_asm=None, + v_scale_asm=None, high_precision=0, kv_cache_tensor_dtype=None, ): @@ -454,8 +458,11 @@ def _normalize_scale(s): # python scalar return torch.tensor(float(s), device=query.device, dtype=dtypes.fp32) - k_scale_tensor = _normalize_scale(k_scale) - v_scale_tensor = _normalize_scale(v_scale) + k_scale_hip_tensor = _normalize_scale(k_scale_hip) + v_scale_hip_tensor = _normalize_scale(v_scale_hip) + # ASM scales are already tensors in the expected layout; just ensure fp32 on device. + k_scale_asm_tensor = _normalize_scale(k_scale_asm) + v_scale_asm_tensor = _normalize_scale(v_scale_asm) # Determine kv_cache_dtype string. def _is_fp8_storage(dt: torch.dtype) -> bool: @@ -492,8 +499,10 @@ def _is_fp8_storage(dt: torch.dtype) -> bool: max_qlen=1, max_seq_len=max_seq_len, cu_query_lens=None, - K_QScale=k_scale_tensor, - V_QScale=v_scale_tensor, + K_QScale_hip=k_scale_hip_tensor, + V_QScale_hip=v_scale_hip_tensor, + K_QScale_asm=k_scale_asm_tensor, + V_QScale_asm=v_scale_asm_tensor, out_=None, qo_indptr=None, high_precision=high_precision, @@ -830,27 +839,30 @@ def test_paged_attention( msg=f"golden vs aiter_asm:{time_aiter_asm:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})", ) - # Test paged_attention_common with quantized cache - out_aiter_common, time_aiter_common = run_aiter_common( - query.contiguous(), - k_quant_, - asm_V_shuffle(v_quant_), - block_tables, - seq_lens, - max_seq_len, - kv_cache_dtype, - num_kv_heads, - scale, - alibi_slopes, - block_tables.stride(0), - k_scale_asm, - v_scale_asm, - ) - checkAllclose( - out_golden, - out_aiter_common, - msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})", - ) + if quant_algo_ == 4: + # Test paged_attention_common with quantized cache + out_aiter_common, time_aiter_common = run_aiter_common( + query.contiguous(), + k_quant_, + asm_V_shuffle(v_quant_), + block_tables, + seq_lens, + max_seq_len, + kv_cache_dtype, + num_kv_heads, + scale, + alibi_slopes, + block_tables.stride(0), + k_scale_hip=k_scale_, + v_scale_hip=v_scale_, + k_scale_asm=k_scale_asm, + v_scale_asm=v_scale_asm, + ) + checkAllclose( + out_golden, + out_aiter_common, + msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})", + ) if ( dtype in [dtypes.bf16, dtypes.fp16] From 39e3d6685f13036bc1689d2e44c087d2762cf806 Mon Sep 17 00:00:00 2001 From: Sergey Sol Date: Tue, 13 Jan 2026 16:09:44 +0000 Subject: [PATCH 09/11] Apply formatting --- aiter/ops/attention.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 90492c8208..ae54450f69 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -155,9 +155,11 @@ def paged_attention_common( max_qlen: int = 1, max_seq_len: int = 1, cu_query_lens: Optional[torch.Tensor] = None, - K_QScale_hip: Optional[torch.Tensor] = None, # [num_seqs, num_heads] + K_QScale_hip: Optional[torch.Tensor] = None, # [num_seqs, num_heads] V_QScale_hip: Optional[torch.Tensor] = None, - K_QScale_asm: Optional[torch.Tensor] = None, # [num_blocks, num_kv_heads, block_size] + K_QScale_asm: Optional[ + torch.Tensor + ] = None, # [num_blocks, num_kv_heads, block_size] V_QScale_asm: Optional[torch.Tensor] = None, out_: Optional[torch.Tensor] = None, qo_indptr: Optional[torch.Tensor] = None, @@ -172,7 +174,7 @@ def paged_attention_common( Paged attention forward pass with automatic kernel selection. ASM is favored for int8 kv caches, for short ctx_len, or when the workload exceeds the heuristic thresholds for larger ctx_len values. - PA is normally using per tensor quant and this is what has been tested, however, + PA is normally using per tensor quant and this is what has been tested, however, per head quant can be supported as well in principle, but not tested. """ kv_cache_tensor_dtype = ( From 3677df8c0e178c652a979c64ed12b71a65f40312 Mon Sep 17 00:00:00 2001 From: Mikko Tukiainen Date: Wed, 14 Jan 2026 15:55:46 +0200 Subject: [PATCH 10/11] Remove redundant parameters --- aiter/ops/attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index ae54450f69..e5f90b4e77 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -150,11 +150,9 @@ def paged_attention_common( block_tables: torch.Tensor, context_lens: torch.Tensor, block_tables_stride0: int, - logits_soft_cap: float, scale: float, max_qlen: int = 1, max_seq_len: int = 1, - cu_query_lens: Optional[torch.Tensor] = None, K_QScale_hip: Optional[torch.Tensor] = None, # [num_seqs, num_heads] V_QScale_hip: Optional[torch.Tensor] = None, K_QScale_asm: Optional[ From 94816c78840349195dff05407237b60f75021de6 Mon Sep 17 00:00:00 2001 From: Mikko Tukiainen Date: Wed, 14 Jan 2026 16:07:39 +0200 Subject: [PATCH 11/11] Remove redundant parameters --- op_tests/test_pa.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py index cc1048a1b0..7c218a29ae 100644 --- a/op_tests/test_pa.py +++ b/op_tests/test_pa.py @@ -494,11 +494,9 @@ def _is_fp8_storage(dt: torch.dtype) -> bool: block_tables=block_tables, context_lens=seq_lens, block_tables_stride0=block_tables_stride0, - logits_soft_cap=0.0, scale=scale, max_qlen=1, max_seq_len=max_seq_len, - cu_query_lens=None, K_QScale_hip=k_scale_hip_tensor, V_QScale_hip=v_scale_hip_tensor, K_QScale_asm=k_scale_asm_tensor,