Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aiter/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def paged_attention_v1(
fp8_out_scale: Optional[torch.Tensor] = None,
partition_size: int = 256,
mtp: int = 1,
sliding_window: int = 0,
) -> torch.Tensor:
paged_attention_v1_core(
out,
Expand All @@ -211,6 +212,7 @@ def paged_attention_v1(
fp8_out_scale,
partition_size,
mtp,
sliding_window=sliding_window,
)
return out

Expand Down
27 changes: 25 additions & 2 deletions csrc/cpp_itfs/pa/pa_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ template <typename scalar_t,
bool ALIBI_ENABLED,
int GQA_RATIO,
int MTP,
typename AttentionVariant>
typename AttentionVariant,
bool SLIDING_WINDOW_ENABLED>
__inline__ __device__ void
_paged_attention_kernel(const int* block_table_seq,
const int64_t query_loc,
Expand All @@ -36,7 +37,8 @@ _paged_attention_kernel(const int* block_table_seq,
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
const AttentionVariant* variant)
const AttentionVariant* variant,
const int sliding_window = 0)
{
const int seq_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
Expand Down Expand Up @@ -439,6 +441,27 @@ _paged_attention_kernel(const int* block_table_seq,
}
}
}
// apply sliding window
if constexpr(SLIDING_WINDOW_ENABLED)
{
for(int token_depth = 0; token_depth < TLOOP; token_depth++)
{
const int local_token_idx = qkout_token_idx + token_depth * 16;
for(int mtp = 0; mtp < mtp_loop; mtp++)
{
for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++)
{
for(int i = 0; i < 4; i++)
{
float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i];
if (local_token_idx + i < context_len - sliding_window)
tmp = -FLT_MAX;
d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp;
}
}
}
}
}
// apply soft-capping to logits
for(int token_depth = 0; token_depth < TLOOP; token_depth++)
{
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpp_itfs/pa/pa_ragged.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
}
const int64_t query_loc = static_cast<int64_t>(seq_idx * MTP);
const int* block_table_seq = kv_page_indices + kv_indptr[seq_idx];
_paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, MTP, AttentionVariant>(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant);
_paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, MTP, AttentionVariant, false>(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, 0);
}

// Grid: (num_heads, num_seqs, mtp).
Expand Down
9 changes: 7 additions & 2 deletions csrc/cpp_itfs/pa/pa_v1.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void {{func_name}}(void* out_ptr,
const int kv_block_stride,
const int kv_head_stride,
const int kv_seq_stride,
const int sliding_window,
void* stream);
}

Expand Down Expand Up @@ -53,6 +54,7 @@ void {{func_name}}(void* out_ptr,
const int kv_block_stride,
const int kv_head_stride,
const int kv_seq_stride,
const int sliding_window,
void* stream)
{
constexpr int head_size = {{head_size}};
Expand Down Expand Up @@ -86,7 +88,9 @@ void {{func_name}}(void* out_ptr,
NTHR,
{{"true" if alibi_enabled else "false"}},
gqa_ratio,
{{mtp}}>
{{mtp}},
decltype(variant),
{{"true" if sliding_window_enabled else "false"}}>
<<<grid, block, 0, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<{{dtype}}*>(query_ptr),
reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr),
reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr),
Expand All @@ -108,7 +112,8 @@ void {{func_name}}(void* out_ptr,
q_scale_ptr,
k_scale_ptr,
v_scale_ptr,
&variant);
&variant,
sliding_window);

