diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 03ea084642..410f3600d6 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -190,6 +190,7 @@ def paged_attention_v1( fp8_out_scale: Optional[torch.Tensor] = None, partition_size: int = 256, mtp: int = 1, + sliding_window: int = 0, ) -> torch.Tensor: paged_attention_v1_core( out, @@ -211,6 +212,7 @@ def paged_attention_v1( fp8_out_scale, partition_size, mtp, + sliding_window=sliding_window, ) return out diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 31e2d3bd7d..6c2cd5df1f 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -11,7 +11,8 @@ template + typename AttentionVariant, + bool SLIDING_WINDOW_ENABLED> __inline__ __device__ void _paged_attention_kernel(const int* block_table_seq, const int64_t query_loc, @@ -36,7 +37,8 @@ _paged_attention_kernel(const int* block_table_seq, const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, - const AttentionVariant* variant) + const AttentionVariant* variant, + const int sliding_window = 0) { const int seq_idx = blockIdx.x; const int partition_idx = blockIdx.y; @@ -439,6 +441,27 @@ _paged_attention_kernel(const int* block_table_seq, } } } + // apply sliding window + if constexpr(SLIDING_WINDOW_ENABLED) + { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int i = 0; i < 4; i++) + { + float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i]; + if (local_token_idx + i < context_len - sliding_window) + tmp = -FLT_MAX; + d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; + } + } + } + } + } // apply soft-capping to logits for(int token_depth = 0; token_depth < TLOOP; token_depth++) { diff --git a/csrc/cpp_itfs/pa/pa_ragged.cuh b/csrc/cpp_itfs/pa/pa_ragged.cuh index 3f24cf2fa1..20ec3b4ed5 100644 --- a/csrc/cpp_itfs/pa/pa_ragged.cuh +++ b/csrc/cpp_itfs/pa/pa_ragged.cuh @@ -89,7 +89,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ } const int64_t query_loc = static_cast(seq_idx * MTP); const int* block_table_seq = kv_page_indices + kv_indptr[seq_idx]; - _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant); + _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, 0); } // Grid: (num_heads, num_seqs, mtp). diff --git a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja index 3f12e96aef..96d53c61d9 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cpp.jinja +++ b/csrc/cpp_itfs/pa/pa_v1.cpp.jinja @@ -26,6 +26,7 @@ void {{func_name}}(void* out_ptr, const int kv_block_stride, const int kv_head_stride, const int kv_seq_stride, + const int sliding_window, void* stream); } @@ -53,6 +54,7 @@ void {{func_name}}(void* out_ptr, const int kv_block_stride, const int kv_head_stride, const int kv_seq_stride, + const int sliding_window, void* stream) { constexpr int head_size = {{head_size}}; @@ -86,7 +88,9 @@ void {{func_name}}(void* out_ptr, NTHR, {{"true" if alibi_enabled else "false"}}, gqa_ratio, - {{mtp}}> + {{mtp}}, + decltype(variant), + {{"true" if sliding_window_enabled else "false"}}> <<(stream)>>>(reinterpret_cast<{{dtype}}*>(query_ptr), reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr), reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr), @@ -108,7 +112,8 @@ void {{func_name}}(void* out_ptr, q_scale_ptr, k_scale_ptr, v_scale_ptr, - &variant); + &variant, + sliding_window); dim3 reduce_grid(num_heads, num_seqs, {{mtp}}); dim3 reduce_block(head_size); diff --git a/csrc/cpp_itfs/pa/pa_v1.cuh b/csrc/cpp_itfs/pa/pa_v1.cuh index 02f3dfc6de..d00308ee16 100644 --- a/csrc/cpp_itfs/pa/pa_v1.cuh +++ b/csrc/cpp_itfs/pa/pa_v1.cuh @@ -34,7 +34,8 @@ template + typename AttentionVariant, + bool SLIDING_WINDOW_ENABLED> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, block_size, num_kv_heads, @@ -61,7 +62,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, - const AttentionVariant* variant) + const AttentionVariant* variant, + const int sliding_window) { const int seq_idx = blockIdx.x; int query_loc = seq_idx * MTP; @@ -82,7 +84,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ return; } const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; - _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant); + _paged_attention_kernel(block_table_seq, static_cast(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window); } // Grid: (num_heads, num_seqs). @@ -133,7 +135,8 @@ template + typename AttentionVariant, + bool SLIDING_WINDOW_ENABLED> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, @@ -160,7 +163,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr, - const AttentionVariant* variant) + const AttentionVariant* variant, + const int sliding_window) { UNREACHABLE_CODE } diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index 746c5cc9ea..cad347d8e9 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -22,6 +22,7 @@ def compile( logits_soft_cap_enabled: bool, partition_size: int = 256, mtp: int = 1, + sliding_window_enabled: bool = False, folder: str = None, ): return compile_template_op( @@ -47,6 +48,7 @@ def compile( logits_soft_cap_enabled=logits_soft_cap_enabled, partition_size=partition_size, mtp=mtp, + sliding_window_enabled=sliding_window_enabled, folder=folder, ) @@ -72,6 +74,7 @@ def paged_attention_v1( partition_size: int = 256, mtp: int = 1, q_scale=None, + sliding_window: int = 0, ): import torch from csrc.cpp_itfs.torch_utils import torch_to_c_types @@ -124,6 +127,7 @@ def paged_attention_v1( npar_loops = int(math.ceil(max_num_partitions / warpSize)) logits_soft_cap_enabled = logits_soft_cap > 0 alibi_enabled = alibi_slopes is not None + sliding_window_enabled = sliding_window > 0 func = compile( gqa_ratio, head_size, @@ -137,6 +141,7 @@ def paged_attention_v1( logits_soft_cap_enabled, partition_size, mtp, + sliding_window_enabled=sliding_window_enabled, ) alibi_slopes_ptr = ( @@ -230,6 +235,7 @@ def paged_attention_v1( kv_block_stride, kv_head_stride, kv_seq_stride, + sliding_window, stream, ) return out diff --git a/op_tests/test_pa_v1.py b/op_tests/test_pa_v1.py index 487db10a1a..306ad5b6f4 100644 --- a/op_tests/test_pa_v1.py +++ b/op_tests/test_pa_v1.py @@ -128,10 +128,13 @@ def ref_masked_attention( scale: float, attn_mask: Optional[torch.Tensor] = None, logits_soft_cap: float = 0.0, + sliding_window: int = 0, ) -> torch.Tensor: attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() if attn_mask is not None: attn_weights = attn_weights + attn_mask.float() + if sliding_window: + attn_weights[:, :, :-sliding_window] = -1e38 if 0 < logits_soft_cap: attn_weights = logits_soft_cap * torch.tanh(attn_weights / logits_soft_cap) attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) @@ -214,6 +217,7 @@ def run_torch( k_scale, v_scale, num_queries_per_kv, + sliding_window, ): output = torch.zeros_like(query) num_query_heads = query.shape[1] @@ -255,7 +259,15 @@ def run_torch( alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) - out = ref_masked_attention(q, keys, values, scale, alibi_bias, logits_soft_cap) + out = ref_masked_attention( + q, + keys, + values, + scale, + alibi_bias, + logits_soft_cap, + sliding_window=sliding_window, + ) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) return output, 1 @@ -278,6 +290,7 @@ def run_aiter( k_scale, v_scale, mtp=1, + sliding_window=0, ): # copied from ops.PagedAttention.forward_decode() _PARTITION_SIZE_ROCM = 256 @@ -328,6 +341,7 @@ def run_aiter( v_scale, fp8_out_scale if cpa_fp8_out else None, _PARTITION_SIZE_ROCM, + sliding_window=sliding_window, ) if cpa_fp8_out: return output.view(num_seqs, num_heads * head_size) @@ -437,6 +451,7 @@ def test_paged_attention( quant_cache_dtype: torch.dtype, seed: int, device: str, + sliding_window: int = 0, ) -> None: if pa_variant == PAVariant.Shomy: if quant_cache_dtype is not None: @@ -448,6 +463,7 @@ def test_paged_attention( or block_size != 16 or dtype is not dtypes.bf16 or quant_cache_dtype not in [None, dtypes.i8] + or sliding_window != 0 ): pytest.skip() elif pa_variant == PAVariant.Naive: @@ -523,6 +539,7 @@ def test_paged_attention( k_scale, v_scale, num_queries_per_kv, + sliding_window, ) cu_query_lens = torch.arange(0, num_seqs + 1, dtype=torch.int) @@ -546,6 +563,7 @@ def test_paged_attention( logits_soft_cap, k_scale, v_scale, + sliding_window=sliding_window, ) assert ( checkAllclose(out_golden, out_aiter, msg=f"golden vs aiter:{time_aiter}") @@ -575,6 +593,37 @@ def test_paged_attention( # f"[test] dim: {str((ctx_lens, num_seqs, num_heads, head_size)):<20}, dtype: {dtype}, finished)\n") +@pytest.mark.parametrize("ctx_lens", [1, 26, 128, 4097]) +@pytest.mark.parametrize("num_seqs", [1, 3, 31, 128]) +@pytest.mark.parametrize("num_heads", [(8, 1), (32, 4)]) +@pytest.mark.parametrize("use_alibi", [False, True]) +@pytest.mark.parametrize("sliding_window", [0, 10]) +def test_paged_attention_sliding_window( + ctx_lens: int, + num_seqs: int, + num_heads: Tuple[int, int], + use_alibi: bool, + sliding_window: int, +) -> None: + test_paged_attention( + ctx_lens, + num_seqs, + num_heads, + 128, + use_alibi, + block_size=16, + dtype=dtypes.fp16, + kv_cache_dtype="auto", + kv_cache_layout="NHD", + logits_soft_cap=0.0, + pa_variant=PAVariant.Shomy, + quant_cache_dtype=None, + seed=0, + device="cuda:0", + sliding_window=sliding_window, + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, @@ -643,4 +692,5 @@ def test_paged_attention( quant_cache_dtype, 0, "cuda:0", + 10, )