From a6e8a590b4c28396451459d788ff076e6351c19a Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 3 Nov 2025 11:11:29 +0000 Subject: [PATCH 1/7] add sliding window to paged_attention_v1 --- aiter/ops/attention.py | 2 ++ csrc/cpp_itfs/pa/pa_kernels.cuh | 25 ++++++++++++++++++++++++- csrc/cpp_itfs/pa/pa_v1.cpp.jinja | 8 ++++++-- csrc/cpp_itfs/pa/pa_v1.cuh | 10 +++++++--- csrc/cpp_itfs/pa/pa_v1.py | 6 ++++++ op_tests/test_pa_v1.py | 14 +++++++++++++- 6 files changed, 58 insertions(+), 7 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 03ea084642..ee9f0c79d6 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -187,6 +187,7 @@ def paged_attention_v1( logits_soft_cap: float, k_scale: torch.Tensor, v_scale: torch.Tensor, + sliding_window: int = 0, fp8_out_scale: Optional[torch.Tensor] = None, partition_size: int = 256, mtp: int = 1, @@ -208,6 +209,7 @@ def paged_attention_v1( logits_soft_cap, k_scale, v_scale, + sliding_window, fp8_out_scale, partition_size, mtp, diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 31e2d3bd7d..4f87853eaa 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -11,6 +11,7 @@ template __inline__ __device__ void _paged_attention_kernel(const int* block_table_seq, @@ -36,7 +37,8 @@ _paged_attention_kernel(const int* block_table_seq, const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, - const AttentionVariant* variant) + const AttentionVariant* variant, + const int sliding_window = 0) { const int seq_idx = blockIdx.x; const int partition_idx = blockIdx.y; @@ -464,6 +466,27 @@ _paged_attention_kernel(const int* block_table_seq, float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{-FLT_MAX}}; float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{0.0f}}; + if constexpr(HAS_SLIDING_WINDOW) + { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) + { + float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i]; + if (local_token_idx + i <= context_len - sliding_window) + tmp = -FLT_MAX; + d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; + } + } + } + } + } + for(int mtp = 0; mtp < mtp_loop; mtp++) { for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) diff --git a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja index 3f12e96aef..61bbe26bee 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja +++ b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja @@ -26,6 +26,7 @@ void {{func_name}}(void* out_ptr, const int kv_block_stride, const int kv_head_stride, const int kv_seq_stride, + const int sliding_window, void* stream); } @@ -53,6 +54,7 @@ void {{func_name}}(void* out_ptr, const int kv_block_stride, const int kv_head_stride, const int kv_seq_stride, + const int sliding_window, void* stream) { constexpr int head_size = {{head_size}}; @@ -86,7 +88,8 @@ void {{func_name}}(void* out_ptr, NTHR, {{"true" if alibi_enabled else "false"}}, gqa_ratio, - {{mtp}}> + {{mtp}}, + {{"true" if sliding_window_enabled else "false"}}> <<(stream)>>>(reinterpret_cast<{{dtype}}*>(query_ptr), reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr), reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr), @@ -108,7 +111,8 @@ void {{func_name}}(void* out_ptr, q_scale_ptr, k_scale_ptr, v_scale_ptr, - &variant); + &variant, + sliding_window); dim3 reduce_grid(num_heads, num_seqs, {{mtp}}); dim3 reduce_block(head_size); diff --git a/csrc/cpp_itfs/pa/pa_v1.cuh b/csrc/cpp_itfs/pa/pa_v1.cuh index 02f3dfc6de..db5f904cc8 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cuh +++ b/csrc/cpp_itfs/pa/pa_v1.cuh @@ -34,6 +34,7 @@ template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -61,7 +62,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, - const AttentionVariant* variant) + const AttentionVariant* variant, + const int sliding_window) { const int seq_idx = blockIdx.x; int query_loc = seq_idx * MTP; @@ -82,7 +84,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ return; } const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; - _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant); + _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); } // Grid: (num_heads, num_seqs). @@ -133,6 +135,7 @@ template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -160,7 +163,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, - const AttentionVariant* variant) + const AttentionVariant* variant, + const int sliding_window) { UNREACHABLE_CODE } diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index 746c5cc9ea..a660a5d320 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -22,6 +22,7 @@ def compile( logits_soft_cap_enabled: bool, partition_size: int = 256, mtp: int = 1, + sliding_window_enabled: bool = False, folder: str = None, ): return compile_template_op( @@ -47,6 +48,7 @@ def compile( logits_soft_cap_enabled=logits_soft_cap_enabled, partition_size=partition_size, mtp=mtp, + sliding_window_enabled=sliding_window_enabled, folder=folder, ) @@ -68,6 +70,7 @@ def paged_attention_v1( logits_soft_cap: float, k_scale, v_scale, + sliding_window: int, fp8_out_scale=None, partition_size: int = 256, mtp: int = 1, @@ -124,6 +127,7 @@ def paged_attention_v1( npar_loops = int(math.ceil(max_num_partitions / warpSize)) logits_soft_cap_enabled = logits_soft_cap > 0 alibi_enabled = alibi_slopes is not None + sliding_window_enabled = sliding_window > 0 func = compile( gqa_ratio, head_size, @@ -137,6 +141,7 @@ def paged_attention_v1( logits_soft_cap_enabled, partition_size, mtp, + sliding_window_enabled=sliding_window_enabled ) alibi_slopes_ptr = ( @@ -230,6 +235,7 @@ def paged_attention_v1( kv_block_stride, kv_head_stride, kv_seq_stride, + sliding_window, stream, ) return out diff --git a/op_tests/test_pa_v1.py b/op_tests/test_pa_v1.py index 487db10a1a..495dc11796 100644 --- a/op_tests/test_pa_v1.py +++ b/op_tests/test_pa_v1.py @@ -128,10 +128,13 @@ def ref_masked_attention( scale: float, attn_mask: Optional[torch.Tensor] = None, logits_soft_cap: float = 0.0, + sliding_window: int = 0, ) -> torch.Tensor: attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() if attn_mask is not None: attn_weights = attn_weights + attn_mask.float() + if sliding_window: + attn_weights[:,:,:-sliding_window] = -1e38 if 0 < logits_soft_cap: attn_weights = logits_soft_cap * torch.tanh(attn_weights / logits_soft_cap) attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) @@ -214,6 +217,7 @@ def run_torch( k_scale, v_scale, num_queries_per_kv, + sliding_window ): output = torch.zeros_like(query) num_query_heads = query.shape[1] @@ -255,7 +259,7 @@ def run_torch( alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) - out = ref_masked_attention(q, keys, values, scale, alibi_bias, logits_soft_cap) + out = ref_masked_attention(q, keys, values, scale, alibi_bias, logits_soft_cap, sliding_window=sliding_window) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) return output, 1 @@ -277,6 +281,7 @@ def run_aiter( logits_soft_cap, k_scale, v_scale, + sliding_window, mtp=1, ): # copied from ops.PagedAttention.forward_decode() @@ -326,6 +331,7 @@ def run_aiter( logits_soft_cap, k_scale, v_scale, + sliding_window, fp8_out_scale if cpa_fp8_out else None, _PARTITION_SIZE_ROCM, ) @@ -422,6 +428,7 @@ class PAVariant(Enum): @pytest.mark.parametrize("quant_cache_dtype", [None, dtypes.fp8, dtypes.i8]) @pytest.mark.parametrize("seed", [0]) @pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("sliding_window", [0, 10]) def test_paged_attention( ctx_lens: int, num_seqs: int, @@ -437,6 +444,7 @@ def test_paged_attention( quant_cache_dtype: torch.dtype, seed: int, device: str, + sliding_window: int, ) -> None: if pa_variant == PAVariant.Shomy: if quant_cache_dtype is not None: @@ -448,6 +456,7 @@ def test_paged_attention( or block_size != 16 or dtype is not dtypes.bf16 or quant_cache_dtype not in [None, dtypes.i8] + or sliding_window != 0 ): pytest.skip() elif pa_variant == PAVariant.Naive: @@ -523,6 +532,7 @@ def test_paged_attention( k_scale, v_scale, num_queries_per_kv, + sliding_window, ) cu_query_lens = torch.arange(0, num_seqs + 1, dtype=torch.int) @@ -546,6 +556,7 @@ def test_paged_attention( logits_soft_cap, k_scale, v_scale, + sliding_window ) assert ( checkAllclose(out_golden, out_aiter, msg=f"golden vs aiter:{time_aiter}") @@ -643,4 +654,5 @@ def test_paged_attention( quant_cache_dtype, 0, "cuda:0", + 0 ) From 2a964ac2fb9166eaee36de08f55cf66fcf19b443 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Wed, 5 Nov 2025 14:29:51 +0800 Subject: [PATCH 2/7] Avoid use fmha_v3_varlen_fwd on unsupported architecture gfx90a --- aiter/ops/mha.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index ebf7f6f28e..65e058d5ee 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1864,6 +1864,7 @@ def can_impl_fmha_v3_fwd(): ret = ret and (dropout_p == 0.0) ret = ret and (hdim_v == 128) ret = ret and (hdim_q == 128 or (get_gfx() == "gfx950" and hdim_q == 192)) + ret = ret and (get_gfx() != "gfx90a") ret = ret and (nhead_q % nhead_k == 0) ret = ret and (not swa) ret = ret and (q.dtype == dtypes.bf16) From 74da0a0068353f93b5f3938e43dceff95e85e808 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 7 Nov 2025 03:28:27 +0000 Subject: [PATCH 3/7] make `sliding_window` default to 0 for better compatibility --- aiter/ops/attention.py | 4 ++-- csrc/cpp_itfs/pa/pa_kernels.cuh | 6 +++--- csrc/cpp_itfs/pa/pa_v1.cpp.jinja | 1 + csrc/cpp_itfs/pa/pa_v1.cuh | 10 +++++----- csrc/cpp_itfs/pa/pa_v1.py | 2 +- op_tests/test_pa_v1.py | 6 +++--- 6 files changed, 15 insertions(+), 14 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index ee9f0c79d6..410f3600d6 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -187,10 +187,10 @@ def paged_attention_v1( logits_soft_cap: float, k_scale: torch.Tensor, v_scale: torch.Tensor, - sliding_window: int = 0, fp8_out_scale: Optional[torch.Tensor] = None, partition_size: int = 256, mtp: int = 1, + sliding_window: int = 0, ) -> torch.Tensor: paged_attention_v1_core( out, @@ -209,10 +209,10 @@ def paged_attention_v1( logits_soft_cap, k_scale, v_scale, - sliding_window, fp8_out_scale, partition_size, mtp, + sliding_window=sliding_window, ) return out diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 4f87853eaa..8ca8658b8a 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -11,8 +11,8 @@ template + typename AttentionVariant, + bool SLIDING_WINDOW_ENABLED> __inline__ __device__ void _paged_attention_kernel(const int* block_table_seq, const int64_t query_loc, @@ -466,7 +466,7 @@ _paged_attention_kernel(const int* block_table_seq, float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{-FLT_MAX}}; float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{0.0f}}; - if constexpr(HAS_SLIDING_WINDOW) + if constexpr(SLIDING_WINDOW_ENABLED) { for(int mtp = 0; mtp < mtp_loop; mtp++) { diff --git a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja index 61bbe26bee..96d53c61d9 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja +++ b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja @@ -89,6 +89,7 @@ void {{func_name}}(void* out_ptr, {{"true" if alibi_enabled else "false"}}, gqa_ratio, {{mtp}}, + decltype(variant), {{"true" if sliding_window_enabled else "false"}}> <<(stream)>>>(reinterpret_cast<{{dtype}}*>(query_ptr), reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr), diff --git a/csrc/cpp_itfs/pa/pa_v1.cuh b/csrc/cpp_itfs/pa/pa_v1.cuh index db5f904cc8..d00308ee16 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cuh +++ b/csrc/cpp_itfs/pa/pa_v1.cuh @@ -34,8 +34,8 @@ template + typename AttentionVariant, + bool SLIDING_WINDOW_ENABLED> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, block_size, num_kv_heads, @@ -84,7 +84,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ return; } const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; - _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); + _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); } // Grid: (num_heads, num_seqs). @@ -135,8 +135,8 @@ template + typename AttentionVariant, + bool SLIDING_WINDOW_ENABLED> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index a660a5d320..2a225bbf0b 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -70,11 +70,11 @@ def paged_attention_v1( logits_soft_cap: float, k_scale, v_scale, - sliding_window: int, fp8_out_scale=None, partition_size: int = 256, mtp: int = 1, q_scale=None, + sliding_window: int = 0, ): import torch from csrc.cpp_itfs.torch_utils import torch_to_c_types diff --git a/op_tests/test_pa_v1.py b/op_tests/test_pa_v1.py index 495dc11796..8324fc5132 100644 --- a/op_tests/test_pa_v1.py +++ b/op_tests/test_pa_v1.py @@ -281,8 +281,8 @@ def run_aiter( logits_soft_cap, k_scale, v_scale, - sliding_window, mtp=1, + sliding_window=0, ): # copied from ops.PagedAttention.forward_decode() _PARTITION_SIZE_ROCM = 256 @@ -331,9 +331,9 @@ def run_aiter( logits_soft_cap, k_scale, v_scale, - sliding_window, fp8_out_scale if cpa_fp8_out else None, _PARTITION_SIZE_ROCM, + sliding_window=sliding_window, ) if cpa_fp8_out: return output.view(num_seqs, num_heads * head_size) @@ -556,7 +556,7 @@ def test_paged_attention( logits_soft_cap, k_scale, v_scale, - sliding_window + sliding_window=sliding_window ) assert ( checkAllclose(out_golden, out_aiter, msg=f"golden vs aiter:{time_aiter}") From 28bde45ba794b83159e50b82096202e3f291b7d7 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 7 Nov 2025 06:36:43 +0000 Subject: [PATCH 4/7] fix possible used compilation problem --- csrc/cpp_itfs/pa/pa_kernels.cuh | 42 ++++++++++++++++----------------- csrc/cpp_itfs/pa/pa_ragged.cuh | 2 +- op_tests/test_pa_v1.py | 2 +- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 8ca8658b8a..6c2cd5df1f 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -441,6 +441,27 @@ _paged_attention_kernel(const int* block_table_seq, } } } + // apply sliding window + if constexpr(SLIDING_WINDOW_ENABLED) + { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int i = 0; i < 4; i++) + { + float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i]; + if (local_token_idx + i < context_len - sliding_window) + tmp = -FLT_MAX; + d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; + } + } + } + } + } // apply soft-capping to logits for(int token_depth = 0; token_depth < TLOOP; token_depth++) { @@ -466,27 +487,6 @@ _paged_attention_kernel(const int* block_table_seq, float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{-FLT_MAX}}; float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{0.0f}}; - if constexpr(SLIDING_WINDOW_ENABLED) - { - for(int mtp = 0; mtp < mtp_loop; mtp++) - { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) - { - for(int token_depth = 0; token_depth < TLOOP; token_depth++) - { - const int local_token_idx = qkout_token_idx + token_depth * 16; - for (int i = 0; i < 4; i++) - { - float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i]; - if (local_token_idx + i <= context_len - sliding_window) - tmp = -FLT_MAX; - d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; - } - } - } - } - } - for(int mtp = 0; mtp < mtp_loop; mtp++) { for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) diff --git a/csrc/cpp_itfs/pa/pa_ragged.cuh b/csrc/cpp_itfs/pa/pa_ragged.cuh index 3f24cf2fa1..0f23d1cfcd 100644 --- a/csrc/cpp_itfs/pa/pa_ragged.cuh +++ b/csrc/cpp_itfs/pa/pa_ragged.cuh @@ -89,7 +89,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ } const int64_t query_loc = static_cast(seq_idx * MTP); const int* block_table_seq = kv_page_indices + kv_indptr[seq_idx]; - _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant); + _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant); } // Grid: (num_heads, num_seqs, mtp). diff --git a/op_tests/test_pa_v1.py b/op_tests/test_pa_v1.py index 8324fc5132..982e009c93 100644 --- a/op_tests/test_pa_v1.py +++ b/op_tests/test_pa_v1.py @@ -654,5 +654,5 @@ def test_paged_attention( quant_cache_dtype, 0, "cuda:0", - 0 + 10 ) From bf8f6dcb3148d27269d39b07331035504af3e61d Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Sun, 9 Nov 2025 07:07:37 +0000 Subject: [PATCH 5/7] fix ci failure --- csrc/cpp_itfs/pa/pa_v1.py | 2 +- op_tests/test_pa_v1.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index 2a225bbf0b..cad347d8e9 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -141,7 +141,7 @@ def paged_attention_v1( logits_soft_cap_enabled, partition_size, mtp, - sliding_window_enabled=sliding_window_enabled + sliding_window_enabled=sliding_window_enabled, ) alibi_slopes_ptr = ( diff --git a/op_tests/test_pa_v1.py b/op_tests/test_pa_v1.py index 982e009c93..1608e78890 100644 --- a/op_tests/test_pa_v1.py +++ b/op_tests/test_pa_v1.py @@ -134,7 +134,7 @@ def ref_masked_attention( if attn_mask is not None: attn_weights = attn_weights + attn_mask.float() if sliding_window: - attn_weights[:,:,:-sliding_window] = -1e38 + attn_weights[:, :, :-sliding_window] = -1e38 if 0 < logits_soft_cap: attn_weights = logits_soft_cap * torch.tanh(attn_weights / logits_soft_cap) attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) @@ -217,7 +217,7 @@ def run_torch( k_scale, v_scale, num_queries_per_kv, - sliding_window + sliding_window, ): output = torch.zeros_like(query) num_query_heads = query.shape[1] @@ -259,7 +259,14 @@ def run_torch( alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) - out = ref_masked_attention(q, keys, values, scale, alibi_bias, logits_soft_cap, sliding_window=sliding_window) + out = ref_masked_attention(q, + keys, + values, + scale, + alibi_bias, + logits_soft_cap, + sliding_window=sliding_window, + ) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) return output, 1 @@ -556,7 +563,7 @@ def test_paged_attention( logits_soft_cap, k_scale, v_scale, - sliding_window=sliding_window + sliding_window=sliding_window, ) assert ( checkAllclose(out_golden, out_aiter, msg=f"golden vs aiter:{time_aiter}") @@ -654,5 +661,5 @@ def test_paged_attention( quant_cache_dtype, 0, "cuda:0", - 10 + 10, ) From 058c28c841da4e289ddd31fd445b93d21c265961 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Sun, 9 Nov 2025 13:16:31 +0000 Subject: [PATCH 6/7] fix ci failure --- op_tests/test_pa_v1.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/op_tests/test_pa_v1.py b/op_tests/test_pa_v1.py index 1608e78890..0ab23848eb 100644 --- a/op_tests/test_pa_v1.py +++ b/op_tests/test_pa_v1.py @@ -259,14 +259,15 @@ def run_torch( alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) - out = ref_masked_attention(q, - keys, - values, - scale, - alibi_bias, - logits_soft_cap, - sliding_window=sliding_window, - ) + out = ref_masked_attention( + q, + keys, + values, + scale, + alibi_bias, + logits_soft_cap, + sliding_window=sliding_window, + ) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) return output, 1 From c4ae933b4a0bed5d48bb6e1be29eb63c88ecc0f9 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Wed, 12 Nov 2025 02:50:25 +0000 Subject: [PATCH 7/7] add a single test to avoid increasing test time a lot --- aiter/ops/mha.py | 1 - csrc/cpp_itfs/pa/pa_ragged.cuh | 2 +- op_tests/test_pa_v1.py | 34 ++++++++++++++++++++++++++++++++-- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 49a18318b9..a05a940c9c 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1907,7 +1907,6 @@ def can_impl_fmha_v3_fwd(): ret = ret and (dropout_p == 0.0) ret = ret and (hdim_v == 128) ret = ret and (hdim_q == 128 or (get_gfx() == "gfx950" and hdim_q == 192)) - ret = ret and (get_gfx() != "gfx90a") ret = ret and (nhead_q % nhead_k == 0) ret = ret and (not swa) ret = ret and (q.dtype == dtypes.bf16) diff --git a/csrc/cpp_itfs/pa/pa_ragged.cuh b/csrc/cpp_itfs/pa/pa_ragged.cuh index 0f23d1cfcd..20ec3b4ed5 100644 --- a/csrc/cpp_itfs/pa/pa_ragged.cuh +++ b/csrc/cpp_itfs/pa/pa_ragged.cuh @@ -89,7 +89,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ } const int64_t query_loc = static_cast(seq_idx * MTP); const int* block_table_seq = kv_page_indices + kv_indptr[seq_idx]; - _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant); + _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, 0); } // Grid: (num_heads, num_seqs, mtp). diff --git a/op_tests/test_pa_v1.py b/op_tests/test_pa_v1.py index 0ab23848eb..306ad5b6f4 100644 --- a/op_tests/test_pa_v1.py +++ b/op_tests/test_pa_v1.py @@ -436,7 +436,6 @@ class PAVariant(Enum): @pytest.mark.parametrize("quant_cache_dtype", [None, dtypes.fp8, dtypes.i8]) @pytest.mark.parametrize("seed", [0]) @pytest.mark.parametrize("device", ["cuda:0"]) -@pytest.mark.parametrize("sliding_window", [0, 10]) def test_paged_attention( ctx_lens: int, num_seqs: int, @@ -452,7 +451,7 @@ def test_paged_attention( quant_cache_dtype: torch.dtype, seed: int, device: str, - sliding_window: int, + sliding_window: int = 0, ) -> None: if pa_variant == PAVariant.Shomy: if quant_cache_dtype is not None: @@ -594,6 +593,37 @@ def test_paged_attention( # f"[test] dim: {str((ctx_lens, num_seqs, num_heads, head_size)):<20}, dtype: {dtype}, finished)\n") +@pytest.mark.parametrize("ctx_lens", [1, 26, 128, 4097]) +@pytest.mark.parametrize("num_seqs", [1, 3, 31, 128]) +@pytest.mark.parametrize("num_heads", [(8, 1), (32, 4)]) +@pytest.mark.parametrize("use_alibi", [False, True]) +@pytest.mark.parametrize("sliding_window", [0, 10]) +def test_paged_attention_sliding_window( + ctx_lens: int, + num_seqs: int, + num_heads: Tuple[int, int], + use_alibi: bool, + sliding_window: int, +) -> None: + test_paged_attention( + ctx_lens, + num_seqs, + num_heads, + 128, + use_alibi, + block_size=16, + dtype=dtypes.fp16, + kv_cache_dtype="auto", + kv_cache_layout="NHD", + logits_soft_cap=0.0, + pa_variant=PAVariant.Shomy, + quant_cache_dtype=None, + seed=0, + device="cuda:0", + sliding_window=sliding_window, + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter,