From d8f094f3d7ff21a962127912db831c6ffbf1ad59 Mon Sep 17 00:00:00 2001 From: jaccob Date: Fri, 21 Nov 2025 00:41:34 -0600 Subject: [PATCH 1/4] Upload experimental pa_ragged kernels and unit test --- csrc/cpp_itfs/pa/pa_kernels.cuh | 1024 +++++++++++++++++++++++ csrc/cpp_itfs/pa/pa_ragged.cpp.jinja | 4 +- csrc/cpp_itfs/pa/pa_ragged.cuh | 18 +- csrc/cpp_itfs/pa/pa_ragged.py | 5 + op_tests/test_pa_ragged_experimental.py | 504 +++++++++++ 5 files changed, 1551 insertions(+), 4 deletions(-) create mode 100644 op_tests/test_pa_ragged_experimental.py diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 6c2cd5df1f..99818f18f3 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -1110,3 +1110,1027 @@ __inline__ __device__ void _paged_attention_ll4mi_reduce_kernel( out_ptr[threadIdx.x] = from_float(acc); } } + + + + + +// ----------------------------------------------------------------------- +// ----------------------------------------------------------------------- +// ----------------------- Experimental ---------------------------------- +// Configs: head_dim=128, cache_t=bf16 +// Feature: +// 1. continuous threads work together to load K cache into LDS, then each thread save the LDS into registers. +// 2. Double buffer of K cache loading +// 3. NT_KV_LOAD set to true +template +__inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( + const int* block_table_seq, + const int64_t query_loc, + int context_len, + const int partition_start_token_idx, + const scalar_t* q, + const cache_t* k_cache, + const cache_t* v_cache, + const float scale, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const int kv_seq_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + float logits_soft_cap, + float logits_soft_cap_rcp, + const float* q_scale_ptr, + const float* k_scale_ptr, + const float* v_scale_ptr, + const AttentionVariant* variant, + const int sliding_window = 0) +{ + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int kv_head_idx = blockIdx.z; + constexpr int T_PAR_SIZE = 256; + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + constexpr int HEAD_LOOP = DIVIDE_ROUND_UP(HEAD_SIZE, 128); + constexpr int HEAD_SIZE_PER_LOOP = DIVIDE_ROUND_UP(HEAD_SIZE, HEAD_LOOP); + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int max_num_partitions = gridDim.y; + constexpr int MAX_ELEMENTS_PER_QUERY = DIVIDE_ROUND_UP(16, GQA_RATIO); + constexpr int MTP_PER_THREAD = DIVIDE_ROUND_UP(MTP, MAX_ELEMENTS_PER_QUERY); + + constexpr int MTP_PARALLEL_THREADS = MTP / MTP_PER_THREAD; + constexpr int GQA_RATIO_LOOP = DIVIDE_ROUND_UP(GQA_RATIO, 16); + constexpr int GQA_RATIO_PER_LOOP = GQA_RATIO / GQA_RATIO_LOOP; + constexpr int GQA_RATIO_MTP_PARALLEL = GQA_RATIO_PER_LOOP * MTP_PARALLEL_THREADS; + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO_MTP_PARALLEL, 4); + + // shared_logits is used for multiple purposes + __shared__ _B16x4 shared_logits[GQA_RATIO_LOOP][HEAD_LOOP][MTP_PER_THREAD][NWARPS][4][16][4]; + + // for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes + // HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + + // rows refers to 16 lanes; refer dpp terminology + constexpr int ROWS_PER_WARP = WARP_SIZE / 16; + // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); + // each fetch across a warp fetches these many elements + constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; + // 1 for 16bit types, 2 for 8bit types + constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); + // 4xQKHE_16B across warp + constexpr int QKHELOOP = HEAD_SIZE_PER_LOOP / QKHE_PER_FETCH; + + _B16x8 Qlocal[GQA_RATIO_LOOP][HEAD_LOOP][MTP_PER_THREAD][QKHELOOP] // Jacob: 1x1x1x4x1 + [QK_SIZE_RATIO]; // note that 16 contiguous elements of Q should + // be fetched per lane for 8 bit cache types : + // QK_SIZE_RATIO changes for this + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + // sub partition of tokens per warp for qk calculation + constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; + // constexpr int TLOOP = TOKENS_PER_WARP / 16; // each mfma16x16x16 instruction processes 16 tokens + + const int wg_start_head_idx = kv_head_idx * GQA_RATIO_PER_LOOP; // Jacob: kv_head_idx=0, GQA_RATIO_PER_LOOP=6 + const int wg_start_kv_head_idx = kv_head_idx; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // Jacob: some variables which are dedicated for for Grok1: HEAD_SIZE=128, cache_t=bf16, blockSize 16/64/256 + constexpr int BYTES_PER_WARP_FETCH = WARP_SIZE * 16; // 1024 bytes + constexpr int TOKEN_PER_WARP_FETCH = BYTES_PER_WARP_FETCH / (HEAD_SIZE * sizeof(cache_t)); // 4 token + // constexpr int TLOOP = TOKENS_PER_WARP / TOKEN_PER_WARP_FETCH; // 16 + // 1st Wavefront loads token 1~4 tokens, 17~20 tokens ... + // 2nd Wavefront loads token 5~8 tokens, 21~24 tokens ... + // 3rd Wavefront loads token 9~12 tokens, 25~28 tokens ... + // 4th Wavefront loads token 13~16 tokens, 29~32 tokens ... 61~64 tokens + // The number of iterations to load 64 tokens + constexpr int ITERS_16TK = 64 / (TOKEN_PER_WARP_FETCH * NWARPS); // 4 + // The number of iterations of ITERS_16TK to load 256 tokens + constexpr int TLOOP = 256 / 64; // 4 + constexpr int TOKEN_PER_WG = TOKEN_PER_WARP_FETCH * NWARPS; // 16 tokens per workgroup + constexpr int THREAD_PER_TOKEN = HEAD_SIZE / CONTIGUOUS_KV_ELEMS_16B_LOAD; // if HEAD_SIZE=128, bf16 --> 16 threads load 1 token + + /// NOTICE: We don't support mask for this kernel, so just use a placeholder type/object here. + using Mask = ck_tile::SimplifiedGenericAttentionMask; + const Mask mask{/*seqlen_q=*/1, /*seqlen_k=*/context_len}; + + // Jacob: for QK mfma, tokens in multiples of TOKEN_PER_WG are loaded in each iteration + // each mfma takes 16(KToken)x16(QHead)x32(out of HEAD_SIZE) across warp + // 1 workgroup has 4 wavefronts which load 16 token in 1 TLOOP iteration + // 4 TLOOP iteration load 64 tokens. Then, let each wavefront process 16 tokens mfma + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + // const int last_ctx_page = PAGE_SIZE * ((num_context_blocks - 1) / PAGE_SIZE); // PAGE_SIZE * floor((num_context_blocks - 1) / PAGE_SIZE) + + int kphysical_block_number[TLOOP][ITERS_16TK]; + int kphysical_offset[TLOOP][ITERS_16TK]; + // fetch k physical block numbers + // Jacob: loading order--> Token [0~16, 64~80, 128~144, 192~208], [16~32, 80~96, 144~160, 208~224]... + for(int token_depth = 0; token_depth < TLOOP; token_depth++) // 4 + { + for(int iter_16tk = 0; iter_16tk < ITERS_16TK; iter_16tk++) // 4 + { + // Jacob: block_table_seq has been shifted based on the index of the sequnece + // Jacob: partition_start_token_idx has been set based on the blockIdx.y + const int warp_token_offset = iter_16tk * 64 + token_depth * 16 + warpid * TOKEN_PER_WARP_FETCH; + const int thread_token_offset = rowid; + const int kglobal_token_idx = partition_start_token_idx + warp_token_offset + thread_token_offset; + const int kblock_idx = + (kglobal_token_idx < context_len) ? kglobal_token_idx / BLOCK_SIZE : last_ctx_block; + const int kblock_offset = // % BLOCK_SIZE --> & (BLOCK_SIZE - 1) + kglobal_token_idx & (BLOCK_SIZE - 1); + kphysical_block_number[token_depth][iter_16tk] = block_table_seq[kblock_idx]; + kphysical_offset[token_depth][iter_16tk] = kblock_offset; + + // if (threadIdx.x %16==0 && /*token_depth==0 &&*/ iter_16tk==0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("[HIP] [K block] threadIdx=%3d, token_depth=%d, iter_16tk=%d, " + // "[kblock_idx=%3d, kpage_idx=%3d, kpage_offset=%2d], last_ctx_block=%3d, last_ctx_page=%3d, " + // "block idx=%3d, kblock_offset=%3d, BLOCK_SIZE=%d, PAGE_SIZE=%d\n", + // threadIdx.x, token_depth, iter_16tk, + // kblock_idx, kpage_idx, kpage_offset, last_ctx_block, last_ctx_page, + // block_table_seq[kblock_idx], kblock_offset, BLOCK_SIZE, PAGE_SIZE); + // } + } + __builtin_amdgcn_sched_group_barrier(0x0020, ITERS_16TK, 0); // VMEM read + } + + + // fetch Q in shared across warps and then write to registers + const int warp_mtp_idx = warpid / (4 / MTP_PARALLEL_THREADS); // Jacob: MTP_PARALLEL_THREADS=1, warpid=0, warp_mtp_idx=0 + const int warp_row_idx = warpid % (4 / MTP_PARALLEL_THREADS); // Jacob:warp_row_idx = 0,1,2,3 + + const int local_qhead_idx = 4 * warpid + rowid; // Jacob: rowid=laneid / 16 = 0,1,2,3 + const int local_mtp_qhead_idx = 4 * warp_row_idx + rowid; // Jacob: local_mtp_qhead_idx= 0~15 + const int global_qhead_idx = wg_start_head_idx + local_mtp_qhead_idx; // Jacob: wg_start_head_idx=0, global_qhead_idx=0~15 + const int64_t query_start_off = static_cast(query_loc + warp_mtp_idx); // Jacob: query_loc=sequence idx + constexpr int mtp_loop = MTP_PER_THREAD; + // Jacob: q_stride = GQA_RATIO * HEAD_SIZE = 6*128=768 + // each thread local 8 data, so 16 threads can load the full 128 HEAD_SIZE of Q + // As GQA_RATIO=6, we need 96 threads to load 6x128 q data + // Btw, a block with 256 threads can load 16x128 q data + for(int mtp = 0; mtp < mtp_loop; mtp++) { // 1 + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { // 1 + const scalar_t* q_ptr = + q + (query_start_off + mtp * MTP_PARALLEL_THREADS) * q_stride + (global_qhead_idx + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * HEAD_SIZE; + + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { // 1 + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B + head_loop * HEAD_SIZE_PER_LOOP; + if ((local_mtp_qhead_idx < GQA_RATIO_MTP_PARALLEL) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = + lane16id / + 4; // 16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i = 0; i < 2; i++) { + const int head_elem = lane16id * 2 + i; // element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem / 4 / 4; + shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + } + } + } + __syncthreads(); + + + + // qk mfma + constexpr bool NT_KV_LOAD = true; + constexpr int KX = 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + int curr = 0, next = 1; + __shared__ cache_t Kbuffer_lds[2][HEAD_LOOP][64][HEAD_SIZE]; // curr and next K buffer, each load 64 tokens K cache + // Each warp processes 16x128 and it was divided into 4 mfma 16x(4x32), each thread records 4x CONTIGUOUS_KV_ELEMS_16B_LOAD elems + _B16x8 Kbuffer_reg[2][HEAD_LOOP][QKHELOOP]; + + // qk mfma - define lambda: loading K cache + int DEBUG_TOKEN_DEPTH=0; + constexpr int n_global_load_per_fragment = HEAD_LOOP*ITERS_16TK; + auto load_K_fragment = [&] __device__ ( // Suppose BLOCK_SIZE=1 + const cache_t* k_ptr, int buf_idx, int kpage[ITERS_16TK], int koffset[ITERS_16TK]) { + // return; // Debug: does loading K cache take time? + for (int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { + for(int iter_16tk = 0; iter_16tk < ITERS_16TK; iter_16tk++){ // 4 + const int64_t kpage_number = static_cast(kpage[iter_16tk]); + const int64_t kpage_offset = static_cast(koffset[iter_16tk]); + const int offset = + kpage_number * kv_block_stride + + kpage_offset * HEAD_SIZE + + lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* k_ptr_B16x8 = reinterpret_cast(k_ptr + offset); + + // Save to LDS + const int token_row = iter_16tk * 16 + threadIdx.x / 16; + const int cache_offset = lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; + if constexpr (NT_KV_LOAD) + *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = + load_ntmprl_16Byte(k_ptr_B16x8); + else + *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = + *k_ptr_B16x8; + + // if (iter_16tk<2 && DEBUG_TOKEN_DEPTH<2 && threadIdx.x < 128 && + // blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("[LDS K cache] threadIdx=%3d, iter_16tk=%d, kblock_number=%3ld, " + // "kblock_offset=%3ld, Save to LDS token_row=%3d, LDS cache_offset=%3d, " + // "addr offset=%d, val=%f \n", + // threadIdx.x, iter_16tk, kblock_number, + // kblock_offset, token_row, cache_offset, + // offset, __bfloat162float(*(k_ptr + offset))); + // DEBUG_TOKEN_DEPTH += 1; + // } + // if (DEBUG_TOKEN_DEPTH==0 && token_row == 0 && + // blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("[LDS K cache] threadIdx=%3d, iter_16tk=%d, kblock_number=%3ld, " + // "kblock_offset=%3ld, Save to LDS token_row=%3d, LDS cache_offset=%3d, " + // "addr offset=%d, val=%f \n", + // threadIdx.x, iter_16tk, kblock_number, + // kblock_offset, token_row, cache_offset, + // offset, __bfloat162float(*(k_ptr + offset))); + // DEBUG_TOKEN_DEPTH += 1; + // } + // if (token_row == 0 && cache_offset == 64 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("[LDS K cache] threadIdx=%3d, iter_16tk=%d, kblock_number=%3ld, " + // "kblock_offset=%3ld, LDS token_row=%3d, LDS cache_offset=%3d, " + // "addr offset=%d, val=%f \n", + // threadIdx.x, iter_16tk, kblock_number, + // kblock_offset, token_row, cache_offset, + // offset, __bfloat162float(*(k_ptr + offset))); + // } + } + } + __syncthreads(); + + // Load LDS into registers for mfma16x16x32 A MATRIX + // Wavefront 1 load lds[00:16][HEAD_SIZE] + // Wavefront 2 load lds[16:32][HEAD_SIZE] + // Wavefront 3 load lds[32:48][HEAD_SIZE] + // Wavefront 4 load lds[48:64][HEAD_SIZE] + for (int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { + for(int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++){ // 4 + const int row_warp_offset = warpid * 16; + const int row_thread_offset = lane16id; + const int row = row_warp_offset + row_thread_offset; + + const int col_16tk_offset = qkhe_depth * 32; // A matrix is 16x32. 32 columns + const int col_thread_offset = CONTIGUOUS_KV_ELEMS_16B_LOAD * rowid; + const int col = col_16tk_offset + col_thread_offset; + Kbuffer_reg[buf_idx][head_loop][qkhe_depth] = + *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][row][col]); + // _B16x8{{ + // {static_cast(1),1,1,1},{1,1,1,1} + // }}; + + // Check NAN + // for(int x=0; x<2; ++x) + // for(int y=0; y<4; ++y) + // if(isnan((float)Kbuffer_reg[buf_idx][head_loop][qkhe_depth].xy[x][y])) + // printf("Kbuffer_reg is nan\n"); + + // if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("[LDS-->Reg] threadIdx=%3d, qkhe_depth=%d, " + // "Kbuffer_reg load from Kbuffer_lds[%d][%d][%d]\n", + // threadIdx.x, qkhe_depth, + // buf_idx, row, col); + // } + } + } + }; + + // qk mfma - Preload the kphysical_block_number[0] cache + load_K_fragment(k_ptr, curr, kphysical_block_number[0], kphysical_offset[0]); + __builtin_amdgcn_sched_group_barrier(0x0020, n_global_load_per_fragment, 0); // VMEM read + + // qk mfma - Setup alibi_slope + float alibi_slope[GQA_RATIO_LOOP]; + if constexpr(ALIBI_ENABLED) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + const int alibi_head_idx = wg_start_head_idx + lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP; + alibi_slope[gqa_ratio_loop] = (lane16id < GQA_RATIO_PER_LOOP) ? alibi_slopes[alibi_head_idx] : 0.f; + } + } + + // qk mfma - calculate post qk mfma scale + float scale2 = scale; + if constexpr(KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) + { + // multiply by k_scale if fp8 kv cache + scale2 *= *k_scale_ptr; + } + + const auto variant_params = [&] { + if constexpr(AttentionVariant::use_logits_soft_cap) + { + return ck_tile::LogitsSoftCapParams{ + mask, scale2, logits_soft_cap, logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, scale2}; + } + }(); + + // qk mfma - load K cache[iter+1] + mfma[iter] + floatx4 d_out[GQA_RATIO_LOOP][mtp_loop][TLOOP]; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { // 4 + // Preload the next K cache + if (token_depth + 1 < TLOOP){ + load_K_fragment(k_ptr, next, kphysical_block_number[token_depth+1], kphysical_offset[token_depth+1]); + __builtin_amdgcn_sched_group_barrier(0x0020, n_global_load_per_fragment, 0); // VMEM read + } + + for (int mtp = 0; mtp < mtp_loop; mtp++) { // 1 + for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { // 1 + d_out[gqa_ratio_loop][mtp][token_depth] = {0}; + for (int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { // 1 + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { // 4 + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + // Load Q from LDS + for (int i = 0; i < 2; i++) + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio].xy[i] = + shared_logits[gqa_ratio_loop][head_loop][mtp][qkhe_depth][rowid] + [lane16id % GQA_RATIO_MTP_PARALLEL][2 * qkratio + i]; + __builtin_amdgcn_sched_group_barrier(0x0100, 2, 0); // LDS read + + // mfma + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + #if defined(__gfx950__) + d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x32_instr( + Kbuffer_reg[curr][head_loop][qkhe_depth], + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio], + d_out[gqa_ratio_loop][mtp][token_depth]); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + #else + for (int i = 0; i < 2; i++) { + d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x16_instr( + Kbuffer_reg[curr][head_loop][qkhe_depth].xy[i], + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio].xy[i], + d_out[gqa_ratio_loop][mtp][token_depth]); + } + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + #endif + + // Check value + // for(int d=0; d<4; ++d) + // if(isnan(d_out[gqa_ratio_loop][mtp][token_depth][d])){ + // for(int x=0; x<2; ++x) + // for(int y=0; y<4; ++y){ + // _B16x4 kdata = Kbuffer_reg[curr][head_loop][qkhe_depth].xy[0]; + // printf("qk_mfma is nan. Kbuffer_reg=%hu %hu %hu %hu\n", + // kdata[0], kdata[1], kdata[2], kdata[3]); + // } + // break; + // } + + } + else { // kv cache dtype fp8 + auto Ktmp = Kbuffer_reg[curr][head_loop][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + // Load Q from LDS + + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + #if defined(__gfx950__) + d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x32_instr( + Klocaltmp, + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio], + d_out[gqa_ratio_loop][mtp][token_depth]); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + #else + for (int i = 0; i < 2; i++) { + d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio].xy[i], + d_out[gqa_ratio_loop][mtp][token_depth]); + } + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + #endif + } + } + } + } + } + + // DEBUG: check values + // if (threadIdx.x==1 && token_depth<4 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // floatx4 data = d_out[gqa_ratio_loop][mtp][token_depth]; + // printf("[check d_out] threadIdx.x=%d, BLOCK_SIZE=%d, d_out=%f,%f %f %f \n", + // threadIdx.x, BLOCK_SIZE ,data[0], data[1], data[2], data[3]); + // for(int y=0; yQueryTransform(variant_params, d_out[gqa_ratio_loop][mtp][token_depth][i]); + } + } + } + int tmp = curr; + curr = next; + next = tmp; + } + + const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + // apply alibi + if constexpr (ALIBI_ENABLED) { + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + const int alibi_offset = local_token_idx - context_len + 1; + for (int mtp = 0; mtp < mtp_loop; mtp++) { + for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + for (int i = 0; i < 4; i++) { + d_out[gqa_ratio_loop][mtp][token_depth][i] += alibi_slope[gqa_ratio_loop] * (alibi_offset + i); + } + } + } + } + } + // apply sliding window + if constexpr(SLIDING_WINDOW_ENABLED) + { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int i = 0; i < 4; i++) + { + float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i]; + if (local_token_idx + i < context_len - sliding_window) + tmp = -FLT_MAX; + d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; + } + } + } + } + } + // apply soft-capping to logits + for (int token_depth = 0; token_depth < TLOOP; token_depth++) + { + for (int mtp = 0; mtp < mtp_loop; mtp++) { + for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + for (int i = 0; i < 4; i++) { + d_out[gqa_ratio_loop][mtp][token_depth][i] = + variant->LogitsTransform(variant_params, + d_out[gqa_ratio_loop][mtp][token_depth][i], + /*batch_idx=*/query_start_off + mtp * MTP_PARALLEL_THREADS, + /*qo_head_idx=*/wg_start_head_idx + lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP, + /*kv_head_idx=*/kv_head_idx); + } + + } + } + } + + // Same as golden but the index [warpid, token_depth] are different. + // if (threadIdx.x%16==0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // for (int token_depth = 0; token_depth < TLOOP; token_depth++){ + // floatx4 data = d_out[0][0][token_depth]; + // printf("[Check d_out + soft-capping] warpid=%d, token_depth=%d, threadIdx.x=%3d, d_out[0][0][%d]=%f %f %f %f\n", + // warpid, token_depth, threadIdx.x, token_depth, data[0], data[1], data[2], data[3]); + // __syncthreads(); + // } + // } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX}; + float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; + + for (int mtp = 0; mtp < mtp_loop; mtp++) { + for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + // Step 1.1 Get max qk per thread: + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for(int i=0; i<4; ++i){ + const float tmp = ((local_token_idx + i) < context_len * (warp_mtp_idx + 1)) + ? d_out[gqa_ratio_loop][mtp][token_depth][i] + : -FLT_MAX; + qk_max[gqa_ratio_loop][mtp] = fmaxf(qk_max[gqa_ratio_loop][mtp], tmp); + } + } + + // for(int d=0; d<4; ++d) + // if(isnan(d_out[gqa_ratio_loop][mtp][token_depth][d])){ + // floatx4 ddata = d_out[gqa_ratio_loop][mtp][token_depth]; + // printf("qk_mfma+soft-capping is nan. d_out=%f %f %f %f\n", + // ddata[0], ddata[1], ddata[2], ddata[3]); + // break; + // } + + // Step 1.2 Get max qk along q head under each wavefront + // According to ROCm CDNA4 mfma16x16x32, The output dim of mfma(qk) is 16x16. + // Thread [1, 17, 33, 49] stores 1 column, 16 elements of mfma(K@Q.T). + // Use the following loop can get the max(thread1, thread17, thread33, thread49) + // "mask >= 16" summed to 16 threads as 1 GQA_RATIO_LOOP process 16 q heads + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + qk_max[gqa_ratio_loop][mtp] = fmaxf(qk_max[gqa_ratio_loop][mtp], __shfl_xor(qk_max[gqa_ratio_loop][mtp], mask)); + } + + // Step 2.1 Calc exp(d_out-qk_max) per thread + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = ((local_token_idx + i) < context_len * (warp_mtp_idx + 1)) + ? __expf(d_out[gqa_ratio_loop][mtp][token_depth][i] - qk_max[gqa_ratio_loop][mtp]) + : 0.0f; + d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; + exp_sum[gqa_ratio_loop][mtp] += tmp; + // if(isnan(tmp)) + // printf("exp(d_out-qk_max) is nan. d_out=%f, qk_max=%f\n", + // d_out[gqa_ratio_loop][mtp][token_depth][i], qk_max[gqa_ratio_loop][mtp]); + } + } + + // Step 2.2 Sum up exp per wavefronts + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + exp_sum[gqa_ratio_loop][mtp] += __shfl_xor(exp_sum[gqa_ratio_loop][mtp], mask); + } + } + } + // __syncthreads(); // sync before writing to shared mem // Why need sync here? no LDS ops before this line + + // Step 3. Save qk_max and exp_sum for the entire workgroup + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + for(int mtp = 0; mtp < mtp_loop; mtp++) { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + const int qk_max_offset = + warpid * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + + mtp; + shared_mem[qk_max_offset] = qk_max[gqa_ratio_loop][mtp]; + const int exp_sum_offset = + NWARPS * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum[gqa_ratio_loop][mtp]; + + // if (threadIdx.x < 256 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("[qk_max + exp_sum] threadIdx=%3d, qk_max_offset=%3d, exp_sum_offset=%3d, " + // "shared_mem[qk_max_offset]=%f, shared_mem[exp_sum_offset]=%f\n", + // threadIdx.x, qk_max_offset, exp_sum_offset, + // shared_mem[qk_max_offset], shared_mem[exp_sum_offset]); + // } + } + } + } + __syncthreads(); + + // Seg 6.2 + // Get qk_max across wavefronts + // calculate partition qk_max and exp_sum + float inv_sum_scale[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; + float partition_qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX}; + float partition_exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; + + for(int mtp = 0; mtp < mtp_loop; mtp++) { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + float warp_qk_max_exp[NWARPS]; + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_mem[w * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + mtp]; + partition_qk_max[gqa_ratio_loop][mtp] = fmaxf(partition_qk_max[gqa_ratio_loop][mtp], warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max[gqa_ratio_loop][mtp]); + partition_exp_sum[gqa_ratio_loop][mtp] += + shared_mem[NWARPS * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + w * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + mtp] * warp_qk_max_exp[w]; + } + + inv_sum_scale[gqa_ratio_loop][mtp] = + __fdividef(1.f, partition_exp_sum[gqa_ratio_loop][mtp] + 1e-6f) * warp_qk_max_exp[warpid]; + + // if (threadIdx.x < 256 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("threadIdx=%3d, warp_qk_max_exp=%f %f %f %f, partition_qk_max[%d][%d]=%f " + // "partition_exp_sum[%d][%d]=%f, " + // "inv_sum_scale[%d][%d]=%f\n", + // threadIdx.x, + // warp_qk_max_exp[0], warp_qk_max_exp[1], warp_qk_max_exp[2], warp_qk_max_exp[3], gqa_ratio_loop, mtp, partition_qk_max[gqa_ratio_loop][mtp], + // gqa_ratio_loop, mtp, partition_exp_sum[gqa_ratio_loop][mtp], + // gqa_ratio_loop, mtp, inv_sum_scale[gqa_ratio_loop][mtp]); + // } + } + } + + __syncthreads(); // Why need sync here? no LDS ops before this line + + + // disable rtz conversion due to its impact on accuracy. + constexpr bool LOGITS_RTZ_CONVERSION = false; + + // write logits to shared mem + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + for (int mtp = 0; mtp < mtp_loop; mtp++) { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + d_out[gqa_ratio_loop][mtp][token_depth] *= inv_sum_scale[gqa_ratio_loop][mtp]; + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz conversion for better performance, with negligible impact on + // accuracy + shared_logits[gqa_ratio_loop][0][mtp][warpid][token_depth][lane16id][rowid] = + from_floatx4_rtz(d_out[gqa_ratio_loop][mtp][token_depth]); + } else { + shared_logits[gqa_ratio_loop][0][mtp][warpid][token_depth][lane16id][rowid] = + from_floatx4(d_out[gqa_ratio_loop][mtp][token_depth]); + } + } + } + } + + + // DEBUG: Get qk_max across blocks + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO_MTP_PARALLEL) { + for(int mtp = 0; mtp < mtp_loop; mtp++) { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + const int qhead_idx = lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP; + const int64_t offset = static_cast(seq_idx + mtp * MTP_PARALLEL_THREADS) * + static_cast(total_num_heads) * + static_cast(max_num_partitions) + + (static_cast(wg_start_head_idx) + + static_cast(qhead_idx)) * + static_cast(max_num_partitions) + + static_cast(partition_idx); + max_logits[offset] = partition_qk_max[gqa_ratio_loop][mtp]; + exp_sums[offset] = partition_exp_sum[gqa_ratio_loop][mtp]; + + // if (threadIdx.x < 64 && blockIdx.x == 0 && blockIdx.y == 7 && blockIdx.z == 0) { + // printf("threadIdx=%3d, blockIdx.y=%d, max_logits[%ld]=%f, exp_sums[%ld]=%f \n", + // threadIdx.x, blockIdx.y, offset, max_logits[offset], offset, exp_sums[offset]); + // } + } + } + } + + __syncthreads(); + + // fetch v physical block numbers + constexpr int n_thread_per_warp = (NWARPS * 16) / CONTIGUOUS_KV_ELEMS_16B_LOAD; // 8 + constexpr int k_thread_per_warp = WARP_SIZE / n_thread_per_warp; // 8 + constexpr int n_thread_per_block = n_thread_per_warp; // 8 + constexpr int k_thread_per_block = NWARPS * k_thread_per_warp; // 32 + constexpr int k_repeat = TOKENS_PER_WARP / k_thread_per_block; // 2 + static_assert(BLOCK_SIZE <= k_thread_per_block); + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 64/4 = 16 contiguous vtokens per lane + constexpr int VBLOCKS_PER_LANE = k_repeat; // assumes block size <= 32 + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = + DIVIDE_ROUND_UP(VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each + // mfma instr works on 16 head elements + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + { + for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) + { + const int vlocal_token_idx = vtoken_depth * TOKENS_PER_WARP + + vblock_depth * k_thread_per_block + + threadIdx.x / n_thread_per_block; + const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = + (vglobal_token_idx < context_len) ? vglobal_token_idx / BLOCK_SIZE : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; // this can be interpreted as B8x16 too + __shared__ unsigned char vlds_ptr[TOKENS_PER_WARP * n_thread_per_block * 16]; + static_assert(VBLOCKS_PER_LANE == VTLANELOOP, + "make sure we can keep un-shuffled data in Vlocal as well"); + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((threadIdx.x / n_thread_per_block) % BLOCK_SIZE) * kv_seq_stride; + + // v fetches are 16head elems across lanes x 16 tokens per lane + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) // 2 + { + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) // 4 + { + for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) // 2 + { + const int vlds_col_idx = laneid % n_thread_per_block; + const int vhead_elem = + vhe_depth * NWARPS * 16 + vlds_col_idx * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const cache_t* v_ptr2 = v_ptr + vhead_elem; + + const int64_t vblock_number = + static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_fetch_ptr = v_ptr2 + (vblock_number * kv_block_stride); + + // Jacob: Non temporal load for large batch size + const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); + if constexpr(NT_KV_LOAD) + { + Vlocal[vtoken_depth][vhe_depth][vblock_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); + } + else{ + Vlocal[vtoken_depth][vhe_depth][vblock_depth] = + *reinterpret_cast(v_fetch_ptr); + } + } + } + } + + + constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; + constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; + + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + { + // 1. store data into LDS + for(int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) + { + const int vlds_col_idx = laneid % n_thread_per_block; + const int vlocal_token_idx = + vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block; + *reinterpret_cast<_B16x8*>(vlds_ptr + + (/*row=*/vlocal_token_idx * n_thread_per_block + + /*col=*/vlds_col_idx) * + 16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth]; + } + __syncthreads(); + + // 2. load data from LDS (transposed), then do multification + for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++){ + const int vlocal_head_elem = warpid * 16 + lane16id; + + const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int vlds_elem_idx = vlocal_head_elem % CONTIGUOUS_KV_ELEMS_16B_LOAD; + + const int vlocal_token_idx = + rowid * VTOKENS_PER_LANE + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + // read data points individually and save them into array + cache_t elems[CONTIGUOUS_KV_ELEMS_16B_LOAD]; + for(int d2 = 0; d2 < CONTIGUOUS_KV_ELEMS_16B_LOAD; ++d2) + { + const cache_t* fetched_elems = reinterpret_cast( + vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block + + /*col=*/vlds_col_idx) * + 16); + + elems[d2] = fetched_elems[vlds_elem_idx]; + } + + // copy all the read data points together + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *reinterpret_cast(elems); + } + __syncthreads(); + } + } + + + _B16x4 outelems[GQA_RATIO_LOOP][MTP_PER_THREAD][VHELOOP]; + + // Softmax V mfma + // v layout: 16he across lanes x 16 tokens per lane + for (int mtp = 0; mtp < mtp_loop; mtp++) { + for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) + { + if constexpr(KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + { + for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) + { + #if defined(__gfx950__) + _B16x8 tmp_in; + for(int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) + { + const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + tmp_in.xy[i] = shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1]; + } + tmp_out = gcn_mfma16x16x32_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth], + tmp_in, + tmp_out); + #else + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems spread + // across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + #endif + } + } + else + { + for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) + { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + // reinterpret V format as 16 elements of 8bits + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for(int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) + { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + + #if defined(__gfx950__) + _B16x8 tmp_in; + for(int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) + { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + tmp_in.xy[i] = shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1]; + } + tmp_out = gcn_mfma16x16x32_instr( + Vlocaltmp, + tmp_in, + tmp_out); + #else + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocaltmp.xy[i], + shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + #endif + } + } + } + __syncthreads(); + } + // apply post Softmax V mfma v_scale + if constexpr(KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) + { + tmp_out *= *v_scale_ptr; + } + outelems[gqa_ratio_loop][mtp][vhe_depth] = from_floatx4(tmp_out); + } + } + } + + __syncthreads(); + + // store Softmax-V mfma output to shared mem + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + // lane16 id head dimension; rowid head element dimension + for(int mtp = 0; mtp < mtp_loop; mtp++) { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + shared_logits[gqa_ratio_loop][0][mtp][warpid][vhe_depth][lane16id][rowid] = outelems[gqa_ratio_loop][mtp][vhe_depth]; + + // if (threadIdx.x==0 && vhe_depth==0 && + // blockIdx.x == 1 && blockIdx.y == 0 && blockIdx.z == 0) { + // _B16x4 data = outelems[gqa_ratio_loop][mtp][vhe_depth]; + // uint16_t v0 = data[0]; + // uint16_t v1 = data[1]; + // uint16_t v2 = data[2]; + // uint16_t v3 = data[3]; + // __hip_bfloat16 b0 = *reinterpret_cast<__hip_bfloat16*>(&v0); + // __hip_bfloat16 b1 = *reinterpret_cast<__hip_bfloat16*>(&v1); + // __hip_bfloat16 b2 = *reinterpret_cast<__hip_bfloat16*>(&v2); + // __hip_bfloat16 b3 = *reinterpret_cast<__hip_bfloat16*>(&v3); + // printf("[outelems] threadIdx.x=%d, vhe_depth=%d, outlems=%f %f %f %f \n", + // threadIdx.x, vhe_depth, + // __bfloat162float(b0), + // __bfloat162float(b1), + // __bfloat162float(b2), + // __bfloat162float(b3)); + // } + } + } + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + for (int mtp = 0; mtp < mtp_loop; mtp++) { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { + _B16x8 vout[GQA_RATIO4]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8 + head_loop * HEAD_SIZE_PER_LOOP; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int offset1 = (head_elem_idx / 16) % 4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4) % 4; + for (int i = 0; i < 2; i++) { + vout[h].xy[i] = + shared_logits[gqa_ratio_loop][0][mtp][offset1][offset2][local_head_idx][offset3 + i]; + } + } + + const int64_t hsz_maxp_mult = + static_cast(HEAD_SIZE * max_num_partitions); + + scalar_t* out_ptr = out + (seq_idx + mtp * MTP_PARALLEL_THREADS) * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO_MTP_PARALLEL) { + const int64_t out_head_idx = + static_cast(wg_start_head_idx + local_head_idx + gqa_ratio_loop * GQA_RATIO_PER_LOOP); + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + + // if (threadIdx.x<64 && + // blockIdx.x == 59 && blockIdx.y == 6 && blockIdx.z == 0) { + // _B16x8 data_x8 = * out_ptr_B16x8; + // uint16_t v[8]; + // __hip_bfloat16 b[8]; + // float c[8]; + // for(int i=0; i<2; ++i) + // for(int j=0; j<4; ++j){ + // v[i*4+j] = data_x8.xy[i][j]; + // b[i*4+j] = *reinterpret_cast<__hip_bfloat16*>(&v[i*4+j]); + // c[i*4+j] = __bfloat162float(b[i*4+j]); + // } + // printf("[out_ptr] threadIdx.x=%3d, h(GQA_RATIO4)=%d, local_head_idx=%3d, head_elem_idx=%3d, " + // "out=%f %f %f %f, %f %f %f %f \n", + // threadIdx.x, h, local_head_idx, head_elem_idx, + // c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]); + // } + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/csrc/cpp_itfs/pa/pa_ragged.cpp.jinja b/csrc/cpp_itfs/pa/pa_ragged.cpp.jinja index d817c0ffe9..d711344a74 100644 --- a/csrc/cpp_itfs/pa/pa_ragged.cpp.jinja +++ b/csrc/cpp_itfs/pa/pa_ragged.cpp.jinja @@ -69,8 +69,10 @@ void {{func_name}}(void* out_ptr, constexpr int NTHR = 256; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); + constexpr int VERSION_ID = {{"0" if version == 'GOLDEN' else "1" if version == 'EXPERIMENTAL' else "0"}}; - paged_attention_ll4mi_QKV_mfma16_kernel<{{dtype}}, + paged_attention_ll4mi_QKV_mfma16_kernel(seq_idx * MTP); const int* block_table_seq = kv_page_indices + kv_indptr[seq_idx]; - _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, 0); + + if constexpr (VERSION_ID == 0) // 0: GOLDEN VERSION + { + _paged_attention_kernel + (block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, 0); + } + else // Experimental VERSION: head_dim 128 + { + _paged_attention_kernel_EXPERIMENTAL + (block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, q_scale_ptr, k_scale_ptr, v_scale_ptr, variant, 0); + } } // Grid: (num_heads, num_seqs, mtp). @@ -135,7 +146,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kern #else // !defined(__HIP__MI3XX_MI250__) TODO: Add NAVI support -template torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + +def kv_cache_factory_v2( + num_blocks: int, + page_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_cache_shape = (num_blocks, 1, num_heads, head_size) + key_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(*uniform_range) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, 1, num_heads, head_size) + value_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty( + size=value_cache_shape, dtype=torch_dtype, device=device + ) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(*uniform_range) + else: + raise ValueError(f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + +def kv_ptr_factory( + num_seqs: int, + ctx_lens: int, + page_size: int, + )-> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # kv_indptr + num_blocks_list = [ctx_lens] * num_seqs + kv_indptr = torch.tensor([0] + num_blocks_list).cumsum( + dim=0, dtype=torch.int + ) + + # kv_page_indices + padded_ctx_lens = page_size * int(np.ceil(ctx_lens / page_size)) # e.g., ctx_lens=10, page_size=3 --> padded_ctx_lens=12 + index_total = num_seqs * padded_ctx_lens + head_per_row = int(np.ceil(ctx_lens / page_size)) + head_total = num_seqs * int(np.ceil(ctx_lens / page_size)) + + # Generate heads (Start from 0, page_size, 2xpage_size, ...) + all_heads = np.arange(0, index_total, page_size) + np.random.shuffle(all_heads) + row_chunks = all_heads.reshape(num_seqs, head_per_row) + + # Sort the chunks since the page indices are in ascending order. + sorted_row_heads = np.sort(row_chunks, axis=1) + extended_heads = np.repeat(sorted_row_heads, page_size, axis=1) + + # Create Offset matrix with shape (1, length) + offset = np.tile(np.arange(page_size), head_per_row) + print(sorted_row_heads.shape) + + # shape (bs, length) + offset_tile = np.tile(offset, (num_seqs, 1)) + + # Extend sorted_row_heads + + kv_page_indices = extended_heads + offset_tile + kv_page_indices = kv_page_indices[:, :ctx_lens] + kv_page_indices = torch.from_numpy(kv_page_indices).to(device='cuda:0', dtype=torch.int32) + return kv_indptr, kv_page_indices.reshape(-1) + +def run_aiter( + output, + workspace_buffer, + query, + key_cache, + value_cache, + scale, + kv_indptr, + kv_page_indices, + kv_last_page_len, + # page_size, # New args + block_size, + max_num_partitions, + alibi_slopes, + kv_cache_dtype, + kv_cache_layout, + logits_soft_cap, + k_scale, + v_scale, + fp8_out_scale, + _PARTITION_SIZE_ROCM, + version="GOLDEN", +): + os.environ['QKV_VERSION'] = version + torch.ops.aiter.paged_attention_ragged( + output, + workspace_buffer, + query, + key_cache, + value_cache, + scale, + kv_indptr, + kv_page_indices, + kv_last_page_len, + # page_size, # New args + block_size, + max_num_partitions, + alibi_slopes, + kv_cache_dtype, + kv_cache_layout, + logits_soft_cap, + k_scale, + v_scale, + fp8_out_scale, + _PARTITION_SIZE_ROCM, + ) + + return workspace_buffer, output + + + +def test_paged_attention( + in_pt:str, + ctx_lens: int, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + page_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + kv_cache_layout: str, + logits_soft_cap: float, + pa_variant: PAVariant, + quant_cache_dtype: torch.dtype, + seed: int, + device: str, + warmup_iter: int, +) -> None: + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + torch.set_default_device(device) + block_size = 1 + + if in_pt == None: + # Using default kv_scale + k_scale = v_scale = torch.tensor([1.0], dtype=dtypes.fp32) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=dtypes.fp32) + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + max_seq_len = ctx_lens + padded_ctx_lens = page_size * int(np.ceil(max_seq_len / page_size)) # e.g., + num_blocks = padded_ctx_lens * num_seqs + + # prepare inputs & golden output + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(*uniform_range) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory_v2( + num_blocks, + page_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + kv_indptr, kv_page_indices = kv_ptr_factory(num_seqs, ctx_lens, page_size) + kv_last_page_len = torch.tensor([block_size for i in range(num_seqs)], dtype=torch.int) + block_size = key_cache.shape[2 if kv_cache_layout == "HND" else 1] + else: # Load from pt + gpu_index = torch.cuda.current_device() + TARGET_DEVICE = torch.device(f'cuda:{gpu_index}') + + data = torch.load(in_pt) + query = data['q'].clone().detach().to(TARGET_DEVICE) + workspace = torch.empty(*data['workspace_buffer_shape']).to(TARGET_DEVICE) + key_cache = torch.empty(*data['k_buffer_shape']).to(TARGET_DEVICE) + value_cache = torch.empty(*data['v_buffer_shape']).to(TARGET_DEVICE) + kv_indptr = data['kv_indptr'].clone().detach().to(TARGET_DEVICE) + kv_page_indices = data['kv_indices'].clone().detach().to(TARGET_DEVICE) + kv_last_page_len = data['kv_last_page_len'].clone().detach().to(TARGET_DEVICE) + page_size = data['page_size'] + block_size = data['block_size'] + max_seq_len = kv_indptr[1] - kv_indptr[0] + kv_cache_dtype = data['kv_cache_dtype'] + kv_cache_layout = data['kv_cache_layout'] + scale = data['scale'] + alibi_slopes = data['alibi_slopes'].to(TARGET_DEVICE) if isinstance(data['alibi_slopes'], torch.Tensor) else data['alibi_slopes'] + logits_soft_cap = data['logits_soft_cap'] + k_scale = data['k_scale'] + v_scale = data['v_scale'] + + + _PARTITION_SIZE_ROCM = 256 + fp8_out_scale = None + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM + assert _PARTITION_SIZE_ROCM % block_size == 0 + + # will use single workspace buffer to accommodate following 3 intermediate tensors: + # 1. tmp_output (shape=(num_seqs, num_heads, max_num_partitions, head_size), dtype=output.dtype) + # 2. exp_sums (shape=(num_seqs, num_heads, max_num_partitions), dtype=float32) + # 3. max_logits (shape=(num_seqs, num_heads, max_num_partitions), dtype=float32) + output = torch.empty_like(query) + nbyes_per_qo_elem = torch.finfo(output.dtype).bits // 8 + workspace_buffer = torch.empty( + (num_seqs * num_heads * max_num_partitions * head_size) * nbyes_per_qo_elem + + 2 * (num_seqs * num_heads * max_num_partitions) * 4, + dtype=torch.uint8, + device=output.device, + ) + + cpa_fp8_out = False + if fp8_out_scale is not None: + output = torch.empty_like(output, dtype=dtypes.fp8) + cpa_fp8_out = True + torch.cuda.synchronize() + + # Debug + print(f"[DEBUG pa_unit_test.py] value_cache.is_contiguous()={value_cache.is_contiguous()}") + print(f"[DEBUG] kv_indptr.shape={kv_indptr.shape}, kv_page_indices.shape={kv_page_indices.shape}, kv_last_page_len.shape={kv_last_page_len.shape}") + print(f"[DEBUG] key_cache.shape={key_cache.shape}, value_cache.shape={value_cache.shape}") + print(f"[DEBUG] kv_page_indices={kv_page_indices}") + print(f"[DEBUG] kv_indptr[-10:]={kv_indptr[-10:]}") + print(f"[DEBUG] kv_page_indices.max()={kv_page_indices.max()}, num_seqs*ctx_lens={num_seqs*ctx_lens}") + # print(f"[DEBUG] kv_last_page_len={kv_last_page_len}") + # print(f"kv_indptr={kv_indptr}") + + + ARGS_TUPLE = ( + output, + workspace_buffer, + query, + key_cache.contiguous(), + value_cache.contiguous(), + scale, + kv_indptr, + kv_page_indices, + kv_last_page_len, + # page_size, # New args + block_size, + max_num_partitions, + alibi_slopes, + kv_cache_dtype, + kv_cache_layout, + logits_soft_cap, + k_scale, + v_scale, + fp8_out_scale if cpa_fp8_out else None, + _PARTITION_SIZE_ROCM, + ) + + # Warmup + for i in range(warmup_iter): + _, _ = run_aiter(*ARGS_TUPLE, version='GOLDEN') + _, _ = run_aiter(*ARGS_TUPLE, version='EXPERIMENTAL') + workspace_golden, out_golden = run_aiter(*ARGS_TUPLE, version='GOLDEN') + workspace_experi, out_experi = run_aiter(*ARGS_TUPLE, version='EXPERIMENTAL') + + # Grok1-bf16-TP8 + bs512-ilen2048: + # num_seqs=512, num_heads=6, max_num_partitions=8, head_size=128, nbyes_per_qo_elem=2 + + # workspace_buffer size: from pa_ragged.cpp.jinja + # exp_sums_ptr: = (num_seqs * num_heads * max_num_partitions) * 4 as type is float + # = 512*6*8*4 bytes + # max_logits_ptr:= (num_seqs * num_heads * max_num_partitions) * 4 as type is float + # = 512*6*8*4 bytes + # tmp_out_ptr: = (num_seqs * num_heads * max_num_partitions * head_size) * nbyes_per_qo_elem + # = 512*6*8*128*2 bytes + # output size = torch.empty_like(query), dtype=dtype + num_seqs, num_heads, head_size = query.shape + block_size = key_cache.shape[2 if kv_cache_layout == "HND" else 1] + _PARTITION_SIZE_ROCM = 256 + max_num_partitions = ( + max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM + nbyes_per_qo_elem = torch.finfo(query.dtype).bits // 8 + bytes_sizes = [num_seqs * num_heads * max_num_partitions * 4, + num_seqs * num_heads * max_num_partitions * 4, + num_seqs * num_heads * max_num_partitions * head_size * nbyes_per_qo_elem] + # print(f"[DEBUG] num_seqs={num_seqs}, num_heads={num_heads}, block_size={block_size}, " + # f"max_num_partitions={max_num_partitions}, head_size={head_size}, " + # f"nbyes_per_qo_elem={nbyes_per_qo_elem}, " + # f"_PARTITION_SIZE_ROCM={_PARTITION_SIZE_ROCM}") + + target_dtypes = [torch.float, torch.float, torch.bfloat16] + import itertools + accu_bytes = list(itertools.accumulate(bytes_sizes, initial=0)) + def split_workspace(workspace): + blocks = [] + for i in range(len(bytes_sizes)): + start_byte_idx = accu_bytes[i] + end_byte_idx = accu_bytes[i+1] + byte_slice = workspace[start_byte_idx:end_byte_idx] + block = byte_slice.view(target_dtypes[i]) + blocks.append(block) + return blocks + + def NumericCheck( + golden_tensor: torch.Tensor, + experi_tensor: torch.Tensor, + name: str = "Tensor", + rtol: float = 1e-5, + atol: float = 1e-8, + max_display: int = 5): + golden_tensor = golden_tensor.reshape(-1) + experi_tensor = experi_tensor.reshape(-1) + mismatch_mask = torch.abs(golden_tensor - experi_tensor) > (atol + rtol * torch.abs(experi_tensor)) + mismatch_indices = torch.nonzero(mismatch_mask, as_tuple=False) + mismatch_count = mismatch_mask.sum().item() + if mismatch_count > 0: + num_to_display = min(5, mismatch_count) + print(f"Numeric Check [{name} Failed] Elem count: {exp_sums_golden.numel()}, mismatch_count = {mismatch_count}") + for i in range(num_to_display): + idx = mismatch_indices[i].item() + golden_val = exp_sums_golden[idx].item() + experi_val = exp_sums_experi[idx].item() + abs_diff = abs(golden_val - experi_val) + print(f" Index [{idx}]: Golden={golden_val:.6e}, experi={experi_val:.6e}, Abs Diff={abs_diff:.2e}") + else: + print(f"Numeric Check [{name} Success]") + + exp_sums_golden, max_logits_golden, tmp_out_golden = split_workspace(workspace_golden) + exp_sums_experi, max_logits_experi, tmp_out_experi = split_workspace(workspace_experi) + NumericCheck(exp_sums_golden, exp_sums_experi, "exp_sums") + NumericCheck(max_logits_golden, max_logits_experi, "max_logits") + NumericCheck(tmp_out_golden, tmp_out_experi, "tmp_out") + NumericCheck(out_golden, out_experi, "out") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + description="Test Paged Attention ragged.", + ) + parser.add_argument( + "-c", + "--ctx_len", + type=int, + default=2048, + help="""Context length. + e.g. -c 128""", + ) + parser.add_argument( + "-p", + "--pa_variant", + type=str, + choices=[member.name for member in PAVariant], + default=[PAVariant.Shomy, PAVariant.Asm], + nargs="*", + help="It is not used. Just place an empty str", + ) + parser.add_argument( + "-q", + "--quant_cache_dtype", + type=str, + choices=["none", "fp8", "i8"], + default=["none", "fp8", "i8"], + nargs="*", + help="""Quantization cache dtype. + e.g. -q fp8""", + ) + parser.add_argument( + "-n", + type=int, + default=512, + help="number of seqs", + ) + parser.add_argument( + "--page-size", + type=int, + default=16, + help="block size(page size)", + ) + parser.add_argument( + "--in-pt", + type=str, + default=None, + help="Load data from pt file" + ) + parser.add_argument( + "--warmup", + type=int, + default=5, + help="warmup iterations", + ) + torch.set_printoptions(sci_mode=False) + args = parser.parse_args() + args.quant_cache_dtype = [ + None if i == "none" else dtypes.d_dtypes[i] for i in args.quant_cache_dtype + ] + + ctx_len = args.ctx_len + pa_variant = args.pa_variant + quant_cache_dtype = args.quant_cache_dtype + # print(f"[DEBUG pa_unit_test.py] ctx_len={ctx_len}, pa_variant={pa_variant}, quant_cache_dtype={quant_cache_dtype}") + + page_size = args.page_size # Original block size is 1 + test_paged_attention( + args.in_pt, + ctx_len, + args.n, + (6, 1), # num_heads: query and KV + 128, # head_size + False, # use_alibi + page_size, + dtypes.bf16, # dtype + "auto", # kv_cache_dtype + "NHD", # kv_cache_layout + 30.0, # logits_soft_cap + pa_variant, + quant_cache_dtype, + 0, # seed + "cuda:0", # device + args.warmup + ) + + + +''' +# Even if the input length is 256, I use "context length = 2048" +# since I would like to know the performance of the kernel when +# the KV cache is longer than prefill 256 tokens. +python ~/Grok_SGLang0.4.9/pa_unit_test_v2.py -n 512 -c 2048 --page-size 1 --warmup 0 + +# E2E +RCCL_MSCCL_ENABLE=0 SGLANG_USE_AITER=1 SGLANG_INT4_WEIGHT=1 python -m sglang.bench_one_batch \ + --batch-size 512 --input 256 --output 2048 --tp 8 --quantization fp8 --trust-remote-code \ + --model /data/huggingface/hub/amd/grok-1-W4A8KV8 \ + --tokenizer-path /data/huggingface/hub/Xenova/grok-1-tokenizer \ + --attention-backend aiter + + +''' \ No newline at end of file From 8730a2b91731bf7a16fda48143566336cb315724 Mon Sep 17 00:00:00 2001 From: Jacob Date: Mon, 24 Nov 2025 03:38:45 -0600 Subject: [PATCH 2/4] Set requirements for using EXPERIMENTAL kernel. Update comment in the kernel. --- csrc/cpp_itfs/pa/pa_kernels.cuh | 15 ++++++--------- csrc/cpp_itfs/pa/pa_ragged.py | 3 +++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 99818f18f3..c995009e9e 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -1118,7 +1118,7 @@ __inline__ __device__ void _paged_attention_ll4mi_reduce_kernel( // ----------------------------------------------------------------------- // ----------------------------------------------------------------------- // ----------------------- Experimental ---------------------------------- -// Configs: head_dim=128, cache_t=bf16 +// Works for: head_dim=128, cache_t=bf16 // Feature: // 1. continuous threads work together to load K cache into LDS, then each thread save the LDS into registers. // 2. Double buffer of K cache loading @@ -1210,20 +1210,17 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // sub partition of tokens per warp for qk calculation constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; - // constexpr int TLOOP = TOKENS_PER_WARP / 16; // each mfma16x16x16 instruction processes 16 tokens - const int wg_start_head_idx = kv_head_idx * GQA_RATIO_PER_LOOP; // Jacob: kv_head_idx=0, GQA_RATIO_PER_LOOP=6 const int wg_start_kv_head_idx = kv_head_idx; const int total_num_heads = gridDim.z * GQA_RATIO; - // Jacob: some variables which are dedicated for for Grok1: HEAD_SIZE=128, cache_t=bf16, blockSize 16/64/256 + // HEAD_SIZE=128, cache_t=bf16, blockSize 16/64/256 constexpr int BYTES_PER_WARP_FETCH = WARP_SIZE * 16; // 1024 bytes constexpr int TOKEN_PER_WARP_FETCH = BYTES_PER_WARP_FETCH / (HEAD_SIZE * sizeof(cache_t)); // 4 token - // constexpr int TLOOP = TOKENS_PER_WARP / TOKEN_PER_WARP_FETCH; // 16 - // 1st Wavefront loads token 1~4 tokens, 17~20 tokens ... - // 2nd Wavefront loads token 5~8 tokens, 21~24 tokens ... - // 3rd Wavefront loads token 9~12 tokens, 25~28 tokens ... - // 4th Wavefront loads token 13~16 tokens, 29~32 tokens ... 61~64 tokens + // 1st Wavefront loads token 1~4 tokens, 65~68 tokens ... + // 2nd Wavefront loads token 5~8 tokens, 69~72 tokens ... + // 3rd Wavefront loads token 9~12 tokens, 73~76 tokens ... + // 4th Wavefront loads token 13~16 tokens, 77~80 tokens ... 253~256 tokens // The number of iterations to load 64 tokens constexpr int ITERS_16TK = 64 / (TOKEN_PER_WARP_FETCH * NWARPS); // 4 // The number of iterations of ITERS_16TK to load 256 tokens diff --git a/csrc/cpp_itfs/pa/pa_ragged.py b/csrc/cpp_itfs/pa/pa_ragged.py index 2d88a1b8c6..4cf50583f4 100644 --- a/csrc/cpp_itfs/pa/pa_ragged.py +++ b/csrc/cpp_itfs/pa/pa_ragged.py @@ -27,6 +27,9 @@ def compile( ): import os version = os.getenv('QKV_VERSION', 'GOLDEN') + if version == 'EXPERIMENTAL': + if head_size != 128 or kv_dtype!="__hip_bfloat16": + print("EXPERIMENTAL pa_ragged kernel requires head_size=128 and kv_dtype=bf16. Fallback to original kernel") return compile_template_op( src_template, From e44d5c395a5e100aca4649bcc13b7dcff1f38681 Mon Sep 17 00:00:00 2001 From: Jacob Date: Tue, 25 Nov 2025 07:37:16 +0000 Subject: [PATCH 3/4] Format the files using black --- csrc/cpp_itfs/pa/pa_kernels.cuh | 162 +++++++++--------- csrc/cpp_itfs/pa/pa_ragged.py | 11 +- op_tests/test_pa_ragged_experimental.py | 211 ++++++++++++++---------- 3 files changed, 208 insertions(+), 176 deletions(-) diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index c995009e9e..05f4230919 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -200,7 +200,7 @@ _paged_attention_kernel(const int* block_table_seq, // set to true to enable non temporal kv loads: has some benefit in very high // batch size cases - constexpr bool NT_KV_LOAD = false; + constexpr bool NT_KV_LOAD = true; constexpr int KX = 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; @@ -1141,7 +1141,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( const int partition_start_token_idx, const scalar_t* q, const cache_t* k_cache, - const cache_t* v_cache, + const cache_t* v_cache, const float scale, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -1164,7 +1164,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( const int seq_idx = blockIdx.x; const int partition_idx = blockIdx.y; const int kv_head_idx = blockIdx.z; - constexpr int T_PAR_SIZE = 256; + constexpr int T_PAR_SIZE = 256; constexpr int NWARPS = NUM_THREADS / WARP_SIZE; constexpr int HEAD_LOOP = DIVIDE_ROUND_UP(HEAD_SIZE, 128); constexpr int HEAD_SIZE_PER_LOOP = DIVIDE_ROUND_UP(HEAD_SIZE, HEAD_LOOP); @@ -1191,15 +1191,15 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // HeadElements in each lane, 4x16B HeadElements across 4 rows of warp // rows refers to 16 lanes; refer dpp terminology - constexpr int ROWS_PER_WARP = WARP_SIZE / 16; + constexpr int ROWS_PER_WARP = WARP_SIZE / 16; // 8 for 16 bit cache type, 16 for 8 bit types - constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); // each fetch across a warp fetches these many elements - constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; + constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; // 1 for 16bit types, 2 for 8bit types - constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); + constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); // 4xQKHE_16B across warp - constexpr int QKHELOOP = HEAD_SIZE_PER_LOOP / QKHE_PER_FETCH; + constexpr int QKHELOOP = HEAD_SIZE_PER_LOOP / QKHE_PER_FETCH; _B16x8 Qlocal[GQA_RATIO_LOOP][HEAD_LOOP][MTP_PER_THREAD][QKHELOOP] // Jacob: 1x1x1x4x1 [QK_SIZE_RATIO]; // note that 16 contiguous elements of Q should @@ -1209,7 +1209,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); // sub partition of tokens per warp for qk calculation - constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; + constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; const int wg_start_head_idx = kv_head_idx * GQA_RATIO_PER_LOOP; // Jacob: kv_head_idx=0, GQA_RATIO_PER_LOOP=6 const int wg_start_kv_head_idx = kv_head_idx; const int total_num_heads = gridDim.z * GQA_RATIO; @@ -1225,7 +1225,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( constexpr int ITERS_16TK = 64 / (TOKEN_PER_WARP_FETCH * NWARPS); // 4 // The number of iterations of ITERS_16TK to load 256 tokens constexpr int TLOOP = 256 / 64; // 4 - constexpr int TOKEN_PER_WG = TOKEN_PER_WARP_FETCH * NWARPS; // 16 tokens per workgroup + constexpr int TOKEN_PER_WG = TOKEN_PER_WARP_FETCH * NWARPS; // 16 tokens per workgroup constexpr int THREAD_PER_TOKEN = HEAD_SIZE / CONTIGUOUS_KV_ELEMS_16B_LOAD; // if HEAD_SIZE=128, bf16 --> 16 threads load 1 token /// NOTICE: We don't support mask for this kernel, so just use a placeholder type/object here. @@ -1259,19 +1259,19 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( kglobal_token_idx & (BLOCK_SIZE - 1); kphysical_block_number[token_depth][iter_16tk] = block_table_seq[kblock_idx]; kphysical_offset[token_depth][iter_16tk] = kblock_offset; - + // if (threadIdx.x %16==0 && /*token_depth==0 &&*/ iter_16tk==0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("[HIP] [K block] threadIdx=%3d, token_depth=%d, iter_16tk=%d, " // "[kblock_idx=%3d, kpage_idx=%3d, kpage_offset=%2d], last_ctx_block=%3d, last_ctx_page=%3d, " // "block idx=%3d, kblock_offset=%3d, BLOCK_SIZE=%d, PAGE_SIZE=%d\n", - // threadIdx.x, token_depth, iter_16tk, + // threadIdx.x, token_depth, iter_16tk, // kblock_idx, kpage_idx, kpage_offset, last_ctx_block, last_ctx_page, - // block_table_seq[kblock_idx], kblock_offset, BLOCK_SIZE, PAGE_SIZE); + // block_table_seq[kblock_idx], kblock_offset, BLOCK_SIZE, PAGE_SIZE); // } - } + } __builtin_amdgcn_sched_group_barrier(0x0020, ITERS_16TK, 0); // VMEM read } - + // fetch Q in shared across warps and then write to registers const int warp_mtp_idx = warpid / (4 / MTP_PARALLEL_THREADS); // Jacob: MTP_PARALLEL_THREADS=1, warpid=0, warp_mtp_idx=0 @@ -1284,13 +1284,13 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( constexpr int mtp_loop = MTP_PER_THREAD; // Jacob: q_stride = GQA_RATIO * HEAD_SIZE = 6*128=768 // each thread local 8 data, so 16 threads can load the full 128 HEAD_SIZE of Q - // As GQA_RATIO=6, we need 96 threads to load 6x128 q data + // As GQA_RATIO=6, we need 96 threads to load 6x128 q data // Btw, a block with 256 threads can load 16x128 q data for(int mtp = 0; mtp < mtp_loop; mtp++) { // 1 for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { // 1 const scalar_t* q_ptr = q + (query_start_off + mtp * MTP_PARALLEL_THREADS) * q_stride + (global_qhead_idx + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * HEAD_SIZE; - + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { // 1 const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B + head_loop * HEAD_SIZE_PER_LOOP; if ((local_mtp_qhead_idx < GQA_RATIO_MTP_PARALLEL) && (qhead_element < HEAD_SIZE)) { @@ -1319,19 +1319,19 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } __syncthreads(); - + // qk mfma constexpr bool NT_KV_LOAD = true; constexpr int KX = 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; - const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; int curr = 0, next = 1; __shared__ cache_t Kbuffer_lds[2][HEAD_LOOP][64][HEAD_SIZE]; // curr and next K buffer, each load 64 tokens K cache // Each warp processes 16x128 and it was divided into 4 mfma 16x(4x32), each thread records 4x CONTIGUOUS_KV_ELEMS_16B_LOAD elems - _B16x8 Kbuffer_reg[2][HEAD_LOOP][QKHELOOP]; - + _B16x8 Kbuffer_reg[2][HEAD_LOOP][QKHELOOP]; + // qk mfma - define lambda: loading K cache int DEBUG_TOKEN_DEPTH=0; constexpr int n_global_load_per_fragment = HEAD_LOOP*ITERS_16TK; @@ -1342,49 +1342,49 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( for(int iter_16tk = 0; iter_16tk < ITERS_16TK; iter_16tk++){ // 4 const int64_t kpage_number = static_cast(kpage[iter_16tk]); const int64_t kpage_offset = static_cast(koffset[iter_16tk]); - const int offset = - kpage_number * kv_block_stride + - kpage_offset * HEAD_SIZE + + const int offset = + kpage_number * kv_block_stride + + kpage_offset * HEAD_SIZE + lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; const _B16x8* k_ptr_B16x8 = reinterpret_cast(k_ptr + offset); - // Save to LDS + // Save to LDS const int token_row = iter_16tk * 16 + threadIdx.x / 16; - const int cache_offset = lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int cache_offset = lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; if constexpr (NT_KV_LOAD) - *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = + *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = load_ntmprl_16Byte(k_ptr_B16x8); else - *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = + *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = *k_ptr_B16x8; - + // if (iter_16tk<2 && DEBUG_TOKEN_DEPTH<2 && threadIdx.x < 128 && // blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("[LDS K cache] threadIdx=%3d, iter_16tk=%d, kblock_number=%3ld, " // "kblock_offset=%3ld, Save to LDS token_row=%3d, LDS cache_offset=%3d, " // "addr offset=%d, val=%f \n", // threadIdx.x, iter_16tk, kblock_number, - // kblock_offset, token_row, cache_offset, - // offset, __bfloat162float(*(k_ptr + offset))); - // DEBUG_TOKEN_DEPTH += 1; - // } - // if (DEBUG_TOKEN_DEPTH==0 && token_row == 0 && + // kblock_offset, token_row, cache_offset, + // offset, __bfloat162float(*(k_ptr + offset))); + // DEBUG_TOKEN_DEPTH += 1; + // } + // if (DEBUG_TOKEN_DEPTH==0 && token_row == 0 && // blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("[LDS K cache] threadIdx=%3d, iter_16tk=%d, kblock_number=%3ld, " // "kblock_offset=%3ld, Save to LDS token_row=%3d, LDS cache_offset=%3d, " // "addr offset=%d, val=%f \n", // threadIdx.x, iter_16tk, kblock_number, - // kblock_offset, token_row, cache_offset, - // offset, __bfloat162float(*(k_ptr + offset))); - // DEBUG_TOKEN_DEPTH += 1; + // kblock_offset, token_row, cache_offset, + // offset, __bfloat162float(*(k_ptr + offset))); + // DEBUG_TOKEN_DEPTH += 1; // } // if (token_row == 0 && cache_offset == 64 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("[LDS K cache] threadIdx=%3d, iter_16tk=%d, kblock_number=%3ld, " // "kblock_offset=%3ld, LDS token_row=%3d, LDS cache_offset=%3d, " // "addr offset=%d, val=%f \n", // threadIdx.x, iter_16tk, kblock_number, - // kblock_offset, token_row, cache_offset, - // offset, __bfloat162float(*(k_ptr + offset))); + // kblock_offset, token_row, cache_offset, + // offset, __bfloat162float(*(k_ptr + offset))); // } } } @@ -1404,12 +1404,12 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( const int col_16tk_offset = qkhe_depth * 32; // A matrix is 16x32. 32 columns const int col_thread_offset = CONTIGUOUS_KV_ELEMS_16B_LOAD * rowid; const int col = col_16tk_offset + col_thread_offset; - Kbuffer_reg[buf_idx][head_loop][qkhe_depth] = + Kbuffer_reg[buf_idx][head_loop][qkhe_depth] = *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][row][col]); // _B16x8{{ // {static_cast(1),1,1,1},{1,1,1,1} // }}; - + // Check NAN // for(int x=0; x<2; ++x) // for(int y=0; y<4; ++y) @@ -1420,12 +1420,12 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // printf("[LDS-->Reg] threadIdx=%3d, qkhe_depth=%d, " // "Kbuffer_reg load from Kbuffer_lds[%d][%d][%d]\n", // threadIdx.x, qkhe_depth, - // buf_idx, row, col); + // buf_idx, row, col); // } } } }; - + // qk mfma - Preload the kphysical_block_number[0] cache load_K_fragment(k_ptr, curr, kphysical_block_number[0], kphysical_offset[0]); __builtin_amdgcn_sched_group_barrier(0x0020, n_global_load_per_fragment, 0); // VMEM read @@ -1469,19 +1469,19 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( __builtin_amdgcn_sched_group_barrier(0x0020, n_global_load_per_fragment, 0); // VMEM read } - for (int mtp = 0; mtp < mtp_loop; mtp++) { // 1 + for (int mtp = 0; mtp < mtp_loop; mtp++) { // 1 for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { // 1 d_out[gqa_ratio_loop][mtp][token_depth] = {0}; for (int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { // 1 for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { // 4 for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { // Load Q from LDS - for (int i = 0; i < 2; i++) + for (int i = 0; i < 2; i++) Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio].xy[i] = shared_logits[gqa_ratio_loop][head_loop][mtp][qkhe_depth][rowid] [lane16id % GQA_RATIO_MTP_PARALLEL][2 * qkratio + i]; __builtin_amdgcn_sched_group_barrier(0x0100, 2, 0); // LDS read - + // mfma if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { #if defined(__gfx950__) @@ -1499,7 +1499,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA #endif - + // Check value // for(int d=0; d<4; ++d) // if(isnan(d_out[gqa_ratio_loop][mtp][token_depth][d])){ @@ -1511,7 +1511,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // } // break; // } - + } else { // kv cache dtype fp8 auto Ktmp = Kbuffer_reg[curr][head_loop][qkhe_depth]; @@ -1540,12 +1540,12 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } } - + // DEBUG: check values // if (threadIdx.x==1 && token_depth<4 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // floatx4 data = d_out[gqa_ratio_loop][mtp][token_depth]; // printf("[check d_out] threadIdx.x=%d, BLOCK_SIZE=%d, d_out=%f,%f %f %f \n", - // threadIdx.x, BLOCK_SIZE ,data[0], data[1], data[2], data[3]); + // threadIdx.x, BLOCK_SIZE ,data[0], data[1], data[2], data[3]); // for(int y=0; yQueryTransform(variant_params, d_out[gqa_ratio_loop][mtp][token_depth][i]); @@ -1631,7 +1631,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( /*qo_head_idx=*/wg_start_head_idx + lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP, /*kv_head_idx=*/kv_head_idx); } - + } } } @@ -1645,14 +1645,14 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // __syncthreads(); // } // } - + // calculate qk_max and exp_sum per warp and write to shared memory float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX}; float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; - + for (int mtp = 0; mtp < mtp_loop; mtp++) { - for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - // Step 1.1 Get max qk per thread: + for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + // Step 1.1 Get max qk per thread: for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for(int i=0; i<4; ++i){ @@ -1662,18 +1662,18 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( qk_max[gqa_ratio_loop][mtp] = fmaxf(qk_max[gqa_ratio_loop][mtp], tmp); } } - + // for(int d=0; d<4; ++d) // if(isnan(d_out[gqa_ratio_loop][mtp][token_depth][d])){ // floatx4 ddata = d_out[gqa_ratio_loop][mtp][token_depth]; // printf("qk_mfma+soft-capping is nan. d_out=%f %f %f %f\n", - // ddata[0], ddata[1], ddata[2], ddata[3]); + // ddata[0], ddata[1], ddata[2], ddata[3]); // break; // } // Step 1.2 Get max qk along q head under each wavefront // According to ROCm CDNA4 mfma16x16x32, The output dim of mfma(qk) is 16x16. - // Thread [1, 17, 33, 49] stores 1 column, 16 elements of mfma(K@Q.T). + // Thread [1, 17, 33, 49] stores 1 column, 16 elements of mfma(K@Q.T). // Use the following loop can get the max(thread1, thread17, thread33, thread49) // "mask >= 16" summed to 16 threads as 1 GQA_RATIO_LOOP process 16 q heads for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { @@ -1699,21 +1699,21 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { exp_sum[gqa_ratio_loop][mtp] += __shfl_xor(exp_sum[gqa_ratio_loop][mtp], mask); } - } + } } // __syncthreads(); // sync before writing to shared mem // Why need sync here? no LDS ops before this line - + // Step 3. Save qk_max and exp_sum for the entire workgroup float* shared_mem = reinterpret_cast(shared_logits); if (laneid < 16) { for(int mtp = 0; mtp < mtp_loop; mtp++) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - const int qk_max_offset = - warpid * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + - (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + const int qk_max_offset = + warpid * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + mtp; shared_mem[qk_max_offset] = qk_max[gqa_ratio_loop][mtp]; - const int exp_sum_offset = + const int exp_sum_offset = NWARPS * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + qk_max_offset; shared_mem[exp_sum_offset] = exp_sum[gqa_ratio_loop][mtp]; @@ -1721,7 +1721,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // printf("[qk_max + exp_sum] threadIdx=%3d, qk_max_offset=%3d, exp_sum_offset=%3d, " // "shared_mem[qk_max_offset]=%f, shared_mem[exp_sum_offset]=%f\n", // threadIdx.x, qk_max_offset, exp_sum_offset, - // shared_mem[qk_max_offset], shared_mem[exp_sum_offset]); + // shared_mem[qk_max_offset], shared_mem[exp_sum_offset]); // } } } @@ -1748,18 +1748,18 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( partition_exp_sum[gqa_ratio_loop][mtp] += shared_mem[NWARPS * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + w * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + mtp] * warp_qk_max_exp[w]; } - + inv_sum_scale[gqa_ratio_loop][mtp] = __fdividef(1.f, partition_exp_sum[gqa_ratio_loop][mtp] + 1e-6f) * warp_qk_max_exp[warpid]; - + // if (threadIdx.x < 256 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("threadIdx=%3d, warp_qk_max_exp=%f %f %f %f, partition_qk_max[%d][%d]=%f " - // "partition_exp_sum[%d][%d]=%f, " + // "partition_exp_sum[%d][%d]=%f, " // "inv_sum_scale[%d][%d]=%f\n", - // threadIdx.x, + // threadIdx.x, // warp_qk_max_exp[0], warp_qk_max_exp[1], warp_qk_max_exp[2], warp_qk_max_exp[3], gqa_ratio_loop, mtp, partition_qk_max[gqa_ratio_loop][mtp], - // gqa_ratio_loop, mtp, partition_exp_sum[gqa_ratio_loop][mtp], - // gqa_ratio_loop, mtp, inv_sum_scale[gqa_ratio_loop][mtp]); + // gqa_ratio_loop, mtp, partition_exp_sum[gqa_ratio_loop][mtp], + // gqa_ratio_loop, mtp, inv_sum_scale[gqa_ratio_loop][mtp]); // } } } @@ -1787,9 +1787,9 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } } - - // DEBUG: Get qk_max across blocks + + // DEBUG: Get qk_max across blocks // write out partition max_logits and exp_sum if (threadIdx.x < GQA_RATIO_MTP_PARALLEL) { for(int mtp = 0; mtp < mtp_loop; mtp++) { @@ -1935,7 +1935,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } - + _B16x4 outelems[GQA_RATIO_LOOP][MTP_PER_THREAD][VHELOOP]; // Softmax V mfma @@ -1965,7 +1965,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( Vlocal[vtoken_depth][vhe_depth][vfetch_depth], tmp_in, tmp_out); - #else + #else for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; @@ -2047,7 +2047,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { shared_logits[gqa_ratio_loop][0][mtp][warpid][vhe_depth][lane16id][rowid] = outelems[gqa_ratio_loop][mtp][vhe_depth]; - // if (threadIdx.x==0 && vhe_depth==0 && + // if (threadIdx.x==0 && vhe_depth==0 && // blockIdx.x == 1 && blockIdx.y == 0 && blockIdx.z == 0) { // _B16x4 data = outelems[gqa_ratio_loop][mtp][vhe_depth]; // uint16_t v0 = data[0]; @@ -2093,7 +2093,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( const int64_t hsz_maxp_mult = static_cast(HEAD_SIZE * max_num_partitions); - + scalar_t* out_ptr = out + (seq_idx + mtp * MTP_PARALLEL_THREADS) * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; for (int h = 0; h < GQA_RATIO4; h++) { @@ -2106,7 +2106,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); *out_ptr_B16x8 = vout[h]; - // if (threadIdx.x<64 && + // if (threadIdx.x<64 && // blockIdx.x == 59 && blockIdx.y == 6 && blockIdx.z == 0) { // _B16x8 data_x8 = * out_ptr_B16x8; // uint16_t v[8]; @@ -2130,4 +2130,4 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } } -} \ No newline at end of file +} diff --git a/csrc/cpp_itfs/pa/pa_ragged.py b/csrc/cpp_itfs/pa/pa_ragged.py index 4cf50583f4..7349457e4f 100644 --- a/csrc/cpp_itfs/pa/pa_ragged.py +++ b/csrc/cpp_itfs/pa/pa_ragged.py @@ -26,10 +26,13 @@ def compile( func_name: str = None, ): import os - version = os.getenv('QKV_VERSION', 'GOLDEN') - if version == 'EXPERIMENTAL': - if head_size != 128 or kv_dtype!="__hip_bfloat16": - print("EXPERIMENTAL pa_ragged kernel requires head_size=128 and kv_dtype=bf16. Fallback to original kernel") + + version = os.getenv("QKV_VERSION", "GOLDEN") + if version == "EXPERIMENTAL": + if head_size != 128 or kv_dtype != "__hip_bfloat16": + print( + "EXPERIMENTAL pa_ragged kernel requires head_size=128 and kv_dtype=bf16. Fallback to original kernel" + ) return compile_template_op( src_template, diff --git a/op_tests/test_pa_ragged_experimental.py b/op_tests/test_pa_ragged_experimental.py index a0123ceea4..2114c5ddc9 100644 --- a/op_tests/test_pa_ragged_experimental.py +++ b/op_tests/test_pa_ragged_experimental.py @@ -15,6 +15,8 @@ from aiter import paged_attention_ragged uniform_range = (-1, 1) + + class PAVariant(Enum): Shomy = 1 Asm = 2 @@ -45,6 +47,7 @@ def get_kv_cache_torch_dtype( raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") return torch_dtype + def kv_cache_factory_v2( num_blocks: int, page_size: int, @@ -86,19 +89,20 @@ def kv_cache_factory_v2( value_caches.append(value_cache) return key_caches, value_caches + def kv_ptr_factory( - num_seqs: int, - ctx_lens: int, - page_size: int, - )-> Tuple[List[torch.Tensor], List[torch.Tensor]]: + num_seqs: int, + ctx_lens: int, + page_size: int, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: # kv_indptr num_blocks_list = [ctx_lens] * num_seqs - kv_indptr = torch.tensor([0] + num_blocks_list).cumsum( - dim=0, dtype=torch.int - ) + kv_indptr = torch.tensor([0] + num_blocks_list).cumsum(dim=0, dtype=torch.int) # kv_page_indices - padded_ctx_lens = page_size * int(np.ceil(ctx_lens / page_size)) # e.g., ctx_lens=10, page_size=3 --> padded_ctx_lens=12 + padded_ctx_lens = page_size * int( + np.ceil(ctx_lens / page_size) + ) # e.g., ctx_lens=10, page_size=3 --> padded_ctx_lens=12 index_total = num_seqs * padded_ctx_lens head_per_row = int(np.ceil(ctx_lens / page_size)) head_total = num_seqs * int(np.ceil(ctx_lens / page_size)) @@ -107,25 +111,28 @@ def kv_ptr_factory( all_heads = np.arange(0, index_total, page_size) np.random.shuffle(all_heads) row_chunks = all_heads.reshape(num_seqs, head_per_row) - + # Sort the chunks since the page indices are in ascending order. sorted_row_heads = np.sort(row_chunks, axis=1) extended_heads = np.repeat(sorted_row_heads, page_size, axis=1) - + # Create Offset matrix with shape (1, length) offset = np.tile(np.arange(page_size), head_per_row) print(sorted_row_heads.shape) - + # shape (bs, length) offset_tile = np.tile(offset, (num_seqs, 1)) - - # Extend sorted_row_heads - + + # Extend sorted_row_heads + kv_page_indices = extended_heads + offset_tile kv_page_indices = kv_page_indices[:, :ctx_lens] - kv_page_indices = torch.from_numpy(kv_page_indices).to(device='cuda:0', dtype=torch.int32) + kv_page_indices = torch.from_numpy(kv_page_indices).to( + device="cuda:0", dtype=torch.int32 + ) return kv_indptr, kv_page_indices.reshape(-1) + def run_aiter( output, workspace_buffer, @@ -149,7 +156,7 @@ def run_aiter( _PARTITION_SIZE_ROCM, version="GOLDEN", ): - os.environ['QKV_VERSION'] = version + os.environ["QKV_VERSION"] = version torch.ops.aiter.paged_attention_ragged( output, workspace_buffer, @@ -174,11 +181,10 @@ def run_aiter( ) return workspace_buffer, output - - + def test_paged_attention( - in_pt:str, + in_pt: str, ctx_lens: int, num_seqs: int, num_heads: Tuple[int, int], @@ -198,7 +204,7 @@ def test_paged_attention( torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) - torch.set_default_device(device) + torch.set_default_device(device) block_size = 1 if in_pt == None: @@ -212,9 +218,9 @@ def test_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads max_seq_len = ctx_lens - padded_ctx_lens = page_size * int(np.ceil(max_seq_len / page_size)) # e.g., + padded_ctx_lens = page_size * int(np.ceil(max_seq_len / page_size)) # e.g., num_blocks = padded_ctx_lens * num_seqs - + # prepare inputs & golden output query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) query.uniform_(*uniform_range) @@ -233,31 +239,36 @@ def test_paged_attention( ) key_cache, value_cache = key_caches[0], value_caches[0] kv_indptr, kv_page_indices = kv_ptr_factory(num_seqs, ctx_lens, page_size) - kv_last_page_len = torch.tensor([block_size for i in range(num_seqs)], dtype=torch.int) + kv_last_page_len = torch.tensor( + [block_size for i in range(num_seqs)], dtype=torch.int + ) block_size = key_cache.shape[2 if kv_cache_layout == "HND" else 1] - else: # Load from pt + else: # Load from pt gpu_index = torch.cuda.current_device() - TARGET_DEVICE = torch.device(f'cuda:{gpu_index}') + TARGET_DEVICE = torch.device(f"cuda:{gpu_index}") data = torch.load(in_pt) - query = data['q'].clone().detach().to(TARGET_DEVICE) - workspace = torch.empty(*data['workspace_buffer_shape']).to(TARGET_DEVICE) - key_cache = torch.empty(*data['k_buffer_shape']).to(TARGET_DEVICE) - value_cache = torch.empty(*data['v_buffer_shape']).to(TARGET_DEVICE) - kv_indptr = data['kv_indptr'].clone().detach().to(TARGET_DEVICE) - kv_page_indices = data['kv_indices'].clone().detach().to(TARGET_DEVICE) - kv_last_page_len = data['kv_last_page_len'].clone().detach().to(TARGET_DEVICE) - page_size = data['page_size'] - block_size = data['block_size'] + query = data["q"].clone().detach().to(TARGET_DEVICE) + workspace = torch.empty(*data["workspace_buffer_shape"]).to(TARGET_DEVICE) + key_cache = torch.empty(*data["k_buffer_shape"]).to(TARGET_DEVICE) + value_cache = torch.empty(*data["v_buffer_shape"]).to(TARGET_DEVICE) + kv_indptr = data["kv_indptr"].clone().detach().to(TARGET_DEVICE) + kv_page_indices = data["kv_indices"].clone().detach().to(TARGET_DEVICE) + kv_last_page_len = data["kv_last_page_len"].clone().detach().to(TARGET_DEVICE) + page_size = data["page_size"] + block_size = data["block_size"] max_seq_len = kv_indptr[1] - kv_indptr[0] - kv_cache_dtype = data['kv_cache_dtype'] - kv_cache_layout = data['kv_cache_layout'] - scale = data['scale'] - alibi_slopes = data['alibi_slopes'].to(TARGET_DEVICE) if isinstance(data['alibi_slopes'], torch.Tensor) else data['alibi_slopes'] - logits_soft_cap = data['logits_soft_cap'] - k_scale = data['k_scale'] - v_scale = data['v_scale'] - + kv_cache_dtype = data["kv_cache_dtype"] + kv_cache_layout = data["kv_cache_layout"] + scale = data["scale"] + alibi_slopes = ( + data["alibi_slopes"].to(TARGET_DEVICE) + if isinstance(data["alibi_slopes"], torch.Tensor) + else data["alibi_slopes"] + ) + logits_soft_cap = data["logits_soft_cap"] + k_scale = data["k_scale"] + v_scale = data["v_scale"] _PARTITION_SIZE_ROCM = 256 fp8_out_scale = None @@ -285,18 +296,25 @@ def test_paged_attention( output = torch.empty_like(output, dtype=dtypes.fp8) cpa_fp8_out = True torch.cuda.synchronize() - + # Debug - print(f"[DEBUG pa_unit_test.py] value_cache.is_contiguous()={value_cache.is_contiguous()}") - print(f"[DEBUG] kv_indptr.shape={kv_indptr.shape}, kv_page_indices.shape={kv_page_indices.shape}, kv_last_page_len.shape={kv_last_page_len.shape}") - print(f"[DEBUG] key_cache.shape={key_cache.shape}, value_cache.shape={value_cache.shape}") + print( + f"[DEBUG pa_unit_test.py] value_cache.is_contiguous()={value_cache.is_contiguous()}" + ) + print( + f"[DEBUG] kv_indptr.shape={kv_indptr.shape}, kv_page_indices.shape={kv_page_indices.shape}, kv_last_page_len.shape={kv_last_page_len.shape}" + ) + print( + f"[DEBUG] key_cache.shape={key_cache.shape}, value_cache.shape={value_cache.shape}" + ) print(f"[DEBUG] kv_page_indices={kv_page_indices}") print(f"[DEBUG] kv_indptr[-10:]={kv_indptr[-10:]}") - print(f"[DEBUG] kv_page_indices.max()={kv_page_indices.max()}, num_seqs*ctx_lens={num_seqs*ctx_lens}") + print( + f"[DEBUG] kv_page_indices.max()={kv_page_indices.max()}, num_seqs*ctx_lens={num_seqs*ctx_lens}" + ) # print(f"[DEBUG] kv_last_page_len={kv_last_page_len}") # print(f"kv_indptr={kv_indptr}") - - + ARGS_TUPLE = ( output, workspace_buffer, @@ -313,7 +331,7 @@ def test_paged_attention( alibi_slopes, kv_cache_dtype, kv_cache_layout, - logits_soft_cap, + logits_soft_cap, k_scale, v_scale, fp8_out_scale if cpa_fp8_out else None, @@ -322,14 +340,14 @@ def test_paged_attention( # Warmup for i in range(warmup_iter): - _, _ = run_aiter(*ARGS_TUPLE, version='GOLDEN') - _, _ = run_aiter(*ARGS_TUPLE, version='EXPERIMENTAL') - workspace_golden, out_golden = run_aiter(*ARGS_TUPLE, version='GOLDEN') - workspace_experi, out_experi = run_aiter(*ARGS_TUPLE, version='EXPERIMENTAL') + _, _ = run_aiter(*ARGS_TUPLE, version="GOLDEN") + _, _ = run_aiter(*ARGS_TUPLE, version="EXPERIMENTAL") + workspace_golden, out_golden = run_aiter(*ARGS_TUPLE, version="GOLDEN") + workspace_experi, out_experi = run_aiter(*ARGS_TUPLE, version="EXPERIMENTAL") - # Grok1-bf16-TP8 + bs512-ilen2048: + # Grok1-bf16-TP8 + bs512-ilen2048: # num_seqs=512, num_heads=6, max_num_partitions=8, head_size=128, nbyes_per_qo_elem=2 - + # workspace_buffer size: from pa_ragged.cpp.jinja # exp_sums_ptr: = (num_seqs * num_heads * max_num_partitions) * 4 as type is float # = 512*6*8*4 bytes @@ -342,56 +360,71 @@ def test_paged_attention( block_size = key_cache.shape[2 if kv_cache_layout == "HND" else 1] _PARTITION_SIZE_ROCM = 256 max_num_partitions = ( - max_seq_len + _PARTITION_SIZE_ROCM - 1 - ) // _PARTITION_SIZE_ROCM + max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM nbyes_per_qo_elem = torch.finfo(query.dtype).bits // 8 - bytes_sizes = [num_seqs * num_heads * max_num_partitions * 4, - num_seqs * num_heads * max_num_partitions * 4, - num_seqs * num_heads * max_num_partitions * head_size * nbyes_per_qo_elem] + bytes_sizes = [ + num_seqs * num_heads * max_num_partitions * 4, + num_seqs * num_heads * max_num_partitions * 4, + num_seqs * num_heads * max_num_partitions * head_size * nbyes_per_qo_elem, + ] # print(f"[DEBUG] num_seqs={num_seqs}, num_heads={num_heads}, block_size={block_size}, " # f"max_num_partitions={max_num_partitions}, head_size={head_size}, " # f"nbyes_per_qo_elem={nbyes_per_qo_elem}, " # f"_PARTITION_SIZE_ROCM={_PARTITION_SIZE_ROCM}") - + target_dtypes = [torch.float, torch.float, torch.bfloat16] import itertools + accu_bytes = list(itertools.accumulate(bytes_sizes, initial=0)) + def split_workspace(workspace): blocks = [] for i in range(len(bytes_sizes)): start_byte_idx = accu_bytes[i] - end_byte_idx = accu_bytes[i+1] + end_byte_idx = accu_bytes[i + 1] byte_slice = workspace[start_byte_idx:end_byte_idx] block = byte_slice.view(target_dtypes[i]) blocks.append(block) return blocks - + def NumericCheck( golden_tensor: torch.Tensor, experi_tensor: torch.Tensor, name: str = "Tensor", rtol: float = 1e-5, atol: float = 1e-8, - max_display: int = 5): + max_display: int = 5, + ): golden_tensor = golden_tensor.reshape(-1) experi_tensor = experi_tensor.reshape(-1) - mismatch_mask = torch.abs(golden_tensor - experi_tensor) > (atol + rtol * torch.abs(experi_tensor)) + mismatch_mask = torch.abs(golden_tensor - experi_tensor) > ( + atol + rtol * torch.abs(experi_tensor) + ) mismatch_indices = torch.nonzero(mismatch_mask, as_tuple=False) mismatch_count = mismatch_mask.sum().item() if mismatch_count > 0: num_to_display = min(5, mismatch_count) - print(f"Numeric Check [{name} Failed] Elem count: {exp_sums_golden.numel()}, mismatch_count = {mismatch_count}") + print( + f"Numeric Check [{name} Failed] Elem count: {exp_sums_golden.numel()}, mismatch_count = {mismatch_count}" + ) for i in range(num_to_display): idx = mismatch_indices[i].item() golden_val = exp_sums_golden[idx].item() experi_val = exp_sums_experi[idx].item() abs_diff = abs(golden_val - experi_val) - print(f" Index [{idx}]: Golden={golden_val:.6e}, experi={experi_val:.6e}, Abs Diff={abs_diff:.2e}") + print( + f" Index [{idx}]: Golden={golden_val:.6e}, experi={experi_val:.6e}, Abs Diff={abs_diff:.2e}" + ) else: print(f"Numeric Check [{name} Success]") - exp_sums_golden, max_logits_golden, tmp_out_golden = split_workspace(workspace_golden) - exp_sums_experi, max_logits_experi, tmp_out_experi = split_workspace(workspace_experi) + exp_sums_golden, max_logits_golden, tmp_out_golden = split_workspace( + workspace_golden + ) + exp_sums_experi, max_logits_experi, tmp_out_experi = split_workspace( + workspace_experi + ) NumericCheck(exp_sums_golden, exp_sums_experi, "exp_sums") NumericCheck(max_logits_golden, max_logits_experi, "max_logits") NumericCheck(tmp_out_golden, tmp_out_experi, "tmp_out") @@ -443,10 +476,7 @@ def NumericCheck( help="block size(page size)", ) parser.add_argument( - "--in-pt", - type=str, - default=None, - help="Load data from pt file" + "--in-pt", type=str, default=None, help="Load data from pt file" ) parser.add_argument( "--warmup", @@ -465,31 +495,30 @@ def NumericCheck( quant_cache_dtype = args.quant_cache_dtype # print(f"[DEBUG pa_unit_test.py] ctx_len={ctx_len}, pa_variant={pa_variant}, quant_cache_dtype={quant_cache_dtype}") - page_size = args.page_size # Original block size is 1 + page_size = args.page_size # Original block size is 1 test_paged_attention( args.in_pt, - ctx_len, + ctx_len, args.n, - (6, 1), # num_heads: query and KV - 128, # head_size - False, # use_alibi + (6, 1), # num_heads: query and KV + 128, # head_size + False, # use_alibi page_size, - dtypes.bf16, # dtype - "auto", # kv_cache_dtype - "NHD", # kv_cache_layout - 30.0, # logits_soft_cap + dtypes.bf16, # dtype + "auto", # kv_cache_dtype + "NHD", # kv_cache_layout + 30.0, # logits_soft_cap pa_variant, quant_cache_dtype, - 0, # seed - "cuda:0", # device - args.warmup + 0, # seed + "cuda:0", # device + args.warmup, ) - -''' +""" # Even if the input length is 256, I use "context length = 2048" -# since I would like to know the performance of the kernel when +# since I would like to know the performance of the kernel when # the KV cache is longer than prefill 256 tokens. python ~/Grok_SGLang0.4.9/pa_unit_test_v2.py -n 512 -c 2048 --page-size 1 --warmup 0 @@ -498,7 +527,7 @@ def NumericCheck( --batch-size 512 --input 256 --output 2048 --tp 8 --quantization fp8 --trust-remote-code \ --model /data/huggingface/hub/amd/grok-1-W4A8KV8 \ --tokenizer-path /data/huggingface/hub/Xenova/grok-1-tokenizer \ - --attention-backend aiter + --attention-backend aiter -''' \ No newline at end of file +""" From c0e968de6f10dd97700aeffa61e4025fe38d7f91 Mon Sep 17 00:00:00 2001 From: Jacob Date: Tue, 25 Nov 2025 03:26:24 -0600 Subject: [PATCH 4/4] Revert NT_KV_LOAD to false --- csrc/cpp_itfs/pa/pa_kernels.cuh | 714 +++++++++++++++++++------------- 1 file changed, 428 insertions(+), 286 deletions(-) diff --git a/csrc/cpp_itfs/pa/pa_kernels.cuh b/csrc/cpp_itfs/pa/pa_kernels.cuh index 05f4230919..44959d0d05 100644 --- a/csrc/cpp_itfs/pa/pa_kernels.cuh +++ b/csrc/cpp_itfs/pa/pa_kernels.cuh @@ -200,7 +200,7 @@ _paged_attention_kernel(const int* block_table_seq, // set to true to enable non temporal kv loads: has some benefit in very high // batch size cases - constexpr bool NT_KV_LOAD = true; + constexpr bool NT_KV_LOAD = false; constexpr int KX = 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; @@ -454,7 +454,7 @@ _paged_attention_kernel(const int* block_table_seq, for(int i = 0; i < 4; i++) { float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i]; - if (local_token_idx + i < context_len - sliding_window) + if(local_token_idx + i < context_len - sliding_window) tmp = -FLT_MAX; d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; } @@ -1111,16 +1111,13 @@ __inline__ __device__ void _paged_attention_ll4mi_reduce_kernel( } } - - - - // ----------------------------------------------------------------------- // ----------------------------------------------------------------------- // ----------------------- Experimental ---------------------------------- // Works for: head_dim=128, cache_t=bf16 // Feature: -// 1. continuous threads work together to load K cache into LDS, then each thread save the LDS into registers. +// 1. continuous threads work together to load K cache into LDS, then each thread save the LDS into +// registers. // 2. Double buffer of K cache loading // 3. NT_KV_LOAD set to true template 16 threads load 1 token + constexpr int THREAD_PER_TOKEN = + HEAD_SIZE / + CONTIGUOUS_KV_ELEMS_16B_LOAD; // if HEAD_SIZE=128, bf16 --> 16 threads load 1 token /// NOTICE: We don't support mask for this kernel, so just use a placeholder type/object here. using Mask = ck_tile::SimplifiedGenericAttentionMask; @@ -1238,80 +1239,102 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // 4 TLOOP iteration load 64 tokens. Then, let each wavefront process 16 tokens mfma const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int last_ctx_block = num_context_blocks - 1; - // const int last_ctx_page = PAGE_SIZE * ((num_context_blocks - 1) / PAGE_SIZE); // PAGE_SIZE * floor((num_context_blocks - 1) / PAGE_SIZE) + // const int last_ctx_page = PAGE_SIZE * ((num_context_blocks - 1) / PAGE_SIZE); // + // PAGE_SIZE * floor((num_context_blocks - 1) / PAGE_SIZE) int kphysical_block_number[TLOOP][ITERS_16TK]; int kphysical_offset[TLOOP][ITERS_16TK]; // fetch k physical block numbers - // Jacob: loading order--> Token [0~16, 64~80, 128~144, 192~208], [16~32, 80~96, 144~160, 208~224]... + // Jacob: loading order--> Token [0~16, 64~80, 128~144, 192~208], [16~32, 80~96, 144~160, + // 208~224]... for(int token_depth = 0; token_depth < TLOOP; token_depth++) // 4 { for(int iter_16tk = 0; iter_16tk < ITERS_16TK; iter_16tk++) // 4 { // Jacob: block_table_seq has been shifted based on the index of the sequnece // Jacob: partition_start_token_idx has been set based on the blockIdx.y - const int warp_token_offset = iter_16tk * 64 + token_depth * 16 + warpid * TOKEN_PER_WARP_FETCH; + const int warp_token_offset = + iter_16tk * 64 + token_depth * 16 + warpid * TOKEN_PER_WARP_FETCH; const int thread_token_offset = rowid; - const int kglobal_token_idx = partition_start_token_idx + warp_token_offset + thread_token_offset; + const int kglobal_token_idx = + partition_start_token_idx + warp_token_offset + thread_token_offset; const int kblock_idx = (kglobal_token_idx < context_len) ? kglobal_token_idx / BLOCK_SIZE : last_ctx_block; const int kblock_offset = // % BLOCK_SIZE --> & (BLOCK_SIZE - 1) kglobal_token_idx & (BLOCK_SIZE - 1); kphysical_block_number[token_depth][iter_16tk] = block_table_seq[kblock_idx]; - kphysical_offset[token_depth][iter_16tk] = kblock_offset; + kphysical_offset[token_depth][iter_16tk] = kblock_offset; - // if (threadIdx.x %16==0 && /*token_depth==0 &&*/ iter_16tk==0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // if (threadIdx.x %16==0 && /*token_depth==0 &&*/ iter_16tk==0 && blockIdx.x == 0 && + // blockIdx.y == 0 && blockIdx.z == 0) { // printf("[HIP] [K block] threadIdx=%3d, token_depth=%d, iter_16tk=%d, " - // "[kblock_idx=%3d, kpage_idx=%3d, kpage_offset=%2d], last_ctx_block=%3d, last_ctx_page=%3d, " - // "block idx=%3d, kblock_offset=%3d, BLOCK_SIZE=%d, PAGE_SIZE=%d\n", + // "[kblock_idx=%3d, kpage_idx=%3d, kpage_offset=%2d], last_ctx_block=%3d, + // last_ctx_page=%3d, " "block idx=%3d, kblock_offset=%3d, BLOCK_SIZE=%d, + // PAGE_SIZE=%d\n", // threadIdx.x, token_depth, iter_16tk, // kblock_idx, kpage_idx, kpage_offset, last_ctx_block, last_ctx_page, // block_table_seq[kblock_idx], kblock_offset, BLOCK_SIZE, PAGE_SIZE); // } } - __builtin_amdgcn_sched_group_barrier(0x0020, ITERS_16TK, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x0020, ITERS_16TK, 0); // VMEM read } - // fetch Q in shared across warps and then write to registers - const int warp_mtp_idx = warpid / (4 / MTP_PARALLEL_THREADS); // Jacob: MTP_PARALLEL_THREADS=1, warpid=0, warp_mtp_idx=0 + const int warp_mtp_idx = + warpid / + (4 / MTP_PARALLEL_THREADS); // Jacob: MTP_PARALLEL_THREADS=1, warpid=0, warp_mtp_idx=0 const int warp_row_idx = warpid % (4 / MTP_PARALLEL_THREADS); // Jacob:warp_row_idx = 0,1,2,3 - const int local_qhead_idx = 4 * warpid + rowid; // Jacob: rowid=laneid / 16 = 0,1,2,3 + const int local_qhead_idx = 4 * warpid + rowid; // Jacob: rowid=laneid / 16 = 0,1,2,3 const int local_mtp_qhead_idx = 4 * warp_row_idx + rowid; // Jacob: local_mtp_qhead_idx= 0~15 - const int global_qhead_idx = wg_start_head_idx + local_mtp_qhead_idx; // Jacob: wg_start_head_idx=0, global_qhead_idx=0~15 - const int64_t query_start_off = static_cast(query_loc + warp_mtp_idx); // Jacob: query_loc=sequence idx + const int global_qhead_idx = + wg_start_head_idx + + local_mtp_qhead_idx; // Jacob: wg_start_head_idx=0, global_qhead_idx=0~15 + const int64_t query_start_off = + static_cast(query_loc + warp_mtp_idx); // Jacob: query_loc=sequence idx constexpr int mtp_loop = MTP_PER_THREAD; // Jacob: q_stride = GQA_RATIO * HEAD_SIZE = 6*128=768 // each thread local 8 data, so 16 threads can load the full 128 HEAD_SIZE of Q // As GQA_RATIO=6, we need 96 threads to load 6x128 q data // Btw, a block with 256 threads can load 16x128 q data - for(int mtp = 0; mtp < mtp_loop; mtp++) { // 1 - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { // 1 + for(int mtp = 0; mtp < mtp_loop; mtp++) + { // 1 + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { // 1 const scalar_t* q_ptr = - q + (query_start_off + mtp * MTP_PARALLEL_THREADS) * q_stride + (global_qhead_idx + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * HEAD_SIZE; + q + (query_start_off + mtp * MTP_PARALLEL_THREADS) * q_stride + + (global_qhead_idx + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * HEAD_SIZE; - for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { // 1 - const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B + head_loop * HEAD_SIZE_PER_LOOP; - if ((local_mtp_qhead_idx < GQA_RATIO_MTP_PARALLEL) && (qhead_element < HEAD_SIZE)) { - const scalar_t* q_fetch_ptr = q_ptr + qhead_element; - const _B16x8* q_fetch_ptr_16B = - reinterpret_cast(q_fetch_ptr); - _B16x8 tmp = *q_fetch_ptr_16B; + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) + { // 1 + const int qhead_element = + lane16id * CONTIGUOUS_SCALAR_ELEMS_16B + head_loop * HEAD_SIZE_PER_LOOP; + if((local_mtp_qhead_idx < GQA_RATIO_MTP_PARALLEL) && (qhead_element < HEAD_SIZE)) + { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + if constexpr(KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + { const int offset1 = lane16id / - 4; // 16 contiguous chunks of head elems are spread across 4x4lanes - shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; - shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; - } else { - for (int i = 0; i < 2; i++) { - const int head_elem = lane16id * 2 + i; // element id in _B16x4 terms - const int offset3 = head_elem % 4; - const int offset2 = (head_elem / 4) % 4; - const int offset1 = head_elem / 4 / 4; - shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + 4; // 16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][lane4id] + [local_qhead_idx][0] = tmp.xy[0]; + shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][lane4id] + [local_qhead_idx][1] = tmp.xy[1]; + } + else + { + for(int i = 0; i < 2; i++) + { + const int head_elem = lane16id * 2 + i; // element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem / 4 / 4; + shared_logits[gqa_ratio_loop][head_loop][mtp][offset1][offset2] + [local_qhead_idx][offset3] = tmp.xy[i]; } } } @@ -1320,43 +1343,47 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } __syncthreads(); - - // qk mfma constexpr bool NT_KV_LOAD = true; constexpr int KX = 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; int curr = 0, next = 1; - __shared__ cache_t Kbuffer_lds[2][HEAD_LOOP][64][HEAD_SIZE]; // curr and next K buffer, each load 64 tokens K cache - // Each warp processes 16x128 and it was divided into 4 mfma 16x(4x32), each thread records 4x CONTIGUOUS_KV_ELEMS_16B_LOAD elems + __shared__ cache_t Kbuffer_lds[2][HEAD_LOOP][64][HEAD_SIZE]; // curr and next K buffer, each + // load 64 tokens K cache + // Each warp processes 16x128 and it was divided into 4 mfma 16x(4x32), each thread records 4x + // CONTIGUOUS_KV_ELEMS_16B_LOAD elems _B16x8 Kbuffer_reg[2][HEAD_LOOP][QKHELOOP]; // qk mfma - define lambda: loading K cache - int DEBUG_TOKEN_DEPTH=0; - constexpr int n_global_load_per_fragment = HEAD_LOOP*ITERS_16TK; - auto load_K_fragment = [&] __device__ ( // Suppose BLOCK_SIZE=1 - const cache_t* k_ptr, int buf_idx, int kpage[ITERS_16TK], int koffset[ITERS_16TK]) { + int DEBUG_TOKEN_DEPTH = 0; + constexpr int n_global_load_per_fragment = HEAD_LOOP * ITERS_16TK; + auto load_K_fragment = [&] __device__( // Suppose BLOCK_SIZE=1 + const cache_t* k_ptr, + int buf_idx, + int kpage[ITERS_16TK], + int koffset[ITERS_16TK]) { // return; // Debug: does loading K cache take time? - for (int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { - for(int iter_16tk = 0; iter_16tk < ITERS_16TK; iter_16tk++){ // 4 + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) + { + for(int iter_16tk = 0; iter_16tk < ITERS_16TK; iter_16tk++) + { // 4 const int64_t kpage_number = static_cast(kpage[iter_16tk]); const int64_t kpage_offset = static_cast(koffset[iter_16tk]); - const int offset = - kpage_number * kv_block_stride + - kpage_offset * HEAD_SIZE + - lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const int offset = kpage_number * kv_block_stride + kpage_offset * HEAD_SIZE + + lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; const _B16x8* k_ptr_B16x8 = reinterpret_cast(k_ptr + offset); // Save to LDS - const int token_row = iter_16tk * 16 + threadIdx.x / 16; + const int token_row = iter_16tk * 16 + threadIdx.x / 16; const int cache_offset = lane16id * CONTIGUOUS_KV_ELEMS_16B_LOAD; - if constexpr (NT_KV_LOAD) - *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = + if constexpr(NT_KV_LOAD) + *reinterpret_cast<_B16x8*>( + &Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = load_ntmprl_16Byte(k_ptr_B16x8); else - *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = - *k_ptr_B16x8; + *reinterpret_cast<_B16x8*>( + &Kbuffer_lds[buf_idx][head_loop][token_row][cache_offset]) = *k_ptr_B16x8; // if (iter_16tk<2 && DEBUG_TOKEN_DEPTH<2 && threadIdx.x < 128 && // blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { @@ -1378,7 +1405,8 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // offset, __bfloat162float(*(k_ptr + offset))); // DEBUG_TOKEN_DEPTH += 1; // } - // if (token_row == 0 && cache_offset == 64 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // if (token_row == 0 && cache_offset == 64 && blockIdx.x == 0 && blockIdx.y == 0 && + // blockIdx.z == 0) { // printf("[LDS K cache] threadIdx=%3d, iter_16tk=%d, kblock_number=%3ld, " // "kblock_offset=%3ld, LDS token_row=%3d, LDS cache_offset=%3d, " // "addr offset=%d, val=%f \n", @@ -1395,20 +1423,22 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // Wavefront 2 load lds[16:32][HEAD_SIZE] // Wavefront 3 load lds[32:48][HEAD_SIZE] // Wavefront 4 load lds[48:64][HEAD_SIZE] - for (int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { - for(int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++){ // 4 - const int row_warp_offset = warpid * 16; + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) + { + for(int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) + { // 4 + const int row_warp_offset = warpid * 16; const int row_thread_offset = lane16id; - const int row = row_warp_offset + row_thread_offset; + const int row = row_warp_offset + row_thread_offset; - const int col_16tk_offset = qkhe_depth * 32; // A matrix is 16x32. 32 columns + const int col_16tk_offset = qkhe_depth * 32; // A matrix is 16x32. 32 columns const int col_thread_offset = CONTIGUOUS_KV_ELEMS_16B_LOAD * rowid; - const int col = col_16tk_offset + col_thread_offset; + const int col = col_16tk_offset + col_thread_offset; Kbuffer_reg[buf_idx][head_loop][qkhe_depth] = *reinterpret_cast<_B16x8*>(&Kbuffer_lds[buf_idx][head_loop][row][col]); - // _B16x8{{ - // {static_cast(1),1,1,1},{1,1,1,1} - // }}; + // _B16x8{{ + // {static_cast(1),1,1,1},{1,1,1,1} + // }}; // Check NAN // for(int x=0; x<2; ++x) @@ -1428,15 +1458,18 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // qk mfma - Preload the kphysical_block_number[0] cache load_K_fragment(k_ptr, curr, kphysical_block_number[0], kphysical_offset[0]); - __builtin_amdgcn_sched_group_barrier(0x0020, n_global_load_per_fragment, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x0020, n_global_load_per_fragment, 0); // VMEM read // qk mfma - Setup alibi_slope float alibi_slope[GQA_RATIO_LOOP]; if constexpr(ALIBI_ENABLED) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - const int alibi_head_idx = wg_start_head_idx + lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP; - alibi_slope[gqa_ratio_loop] = (lane16id < GQA_RATIO_PER_LOOP) ? alibi_slopes[alibi_head_idx] : 0.f; + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + const int alibi_head_idx = + wg_start_head_idx + lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP; + alibi_slope[gqa_ratio_loop] = + (lane16id < GQA_RATIO_PER_LOOP) ? alibi_slopes[alibi_head_idx] : 0.f; } } @@ -1462,79 +1495,107 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // qk mfma - load K cache[iter+1] + mfma[iter] floatx4 d_out[GQA_RATIO_LOOP][mtp_loop][TLOOP]; - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { // 4 + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { // 4 // Preload the next K cache - if (token_depth + 1 < TLOOP){ - load_K_fragment(k_ptr, next, kphysical_block_number[token_depth+1], kphysical_offset[token_depth+1]); - __builtin_amdgcn_sched_group_barrier(0x0020, n_global_load_per_fragment, 0); // VMEM read + if(token_depth + 1 < TLOOP) + { + load_K_fragment(k_ptr, + next, + kphysical_block_number[token_depth + 1], + kphysical_offset[token_depth + 1]); + __builtin_amdgcn_sched_group_barrier( + 0x0020, n_global_load_per_fragment, 0); // VMEM read } - for (int mtp = 0; mtp < mtp_loop; mtp++) { // 1 - for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { // 1 + for(int mtp = 0; mtp < mtp_loop; mtp++) + { // 1 + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { // 1 d_out[gqa_ratio_loop][mtp][token_depth] = {0}; - for (int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { // 1 - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { // 4 - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) + { // 1 + for(int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) + { // 4 + for(int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) + { // Load Q from LDS - for (int i = 0; i < 2; i++) + for(int i = 0; i < 2; i++) Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio].xy[i] = shared_logits[gqa_ratio_loop][head_loop][mtp][qkhe_depth][rowid] - [lane16id % GQA_RATIO_MTP_PARALLEL][2 * qkratio + i]; + [lane16id % GQA_RATIO_MTP_PARALLEL] + [2 * qkratio + i]; __builtin_amdgcn_sched_group_barrier(0x0100, 2, 0); // LDS read // mfma - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { - #if defined(__gfx950__) - d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x32_instr( - Kbuffer_reg[curr][head_loop][qkhe_depth], - Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio], - d_out[gqa_ratio_loop][mtp][token_depth]); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - #else - for (int i = 0; i < 2; i++) { - d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x16_instr( - Kbuffer_reg[curr][head_loop][qkhe_depth].xy[i], - Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio].xy[i], + if constexpr(KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + { +#if defined(__gfx950__) + d_out[gqa_ratio_loop][mtp][token_depth] = + gcn_mfma16x16x32_instr( + Kbuffer_reg[curr][head_loop][qkhe_depth], + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio], d_out[gqa_ratio_loop][mtp][token_depth]); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA +#else + for(int i = 0; i < 2; i++) + { + d_out[gqa_ratio_loop][mtp][token_depth] = + gcn_mfma16x16x16_instr( + Kbuffer_reg[curr][head_loop][qkhe_depth].xy[i], + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth] + [qkratio] + .xy[i], + d_out[gqa_ratio_loop][mtp][token_depth]); } - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - #endif + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA +#endif // Check value // for(int d=0; d<4; ++d) // if(isnan(d_out[gqa_ratio_loop][mtp][token_depth][d])){ // for(int x=0; x<2; ++x) // for(int y=0; y<4; ++y){ - // _B16x4 kdata = Kbuffer_reg[curr][head_loop][qkhe_depth].xy[0]; - // printf("qk_mfma is nan. Kbuffer_reg=%hu %hu %hu %hu\n", + // _B16x4 kdata = + // Kbuffer_reg[curr][head_loop][qkhe_depth].xy[0]; + // printf("qk_mfma is nan. Kbuffer_reg=%hu %hu %hu + // %hu\n", // kdata[0], kdata[1], kdata[2], kdata[3]); // } // break; // } - } - else { // kv cache dtype fp8 - auto Ktmp = Kbuffer_reg[curr][head_loop][qkhe_depth]; + else + { // kv cache dtype fp8 + auto Ktmp = Kbuffer_reg[curr][head_loop][qkhe_depth]; _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for(int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) + { // Load Q from LDS - _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); - #if defined(__gfx950__) - d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x32_instr( - Klocaltmp, - Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio], - d_out[gqa_ratio_loop][mtp][token_depth]); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - #else - for (int i = 0; i < 2; i++) { - d_out[gqa_ratio_loop][mtp][token_depth] = gcn_mfma16x16x16_instr( - Klocaltmp.xy[i], Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth][qkratio].xy[i], +#if defined(__gfx950__) + d_out[gqa_ratio_loop][mtp][token_depth] = + gcn_mfma16x16x32_instr( + Klocaltmp, + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth] + [qkratio], d_out[gqa_ratio_loop][mtp][token_depth]); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA +#else + for(int i = 0; i < 2; i++) + { + d_out[gqa_ratio_loop][mtp][token_depth] = + gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], + Qlocal[gqa_ratio_loop][head_loop][mtp][qkhe_depth] + [qkratio] + .xy[i], + d_out[gqa_ratio_loop][mtp][token_depth]); } - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - #endif + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA +#endif } } } @@ -1542,7 +1603,8 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } // DEBUG: check values - // if (threadIdx.x==1 && token_depth<4 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // if (threadIdx.x==1 && token_depth<4 && blockIdx.x == 0 && blockIdx.y == 0 && + // blockIdx.z == 0) { // floatx4 data = d_out[gqa_ratio_loop][mtp][token_depth]; // printf("[check d_out] threadIdx.x=%d, BLOCK_SIZE=%d, d_out=%f,%f %f %f \n", // threadIdx.x, BLOCK_SIZE ,data[0], data[1], data[2], data[3]); @@ -1571,27 +1633,34 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // } // } - for (int i = 0; i < 4; i++) + for(int i = 0; i < 4; i++) { - d_out[gqa_ratio_loop][mtp][token_depth][i] = variant->QueryTransform(variant_params, d_out[gqa_ratio_loop][mtp][token_depth][i]); + d_out[gqa_ratio_loop][mtp][token_depth][i] = variant->QueryTransform( + variant_params, d_out[gqa_ratio_loop][mtp][token_depth][i]); } } } int tmp = curr; - curr = next; - next = tmp; + curr = next; + next = tmp; } const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; // apply alibi - if constexpr (ALIBI_ENABLED) { - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + if constexpr(ALIBI_ENABLED) + { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { const int local_token_idx = qkout_token_idx + token_depth * 16; - const int alibi_offset = local_token_idx - context_len + 1; - for (int mtp = 0; mtp < mtp_loop; mtp++) { - for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - for (int i = 0; i < 4; i++) { - d_out[gqa_ratio_loop][mtp][token_depth][i] += alibi_slope[gqa_ratio_loop] * (alibi_offset + i); + const int alibi_offset = local_token_idx - context_len + 1; + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int i = 0; i < 4; i++) + { + d_out[gqa_ratio_loop][mtp][token_depth][i] += + alibi_slope[gqa_ratio_loop] * (alibi_offset + i); } } } @@ -1610,7 +1679,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( for(int i = 0; i < 4; i++) { float tmp = d_out[gqa_ratio_loop][mtp][token_depth][i]; - if (local_token_idx + i < context_len - sliding_window) + if(local_token_idx + i < context_len - sliding_window) tmp = -FLT_MAX; d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; } @@ -1619,19 +1688,22 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } // apply soft-capping to logits - for (int token_depth = 0; token_depth < TLOOP; token_depth++) + for(int token_depth = 0; token_depth < TLOOP; token_depth++) { - for (int mtp = 0; mtp < mtp_loop; mtp++) { - for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - for (int i = 0; i < 4; i++) { - d_out[gqa_ratio_loop][mtp][token_depth][i] = - variant->LogitsTransform(variant_params, - d_out[gqa_ratio_loop][mtp][token_depth][i], - /*batch_idx=*/query_start_off + mtp * MTP_PARALLEL_THREADS, - /*qo_head_idx=*/wg_start_head_idx + lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP, - /*kv_head_idx=*/kv_head_idx); + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int i = 0; i < 4; i++) + { + d_out[gqa_ratio_loop][mtp][token_depth][i] = variant->LogitsTransform( + variant_params, + d_out[gqa_ratio_loop][mtp][token_depth][i], + /*batch_idx=*/query_start_off + mtp * MTP_PARALLEL_THREADS, + /*qo_head_idx=*/wg_start_head_idx + lane16id + + gqa_ratio_loop * GQA_RATIO_PER_LOOP, + /*kv_head_idx=*/kv_head_idx); } - } } } @@ -1640,25 +1712,31 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // if (threadIdx.x%16==0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // for (int token_depth = 0; token_depth < TLOOP; token_depth++){ // floatx4 data = d_out[0][0][token_depth]; - // printf("[Check d_out + soft-capping] warpid=%d, token_depth=%d, threadIdx.x=%3d, d_out[0][0][%d]=%f %f %f %f\n", - // warpid, token_depth, threadIdx.x, token_depth, data[0], data[1], data[2], data[3]); + // printf("[Check d_out + soft-capping] warpid=%d, token_depth=%d, threadIdx.x=%3d, + // d_out[0][0][%d]=%f %f %f %f\n", + // warpid, token_depth, threadIdx.x, token_depth, data[0], data[1], data[2], + // data[3]); // __syncthreads(); // } // } // calculate qk_max and exp_sum per warp and write to shared memory - float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX}; + float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX}; float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; - for (int mtp = 0; mtp < mtp_loop; mtp++) { - for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { // Step 1.1 Get max qk per thread: - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { const int local_token_idx = qkout_token_idx + token_depth * 16; - for(int i=0; i<4; ++i){ + for(int i = 0; i < 4; ++i) + { const float tmp = ((local_token_idx + i) < context_len * (warp_mtp_idx + 1)) - ? d_out[gqa_ratio_loop][mtp][token_depth][i] - : -FLT_MAX; + ? d_out[gqa_ratio_loop][mtp][token_depth][i] + : -FLT_MAX; qk_max[gqa_ratio_loop][mtp] = fmaxf(qk_max[gqa_ratio_loop][mtp], tmp); } } @@ -1676,49 +1754,61 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // Thread [1, 17, 33, 49] stores 1 column, 16 elements of mfma(K@Q.T). // Use the following loop can get the max(thread1, thread17, thread33, thread49) // "mask >= 16" summed to 16 threads as 1 GQA_RATIO_LOOP process 16 q heads - for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { - qk_max[gqa_ratio_loop][mtp] = fmaxf(qk_max[gqa_ratio_loop][mtp], __shfl_xor(qk_max[gqa_ratio_loop][mtp], mask)); + for(int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) + { + qk_max[gqa_ratio_loop][mtp] = fmaxf(qk_max[gqa_ratio_loop][mtp], + __shfl_xor(qk_max[gqa_ratio_loop][mtp], mask)); } // Step 2.1 Calc exp(d_out-qk_max) per thread - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { const int local_token_idx = qkout_token_idx + token_depth * 16; - for (int i = 0; i < 4; i++) { + for(int i = 0; i < 4; i++) + { const float tmp = ((local_token_idx + i) < context_len * (warp_mtp_idx + 1)) - ? __expf(d_out[gqa_ratio_loop][mtp][token_depth][i] - qk_max[gqa_ratio_loop][mtp]) - : 0.0f; + ? __expf(d_out[gqa_ratio_loop][mtp][token_depth][i] - + qk_max[gqa_ratio_loop][mtp]) + : 0.0f; d_out[gqa_ratio_loop][mtp][token_depth][i] = tmp; exp_sum[gqa_ratio_loop][mtp] += tmp; // if(isnan(tmp)) // printf("exp(d_out-qk_max) is nan. d_out=%f, qk_max=%f\n", - // d_out[gqa_ratio_loop][mtp][token_depth][i], qk_max[gqa_ratio_loop][mtp]); + // d_out[gqa_ratio_loop][mtp][token_depth][i], + // qk_max[gqa_ratio_loop][mtp]); } } // Step 2.2 Sum up exp per wavefronts - for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + for(int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) + { exp_sum[gqa_ratio_loop][mtp] += __shfl_xor(exp_sum[gqa_ratio_loop][mtp], mask); } } } - // __syncthreads(); // sync before writing to shared mem // Why need sync here? no LDS ops before this line + // __syncthreads(); // sync before writing to shared mem // Why need sync here? no LDS ops + // before this line // Step 3. Save qk_max and exp_sum for the entire workgroup float* shared_mem = reinterpret_cast(shared_logits); - if (laneid < 16) { - for(int mtp = 0; mtp < mtp_loop; mtp++) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + if(laneid < 16) + { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { const int qk_max_offset = warpid * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + - (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + - mtp; + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + mtp; shared_mem[qk_max_offset] = qk_max[gqa_ratio_loop][mtp]; const int exp_sum_offset = NWARPS * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + qk_max_offset; shared_mem[exp_sum_offset] = exp_sum[gqa_ratio_loop][mtp]; - // if (threadIdx.x < 256 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { - // printf("[qk_max + exp_sum] threadIdx=%3d, qk_max_offset=%3d, exp_sum_offset=%3d, " + // if (threadIdx.x < 256 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // { + // printf("[qk_max + exp_sum] threadIdx=%3d, qk_max_offset=%3d, + // exp_sum_offset=%3d, " // "shared_mem[qk_max_offset]=%f, shared_mem[exp_sum_offset]=%f\n", // threadIdx.x, qk_max_offset, exp_sum_offset, // shared_mem[qk_max_offset], shared_mem[exp_sum_offset]); @@ -1731,56 +1821,77 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // Seg 6.2 // Get qk_max across wavefronts // calculate partition qk_max and exp_sum - float inv_sum_scale[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; - float partition_qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX}; + float inv_sum_scale[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; + float partition_qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX}; float partition_exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f}; - for(int mtp = 0; mtp < mtp_loop; mtp++) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { float warp_qk_max_exp[NWARPS]; - for (int w = 0; w < NWARPS; w++) { - warp_qk_max_exp[w] = shared_mem[w * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + mtp]; - partition_qk_max[gqa_ratio_loop][mtp] = fmaxf(partition_qk_max[gqa_ratio_loop][mtp], warp_qk_max_exp[w]); + for(int w = 0; w < NWARPS; w++) + { + warp_qk_max_exp[w] = + shared_mem[w * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + + mtp]; + partition_qk_max[gqa_ratio_loop][mtp] = + fmaxf(partition_qk_max[gqa_ratio_loop][mtp], warp_qk_max_exp[w]); } - for (int w = 0; w < NWARPS; w++) { - warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max[gqa_ratio_loop][mtp]); + for(int w = 0; w < NWARPS; w++) + { + warp_qk_max_exp[w] = + __expf(warp_qk_max_exp[w] - partition_qk_max[gqa_ratio_loop][mtp]); partition_exp_sum[gqa_ratio_loop][mtp] += - shared_mem[NWARPS * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + w * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + mtp] * warp_qk_max_exp[w]; + shared_mem[NWARPS * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + + w * 16 * GQA_RATIO_LOOP * MTP_PER_THREAD + + (lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP) * MTP_PER_THREAD + + mtp] * + warp_qk_max_exp[w]; } inv_sum_scale[gqa_ratio_loop][mtp] = - __fdividef(1.f, partition_exp_sum[gqa_ratio_loop][mtp] + 1e-6f) * warp_qk_max_exp[warpid]; + __fdividef(1.f, partition_exp_sum[gqa_ratio_loop][mtp] + 1e-6f) * + warp_qk_max_exp[warpid]; // if (threadIdx.x < 256 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("threadIdx=%3d, warp_qk_max_exp=%f %f %f %f, partition_qk_max[%d][%d]=%f " // "partition_exp_sum[%d][%d]=%f, " // "inv_sum_scale[%d][%d]=%f\n", // threadIdx.x, - // warp_qk_max_exp[0], warp_qk_max_exp[1], warp_qk_max_exp[2], warp_qk_max_exp[3], gqa_ratio_loop, mtp, partition_qk_max[gqa_ratio_loop][mtp], - // gqa_ratio_loop, mtp, partition_exp_sum[gqa_ratio_loop][mtp], - // gqa_ratio_loop, mtp, inv_sum_scale[gqa_ratio_loop][mtp]); + // warp_qk_max_exp[0], warp_qk_max_exp[1], warp_qk_max_exp[2], + // warp_qk_max_exp[3], gqa_ratio_loop, mtp, + // partition_qk_max[gqa_ratio_loop][mtp], gqa_ratio_loop, mtp, + // partition_exp_sum[gqa_ratio_loop][mtp], gqa_ratio_loop, mtp, + // inv_sum_scale[gqa_ratio_loop][mtp]); // } } } __syncthreads(); // Why need sync here? no LDS ops before this line - // disable rtz conversion due to its impact on accuracy. constexpr bool LOGITS_RTZ_CONVERSION = false; // write logits to shared mem - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - for (int mtp = 0; mtp < mtp_loop; mtp++) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + for(int token_depth = 0; token_depth < TLOOP; token_depth++) + { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { d_out[gqa_ratio_loop][mtp][token_depth] *= inv_sum_scale[gqa_ratio_loop][mtp]; - if constexpr (LOGITS_RTZ_CONVERSION) { + if constexpr(LOGITS_RTZ_CONVERSION) + { // use rtz conversion for better performance, with negligible impact on // accuracy shared_logits[gqa_ratio_loop][0][mtp][warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(d_out[gqa_ratio_loop][mtp][token_depth]); - } else { + } + else + { shared_logits[gqa_ratio_loop][0][mtp][warpid][token_depth][lane16id][rowid] = from_floatx4(d_out[gqa_ratio_loop][mtp][token_depth]); } @@ -1788,26 +1899,30 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } - // DEBUG: Get qk_max across blocks // write out partition max_logits and exp_sum - if (threadIdx.x < GQA_RATIO_MTP_PARALLEL) { - for(int mtp = 0; mtp < mtp_loop; mtp++) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { + if(threadIdx.x < GQA_RATIO_MTP_PARALLEL) + { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { const int qhead_idx = lane16id + gqa_ratio_loop * GQA_RATIO_PER_LOOP; - const int64_t offset = static_cast(seq_idx + mtp * MTP_PARALLEL_THREADS) * - static_cast(total_num_heads) * - static_cast(max_num_partitions) + - (static_cast(wg_start_head_idx) + - static_cast(qhead_idx)) * - static_cast(max_num_partitions) + - static_cast(partition_idx); + const int64_t offset = + static_cast(seq_idx + mtp * MTP_PARALLEL_THREADS) * + static_cast(total_num_heads) * + static_cast(max_num_partitions) + + (static_cast(wg_start_head_idx) + static_cast(qhead_idx)) * + static_cast(max_num_partitions) + + static_cast(partition_idx); max_logits[offset] = partition_qk_max[gqa_ratio_loop][mtp]; - exp_sums[offset] = partition_exp_sum[gqa_ratio_loop][mtp]; + exp_sums[offset] = partition_exp_sum[gqa_ratio_loop][mtp]; // if (threadIdx.x < 64 && blockIdx.x == 0 && blockIdx.y == 7 && blockIdx.z == 0) { - // printf("threadIdx=%3d, blockIdx.y=%d, max_logits[%ld]=%f, exp_sums[%ld]=%f \n", - // threadIdx.x, blockIdx.y, offset, max_logits[offset], offset, exp_sums[offset]); + // printf("threadIdx=%3d, blockIdx.y=%d, max_logits[%ld]=%f, exp_sums[%ld]=%f + // \n", + // threadIdx.x, blockIdx.y, offset, max_logits[offset], offset, + // exp_sums[offset]); // } } } @@ -1876,9 +1991,11 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); if constexpr(NT_KV_LOAD) { - Vlocal[vtoken_depth][vhe_depth][vblock_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); + Vlocal[vtoken_depth][vhe_depth][vblock_depth] = + load_ntmprl_16Byte(v_fetch_ptr_16B); } - else{ + else + { Vlocal[vtoken_depth][vhe_depth][vblock_depth] = *reinterpret_cast(v_fetch_ptr); } @@ -1886,11 +2003,11 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( } } - constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; - for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + { for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { // 1. store data into LDS @@ -1900,14 +2017,15 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( const int vlocal_token_idx = vblock_depth * k_thread_per_block + threadIdx.x / n_thread_per_block; *reinterpret_cast<_B16x8*>(vlds_ptr + - (/*row=*/vlocal_token_idx * n_thread_per_block + + (/*row=*/vlocal_token_idx * n_thread_per_block + /*col=*/vlds_col_idx) * - 16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth]; + 16) = Vlocal[vtoken_depth][vhe_depth][vblock_depth]; } __syncthreads(); // 2. load data from LDS (transposed), then do multification - for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++){ + for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) + { const int vlocal_head_elem = warpid * 16 + lane16id; const int vlds_col_idx = vlocal_head_elem / CONTIGUOUS_KV_ELEMS_16B_LOAD; @@ -1923,26 +2041,29 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( const cache_t* fetched_elems = reinterpret_cast( vlds_ptr + (/*row=*/(vlocal_token_idx + d2) * n_thread_per_block + /*col=*/vlds_col_idx) * - 16); + 16); elems[d2] = fetched_elems[vlds_elem_idx]; } // copy all the read data points together - Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *reinterpret_cast(elems); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = + *reinterpret_cast(elems); } __syncthreads(); } } - _B16x4 outelems[GQA_RATIO_LOOP][MTP_PER_THREAD][VHELOOP]; // Softmax V mfma // v layout: 16he across lanes x 16 tokens per lane - for (int mtp = 0; mtp < mtp_loop; mtp++) { - for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + { floatx4 tmp_out = {0}; for(int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) @@ -1951,34 +2072,35 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( { for(int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - #if defined(__gfx950__) +#if defined(__gfx950__) _B16x8 tmp_in; for(int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + - vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; const int offset1 = offset % ROWS_PER_WARP; const int offset2 = offset / ROWS_PER_WARP; - tmp_in.xy[i] = shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1]; + tmp_in.xy[i] = shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth] + [offset2][lane16id][offset1]; } tmp_out = gcn_mfma16x16x32_instr( - Vlocal[vtoken_depth][vhe_depth][vfetch_depth], - tmp_in, - tmp_out); - #else - for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth], tmp_in, tmp_out); +#else + for(int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) + { const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + - vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; const int offset1 = offset % ROWS_PER_WARP; const int offset2 = offset / ROWS_PER_WARP; // output format is 16 qheads across 16 lanes, 16 head elems spread // across 4 rows tmp_out = gcn_mfma16x16x16_instr( Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], - shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1], + shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2] + [lane16id][offset1], tmp_out); } - #endif +#endif } } else @@ -1993,7 +2115,7 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); - #if defined(__gfx950__) +#if defined(__gfx950__) _B16x8 tmp_in; for(int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { @@ -2002,14 +2124,15 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( j * ELEMS8_ELEMS4_RATIO + i; const int offset1 = offset % ROWS_PER_WARP; const int offset2 = offset / ROWS_PER_WARP; - tmp_in.xy[i] = shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1]; + tmp_in.xy[i] = + shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2] + [lane16id][offset1]; } tmp_out = gcn_mfma16x16x32_instr( - Vlocaltmp, - tmp_in, - tmp_out); - #else - for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + Vlocaltmp, tmp_in, tmp_out); +#else + for(int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) + { const int offset = rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + j * ELEMS8_ELEMS4_RATIO + i; @@ -2019,10 +2142,11 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // spread across 4 rows tmp_out = gcn_mfma16x16x16_instr( Vlocaltmp.xy[i], - shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2][lane16id][offset1], + shared_logits[gqa_ratio_loop][0][mtp][vtoken_depth][offset2] + [lane16id][offset1], tmp_out); } - #endif +#endif } } } @@ -2041,11 +2165,15 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( __syncthreads(); // store Softmax-V mfma output to shared mem - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + for(int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) + { // lane16 id head dimension; rowid head element dimension - for(int mtp = 0; mtp < mtp_loop; mtp++) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - shared_logits[gqa_ratio_loop][0][mtp][warpid][vhe_depth][lane16id][rowid] = outelems[gqa_ratio_loop][mtp][vhe_depth]; + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + shared_logits[gqa_ratio_loop][0][mtp][warpid][vhe_depth][lane16id][rowid] = + outelems[gqa_ratio_loop][mtp][vhe_depth]; // if (threadIdx.x==0 && vhe_depth==0 && // blockIdx.x == 1 && blockIdx.y == 0 && blockIdx.z == 0) { @@ -2072,39 +2200,51 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( __syncthreads(); // write to tmp_out with coalesced writes after reading from shared mem - if (warpid == 0) { - for (int mtp = 0; mtp < mtp_loop; mtp++) { - for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) { - for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) { + if(warpid == 0) + { + for(int mtp = 0; mtp < mtp_loop; mtp++) + { + for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) + { + for(int head_loop = 0; head_loop < HEAD_LOOP; head_loop++) + { _B16x8 vout[GQA_RATIO4]; // each lane writes out 16Bytes of tmp_out along head elem dimension const int head_elem_idx = lane16id * 8 + head_loop * HEAD_SIZE_PER_LOOP; - if (head_elem_idx < HEAD_SIZE) { - for (int h = 0; h < GQA_RATIO4; h++) { + if(head_elem_idx < HEAD_SIZE) + { + for(int h = 0; h < GQA_RATIO4; h++) + { const int local_head_idx = 4 * h + rowid; - const int offset1 = (head_elem_idx / 16) % 4; - const int offset2 = head_elem_idx / 16 / NWARPS; - const int offset3 = (head_elem_idx / 4) % 4; - for (int i = 0; i < 2; i++) { - vout[h].xy[i] = - shared_logits[gqa_ratio_loop][0][mtp][offset1][offset2][local_head_idx][offset3 + i]; + const int offset1 = (head_elem_idx / 16) % 4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4) % 4; + for(int i = 0; i < 2; i++) + { + vout[h].xy[i] = shared_logits[gqa_ratio_loop][0][mtp][offset1] + [offset2][local_head_idx][offset3 + i]; } } const int64_t hsz_maxp_mult = static_cast(HEAD_SIZE * max_num_partitions); - scalar_t* out_ptr = out + (seq_idx + mtp * MTP_PARALLEL_THREADS) * total_num_heads * hsz_maxp_mult + + scalar_t* out_ptr = out + + (seq_idx + mtp * MTP_PARALLEL_THREADS) * + total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; - for (int h = 0; h < GQA_RATIO4; h++) { + for(int h = 0; h < GQA_RATIO4; h++) + { const int local_head_idx = 4 * h + rowid; - if (local_head_idx < GQA_RATIO_MTP_PARALLEL) { + if(local_head_idx < GQA_RATIO_MTP_PARALLEL) + { const int64_t out_head_idx = - static_cast(wg_start_head_idx + local_head_idx + gqa_ratio_loop * GQA_RATIO_PER_LOOP); - scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; - scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + static_cast(wg_start_head_idx + local_head_idx + + gqa_ratio_loop * GQA_RATIO_PER_LOOP); + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); - *out_ptr_B16x8 = vout[h]; + *out_ptr_B16x8 = vout[h]; // if (threadIdx.x<64 && // blockIdx.x == 59 && blockIdx.y == 6 && blockIdx.z == 0) { @@ -2115,10 +2255,12 @@ __inline__ __device__ void _paged_attention_kernel_EXPERIMENTAL( // for(int i=0; i<2; ++i) // for(int j=0; j<4; ++j){ // v[i*4+j] = data_x8.xy[i][j]; - // b[i*4+j] = *reinterpret_cast<__hip_bfloat16*>(&v[i*4+j]); + // b[i*4+j] = + // *reinterpret_cast<__hip_bfloat16*>(&v[i*4+j]); // c[i*4+j] = __bfloat162float(b[i*4+j]); // } - // printf("[out_ptr] threadIdx.x=%3d, h(GQA_RATIO4)=%d, local_head_idx=%3d, head_elem_idx=%3d, " + // printf("[out_ptr] threadIdx.x=%3d, h(GQA_RATIO4)=%d, + // local_head_idx=%3d, head_elem_idx=%3d, " // "out=%f %f %f %f, %f %f %f %f \n", // threadIdx.x, h, local_head_idx, head_elem_idx, // c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]);