diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 64a93bdba1..dac5a1c74c 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1,8 +1,8 @@ // SPDX-FileCopyrightText: 2023-2025 FlashInfer team. // SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. // SPDX-License-Identifier: Apache-2.0 -#ifndef FLASHINFER_PREFILL_CUH_ -#define FLASHINFER_PREFILL_CUH_ + +#pragma once #include "gpu_iface/cooperative_groups.h" #include "gpu_iface/fastdiv.cuh" @@ -1768,6 +1768,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + kv_head_idx * k_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV* v_ptr = v + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + @@ -2057,6 +2058,11 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = + KTraits::THREADS_PER_BMATRIX_ROW_SET; [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; DTypeQ* q = params.q; @@ -2093,6 +2099,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; + extern __shared__ uint8_t smem[]; auto& smem_storage = reinterpret_cast(smem); AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); @@ -2107,10 +2114,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); - DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; - DTypeQKAccum m[NUM_MMA_Q][2]; - float d[NUM_MMA_Q][2]; + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; + DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { @@ -2122,7 +2129,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV const uint32_t qo_packed_idx_base = (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); + smem_t qo_smem(smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = @@ -2181,15 +2188,24 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV : chunk_size) / CTA_TILE_KV; - smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); - + smem_t k_smem(smem_storage.k_smem); + smem_t v_smem(smem_storage.v_smem); +#if defined(PLATFORM_HIP_DEVICE) uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + - lane_idx % 8, - (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), - k_smem_offset_w = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16)); + + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16); + +#elif defined(PLATFORM_CUDA_DEVICE) + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16); +#endif + uint32_t k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), v_smem_offset_w = v_smem.template get_permuted_offset( @@ -2214,6 +2230,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV memory::commit_group(); produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); + memory::commit_group(); #pragma unroll 1 @@ -2293,9 +2310,13 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { uint32_t q, r; - group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); + group_size.divmod( + qo_packed_idx_base + + (lane_idx / THREADS_PER_BMATRIX_ROW_SET) * NUM_ACCUM_ROWS_PER_THREAD + j + + mma_q * 16, + q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_len) { @@ -2824,5 +2845,3 @@ gpuError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params } } // namespace flashinfer - -#endif // FLASHINFER_PREFILL_CUH_ diff --git a/libflashinfer/tests/hip/test_batch_prefill.cpp b/libflashinfer/tests/hip/test_batch_prefill.cpp index 0f195b5184..e20d437f18 100644 --- a/libflashinfer/tests/hip/test_batch_prefill.cpp +++ b/libflashinfer/tests/hip/test_batch_prefill.cpp @@ -1,3 +1,4 @@ +// SPDX-FileCopyrightText: 2023-2025 Flashinfer team // SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. // // SPDX-License-Identifier: Apache 2.0