Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions aiter/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,105 @@ def pa_fwd_asm(
) -> torch.Tensor: ...


def _should_use_asm_kernel(
num_seqs: int,
num_heads: int,
kv_cache_tensor_dtype: torch.dtype,
) -> bool:
# TODO: HIP kernel yet isn't supporting fp8 scales in asm layout.
if (
kv_cache_tensor_dtype == torch.int8
or kv_cache_tensor_dtype == torch.float8_e4m3fnuz
):
return True

# Get GPU compute units (CUs)
gpu = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count
# ASM kernel becomes relevant, once the total_heads is sufficiently large compared to CUs
total_heads = num_seqs * num_heads
return total_heads > 2 * cu_num


def paged_attention_common(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    workspace_buffer: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables_stride0: int,
    logits_soft_cap: float,
    scale: float,
    max_qlen: int = 1,
    max_seq_len: int = 1,
    cu_query_lens: Optional[torch.Tensor] = None,
    K_QScale: Optional[torch.Tensor] = None,
    V_QScale: Optional[torch.Tensor] = None,
    out_: Optional[torch.Tensor] = None,
    qo_indptr: Optional[torch.Tensor] = None,
    high_precision: Optional[
        int
    ] = 1,  # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
    kernelName: Optional[str] = None,
    kv_cache_dtype: str = "auto",
    kv_cache_tensor_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Run the paged attention forward pass, choosing the kernel automatically.

    The choice between ``pa_fwd_asm`` and ``paged_attention_v1`` is delegated
    to ``_should_use_asm_kernel``, which is given the first two dimensions of
    ``Q`` and the KV cache dtype.

    Args:
        Q: Query tensor; its shape is unpacked as
            ``(num_seqs, num_heads, head_size)``, so it must be 3-D.
        K: Key cache. Its dtype is used as the KV cache dtype when
            ``kv_cache_tensor_dtype`` is ``None``.
        V: Value cache.
        workspace_buffer: Forwarded to ``paged_attention_v1`` only.
        block_tables: Forwarded to both kernels.
        context_lens: Forwarded to both kernels.
        block_tables_stride0: Forwarded to ``pa_fwd_asm`` only.
        logits_soft_cap: Forwarded to ``paged_attention_v1`` only.
        scale: Forwarded to ``paged_attention_v1`` only.
        max_qlen: Forwarded to ``pa_fwd_asm`` only.
        max_seq_len: Forwarded to ``paged_attention_v1`` only.
        cu_query_lens: Forwarded to ``paged_attention_v1`` only.
        K_QScale: Forwarded to both kernels.
        V_QScale: Forwarded to both kernels.
        out_: Optional output tensor. On the HIP path a new tensor shaped
            like ``Q`` is allocated when this is ``None``; on the ASM path it
            is handed to ``pa_fwd_asm`` as-is.
        qo_indptr: Forwarded to ``pa_fwd_asm`` only.
        high_precision: Forwarded to ``pa_fwd_asm`` only.
        kernelName: Forwarded to ``pa_fwd_asm`` only.
        kv_cache_dtype: Forwarded to ``paged_attention_v1`` only.
        kv_cache_tensor_dtype: Dtype used for kernel selection; defaults to
            ``K.dtype``.

    Returns:
        The value returned by ``pa_fwd_asm`` on the ASM path, otherwise the
        output tensor that was passed to ``paged_attention_v1``.
    """
    if kv_cache_tensor_dtype is None:
        kv_cache_tensor_dtype = K.dtype

    # Only the first two dimensions feed the heuristic; the three-way unpack
    # is kept so that a non-3-D Q still fails here.
    num_seqs, num_heads, _head_size = Q.shape

    if not _should_use_asm_kernel(num_seqs, num_heads, kv_cache_tensor_dtype):
        # HIP kernel for the smaller workloads (original note: 5D V cache).
        result = torch.empty_like(Q) if out_ is None else out_
        paged_attention_v1(
            result,
            workspace_buffer,
            Q,
            K,
            V,
            scale,
            block_tables,
            cu_query_lens,
            context_lens,
            max_seq_len,
            None,  # alibi_slopes
            kv_cache_dtype,
            "HND",
            logits_soft_cap,
            K_QScale,
            V_QScale,
        )
        return result

    # ASM kernel, selected by the heuristic above.
    return pa_fwd_asm(
        Q,
        K,
        V,
        block_tables,
        context_lens,
        block_tables_stride0,
        max_qlen,
        K_QScale,
        V_QScale,
        out_,
        qo_indptr,
        high_precision,
        kernelName,
    )


def gen_pa_ps_fwd_asm(
Q: torch.Tensor,
K: torch.Tensor,
Expand Down
158 changes: 101 additions & 57 deletions csrc/cpp_itfs/pa/pa_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ template <typename scalar_t,
int GQA_RATIO,
int MTP,
typename AttentionVariant,
bool SLIDING_WINDOW_ENABLED>
bool SLIDING_WINDOW_ENABLED,
bool USE_5D_VCACHE = false>
__inline__ __device__ void
_paged_attention_kernel(const int* block_table_seq,
const int64_t query_loc,
Expand Down Expand Up @@ -221,10 +222,20 @@ _paged_attention_kernel(const int* block_table_seq,
for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++)
{
const int head_elem =
row_head_elem + qkhe_depth * QKHE_PER_FETCH + head_loop * HEAD_SIZE_PER_LOOP;
const int offset1 = head_elem / KX;
const int offset2 = head_elem % KX;
const cache_t* k_fetch_ptr = k_ptr3 + offset1 * KX + offset2;
row_head_elem + qkhe_depth * QKHE_PER_FETCH + head_loop * HEAD_SIZE_PER_LOOP;
const int offset1 = head_elem / KX;
const int offset2 = head_elem % KX;
const cache_t* k_fetch_ptr = [&] {
if constexpr(USE_5D_VCACHE)
{
const int head_stride = BLOCK_SIZE * kv_seq_stride;
return k_ptr3 + offset1 * head_stride + offset2;
}
else
{
return k_ptr3 + offset1 * KX + offset2;
}
}();
const _B16x8* k_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(k_fetch_ptr);
if constexpr(NT_KV_LOAD)
{
Expand Down Expand Up @@ -291,27 +302,62 @@ _paged_attention_kernel(const int* block_table_seq,
static_assert(VBLOCKS_PER_LANE == VTLANELOOP,
"make sure we can keep un-shuffled data in Vlocal as well");

const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride +
((threadIdx.x / n_thread_per_block) % BLOCK_SIZE) * kv_seq_stride;
constexpr int V_X = CONTIGUOUS_KV_ELEMS_16B_LOAD;

// v fetches are 16head elems across lanes x 16 tokens per lane
for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++)
if constexpr(USE_5D_VCACHE)
{
for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++)
constexpr int V5D_TOKEN_GRP_STRIDE = HEAD_SIZE * V_X;
constexpr int V5D_HEAD_STRIDE = V_X;

for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++)
{
for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++)
for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++)
{
const int vlds_col_idx = laneid % n_thread_per_block;
const int vhead_elem =
vhe_depth * NWARPS * 16 + vlds_col_idx * CONTIGUOUS_KV_ELEMS_16B_LOAD;
const cache_t* v_ptr2 = v_ptr + vhead_elem;

const int64_t vblock_number =
static_cast<int64_t>(vphysical_block_number[vtoken_depth][vblock_depth]);
const cache_t* v_fetch_ptr = v_ptr2 + (vblock_number * kv_block_stride);
for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++)
{
const int vlocal_token_idx = rowid * VTOKENS_PER_LANE + vfetch_depth * V_X;
const int global_token_idx = partition_start_token_idx +
vtoken_depth * TOKENS_PER_WARP +
vlocal_token_idx;
const int block_idx = global_token_idx / BLOCK_SIZE;
const int token_in_block = global_token_idx % BLOCK_SIZE;
const int token_grp = token_in_block / V_X;
const int safe_block_idx =
(global_token_idx < context_len) ? block_idx : last_ctx_block;
const int physical_block = block_table_seq[safe_block_idx];
const int head_elem = (warpid * 16 + lane16id) + vhe_depth * NWARPS * 16;
const int64_t v_offset = static_cast<int64_t>(physical_block) * kv_block_stride +
wg_start_kv_head_idx * kv_head_stride +
token_grp * V5D_TOKEN_GRP_STRIDE +
head_elem * V5D_HEAD_STRIDE;
const cache_t* v_fetch_ptr = v_cache + v_offset;
Vlocal[vtoken_depth][vhe_depth][vfetch_depth] =
*reinterpret_cast<const _B16x8*>(v_fetch_ptr);
}
}
}
}
else
{
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride +
((threadIdx.x / n_thread_per_block) % BLOCK_SIZE) * kv_seq_stride;

Vlocal[vtoken_depth][vhe_depth][vblock_depth] =
*reinterpret_cast<const _B16x8*>(v_fetch_ptr);
for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++)
{
for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++)
{
for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++)
{
const int vlds_col_idx = laneid % n_thread_per_block;
const int vhead_elem =
vhe_depth * NWARPS * 16 + vlds_col_idx * V_X;
const cache_t* v_ptr2 = v_ptr + vhead_elem;
const int64_t vblock_number =
static_cast<int64_t>(vphysical_block_number[vtoken_depth][vblock_depth]);
const cache_t* v_fetch_ptr = v_ptr2 + (vblock_number * kv_block_stride);
Vlocal[vtoken_depth][vhe_depth][vblock_depth] =
*reinterpret_cast<const _B16x8*>(v_fetch_ptr);
}
}
}
}
Expand Down Expand Up @@ -675,54 +721,52 @@ _paged_attention_kernel(const int* block_table_seq,
constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4;
constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8;

for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++)
if constexpr(!USE_5D_VCACHE)
{
for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++)
for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++)
{
// 1. store data into LDS
for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++)
for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++)
{
const int vlds_col_idx = laneid % n_thread_per_block;
const int vlocal_token_idx =
vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block;
*reinterpret_cast<_B16x8*>(vlds_ptr +
(/*row=*/vlocal_token_idx * n_thread_per_block +
/*col=*/vlds_col_idx) *
16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth];
}
__syncthreads();
for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++)
{
const int vlds_col_idx = laneid % n_thread_per_block;
const int vlocal_token_idx =
vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block;
*reinterpret_cast<_B16x8*>(vlds_ptr +
(/*row=*/vlocal_token_idx * n_thread_per_block +
/*col=*/vlds_col_idx) *
16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth];
}
__syncthreads();

// 2. load data from LDS (transposed), then do multiplication
for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++)
{
const int vlocal_head_elem = warpid * 16 + lane16id;
for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++)
{
const int vlocal_head_elem = warpid * 16 + lane16id;

const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD;
const int vlds_elem_idx = vlocal_head_elem % CONTIGUOUS_KV_ELEMS_16B_LOAD;
const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD;
const int vlds_elem_idx = vlocal_head_elem % CONTIGUOUS_KV_ELEMS_16B_LOAD;

const int vlocal_token_idx =
rowid * VTOKENS_PER_LANE + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD;
const int vlocal_token_idx =
rowid * VTOKENS_PER_LANE + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD;

// read data points individually and save them into array
cache_t elems[CONTIGUOUS_KV_ELEMS_16B_LOAD];
for(int d2 = 0; d2 < CONTIGUOUS_KV_ELEMS_16B_LOAD; ++d2)
{
const cache_t* fetched_elems = reinterpret_cast<const cache_t*>(
vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block +
/*col=*/vlds_col_idx) *
16);
cache_t elems[CONTIGUOUS_KV_ELEMS_16B_LOAD];
for(int d2 = 0; d2 < CONTIGUOUS_KV_ELEMS_16B_LOAD; ++d2)
{
const cache_t* fetched_elems = reinterpret_cast<const cache_t*>(
vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block +
/*col=*/vlds_col_idx) *
16);
elems[d2] = fetched_elems[vlds_elem_idx];
}

elems[d2] = fetched_elems[vlds_elem_idx];
Vlocal[vtoken_depth][vhe_depth][vfetch_depth] =
*reinterpret_cast<const _B16x8*>(elems);
}

// copy all the read data points together
Vlocal[vtoken_depth][vhe_depth][vfetch_depth] =
*reinterpret_cast<const _B16x8*>(elems);
__syncthreads();
}
__syncthreads();
}
}

// For 5D, Vlocal is already in the correct format from the load phase
_B16x4 outelems[GQA_RATIO_LOOP][MTP_PER_THREAD][VHELOOP];

// Softmax V mfma
Expand Down
3 changes: 2 additions & 1 deletion csrc/cpp_itfs/pa/pa_v1.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ void {{func_name}}(void* out_ptr,
gqa_ratio,
{{mtp}},
decltype(variant),
{{"true" if sliding_window_enabled else "false"}}>
{{"true" if sliding_window_enabled else "false"}},
{{"true" if use_5d_vcache else "false"}}>
<<<grid, block, 0, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<{{dtype}}*>(query_ptr),
reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr),
reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr),
Expand Down
19 changes: 10 additions & 9 deletions csrc/cpp_itfs/pa/pa_v1.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ template <typename scalar_t,
int GQA_RATIO,
int MTP,
typename AttentionVariant,
bool SLIDING_WINDOW_ENABLED>
bool SLIDING_WINDOW_ENABLED,
bool USE_5D_VCACHE = false>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, block_size, num_kv_heads,
// head_size]
const cache_t* __restrict__ v_cache, // [num_blocks, block_size, num_kv_heads,
// head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
Comment thread
JohnNikolay84 marked this conversation as resolved.
Outdated
const cache_t* __restrict__ v_cache, // 4D: [num_blocks, num_kv_heads, head_size, block_size]
// 5D: [num_blocks, num_kv_heads, block_size/x, head_size, x]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ cu_query_lens, // [num_seqs+1]
Expand Down Expand Up @@ -84,7 +84,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
return;
}
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
_paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, MTP, AttentionVariant, SLIDING_WINDOW_ENABLED>(block_table_seq, static_cast<int64_t>(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window);
_paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, MTP, AttentionVariant, SLIDING_WINDOW_ENABLED, USE_5D_VCACHE>(block_table_seq, static_cast<int64_t>(query_loc), context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, sliding_window);
}

// Grid: (num_heads, num_seqs).
Expand Down Expand Up @@ -136,13 +136,14 @@ template <typename scalar_t,
int GQA_RATIO,
int MTP,
typename AttentionVariant,
bool SLIDING_WINDOW_ENABLED>
bool SLIDING_WINDOW_ENABLED,
bool USE_5D_VCACHE = false>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const cache_t* __restrict__ v_cache, // 4D: [num_blocks, num_kv_heads, head_size, block_size]
// 5D: [num_blocks, num_kv_heads, block_size/x, head_size, x]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ cu_query_lens, // [num_seqs+1]
Expand Down
Loading