dim3 reduce_grid(num_heads, num_seqs, {{mtp}});
dim3 reduce_block(head_size);
Expand Down
14 changes: 9 additions & 5 deletions csrc/cpp_itfs/pa/pa_v1.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ template <typename scalar_t,
bool ALIBI_ENABLED,
int GQA_RATIO,
int MTP,
typename AttentionVariant>
typename AttentionVariant,
bool SLIDING_WINDOW_ENABLED>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, block_size, num_kv_heads,
Expand All @@ -61,7 +62,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
const AttentionVariant* variant)
const AttentionVariant* variant,
const int sliding_window)
{
const int seq_idx = blockIdx.x;
int query_loc = seq_idx * MTP;
Expand All @@ -82,7 +84,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
return;
}
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
_paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, MTP, AttentionVariant>(block_table_seq, static_cast<int64_t>(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant);
_paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, MTP, AttentionVariant, SLIDING_WINDOW_ENABLED>(block_table_seq, static_cast<int64_t>(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window);
}

// Grid: (num_heads, num_seqs).
Expand Down Expand Up @@ -133,7 +135,8 @@ template <typename scalar_t,
bool ALIBI_ENABLED,
int GQA_RATIO,
int MTP,
typename AttentionVariant>
typename AttentionVariant,
bool SLIDING_WINDOW_ENABLED>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
Expand All @@ -160,7 +163,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
const float* q_scale_ptr,
const float* k_scale_ptr,
const float* v_scale_ptr,
const AttentionVariant* variant)
const AttentionVariant* variant,
const int sliding_window)
{
UNREACHABLE_CODE
}
Expand Down
6 changes: 6 additions & 0 deletions csrc/cpp_itfs/pa/pa_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def compile(
logits_soft_cap_enabled: bool,
partition_size: int = 256,
mtp: int = 1,
sliding_window_enabled: bool = False,
folder: str = None,
):
return compile_template_op(
Expand All @@ -47,6 +48,7 @@ def compile(
logits_soft_cap_enabled=logits_soft_cap_enabled,
partition_size=partition_size,
mtp=mtp,
sliding_window_enabled=sliding_window_enabled,
folder=folder,
)

Expand All @@ -72,6 +74,7 @@ def paged_attention_v1(
partition_size: int = 256,
mtp: int = 1,
q_scale=None,
sliding_window: int = 0,
):
import torch
from csrc.cpp_itfs.torch_utils import torch_to_c_types
Expand Down Expand Up @@ -124,6 +127,7 @@ def paged_attention_v1(
npar_loops = int(math.ceil(max_num_partitions / warpSize))
logits_soft_cap_enabled = logits_soft_cap > 0
alibi_enabled = alibi_slopes is not None
sliding_window_enabled = sliding_window > 0
func = compile(
gqa_ratio,
head_size,
Expand All @@ -137,6 +141,7 @@ def paged_attention_v1(
logits_soft_cap_enabled,
partition_size,
mtp,
sliding_window_enabled=sliding_window_enabled,
)

alibi_slopes_ptr = (
Expand Down Expand Up @@ -230,6 +235,7 @@ def paged_attention_v1(
kv_block_stride,
kv_head_stride,
kv_seq_stride,
sliding_window,
stream,
)
return out
Expand Down
52 changes: 51 additions & 1 deletion op_tests/test_pa_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,13 @@ def ref_masked_attention(
scale: float,
attn_mask: Optional[torch.Tensor] = None,
logits_soft_cap: float = 0.0,
sliding_window: int = 0,
) -> torch.Tensor:
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
attn_weights = attn_weights + attn_mask.float()
if sliding_window:
attn_weights[:, :, :-sliding_window] = -1e38
if 0 < logits_soft_cap:
attn_weights = logits_soft_cap * torch.tanh(attn_weights / logits_soft_cap)
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
Expand Down Expand Up @@ -214,6 +217,7 @@ def run_torch(
k_scale,
v_scale,
num_queries_per_kv,
sliding_window,
):
output = torch.zeros_like(query)
num_query_heads = query.shape[1]
Expand Down Expand Up @@ -255,7 +259,15 @@ def run_torch(
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

out = ref_masked_attention(q, keys, values, scale, alibi_bias, logits_soft_cap)
out = ref_masked_attention(
q,
keys,
values,
scale,
alibi_bias,
logits_soft_cap,
sliding_window=sliding_window,
)
out = out.view(num_query_heads, head_size)
output[i].copy_(out, non_blocking=True)
return output, 1
Expand All @@ -278,6 +290,7 @@ def run_aiter(
k_scale,
v_scale,
mtp=1,
sliding_window=0,
):
# copied from ops.PagedAttention.forward_decode()
_PARTITION_SIZE_ROCM = 256
Expand Down Expand Up @@ -328,6 +341,7 @@ def run_aiter(
v_scale,
fp8_out_scale if cpa_fp8_out else None,
_PARTITION_SIZE_ROCM,
sliding_window=sliding_window,
)
if cpa_fp8_out:
return output.view(num_seqs, num_heads * head_size)
Expand Down Expand Up @@ -437,6 +451,7 @@ def test_paged_attention(
quant_cache_dtype: torch.dtype,
seed: int,
device: str,
sliding_window: int = 0,
) -> None:
if pa_variant == PAVariant.Shomy:
if quant_cache_dtype is not None:
Expand All @@ -448,6 +463,7 @@ def test_paged_attention(
or block_size != 16
or dtype is not dtypes.bf16
or quant_cache_dtype not in [None, dtypes.i8]
or sliding_window != 0
):
pytest.skip()
elif pa_variant == PAVariant.Naive:
Expand Down Expand Up @@ -523,6 +539,7 @@ def test_paged_attention(
k_scale,
v_scale,
num_queries_per_kv,
sliding_window,
)
cu_query_lens = torch.arange(0, num_seqs + 1, dtype=torch.int)

Expand All @@ -546,6 +563,7 @@ def test_paged_attention(
logits_soft_cap,
k_scale,
v_scale,
sliding_window=sliding_window,
)
assert (
checkAllclose(out_golden, out_aiter, msg=f"golden vs aiter:{time_aiter}")
Expand Down Expand Up @@ -575,6 +593,37 @@ def test_paged_attention(
# f"[test] dim: {str((ctx_lens, num_seqs, num_heads, head_size)):<20}, dtype: {dtype}, finished)\n")


@pytest.mark.parametrize("ctx_lens", [1, 26, 128, 4097])
@pytest.mark.parametrize("num_seqs", [1, 3, 31, 128])
@pytest.mark.parametrize("num_heads", [(8, 1), (32, 4)])
@pytest.mark.parametrize("use_alibi", [False, True])
@pytest.mark.parametrize("sliding_window", [0, 10])
def test_paged_attention_sliding_window(
    ctx_lens: int,
    num_seqs: int,
    num_heads: Tuple[int, int],
    use_alibi: bool,
    sliding_window: int,
) -> None:
    # Thin wrapper around test_paged_attention that sweeps the sliding-window
    # size (0 disables the window) while pinning every other configuration
    # axis to a single known-good combination.
    swept_args = (ctx_lens, num_seqs, num_heads, 128, use_alibi)
    pinned_config = dict(
        block_size=16,
        dtype=dtypes.fp16,
        kv_cache_dtype="auto",
        kv_cache_layout="NHD",
        logits_soft_cap=0.0,
        pa_variant=PAVariant.Shomy,
        quant_cache_dtype=None,
        seed=0,
        device="cuda:0",
        sliding_window=sliding_window,
    )
    test_paged_attention(*swept_args, **pinned_config)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter,
Expand Down Expand Up @@ -643,4 +692,5 @@ def test_paged_attention(
quant_cache_dtype,
0,
"cuda:0",
10,
)
Loading