Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions libflashinfer/include/flashinfer/attention/generic/prefill.cuh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// SPDX-FileCopyrightText: 2023-2025 FlashInfer team.
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0
#ifndef FLASHINFER_PREFILL_CUH_
#define FLASHINFER_PREFILL_CUH_

#pragma once

#include "gpu_iface/cooperative_groups.h"
#include "gpu_iface/fastdiv.cuh"
Expand Down Expand Up @@ -1768,6 +1768,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice(
(chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n +
kv_head_idx * k_stride_h +
(lane_idx % KV_THR_LAYOUT_COL) * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();

DTypeKV* v_ptr =
v +
(chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n +
Expand Down Expand Up @@ -2057,6 +2058,11 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
[[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW;
[[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL;
[[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE;
[[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD;
[[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD =
KTraits::NUM_ACCUM_ROWS_PER_THREAD;
[[maybe_unused]] constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET =
KTraits::THREADS_PER_BMATRIX_ROW_SET;
[[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH;

DTypeQ* q = params.q;
Expand Down Expand Up @@ -2093,6 +2099,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads;
const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx],
kv_tile_idx = kv_tile_indices[bx];

extern __shared__ uint8_t smem[];
auto& smem_storage = reinterpret_cast<typename KTraits::SharedStorage&>(smem);
AttentionVariant variant(params, /*batch_idx=*/request_idx, smem);
Expand All @@ -2107,10 +2114,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
const uint32_t qo_upper_bound =
min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size));

DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8];
alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8];
DTypeQKAccum m[NUM_MMA_Q][2];
float d[NUM_MMA_Q][2];
DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD];
alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD];
DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD];
float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD];
float rope_freq[NUM_MMA_D_QK / 2][4];

if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) {
Expand All @@ -2122,7 +2129,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV

const uint32_t qo_packed_idx_base =
(qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q<KTraits>(tid.y)) * NUM_MMA_Q * 16;
smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
smem_t<SWIZZLE_MODE_KV, typename KTraits::SmemBasePtrTy> qo_smem(smem_storage.q_smem);
const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO;

DTypeQ* q_ptr_base =
Expand Down Expand Up @@ -2181,15 +2188,24 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
: chunk_size) /
CTA_TILE_KV;

smem_t<SWIZZLE_MODE_KV> k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem);

smem_t<SWIZZLE_MODE_KV, typename KTraits::SmemBasePtrTy> k_smem(smem_storage.k_smem);
smem_t<SWIZZLE_MODE_KV, typename KTraits::SmemBasePtrTy> v_smem(smem_storage.v_smem);
#if defined(PLATFORM_HIP_DEVICE)
uint32_t k_smem_offset_r = k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) +
lane_idx % 8,
(lane_idx % 16) / 8),
v_smem_offset_r = v_smem.template get_permuted_offset<UPCAST_STRIDE_V>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16),
k_smem_offset_w = k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16));

uint32_t v_smem_offset_r = v_smem.template get_permuted_offset<UPCAST_STRIDE_V>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16);

#elif defined(PLATFORM_CUDA_DEVICE)
uint32_t k_smem_offset_r = k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) +
lane_idx % 8,
(lane_idx % 16) / 8),
v_smem_offset_r = v_smem.template get_permuted_offset<UPCAST_STRIDE_V>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16);
#endif
uint32_t k_smem_offset_w = k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL,
lane_idx % KV_THR_LAYOUT_COL),
v_smem_offset_w = v_smem.template get_permuted_offset<UPCAST_STRIDE_V>(
Expand All @@ -2214,6 +2230,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
memory::commit_group();
produce_kv<true, SharedMemFillMode::kFillZero, KTraits>(v_smem, &v_smem_offset_w, &v_ptr,
v_stride_n, 0, chunk_size, tid);

memory::commit_group();

#pragma unroll 1
Expand Down Expand Up @@ -2293,9 +2310,13 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
#pragma unroll
for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) {
uint32_t q, r;
group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r);
group_size.divmod(
qo_packed_idx_base +
(lane_idx / THREADS_PER_BMATRIX_ROW_SET) * NUM_ACCUM_ROWS_PER_THREAD + j +
mma_q * 16,
q, r);
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
const uint32_t qo_idx = q;
if (qo_idx < qo_len) {
Expand Down Expand Up @@ -2824,5 +2845,3 @@ gpuError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params
}

} // namespace flashinfer

#endif // FLASHINFER_PREFILL_CUH_
1 change: 1 addition & 0 deletions libflashinfer/tests/hip/test_batch_prefill.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: 2023-2025 Flashinfer team
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
//
// SPDX-License-Identifier: Apache 2.0
Expand Down