From 4f6bd68d0ba6fed9f42253c31b35bee708bae87d Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 21 Feb 2025 23:45:13 -0500 Subject: [PATCH 1/7] Enable custom paged attention kernel for Navi3x Signed-off-by: Hosang Yoon --- csrc/rocm/attention.cu | 1270 +++++++++++++++++--- csrc/rocm/ops.h | 2 +- csrc/rocm/torch_bindings.cpp | 3 +- tests/kernels/attention/test_attention.py | 13 +- vllm/_custom_ops.py | 3 +- vllm/attention/backends/rocm_flash_attn.py | 4 +- 6 files changed, 1115 insertions(+), 180 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 2c3cae95e7f5..7f447d7ec855 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -29,6 +29,10 @@ #define __HIP__MI300_MI250__ #endif +#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__)) + #define __HIP__NAVI3__ +#endif + #if defined(NDEBUG) #undef NDEBUG #include @@ -1479,195 +1483,939 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#elif defined(__HIP__NAVI3__) -// clang-format off -template -__global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, const float* k_scale, const float* v_scale) { - UNREACHABLE_CODE +using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; + +using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t; +union b16x8_u { + bit16x8 u16x8; + _B16x4 xy[2]; +}; +typedef b16x8_u _B16x8; + +using bit16x16 = + __attribute__((__vector_size__(16 * sizeof(uint16_t)))) uint16_t; +union b16x16_u { + bit16x16 u16x16; + _B16x8 xy[2]; +}; +typedef b16x16_u _B16x16; + +using _B8x8 = uint2; +using bit8_t = uint8_t; + +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + +template +__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x16& inpA, + const bit16x16& inpB, + const floatx8& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(inpA, inpB, inpC); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(inpA, inpB, inpC); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ 
__forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { + if constexpr (std::is_same::value) { + union h2cvt { + __half2 h2[4]; + _B16x8 b16x8; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5])); + u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7])); + return u.b16x8; + } else if constexpr (std::is_same::value) { + union b2cvt { + __hip_bfloat162 b2[4]; + _B16x8 b16x8; + } u; + + u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1])); + u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3])); + u.b2[2] = __float22bfloat162_rn(make_float2(inp[4], inp[5])); + u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7])); + + return u.b16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } } template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO> __global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] +__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale, const float* v_scale) { - 
UNREACHABLE_CODE -} + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11 + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane2id = laneid % 2; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; -// Grid: (num_heads, num_seqs). -template -__global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_partitions) { - UNREACHABLE_CODE -} -// clang-format on + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 -#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma16_kernel \ - <<>>( \ - query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ - max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ - max_ctx_blocks, k_scale_ptr, v_scale_ptr); + const int max_num_partitions = gridDim.y; -#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma4_kernel \ - <<>>( \ - query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ - max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ - max_ctx_blocks, k_scale_ptr, v_scale_ptr); + const int context_len = context_lens[seq_idx]; // length of a seq -#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ - paged_attention_ll4mi_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ - context_lens_ptr, query_start_loc_ptr, max_num_partitions); + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } -template -void paged_attention_custom_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, const int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& context_lens, - const std::optional& query_start_loc, int max_context_len, - const std::optional& alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale) { - int num_seqs = block_tables.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); + constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2); - // NOTE: query start location is optional for V0 decode should not be used. - // If batch contains mix of prefills and decode, prefills should be skipped. 
- const int* query_start_loc_ptr = - query_start_loc - ? reinterpret_cast(query_start_loc.value().data_ptr()) - : nullptr; + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x16 shared_logits[NWARPS][2][16][2]; - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + // for QK wmma16x16, layout is QHead/Tokenx16 across every 16 lanes, + // 32 Bytes HeadElements in each lane, 2x16B HeadElements across a row of warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16 / 2; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across + // warp - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); - const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); - OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + _B16x16 Qlocal[QKHELOOP / 2]; // note that 16 contiguous elements of Q should + // be fetched per lane for 16 bit cache types - const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); - // partition size is fixed at 256 since both mfma4 and mfma16 kernels support - // it mfma4 kernel also supports partition size 512 - constexpr int PARTITION_SIZE = 256; - const int max_num_partitions = - DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - const int gqa_ratio = num_heads / num_kv_heads; - assert(num_heads % num_kv_heads == 0); - assert(head_size == HEAD_SIZE); + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each wmma16x16x16 instruction processes 16 tokens - constexpr int NTHR = 256; - dim3 grid(num_seqs, max_num_partitions, num_kv_heads); - dim3 block(NTHR); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + _B16x16 Klocal[TLOOP] + [QKHELOOP / 2]; // can be interpreted as B8x16 for 8 bit types - // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 - switch (gqa_ratio) { - case 1: - LAUNCH_CUSTOM_ATTENTION_MFMA4(1); - break; - case 2: - LAUNCH_CUSTOM_ATTENTION_MFMA4(2); - break; - case 3: - LAUNCH_CUSTOM_ATTENTION_MFMA4(3); - break; - case 4: - LAUNCH_CUSTOM_ATTENTION_MFMA4(4); - break; - case 5: - LAUNCH_CUSTOM_ATTENTION_MFMA16(5); - break; - case 6: - LAUNCH_CUSTOM_ATTENTION_MFMA16(6); - break; - case 7: - LAUNCH_CUSTOM_ATTENTION_MFMA16(7); - break; - case 8: + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = 
gridDim.z * GQA_RATIO; + + // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each wmma takes QH16xT16x16HE across warp + // repeat wmma across QKHELOOP dimension + // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens + // across 2 rows x 8 tokens per lane + + if (GQA_RATIO == 1) { + const int local_qhead_idx = lane16id % GQA_RATIO; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = + q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + if (lane16id < GQA_RATIO) { + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH * 2; + const _B16x16* q_fetch_ptr_32B = + reinterpret_cast(q_fetch_ptr); + Qlocal[qkhe_depth] = *q_fetch_ptr_32B; + } + } + } else { + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 2 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = + q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + + const int offset1 = + lane16id / + 2; // 16 contiguous chunks of head elems are spread across 8x2lanes + shared_logits[offset1][lane2id][local_qhead_idx][0].xy[0] = tmp; + } + + __syncthreads(); + + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + Qlocal[qkhe_depth].xy[0] = + shared_logits[qkhe_depth][0][lane16id % GQA_RATIO][0].xy[0]; + Qlocal[qkhe_depth].xy[1] = + shared_logits[qkhe_depth][1][lane16id % GQA_RATIO][0].xy[0]; + } + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? 
kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = 0; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth / 2].xy[qkhe_depth % 2] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 32/1 = 32 vtokens per lane + constexpr int VBLOCKS_PER_LANE = 2; // assumes block size >=16 + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = + HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each wmma + // instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? 
vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x16 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP / 2]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + // v fetches are 16head elems across lanes x (16x2) tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth] + [vfetch_depth / VBLOCKS_PER_LANE]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + + (vfetch_depth % VBLOCKS_PER_LANE) * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth / 2].xy[vfetch_depth % 2] = + *v_fetch_ptr_16B; + } + } + } + + floatx8 dout[TLOOP]; + // qk wmma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + dout[token_depth] = gcn_wmma16x16x16_instr( + Klocal[token_depth][qkhe_depth].u16x16, Qlocal[qkhe_depth].u16x16, + dout[token_depth]); + } + dout[token_depth] *= scale; + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + 2 * i < context_len) + ? dout[token_depth][i] + : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16)); + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + 2 * i < context_len) + ? 
__expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + exp_sum += __shfl_xor(exp_sum, 16); + + __syncthreads(); + + if (laneid < 16) { + shared_qk_max[warpid][lane16id] = qk_max; + shared_exp_sum[warpid][lane16id] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_qk_max[w][lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // write logits to shared mem + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + shared_logits[warpid][token_depth][lane16id][0].xy[rowid] = + from_floatx8(dout[token_depth]); + } + __syncthreads(); + + _B16x8 swp_buf[TLOOP][2]; + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + swp_buf[token_depth][0] = + shared_logits[warpid][token_depth][lane16id][0].xy[0]; + swp_buf[token_depth][1] = + shared_logits[warpid][token_depth][lane16id][0].xy[1]; + } + + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + shared_logits[warpid][token_depth][lane16id][0].xy[rowid].u16x8[i] = + swp_buf[token_depth][i % 2].u16x8[4 * rowid + (i / 2)]; + } + } + + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x8 outelems[VHELOOP]; + // Softmax V wmma + // v layout: 16he across lanes x (16x2) tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx8 tmp_out = {0}; + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP / 2; + vfetch_depth++) { + const int offset = vfetch_depth; + // if output format is 16 qheads across 16 lanes, 16 head elems spread + // across rows + tmp_out = gcn_wmma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x16, + shared_logits[vtoken_depth][offset][lane16id][0].u16x16, tmp_out); + } + } + outelems[vhe_depth] = from_floatx8(tmp_out); + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid] = + outelems[vhe_depth]; // lane16 id head dimension; rowid head element + // dimension + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + swp_buf[vhe_depth][0] = shared_logits[warpid][vhe_depth][lane16id][0].xy[0]; + swp_buf[vhe_depth][1] = shared_logits[warpid][vhe_depth][lane16id][0].xy[1]; + } + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid].u16x8[i] = + 
swp_buf[vhe_depth][i % 2].u16x8[4 * rowid + (i / 2)]; + } + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO2]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + const int offset1 = (head_elem_idx / 16) % NWARPS; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row + vout[h] = + shared_logits[offset1][offset2][local_head_idx][0].xy[offset3]; + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? 
expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 32; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else + +// clang-format off +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_partitions) { + UNREACHABLE_CODE +} +// clang-format on + +#endif + +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ + max_ctx_blocks, k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ + max_ctx_blocks, k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ + context_lens_ptr, query_start_loc_ptr, max_num_partitions); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + const std::optional& query_start_loc, int max_context_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + int num_seqs = block_tables.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: query start location is optional for V0 decode should not be used. + // If batch contains mix of prefills and decode, prefills should be skipped. + const int* query_start_loc_ptr = + query_start_loc + ? reinterpret_cast(query_start_loc.value().data_ptr()) + : nullptr; + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + // partition size is fixed at 256 since both mfma4 and mfma16 kernels support + // it mfma4 kernel also supports partition size 512 + constexpr int PARTITION_SIZE = 256; + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + + constexpr int NTHR = 256; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); + break; + case 8: LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: @@ -1735,13 +2483,185 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \ - ALIBI_ENABLED) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale); +template +void paged_attention_custom_launcher_navi( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, const std::optional& alibi_slopes, + torch::Tensor& k_scale, torch::Tensor& v_scale) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: Navi does not support alibi_slopes. 
+ const float* alibi_slopes_ptr = nullptr; + + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + constexpr int PARTITION_SIZE = 256; + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + + constexpr int NTHR = 256; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION_MFMA16(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION_MFMA16(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION_MFMA16(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION_MFMA16(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + const int warp_size = 32; + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, warp_size); + // reduction kernel supports upto 16 NPAR_loops * 32 (warp_size) * 256 + // (partition size) = 128K context length + switch (npar_loops) { + case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + case 5: + LAUNCH_CUSTOM_REDUCTION(5); + break; + case 6: + LAUNCH_CUSTOM_REDUCTION(6); + break; + case 7: + LAUNCH_CUSTOM_REDUCTION(7); + break; + case 8: + LAUNCH_CUSTOM_REDUCTION(8); + break; + case 9: + LAUNCH_CUSTOM_REDUCTION(9); + break; + case 10: + LAUNCH_CUSTOM_REDUCTION(10); + break; + case 11: + LAUNCH_CUSTOM_REDUCTION(11); + break; + case 12: + LAUNCH_CUSTOM_REDUCTION(12); + break; + case 13: + LAUNCH_CUSTOM_REDUCTION(13); + break; + case 14: + LAUNCH_CUSTOM_REDUCTION(14); + break; + case 15: + LAUNCH_CUSTOM_REDUCTION(15); + break; + case 16: + LAUNCH_CUSTOM_REDUCTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); + break; + } +} + +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, 
HEAD_SIZE, PSIZE, \ + ALIBI_ENABLED) \ + if (!is_navi) { \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); \ + } else { \ + paged_attention_custom_launcher_navi( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); \ + } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ PSIZE) \ @@ -1794,7 +2714,7 @@ void paged_attention( int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, bool is_navi) { // clang-format on const int head_size = query.size(2); if (kv_cache_dtype == "auto") { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index b90cfdc617af..05f8fd2bce49 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -21,4 +21,4 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale); + torch::Tensor& v_scale, bool is_navi); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 4ac6fd1e9940..397ddddb2b41 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale," + " bool is_navi) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index e5650136f258..558bb4d1597d 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -148,6 +148,16 @@ def test_paged_attention( or (version == "rocm" and head_size not in (64, 128))): pytest.skip() + is_rocm_navi = False + if current_platform.is_rocm(): + is_rocm_navi = "gfx1" in torch.cuda.get_device_properties( + "cuda").gcnArchName + + if (version == "rocm" and is_rocm_navi + and (kv_cache_dtype == "fp8" or head_size != 128 + or block_size != 16 or use_alibi)): + pytest.skip() + global PARTITION_SIZE current_platform.seed_everything(seed) @@ -281,13 +291,14 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, + is_rocm_navi, ) opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), + kv_cache_dtype, k_scale, v_scale, is_rocm_navi), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4c577c1c47e7..d56eb8452b1d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -117,13 +117,14 @@ def paged_attention_rocm( kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, + is_navi: bool = False, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, query_start_loc, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, - 
v_scale) + v_scale, is_navi) def mla_decode_kvcache_cpu( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8076c4791d3c..6ae0ab6482c8 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -861,7 +861,8 @@ def forward( gqa_ratio = num_heads // self.num_kv_heads use_custom = use_rocm_custom_paged_attention( decode_query.dtype, head_size, block_size, gqa_ratio, - decode_meta.max_decode_seq_len, self.sliding_window) + decode_meta.max_decode_seq_len, self.sliding_window, + self.kv_cache_dtype, self.alibi_slopes) if use_custom: max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type != AttentionType.ENCODER_DECODER else @@ -907,6 +908,7 @@ def forward( self.kv_cache_dtype, layer._k_scale, layer._v_scale, + _ON_NAVI, ) else: output[num_prefill_tokens:] = paged_attn.forward_decode( From 7c494875ef6e47a3cb295ff4385b79385f6b8456 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Wed, 5 Mar 2025 15:57:35 -0500 Subject: [PATCH 2/7] add navi4x support for custom paged attention kernel Signed-off-by: Hosang Yoon --- .../kernels/benchmark_paged_attention.py | 5 +- csrc/rocm/attention.cu | 724 +++++++++++++++++- 2 files changed, 721 insertions(+), 8 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 2625239b08ef..002950770a42 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -17,6 +17,8 @@ NUM_BLOCKS = 128 * 1024 PARTITION_SIZE = 512 PARTITION_SIZE_ROCM = 256 +GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName +ON_NAVI = "gfx1" in GPU_ARCH @torch.inference_mode() @@ -86,7 +88,7 @@ def main( if version == "v2": if current_platform.is_rocm(): global PARTITION_SIZE - if not args.custom_paged_attn: + if not args.custom_paged_attn and not ON_NAVI: PARTITION_SIZE = 1024 else: PARTITION_SIZE = PARTITION_SIZE_ROCM @@ -172,6 +174,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: kv_cache_dtype, k_scale, v_scale, + ON_NAVI, ) else: raise ValueError(f"Invalid version: {version}") diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 7f447d7ec855..ecc06e5bf3a4 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -33,6 +33,10 @@ #define __HIP__NAVI3__ #endif +#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__)) + #define __HIP__NAVI4__ +#endif + #if defined(NDEBUG) #undef NDEBUG #include @@ -1486,12 +1490,6 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #elif defined(__HIP__NAVI3__) using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; -using float16x4 = - __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; -typedef float16x4 _Half4; -typedef struct _Half8 { - _Half4 xy[2]; -} _Half8; using bit16_t = uint16_t; using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; @@ -1587,7 +1585,7 @@ template __global__ -__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( +__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -2227,6 +2225,718 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( out_ptr[threadIdx.x] = from_float(acc); } +#elif 
defined(__HIP__NAVI4__) + +using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; + +using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t; +union b16x8_u { + bit16x8 u16x8; + _B16x4 xy[2]; +}; +typedef b16x8_u _B16x8; + +using _B8x8 = uint2; +using bit8_t = uint8_t; + +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + +template +__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x8& inpA, + const bit16x8& inpB, + const floatx8& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(inpA, inpB, inpC); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(inpA, inpB, inpC); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float_b16(const bit16_t& inp) { + union tmpcvt { + bit16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + t16.u = inp; + if constexpr (std::is_same::value) { + return (float)t16.f; + } else if constexpr (std::is_same::value) { + return __bfloat162float(t16.b); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { + if constexpr (std::is_same::value) { + union h2cvt { + __half2 h2[4]; + _B16x8 b16x8; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5])); + u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7])); + return u.b16x8; + } else if constexpr (std::is_same::value) { + union b2cvt { + __hip_bfloat162 b2[4]; + _B16x8 b16x8; + } u; + + u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1])); + u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3])); + u.b2[2] = __float22bfloat162_rn(make_float2(inp[4], inp[5])); + u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7])); + + return u.b16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__global__ +__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, 
max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11 + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane2id = laneid % 2; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; // length of a seq + + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x8 shared_logits[NWARPS][2][16][2]; + + // for QK wmma16x16_gfx12, layout is QHead/Tokenx16 across every 16 lanes, + // 16 Bytes HeadElements in each lane, 2x16B HeadElements across 2 rows of + // warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across + // warp + + _B16x8 Qlocal[QKHELOOP]; // note that 16 contiguous elements of Q should + // be fetched per lane for 16 bit cache types + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each wmma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP] + [QKHELOOP]; // can be interpreted as B8x16 for 8 bit types + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each wmma takes QH16xT16x16HE across warp + // repeat wmma across QKHELOOP dimension + // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens + // across 2 rows x 8 tokens per lane + + if (GQA_RATIO == 1) { + const int local_qhead_idx = lane16id % GQA_RATIO; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = q + seq_idx64 * q_stride + + global_qhead_idx * HEAD_SIZE + + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + if (lane16id < GQA_RATIO) { + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + Qlocal[qkhe_depth] = *q_fetch_ptr_16B; + } + } + } else { + // fetch Q in 
shared across warps and then write to registers + const int local_qhead_idx = 2 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = + q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + + const int offset1 = + lane16id / + 2; // 16 contiguous chunks of head elems are spread across 8x2lanes + shared_logits[offset1][lane2id][local_qhead_idx][0] = tmp; + } + + __syncthreads(); + + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + Qlocal[qkhe_depth] = + shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0]; + } + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 32/2 = 16 vtokens per lane + constexpr int VBLOCKS_PER_LANE = 1; // assumes block size >=16 + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = + HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each wmma + // instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < 
VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE); + + // v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + + floatx8 dout[TLOOP]; + // qk wmma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + dout[token_depth] = gcn_wmma16x16x16_instr( + Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8, + dout[token_depth]); + } + dout[token_depth] *= scale; + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 8; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = + (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16)); + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + i < context_len) + ? 
__expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + exp_sum += __shfl_xor(exp_sum, 16); + + __syncthreads(); + + if (laneid < 16) { + shared_qk_max[warpid][lane16id] = qk_max; + shared_exp_sum[warpid][lane16id] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_qk_max[w][lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // write logits to shared mem + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx8(dout[token_depth]); + } + + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x8 outelems[VHELOOP]; + // Softmax V wmma + // v layout: 16he across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx8 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int offset = rowid * VTLANELOOP + vfetch_depth; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // if output format is 16 qheads across 16 lanes, 16 head elems spread + // across rows + tmp_out = gcn_wmma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8, + shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8, + tmp_out); + } + } + outelems[vhe_depth] = from_floatx8(tmp_out); + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][rowid] = + outelems[vhe_depth]; // lane16 id head dimension; rowid head element + // dimension + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO2]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + const int offset1 = (head_elem_idx / 16) % NWARPS; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row + vout[h] = shared_logits[offset1][offset2][local_head_idx][offset3]; + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = 
wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? 
partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 32; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. 
+ float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + #else // clang-format off From 7af074dddfb6cfef1f5bb2b02bf3eb390eb00ce4 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Mon, 21 Apr 2025 16:47:23 -0400 Subject: [PATCH 3/7] Enable AMD Radeon GPU Custom Paged Attention on v1 Signed-off-by: Hosang Yoon --- .../kernels/benchmark_paged_attention.py | 6 +- csrc/rocm/attention.cu | 111 +++++++++++++----- tests/kernels/attention/test_attention.py | 13 +- vllm/_custom_ops.py | 25 +++- vllm/attention/backends/rocm_flash_attn.py | 1 - .../ops/chunked_prefill_paged_decode.py | 3 +- vllm/platforms/rocm.py | 53 ++++++--- 7 files changed, 150 insertions(+), 62 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 002950770a42..116b227b6cb9 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -17,8 +17,6 @@ NUM_BLOCKS = 128 * 1024 PARTITION_SIZE = 512 PARTITION_SIZE_ROCM = 256 -GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName -ON_NAVI = "gfx1" in GPU_ARCH @torch.inference_mode() @@ -88,7 +86,7 @@ def main( if version == "v2": if current_platform.is_rocm(): global PARTITION_SIZE - if not args.custom_paged_attn and not ON_NAVI: + if not args.custom_paged_attn and not current_platform.is_navi(): PARTITION_SIZE = 1024 else: PARTITION_SIZE = PARTITION_SIZE_ROCM @@ -168,13 +166,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: scale, block_tables, seq_lens, + None, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, - ON_NAVI, ) else: raise ValueError(f"Invalid version: {version}") diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index ecc06e5bf3a4..245a06aec543 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1581,6 +1581,7 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { } } +// clang-format off template @@ -1594,6 +1595,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const 
int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -1604,6 +1606,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // head_size] OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale, const float* v_scale) { + // clang-format on constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11 const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -1613,6 +1616,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int rowid = laneid / 16; const int seq_idx = blockIdx.x; + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. + if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) { + return; + } + const int partition_idx = blockIdx.y; constexpr int T_PAR_SIZE = 256; // token partition size set to 256 @@ -1671,12 +1681,14 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens // across 2 rows x 8 tokens per lane + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + if (GQA_RATIO == 1) { const int local_qhead_idx = lane16id % GQA_RATIO; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; - const int64_t seq_idx64 = static_cast(seq_idx); const scalar_t* q_ptr = - q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; if (lane16id < GQA_RATIO) { #pragma unroll for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { @@ -1690,9 +1702,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // fetch Q in shared across warps and then write to registers const int local_qhead_idx = 2 * warpid + rowid; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; - const int64_t seq_idx64 = static_cast(seq_idx); const scalar_t* q_ptr = - q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { @@ -2024,6 +2035,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -2050,15 +2062,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; + const auto num_heads = gridDim.x; + const auto head_idx = blockIdx.x; + const auto seq_idx = blockIdx.y; + + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. 
+ if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; - const int laneid = threadIdx.x % WARP_SIZE; + [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE; __shared__ float shared_global_exp_sum; // max num partitions supported is warp_size * NPAR_LOOPS @@ -2221,7 +2242,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); acc *= inv_global_exp_sum; - OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + + static_cast(head_idx) * HEAD_SIZE; out_ptr[threadIdx.x] = from_float(acc); } @@ -2328,6 +2353,7 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { } } +// clang-format off template @@ -2341,6 +2367,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -2351,6 +2378,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // head_size] OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale, const float* v_scale) { + // clang-format on constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11 const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -2360,6 +2388,12 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int rowid = laneid / 16; const int seq_idx = blockIdx.x; + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. + if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } const int partition_idx = blockIdx.y; constexpr int T_PAR_SIZE = 256; // token partition size set to 256 @@ -2419,11 +2453,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens // across 2 rows x 8 tokens per lane + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? 
query_start_loc_ptr[seq_idx] : seq_idx); + if (GQA_RATIO == 1) { const int local_qhead_idx = lane16id % GQA_RATIO; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; - const int64_t seq_idx64 = static_cast(seq_idx); - const scalar_t* q_ptr = q + seq_idx64 * q_stride + + const scalar_t* q_ptr = q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; if (lane16id < GQA_RATIO) { @@ -2439,9 +2475,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // fetch Q in shared across warps and then write to registers const int local_qhead_idx = 2 * warpid + rowid; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; - const int64_t seq_idx64 = static_cast(seq_idx); const scalar_t* q_ptr = - q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { @@ -2736,6 +2771,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -2762,15 +2798,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; + const auto num_heads = gridDim.x; + const auto head_idx = blockIdx.x; + const auto seq_idx = blockIdx.y; + + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. + if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; - const int laneid = threadIdx.x % WARP_SIZE; + [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE; __shared__ float shared_global_exp_sum; // max num partitions supported is warp_size * NPAR_LOOPS @@ -2933,7 +2978,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); acc *= inv_global_exp_sum; - OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? 
query_start_loc_ptr[seq_idx] : seq_idx); + OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + + static_cast(head_idx) * HEAD_SIZE; out_ptr[threadIdx.x] = from_float(acc); } @@ -3201,9 +3250,10 @@ void paged_attention_custom_launcher_navi( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, - int max_context_len, const std::optional& alibi_slopes, - torch::Tensor& k_scale, torch::Tensor& v_scale) { - int num_seqs = query.size(0); + const std::optional& query_start_loc, int max_context_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + int num_seqs = block_tables.size(0); int num_heads = query.size(1); int head_size = query.size(2); int max_num_blocks_per_seq = block_tables.size(1); @@ -3211,6 +3261,13 @@ void paged_attention_custom_launcher_navi( int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); + // NOTE: query start location is optional for V0 decode should not be used. + // If batch contains mix of prefills and decode, prefills should be skipped. + const int* query_start_loc_ptr = + query_start_loc + ? reinterpret_cast(query_start_loc.value().data_ptr()) + : nullptr; + // NOTE: Navi does not support alibi_slopes. const float* alibi_slopes_ptr = nullptr; @@ -3363,14 +3420,14 @@ void paged_attention_custom_launcher_navi( paged_attention_custom_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes, k_scale, v_scale); \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale); \ } else { \ paged_attention_custom_launcher_navi( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes, k_scale, v_scale); \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale); \ } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 558bb4d1597d..d9f956fbc7c0 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -148,12 +148,7 @@ def test_paged_attention( or (version == "rocm" and head_size not in (64, 128))): pytest.skip() - is_rocm_navi = False - if current_platform.is_rocm(): - is_rocm_navi = "gfx1" in torch.cuda.get_device_properties( - "cuda").gcnArchName - - if (version == "rocm" and is_rocm_navi + if (version == "rocm" and current_platform.is_navi() and (kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi)): pytest.skip() @@ -285,20 +280,20 @@ def test_paged_attention( scale, block_tables, seq_lens, + None, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, - is_rocm_navi, ) opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, is_rocm_navi), + seq_lens, None, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) diff --git 
a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d56eb8452b1d..7756885d6be8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -119,12 +119,25 @@ def paged_attention_rocm( v_scale: torch.Tensor, is_navi: bool = False, ) -> None: - torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, - scale, block_tables, seq_lens, - query_start_loc, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, - v_scale, is_navi) + torch.ops._rocm_C.paged_attention(out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + query_start_loc, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + is_navi=current_platform.is_navi()) def mla_decode_kvcache_cpu( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6ae0ab6482c8..abcb68911a8b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -908,7 +908,6 @@ def forward( self.kv_cache_dtype, layer._k_scale, layer._v_scale, - _ON_NAVI, ) else: output[num_prefill_tokens:] = paged_attn.forward_decode( diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 1b47581641b0..e5f7a580bb0a 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -283,7 +283,8 @@ def chunked_prefill_paged_decode( use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, block_size, num_queries_per_kv, - max_seq_len, sliding_window) + max_seq_len, sliding_window, + kv_cache_dtype, alibi_slopes) if use_custom: _PARTITION_SIZE_ROCM = 256 max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ba8f49ca9150..0b2f375bee3c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -100,24 +100,45 @@ def on_mi250_mi300() -> bool: return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) +def on_navi3_navi4() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) + + @cache -def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, gqa_ratio: int, - max_seq_len: int, - sliding_window: int) -> bool: +def use_rocm_custom_paged_attention( + qtype: torch.dtype, + head_size: int, + block_size: int, + gqa_ratio: int, + max_seq_len: int, + sliding_window: int, + kv_cache_dtype: str, + alibi_slopes: Optional[torch.Tensor] = None) -> bool: - # rocm custom page attention not support on gfx1* # custom paged attn always supported on V0. On V1, requires sliding window # disabled due to observed numerical discrepancy. 
- return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) - and (qtype == torch.half or qtype == torch.bfloat16) - and (head_size == 64 or head_size == 128) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 - and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER)) + if on_mi250_mi300(): + return ((not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + and envs.VLLM_ROCM_USE_AITER)) + + else: + return (on_navi3_navi4() + and (not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and head_size == 128 and block_size == 16 + and (gqa_ratio >= 3 and gqa_ratio <= 16) + and max_seq_len <= 32768 and alibi_slopes is None + and kv_cache_dtype == "auto" + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) class RocmPlatform(Platform): @@ -344,3 +365,7 @@ def use_custom_allreduce(cls) -> bool: def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( device_id).multi_processor_count + + @classmethod + def is_navi(cls) -> bool: + return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName From 6c1ec1082b58a67ec1fb4d259a4c76e30240253c Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 2 May 2025 12:09:19 -0400 Subject: [PATCH 4/7] remove unnecessary arguments Signed-off-by: Hosang Yoon --- csrc/rocm/attention.cu | 10 +++++++++- csrc/rocm/ops.h | 2 +- csrc/rocm/torch_bindings.cpp | 3 +-- vllm/_custom_ops.py | 26 ++++++-------------------- 4 files changed, 17 insertions(+), 24 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index f023925ff080..b78a0668108c 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -3465,6 +3465,10 @@ void paged_attention_custom_launcher_navi( break; \ } +bool is_navi_gpu(const std::string& arch) { + return arch.find("gfx11") == 0 || arch.find("gfx12") == 0; +} + // clang-format off void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] @@ -3482,8 +3486,12 @@ void paged_attention( int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, bool is_navi) { + torch::Tensor& v_scale) { // clang-format on + hipDeviceProp_t deviceProp; + hipGetDeviceProperties(&deviceProp, 0); + bool is_navi = is_navi_gpu(deviceProp.gcnArchName); + const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 05f8fd2bce49..b90cfdc617af 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -21,4 +21,4 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, bool is_navi); + torch::Tensor& v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 397ddddb2b41..4ac6fd1e9940 100644 --- a/csrc/rocm/torch_bindings.cpp +++ 
b/csrc/rocm/torch_bindings.cpp @@ -47,8 +47,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale," - " bool is_navi) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 15fef7c48b81..7bb01507ac2c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -117,27 +117,13 @@ def paged_attention_rocm( kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, - is_navi: bool = False, ) -> None: - torch.ops._rocm_C.paged_attention(out, - exp_sum, - max_logits, - tmp_out, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - query_start_loc, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - is_navi=current_platform.is_navi()) + torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + query_start_loc, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, + v_scale) def mla_decode_kvcache_cpu( From 87c7c3ea26854496cbffc69b3e359cb5e102e383 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 2 May 2025 12:19:16 -0400 Subject: [PATCH 5/7] chage gpu naming convention Signed-off-by: Hosang Yoon --- csrc/rocm/attention.cu | 8 ++++---- vllm/platforms/rocm.py | 11 +++-------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index b78a0668108c..fb9c3e688e0b 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -31,11 +31,11 @@ #endif #if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__)) - #define __HIP__NAVI3__ + #define __HIP__GFX11__ #endif #if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__)) - #define __HIP__NAVI4__ + #define __HIP__GFX12__ #endif #if defined(NDEBUG) @@ -1488,7 +1488,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } -#elif defined(__HIP__NAVI3__) +#elif defined(__HIP__GFX11__) using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; @@ -2251,7 +2251,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( out_ptr[threadIdx.x] = from_float(acc); } -#elif defined(__HIP__NAVI4__) +#elif defined(__HIP__GFX12__) using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5281db0930ed..2933f40618a9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -100,11 +100,6 @@ def on_mi250_mi300() -> bool: return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) -def on_navi3_navi4() -> bool: - GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName - return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) - - @cache def use_rocm_custom_paged_attention( qtype: torch.dtype, @@ -118,6 +113,7 @@ def use_rocm_custom_paged_attention( GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) + ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) # custom paged attn always supported on V0. On V1, requires sliding window # disabled due to observed numerical discrepancy. 
@@ -133,9 +129,8 @@ def use_rocm_custom_paged_attention( and envs.VLLM_ROCM_USE_AITER)) else: - return (on_navi3_navi4() - and (not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) + return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 and block_size == 16 and (gqa_ratio >= 3 and gqa_ratio <= 16) From b5fdeb575029783133daceb57513c3367a8b9bcb Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Mon, 5 May 2025 15:39:30 -0400 Subject: [PATCH 6/7] cache is_navi_gpu result Signed-off-by: Hosang Yoon --- csrc/rocm/attention.cu | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index fb9c3e688e0b..93dd00cfdcd1 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -3465,8 +3465,22 @@ void paged_attention_custom_launcher_navi( break; \ } -bool is_navi_gpu(const std::string& arch) { - return arch.find("gfx11") == 0 || arch.find("gfx12") == 0; +bool is_navi_gpu() { + static bool is_cached = false; + static bool result; + + if (!is_cached) { + int device_id; + hipDeviceProp_t deviceProp; + hipGetDevice(&device_id); + hipGetDeviceProperties(&deviceProp, device_id); + + std::string arch = deviceProp.gcnArchName; + result = arch.find("gfx11") == 0 || arch.find("gfx12") == 0; + is_cached = true; + } + + return result; } // clang-format off @@ -3488,9 +3502,7 @@ void paged_attention( const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale) { // clang-format on - hipDeviceProp_t deviceProp; - hipGetDeviceProperties(&deviceProp, 0); - bool is_navi = is_navi_gpu(deviceProp.gcnArchName); + bool is_navi = is_navi_gpu(); const int head_size = query.size(2); if (kv_cache_dtype == "auto") { From 95d92dabf516cb99b2ed792e7d49791e2bf3b30a Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Wed, 7 May 2025 23:45:24 -0400 Subject: [PATCH 7/7] fix vheloop to ensure minimum value of 1 Signed-off-by: Hosang Yoon --- csrc/rocm/attention.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index dfe7af51df25..449a517a695a 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1784,9 +1784,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( VTOKENS_PER_LANE, CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes // minimum block size is 16 - constexpr int VHELOOP = - HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each wmma - // instr works on 16 head elements + constexpr int VHELOOP = DIVIDE_ROUND_UP( + (HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each + // wmma instr works on 16 head elements int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; @@ -2555,9 +2555,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( VTOKENS_PER_LANE, CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes // minimum block size is 16 - constexpr int VHELOOP = - HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each wmma - // instr works on 16 head elements + constexpr int VHELOOP = DIVIDE_ROUND_UP( + (HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each + // wmma instr works on 16 head elements int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE];
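
A note on the cache indexing used by the gfx11/gfx12 QKV kernels above: K is laid
out as [num_blocks, num_kv_heads, head_size/x, block_size, x] with
x = 16 / sizeof(cache_t), and V as [num_blocks, num_kv_heads, head_size,
block_size]. The following is a minimal Python sketch of that pointer arithmetic
(the function names are hypothetical; it is not part of the patches) mirroring the
offsets built up through k_ptr/k_ptr2/k_ptr3 and v_ptr/v_ptr2/v_ptr3 in the fetch
loops:

def k_cache_offset(block_number, kv_head, token_in_block, head_elem,
                   kv_block_stride, kv_head_stride, block_size, kx):
    # K layout: [num_blocks, num_kv_heads, head_size/x, block_size, x],
    # with kx = 16 / sizeof(cache_t) contiguous elements per token per 16B group.
    return (block_number * kv_block_stride
            + kv_head * kv_head_stride
            + (head_elem // kx) * block_size * kx
            + token_in_block * kx
            + head_elem % kx)

def v_cache_offset(block_number, kv_head, token_in_block, head_elem,
                   kv_block_stride, kv_head_stride, block_size):
    # V layout: [num_blocks, num_kv_heads, head_size, block_size]:
    # one block_size-long run of tokens per head element.
    return (block_number * kv_block_stride
            + kv_head * kv_head_stride
            + head_elem * block_size
            + token_in_block)

For a 16-bit cache type kx is 8, which is why each lane's 16-byte fetch covers
eight head elements of a single token.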
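
The reduce kernel's math is the standard merge of independently normalized softmax
partitions. As a reference, here is a small NumPy sketch of what it computes for
one (sequence, head) pair; the function name and array shapes are illustrative,
and the per-partition 1e-6 epsilon applied inside the QKV kernel is ignored:

import numpy as np

def merge_partitions(tmp_out, max_logits, exp_sums):
    """Reference for the reduce step, for one (seq, head) pair.

    tmp_out:    [num_partitions, head_size] per-partition outputs, each already
                normalized by its local exp_sum in the QKV kernel
    max_logits: [num_partitions]            per-partition max of scaled qk
    exp_sums:   [num_partitions]            per-partition softmax denominators
    """
    global_max = max_logits.max()
    weights = exp_sums * np.exp(max_logits - global_max)  # rescaled exp sums
    global_exp_sum = weights.sum()
    acc = (tmp_out * weights[:, None]).sum(axis=0)
    return acc / (global_exp_sum + 1e-6)

Because each partition's tmp_out slice is already divided by its local exp_sum in
the QKV kernel, the reduce step only re-weights by exp_sum_p * exp(max_p -
global_max) and renormalizes by the global sum.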
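
The query_start_loc plumbing added for V1 reduces to one predicate: a sequence is
handled by the custom decode kernels only if it contributes exactly one query
token; longer (prefill) sequences return early and are left to the prefill path,
and outputs are indexed by query_start_loc rather than by seq_idx so that mixed
batches line up. A small sketch of that predicate (the helper name is
hypothetical):

def decode_seq_indices(query_start_loc):
    """query_start_loc: cumulative query offsets, length num_seqs + 1.

    Mirrors the kernel-side early return on
    query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1.
    """
    return [i for i in range(len(query_start_loc) - 1)
            if query_start_loc[i + 1] - query_start_loc[i] == 1]

# Query lengths [5, 1, 1, 7] give query_start_loc [0, 5, 6, 7, 14];
# only sequences 1 and 2 are decodes.
assert decode_seq_indices([0, 5, 6, 7, 14]) == [1, 2]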
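
Condensed from the gfx11/gfx12 branch of use_rocm_custom_paged_attention, the
conditions under which the Navi custom kernel is selected can be summarized as
below (V0/V1 and environment-variable gating omitted; the helper name is
hypothetical):

import torch

def navi_custom_paged_attn_eligible(qtype, head_size, block_size, gqa_ratio,
                                    max_seq_len, kv_cache_dtype, alibi_slopes):
    # Mirrors the gfx11/gfx12 branch: fp16/bf16 queries, head size 128,
    # block size 16, GQA ratio 3..16, contexts up to 32768 tokens,
    # no ALiBi, and an unquantized ("auto") KV cache.
    return (qtype in (torch.half, torch.bfloat16)
            and head_size == 128
            and block_size == 16
            and 3 <= gqa_ratio <= 16
            and max_seq_len <= 32768
            and alibi_slopes is None
            and kv_cache_dtype == "auto")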
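
The VHELOOP change in the last patch is easiest to check numerically. Assuming the
wave32 configuration the kernel comments imply (NUM_THREADS = 256, so NWARPS = 8),
the old expression HEAD_SIZE / 16 / NWARPS truncates to 0 for any head size below
128, while the round-up division keeps the loop bound at least 1:

def divide_round_up(a, b):
    return (a + b - 1) // b

WARP_SIZE = 32                       # wave32 assumed for gfx11/gfx12
NUM_THREADS = 256
NWARPS = NUM_THREADS // WARP_SIZE    # 8, matching the "8 warps on gfx11" comment

for head_size in (64, 128):
    old_vheloop = head_size // 16 // NWARPS                 # truncating division
    new_vheloop = divide_round_up(head_size // 16, NWARPS)  # never below 1
    print(head_size, old_vheloop, new_vheloop)              # 64: 0 -> 1, 128: 1 -> 1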