From 2c7b41c20d6478de8e501b447556e95726d7beed Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Tue, 21 Oct 2025 22:17:46 -0500
Subject: [PATCH 01/13] Initial CDNA3 single prefill kernel using MFMA
 intrinsics.

Ports the SinglePrefillWithKVCacheDevice kernel to HIP using CDNA3 MFMA
intrinsics.

The following device functions have been ported:

- `load_q_global_smem`
- `produce_kv`
- `compute_qk`
- `update_mdo_states`
- `compute_sfm_v`

Unit test source is `libflashinfer/tests/hip/test_single_prefill.cpp`.
---
 .../generic/default_prefill_params.cuh        |   30 +-
 .../flashinfer/attention/generic/dispatch.cuh |  216 +++
 .../flashinfer/attention/generic/page.cuh     |    1 -
 .../attention/generic/permuted_smem.cuh       |   49 +-
 .../flashinfer/attention/generic/prefill.cuh  | 1239 +++++++++++------
 .../backend/hip/mma_debug_utils_hip.h         |  280 +++-
 .../include/gpu_iface/backend/hip/mma_hip.h   |   21 +-
 libflashinfer/include/gpu_iface/mma_ops.hpp   |   13 +-
 .../tests/hip/test_single_prefill.cpp         |  619 ++++++++
 libflashinfer/utils/cpu_reference_hip.h       |   15 +-
 .../utils/flashinfer_prefill_ops.hip.h        |  122 ++
 libflashinfer/utils/utils_hip.h               |   37 +-
 12 files changed, 2084 insertions(+), 558 deletions(-)
 create mode 100644 libflashinfer/include/flashinfer/attention/generic/dispatch.cuh
 create mode 100644 libflashinfer/tests/hip/test_single_prefill.cpp
 create mode 100644 libflashinfer/utils/flashinfer_prefill_ops.hip.h

diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
index 65007cb94f..e8014f199e 100644
--- a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
@@ -1,27 +1,15 @@
-/*
- * Copyright (c) 2024 by FlashInfer team.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// SPDX-FileCopyrightText: 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0

 #ifndef FLASHINFER_PREFILL_PARAMS_CUH_
 #define FLASHINFER_PREFILL_PARAMS_CUH_

-#include <cuda_runtime.h>
-
 #include <cstdint>
 #include <type_traits>

-#include "../page.cuh"
+#include "gpu_iface/gpu_runtime_compat.hpp"
+#include "page.cuh"

 namespace flashinfer {

@@ -39,10 +27,10 @@ struct SinglePrefillParams {
   float* lse;
   float* maybe_alibi_slopes;
   uint_fastdiv group_size;
-  uint32_t qo_len;
-  uint32_t kv_len;
   uint32_t num_qo_heads;
   uint32_t num_kv_heads;
+  uint32_t qo_len;
+  uint32_t kv_len;
   uint32_t q_stride_n;
   uint32_t q_stride_h;
   uint32_t k_stride_n;
@@ -388,4 +376,4 @@ struct BatchPrefillPagedParams {

 }  // namespace flashinfer

-#endif  // FLASHINFER_DECODE_PARAMS_CUH_
+#endif  // FLASHINFER_PREFILL_PARAMS_CUH_
diff --git a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh
new file mode 100644
index 0000000000..abe0a3020e
--- /dev/null
+++ b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh
@@ -0,0 +1,216 @@
+// SPDX-FileCopyrightText: 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "gpu_iface/enums.hpp"
+#include "gpu_iface/exception.h"
+
+#define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...) \
+  if (use_fp16_qk_reduction) { \
+    FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \
+  } else { \
+    constexpr bool USE_FP16_QK_REDUCTION = false; \
+    __VA_ARGS__ \
+  }
+
+#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) \
+  if (num_mma_q == 1) { \
+    constexpr size_t NUM_MMA_Q = 1; \
+    __VA_ARGS__ \
+  } else if (num_mma_q == 2) { \
+    constexpr size_t NUM_MMA_Q = 2; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported num_mma_q: " << num_mma_q; \
+    FLASHINFER_ERROR(err_msg.str()); \
+  }
+
+#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \
+  if (max_mma_kv >= 8) { \
+    constexpr size_t NUM_MMA_KV = 8; \
+    __VA_ARGS__ \
+  } else if (max_mma_kv >= 4) { \
+    constexpr size_t NUM_MMA_KV = 4; \
+    __VA_ARGS__ \
+  } else if (max_mma_kv >= 2) { \
+    constexpr size_t NUM_MMA_KV = 2; \
+    __VA_ARGS__ \
+  } else if (max_mma_kv >= 1) { \
+    constexpr size_t NUM_MMA_KV = 1; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \
+    FLASHINFER_ERROR(err_msg.str()); \
+  }
+
+#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
+  switch (cta_tile_q) { \
+    case 128: { \
+      constexpr uint32_t CTA_TILE_Q = 128; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    case 64: { \
+      constexpr uint32_t CTA_TILE_Q = 64; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    case 16: { \
+      constexpr uint32_t CTA_TILE_Q = 16; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    default: { \
+      std::ostringstream err_msg; \
+      err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
+      FLASHINFER_ERROR(err_msg.str()); \
+    } \
+  }
+
+#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...)
\ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported group_size: " << group_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } + +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported mask_mode: " << int(mask_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +// convert head_dim to compile-time constant +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 64: { \ + constexpr size_t HEAD_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: { \ + constexpr size_t HEAD_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + case 512: { \ + constexpr size_t HEAD_DIM = 512; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head_dim: " << head_dim; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \ + switch (pos_encoding_mode) { \ + case PosEncodingMode::kNone: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kRoPELlama: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kALiBi: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ + switch (aligned_vec_size) { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t NUM_STAGES_SMEM = 2; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t NUM_STAGES_SMEM = 1; \ + __VA_ARGS__ \ + } diff --git a/libflashinfer/include/flashinfer/attention/generic/page.cuh b/libflashinfer/include/flashinfer/attention/generic/page.cuh index 28ed38bebc..14df54be80 100644 --- a/libflashinfer/include/flashinfer/attention/generic/page.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/page.cuh @@ -262,7 +262,6 @@ __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, size_t append_k_stride_n, size_t append_k_stride_h, size_t append_v_stride_n, size_t append_v_stride_h) { uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t num_heads = paged_kv.num_heads; uint32_t head_idx = ty; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 2fd12b924f..1045a0bda4 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -10,12 +10,6 @@ #include "gpu_iface/mma_ops.hpp" #include "gpu_iface/platform.hpp" -#if 0 -#include - -#include "mma.cuh" -#endif - namespace gpu_mem = flashinfer::gpu_iface::memory; namespace flashinfer { @@ -138,6 +132,7 @@ struct smem_t { #endif } +#if defined(PLATFORM_HIP_DEVICE) /*! * \brief Loads a fragment from shared memory and performs an in-register transpose across a quad. * \details This function is designed to prepare the B-matrix operand for a CDNA3 MFMA @@ -172,14 +167,11 @@ struct smem_t { * \param frag A pointer to the thread's local registers to store the resulting column fragment. 
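+ *
+ * \par Usage sketch
+ * A minimal, illustrative example (the fragment and offset names are ours;
+ * the member signature is as declared below):
+ * \code
+ *   // Load one 16x16 half-precision tile as the transposed (column) B
+ *   // operand of a CDNA3 MFMA; each lane of the 64-wide wavefront ends up
+ *   // with four halfs packed into two 32-bit registers.
+ *   uint32_t b_frag[2];
+ *   k_smem.load_matrix_m16n16_trans(k_smem_offset, b_frag);
+ * \endcode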
*/ template - __device__ __forceinline__ void load_fragment_and_quad_transpose(uint32_t offset, T* frag) { -#if defined(PLATFORM_HIP_DEVICE) - auto smem_t_ptr = reinterpret_cast(base + offset); - flashinfer::gpu_iface::mma::load_quad_transposed_fragment(frag, smem_t_ptr); -#else - static_assert(sizeof(T) == 0, "Not supported on current platform"); -#endif + __device__ __forceinline__ void load_matrix_m16n16_trans(uint32_t offset, T* frag) { + load_fragment(offset, frag); + gpu_iface::mma::transpose_mma_tile(frag); } +#endif template __device__ __forceinline__ void store_fragment(uint32_t offset, const T* frag) { @@ -191,41 +183,42 @@ struct smem_t { #endif } +#if defined(PLATFORM_CUDA_DEVICE) __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, uint32_t* R) { - // b128_t *smem_ptr = base + offset; - // mma::ldmatrix_m8n8x4(R, smem_ptr); + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4(R, smem_ptr); } __device__ __forceinline__ void ldmatrix_m8n8x4_left_half(uint32_t offset, uint32_t* R) { - // b128_t *smem_ptr = base + offset; - // mma::ldmatrix_m8n8x4_left_half(R, smem_ptr); + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_left_half(R, smem_ptr); } __device__ __forceinline__ void ldmatrix_m8n8x4_right_half(uint32_t offset, uint32_t* R) { - // b128_t *smem_ptr = base + offset; - // mma::ldmatrix_m8n8x4_right_half(R, smem_ptr); + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_right_half(R, smem_ptr); } __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t offset, uint32_t* R) { - // b128_t *smem_ptr = base + offset; - // mma::stmatrix_m8n8x4(R, smem_ptr); + b128_t* smem_ptr = base + offset; + mma::stmatrix_m8n8x4(R, smem_ptr); } __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, uint32_t* R) { - // b128_t *smem_ptr = base + offset; - // mma::ldmatrix_m8n8x4_trans(R, smem_ptr); + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans(R, smem_ptr); } __device__ __forceinline__ void ldmatrix_m8n8x4_trans_left_half(uint32_t offset, uint32_t* R) { - // b128_t *smem_ptr = base + offset; - // mma::ldmatrix_m8n8x4_trans_left_half(R, smem_ptr); + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans_left_half(R, smem_ptr); } __device__ __forceinline__ void ldmatrix_m8n8x4_trans_right_half(uint32_t offset, uint32_t* R) { - // b128_t *smem_ptr = base + offset; - // mma::ldmatrix_m8n8x4_trans_right_half(R, smem_ptr); + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans_right_half(R, smem_ptr); } - +#endif template __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr, bool predicate) { b128_t* smem_ptr = base + offset; diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 19200b3792..b6d4d8e87a 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1,52 +1,44 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// SPDX-FileCopyrightText: 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0

 #ifndef FLASHINFER_PREFILL_CUH_
 #define FLASHINFER_PREFILL_CUH_

-#include <cooperative_groups.h>
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <cuda_fp8.h>
-#include <cuda_runtime.h>
+#include "gpu_iface/cooperative_groups.h"
+#include "gpu_iface/fastdiv.cuh"
+#include "gpu_iface/math_ops.hpp"
+#include "gpu_iface/memory_ops.hpp"
+#include "gpu_iface/mma_ops.hpp"
+#include "gpu_iface/platform.hpp"
+#include "gpu_iface/utils.cuh"

-#include "../cp_async.cuh"
-#include "../fastdiv.cuh"
 #ifdef FP16_QK_REDUCTION_SUPPORTED
-#include "../fp16.h"
+#include "../../fp16.h"
 #endif
-#include "../frag_layout_swizzle.cuh"
-#include "../math.cuh"
-#include "../mma.cuh"
-#include "../page.cuh"
-#include "../permuted_smem.cuh"
-#include "../pos_enc.cuh"
-#include "../utils.cuh"
+#include <sstream>
+
 #include "cascade.cuh"
-#include "mask.cuh"
+#include "dispatch.cuh"
+#include "frag_layout_swizzle.cuh"
+#include "page.cuh"
+#include "permuted_smem.cuh"
+#include "pos_enc.cuh"
 #include "variants.cuh"
+
 namespace flashinfer {

 DEFINE_HAS_MEMBER(maybe_q_rope_offset)
 DEFINE_HAS_MEMBER(maybe_k_rope_offset)

-namespace cg = cooperative_groups;
-using cp_async::SharedMemFillMode;
+namespace cg = flashinfer::gpu_iface::cg;
+namespace memory = flashinfer::gpu_iface::memory;
+namespace mma = gpu_iface::mma;
+
+using gpu_iface::vec_dtypes::vec_cast;
 using mma::MMAMode;

-constexpr uint32_t WARP_SIZE = 32;
+constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize;

 constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) {
   if (cta_tile_q > 16) {
@@ -101,23 +93,13 @@ struct KernelTraits {
   static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_;
   static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_;
   static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_;
-  static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE;
   static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV;
   static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16;
   static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16;
-  static constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM_QK / upcast_size<DTypeQ_>();
-  static constexpr uint32_t UPCAST_STRIDE_K = HEAD_DIM_QK / upcast_size<DTypeKV_>();
-  static constexpr uint32_t UPCAST_STRIDE_V = HEAD_DIM_VO / upcast_size<DTypeKV_>();
-  static constexpr uint32_t UPCAST_STRIDE_O = HEAD_DIM_VO / upcast_size<DTypeO_>();
   static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_;
   static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16;
-
-  static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B;
-  static constexpr SwizzleMode SWIZZLE_MODE_KV =
-      (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B;
-  static constexpr uint32_t KV_THR_LAYOUT_ROW = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 8;
-  static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4;
   static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_;
+
   using DTypeQ = DTypeQ_;
   using DTypeKV = DTypeKV_;
   using DTypeO = DTypeO_;
@@ -125,6 +107,72 @@ struct KernelTraits {
   using IdType = IdType_;
   using AttentionVariant = AttentionVariant_;

+#if defined(PLATFORM_HIP_DEVICE)
+  static_assert(sizeof(DTypeKV_) != 1, "8-bit types not supported for CDNA3");
+
+  using SmemBasePtrTy = uint2;
+  static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64;
+  static constexpr uint32_t WARP_THREAD_ROWS = 4;
+  static constexpr uint32_t WARP_THREAD_COLS = 16;
+  static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4;
+  static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2;
+  static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16;
+  // FIXME: Update with a proper swizzle pattern. Linear is used primarily
+  // for initial testing.
+  static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear;
+  static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear;
+
+  // Presently we use a 4x16 (rows x columns) thread layout for all cases.
+  static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS;
+  static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS;
+  // FIXME: [The comment is not correct] The constant is defined based on the
+  // matrix layout of the "D/C" accumulator matrix in a D = A*B+C computation.
+  // On CDNA3 the D/C matrices are distributed as four 4x16 bands across the
+  // 64 threads. Each thread owns one element from four different rows.
+  static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4;
+  // Number of threads that collaboratively handle the same set of matrix rows
+  // in attention score computation and cross-warp synchronization.
+  // CUDA: 4 threads (each thread handles 2 elements from same row group)
+  // CDNA3: 16 threads (each thread handles 1 element from same row group)
+  static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 16;
+  // Controls the indexing stride used in logits-related functions
+  // (logits_transform, logits_mask, and LSE writing).
+  static constexpr uint32_t LOGITS_INDEX_STRIDE = 4;
+#else
+  using SmemBasePtrTy = uint4;
+  static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32;
+  static constexpr uint32_t WARP_THREAD_ROWS = 4;
+  static constexpr uint32_t WARP_THREAD_COLS = 8;
+  static constexpr uint32_t HALF_ELEMS_PER_THREAD = 8;
+  static constexpr uint32_t INT32_ELEMS_PER_THREAD = 4;
+  static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16;
+
+  static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B;
+  static constexpr SwizzleMode SWIZZLE_MODE_KV =
+      (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B;
+  static constexpr uint32_t KV_THR_LAYOUT_ROW =
+      SWIZZLE_MODE_KV == SwizzleMode::k128B ? WARP_THREAD_ROWS : WARP_THREAD_COLS;
+  static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4;
+
+  // The constant is defined based on the matrix layout of the "D/C"
+  // accumulator matrix in a D = A*B+C computation. On CUDA for
+  // m16n8k16 mma ops the D/C matrix is distributed as four 8x8 blocks and
+  // each thread stores eight elements from two different rows.
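+  // Worked example (our annotation, assuming the standard PTX m16n8k16
+  // accumulator layout): lane 5 has group id 5/4 = 1 and in-group id
+  // 5%4 = 1, so it owns columns {2,3} of rows 1 and 9 in each 8x8 quadrant,
+  // i.e. columns {2,3,10,11} of rows 1 and 9 of the composed 16x16 tile.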
+ // Refer: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; + static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 4; + static constexpr uint32_t LOGITS_INDEX_STRIDE = 8; +#endif + static constexpr uint32_t UPCAST_STRIDE_Q = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_K = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_V = + HEAD_DIM_VO / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_O = + HEAD_DIM_VO / upcast_size(); + static constexpr bool IsInvalid() { return ((NUM_MMA_D_VO < 4) || (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) || (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && NUM_MMA_D_VO > 4 && @@ -140,9 +188,9 @@ struct KernelTraits { template static constexpr DT getNegInf() { if constexpr (std::is_same::value) { - return std::bit_cast(fp16_ieee_from_fp32_value(-math::inf)); + return std::bit_cast(fp16_ieee_from_fp32_value(-gpu_iface::math::inf)); } else { - return static_cast(-math::inf); + return static_cast(-gpu_iface::math::inf); } } @@ -153,7 +201,7 @@ struct KernelTraits { "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " "then recompile to support fp16 reduction"); static constexpr DTypeQKAccum MaskFillValue = - AttentionVariant::use_softmax ? DTypeQKAccum(-math::inf) : DTypeQKAccum(0.f); + AttentionVariant::use_softmax ? DTypeQKAccum(-gpu_iface::math::inf) : DTypeQKAccum(0.f); #endif }; @@ -194,18 +242,23 @@ __device__ __forceinline__ uint32_t get_warp_idx(const uint32_t tid_y = threadId * \note The sin/cos computation is slow, especially for A100 GPUs which has low * non tensor-ops flops, will optimize in the future. */ -template +template __device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t kv_offset) { static_assert(sizeof(T) == 2); #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; // 0 1 | 2 3 // --------- // 4 5 | 6 7 + +#if defined(PLATFORM_HIP_DEVICE) + uint32_t i = reg_id / 2, j = reg_id % 2; +#else uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; +#endif __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); @@ -213,27 +266,33 @@ __device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_se } } -template +template __device__ __forceinline__ void q_frag_apply_llama_rope(T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t qo_packed_offset, const uint_fastdiv group_size) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; +#if defined(PLATFORM_HIP_DEVICE) + uint32_t freq_idx = reg_id; + uint32_t position = qo_packed_offset; +#else // 0 1 | 4 5 // --------- // 2 3 | 6 7 uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); - __sincosf(float((qo_packed_offset + 8 * i) / group_size) * rope_freq[2 * j + reg_id % 2], &sin, - &cos); + uint32_t freq_idx = 2 * j + reg_id % 2; + uint32_t position = qo_packed_offset + 8 * i; +#endif + __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); x_second_half[reg_id] = 
((float)x_second_half[reg_id] * cos + tmp * sin); } } -template +template __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t qo_packed_offset, @@ -242,12 +301,18 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half float pos[2] = {static_cast(q_rope_offset[qo_packed_offset / group_size]), static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])}; #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; // 0 1 | 4 5 // --------- // 2 3 | 6 7 - uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t i = reg_id / 2; + const uint32_t j = reg_id % 2; +#else + const uint32_t i = (reg_id % 4) / 2; + const uint32_t j = reg_id / 4; +#endif __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); @@ -255,54 +320,41 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half } } -/*! - * \brief Produce k/v fragments from global memory to shared memory. - * \tparam fill_mode The fill mode of the shared memory. - * \tparam NUM_MMA_D_VO The number of fragments in y dimension. - * \tparam NUM_MMA_KV The number of fragments in z dimension. - * \tparam num_warps The number of warps in the threadblock. - * \tparam T The data type of the input tensor. - * \param smem The shared memory to store kv fragments. - * \param gptr The global memory pointer. - * \param kv_idx_base The base kv index. - * \param kv_len The length of kv tensor. - */ -template -__device__ __forceinline__ void produce_kv(smem_t smem, - uint32_t* smem_offset, typename KTraits::DTypeKV** gptr, - const uint32_t stride_n, const uint32_t kv_idx_base, - const uint32_t kv_len, const dim3 tid = threadIdx) { - // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment +template +__device__ __forceinline__ void produce_kv_impl_cuda_( + uint32_t warp_idx, uint32_t lane_idx, + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len) { using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / - // num_warps + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); - *gptr += 8 * upcast_size(); + *gptr += 8 * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = smem.template advance_offset_by_row(*smem_offset) - sizeof(DTypeKV) * NUM_MMA_D; - *gptr += NUM_WARPS * 4 * stride_n - sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); } - *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / @@ -310,7 +362,7 @@ __device__ __forceinline__ void produce_kv(smem_t smem static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_row(*smem_offset); kv_idx += NUM_WARPS * 8; @@ -320,15 +372,81 @@ __device__ __forceinline__ void produce_kv(smem_t smem } } +template +__device__ __forceinline__ void produce_kv_impl_cdna3_( + uint32_t warp_idx, uint32_t lane_idx, + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len) { + static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear); + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; + +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); + *gptr += 16 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row(*smem_offset) - + (sizeof(DTypeKV) * NUM_MMA_D * 2); + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; +} + +/*! + * \brief Produce k/v fragments from global memory to shared memory. + * \tparam fill_mode The fill mode of the shared memory. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. + * \tparam num_warps The number of warps in the threadblock. + * \tparam T The data type of the input tensor. + * \param smem The shared memory to store kv fragments. + * \param gptr The global memory pointer. + * \param kv_idx_base The base kv index. + * \param kv_len The length of kv tensor. + */ +template +__device__ __forceinline__ void produce_kv( + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len, const dim3 tid = threadIdx) { + // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + +#if defined(PLATFORM_HIP_DEVICE) + produce_kv_impl_cdna3_(warp_idx, lane_idx, smem, smem_offset, gptr, + stride_n, kv_idx_base, kv_len); +#elif defined(PLATFORM_CUDA_DEVICE) + produce_kv_impl_cuda_(warp_idx, lane_idx, smem, smem_offset, gptr, + stride_n, kv_idx_base, kv_len); +#endif +} + template __device__ __forceinline__ void page_produce_kv( - smem_t smem, uint32_t* smem_offset, + smem_t smem, uint32_t* smem_offset, const paged_kv_t& paged_kv, const uint32_t kv_idx_base, const size_t* thr_local_kv_offset, const uint32_t kv_len, const dim3 tid = threadIdx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment using DType = typename KTraits::DTypeKV; - using IdType = typename KTraits::IdType; constexpr SharedMemFillMode fill_mode = produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; @@ -337,11 +455,12 @@ __device__ __forceinline__ void page_produce_kv( constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K;
+  constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH;
+
   const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x;
   if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) {
     uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
-    // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 /
-    // num_warps
+    // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps
     static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0);
 #pragma unroll
     for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
       DType* gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i]
                               : paged_kv.k_data + thr_local_kv_offset[i];
 #pragma unroll
       for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) {
-        smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len);
+        smem.template load_vector_async<VECTOR_BIT_WIDTH, fill_mode>(*smem_offset, gptr,
+                                                                     kv_idx < kv_len);
         *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j);
-        gptr += 8 * upcast_size();
+        gptr += 8 * upcast_size();
       }
       kv_idx += NUM_WARPS * 4;
       *smem_offset =
@@ -361,14 +480,13 @@ __device__ __forceinline__ void page_produce_kv(
     *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
   } else {
     uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4;
-    // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 /
-    // num_warps
+    // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps
     static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0);
 #pragma unroll
     for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) {
       DType* gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i]
                               : paged_kv.k_data + thr_local_kv_offset[i];
-      smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len);
+      smem.template load_vector_async<VECTOR_BIT_WIDTH, fill_mode>(*smem_offset, gptr,
+                                                                   kv_idx < kv_len);
       kv_idx += NUM_WARPS * 8;
       *smem_offset =
           smem.template advance_offset_by_row(*smem_offset);
@@ -377,36 +495,71 @@ __device__ __forceinline__ void page_produce_kv(
   }
 }

+template <uint32_t HEAD_DIM>
+__device__ __forceinline__ uint32_t get_feature_index(uint32_t mma_d, uint32_t lane_idx,
+                                                      uint32_t j) {
+#if defined(PLATFORM_HIP_DEVICE)
+  // CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront:
+  // Each group of 16 threads handles the same four consecutive features for
+  // different sequences:
+  //   T0-T15:  Features [0,1,2,3] for sequences 0-15 respectively
+  //   T16-T31: Features [4,5,6,7] for sequences 0-15 respectively
+  //   T32-T47: Features [8,9,10,11] for sequences 0-15 respectively
+  //   T48-T63: Features [12,13,14,15] for sequences 0-15 respectively
+  //
+  uint32_t feature_index = (mma_d * 16 + (lane_idx / 4) + j) % (HEAD_DIM / 2);
+#elif defined(PLATFORM_CUDA_DEVICE)
+  // CUDA A-matrix MMA tile to thread mapping for a 32 thread warp:
+  // Each group of four consecutive threads maps four different features for
+  // the same sequence.
+  //   T0: {0,1,8,9}, T1: {2,3,10,11}, T2: {4,5,12,13}, T3: {6,7,14,15}
+  //
+  // The pattern repeats across 8 rows with each row mapped to a set of four
+  // consecutive threads.
+  //   row 0 --> T0, T1, T2, T3
+  //   row 1 --> T4, T5, T6, T7
+  //   ...
+  //   row 7 --> T28, T29, T30, T31
+  // The full data to thread mapping repeats again for the next set of 16
+  // rows, thereby forming a 16x16 MMA tile subdivided into four 8x8
+  // quadrants.
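+  // Worked example (our annotation): with mma_d = 0 and lane_idx = 1, the
+  // expression below yields features {2, 3, 10, 11} for j = 0..3, i.e. the
+  // T1 mapping from the table above.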
+ uint32_t feature_index = + ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % (HEAD_DIM / 2)); +#endif + return feature_index; +} + template __device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const float rope_rcp_scale, const float rope_rcp_theta, const uint32_t tid_x = threadIdx.x) { constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; const uint32_t lane_idx = tid_x; + #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { - rope_freq[mma_d][j] = - rope_rcp_scale * - __powf(rope_rcp_theta, - float(2 * ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % - (HEAD_DIM / 2))) / - float(HEAD_DIM)); + uint32_t feature_index = get_feature_index(mma_d, lane_idx, j); + float freq_base = float(2 * feature_index) / float(HEAD_DIM); + rope_freq[mma_d][j] = rope_rcp_scale * __powf(rope_rcp_theta, freq_base); } } } template -__device__ __forceinline__ void init_states(typename KTraits::AttentionVariant variant, - float (*o_frag)[KTraits::NUM_MMA_D_VO][8], - typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) { +__device__ __forceinline__ void init_states( + typename KTraits::AttentionVariant variant, + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { o_frag[mma_q][mma_d][reg_id] = 0.f; } } @@ -416,8 +569,8 @@ __device__ __forceinline__ void init_states(typename KTraits::AttentionVariant v #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - m[mma_q][j] = typename KTraits::DTypeQKAccum(-math::inf); + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + m[mma_q][j] = typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); d[mma_q][j] = 1.f; } } @@ -428,35 +581,51 @@ template __device__ __forceinline__ void load_q_global_smem( uint32_t packed_offset, const uint32_t qo_upper_bound, typename KTraits::DTypeQ* q_ptr_base, const uint32_t q_stride_n, const uint32_t q_stride_h, const uint_fastdiv group_size, - smem_t* q_smem, const dim3 tid = threadIdx) { + smem_t* q_smem, + const dim3 tid = threadIdx) { using DTypeQ = typename KTraits::DTypeQ; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + +#if defined(PLATFORM_HIP_DEVICE) + constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; +#else + constexpr uint32_t COLUMN_RESET_OFFSET = 2 * KTraits::NUM_MMA_D_QK; +#endif + const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y); + uint32_t row = lane_idx / WARP_THREAD_COLS; + uint32_t col = lane_idx % WARP_THREAD_COLS; if (get_warp_idx_kv(tid.z) == 0) { - uint32_t q_smem_offset_w = 
q_smem->get_permuted_offset( - warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); + uint32_t q_smem_offset_w = q_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2 * 2; ++j) { uint32_t q, r; - group_size.divmod(packed_offset + lane_idx / 8 + mma_q * 16 + j * 4, q, r); + group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q, r); const uint32_t q_idx = q; - DTypeQ* q_ptr = - q_ptr_base + q * q_stride_n + r * q_stride_h + (lane_idx % 8) * upcast_size(); + DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + + col * upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) { // load q fragment from gmem to smem - q_smem->load_128b_async(q_smem_offset_w, q_ptr, - q_idx < qo_upper_bound); - q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do); - q_ptr += 8 * upcast_size(); + q_smem->template load_vector_async(q_smem_offset_w, q_ptr, + q_idx < qo_upper_bound); + q_smem_offset_w = + q_smem->template advance_offset_by_column(q_smem_offset_w, mma_do); + q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); } - q_smem_offset_w = - q_smem->template advance_offset_by_row<4, UPCAST_STRIDE_Q>(q_smem_offset_w) - - 2 * KTraits::NUM_MMA_D_QK; + q_smem_offset_w = q_smem->template advance_offset_by_row( + q_smem_offset_w) - + COLUMN_RESET_OFFSET; } } } @@ -465,65 +634,82 @@ __device__ __forceinline__ void load_q_global_smem( template __device__ __forceinline__ void q_smem_inplace_apply_rotary( const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, - const uint_fastdiv group_size, smem_t* q_smem, + const uint_fastdiv group_size, + smem_t* q_smem, uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { - if (get_warp_idx_kv(tid.z) == 0) { - constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; - const uint32_t lane_idx = tid.x; - uint32_t q_frag_local[2][4]; - static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); + if (get_warp_idx_kv(tid.z) != 0) return; + + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t COL_ADVANCE_TO_NEXT = 16 / KTraits::HALF_ELEMS_PER_THREAD; + +#if defined(PLATFORM_HIP_DEVICE) + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK * 2; +#elif defined(PLATFORM_CUDA_DEVICE) + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK; +#endif + + const uint32_t lane_idx = tid.x; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t seq_id = + q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx % 16; +#elif defined(PLATFORM_CUDA_DEVICE) + const uint32_t seq_id = + q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx / 4; +#endif #pragma unroll - for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); - uint32_t q_smem_offset_r_last_half = - 
q_smem->template advance_offset_by_column( - q_smem_offset_r_first_half, 0); - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - q_frag_apply_llama_rope( - (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], - rope_freq[mma_di], - q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx / 4, - group_size); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); - q_smem_offset_r_first_half = - q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); - } - *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; + for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { + q_smem->load_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, 0); + q_smem->load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_frag_apply_llama_rope( + (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], + rope_freq[mma_di], seq_id, group_size); + q_smem->store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_smem->store_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + q_smem_offset_r_first_half = q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, mma_di); } - *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; } + *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; } template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( const uint32_t q_packed_idx_base, const typename KTraits::IdType* q_rope_offset, - smem_t* q_smem, const uint_fastdiv group_size, - uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { + smem_t* q_smem, + const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], + const dim3 tid = threadIdx) { if (get_warp_idx_kv(tid.z) == 0) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; const uint32_t lane_idx = tid.x; - uint32_t q_frag_local[2][4]; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); + q_smem->load_fragment(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = q_smem->template advance_offset_by_column( q_smem_offset_r_first_half, 0); - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - q_frag_apply_llama_rope_with_pos( + q_smem->load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_frag_apply_llama_rope_with_pos( (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], - rope_freq[mma_di], q_packed_idx_base + mma_q * 16 + lane_idx / 4, group_size, - q_rope_offset); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); + rope_freq[mma_di], + q_packed_idx_base + mma_q * 16 + lane_idx / KTraits::THREADS_PER_BMATRIX_ROW_SET, + group_size, q_rope_offset); + 
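+      // Annotation: the call above applies the standard RoPE pair rotation
+      //   x' = x * cos(p * f) - y * sin(p * f)
+      //   y' = y * cos(p * f) + x * sin(p * f)
+      // where p is the rope-offset-adjusted position and f the per-feature
+      // inverse frequency read from rope_freq.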
q_smem->store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_smem->store_fragment(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); } @@ -535,12 +721,15 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( template __device__ __forceinline__ void k_smem_inplace_apply_rotary( - const uint32_t kv_idx_base, smem_t* k_smem, uint32_t* k_smem_offset_r, - float (*rope_freq)[4], const dim3 tid = threadIdx) { + const uint32_t kv_idx_base, + smem_t* k_smem, + uint32_t* k_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - uint32_t k_frag_local[2][4]; + constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; const uint32_t lane_idx = tid.x; if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) { static_assert(KTraits::NUM_WARPS_KV == 1); @@ -552,21 +741,21 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | static_assert(KTraits::NUM_MMA_KV % 2 == 0, "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); - uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / 4; + uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / THREADS_PER_BMATRIX_ROW_SET; *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * UPCAST_STRIDE_K; #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV / 2; ++i) { uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; uint32_t mma_di = (warp_idx % 2); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = k_smem->template advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], - rope_freq[mma_di], kv_idx); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); *k_smem_offset_r += 32 * UPCAST_STRIDE_K; kv_idx += 32; } @@ -578,12 +767,12 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( static_assert(KTraits::NUM_MMA_D_QK % (2 * KTraits::NUM_WARPS_Q) == 0); // horizontal axis: y // vertical axis: z - // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 - // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | - // (0, 3) | ... | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, - // 2) | (1, 3) | ... - // ... - uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + lane_idx / 4; + // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 + // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) + // | ... 
| 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) + // | ... ... + uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + + lane_idx / THREADS_PER_BMATRIX_ROW_SET; *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { @@ -591,15 +780,15 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( #pragma unroll for (uint32_t j = 0; j < KTraits::NUM_MMA_D_QK / (2 * KTraits::NUM_WARPS_Q); ++j) { uint32_t mma_di = warp_idx_x + j * KTraits::NUM_WARPS_Q; - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = k_smem->template advance_offset_by_column( k_smem_offset_r_first_half, 0); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], - rope_freq[mma_di], kv_idx); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); k_smem_offset_r_first_half = k_smem->template advance_offset_by_column<2 * KTraits::NUM_WARPS_Q>( k_smem_offset_r_first_half, mma_di); @@ -614,28 +803,37 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( template __device__ __forceinline__ void compute_qk( - smem_t* q_smem, uint32_t* q_smem_offset_r, - smem_t* k_smem, uint32_t* k_smem_offset_r, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8]) { + smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD]) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - uint32_t a_frag[KTraits::NUM_MMA_Q][4], b_frag[4]; + constexpr uint32_t QK_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; + + uint32_t a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD], + b_frag[KTraits::INT32_ELEMS_PER_THREAD]; // compute q*k^T #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[mma_q]); + q_smem->load_fragment(*q_smem_offset_r, a_frag[mma_q]); *q_smem_offset_r = q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>(*q_smem_offset_r); } - *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, mma_d) - - KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + *q_smem_offset_r = + q_smem->template advance_offset_by_column(*q_smem_offset_r, mma_d) - + KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, "FP8 support not yet implemented for CDNA3"); +#else uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_f8); @@ -644,11 +842,13 @@ __device__ __forceinline__ void 
compute_qk( } b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); - vec_cast::cast<8>( + vec_cast::template cast<8>( (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); +#endif } else { - k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + k_smem->load_fragment(*k_smem_offset_r, b_frag); } + *k_smem_offset_r = k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>(*k_smem_offset_r); @@ -662,7 +862,12 @@ __device__ __forceinline__ void compute_qk( mma::mma_sync_m16n16k16_row_col_f16f16f32( s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } - } else if (std::is_same_v) { + } + + else if (std::is_same_v) { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, "FP16 DTypeQKAccum not yet implemented for CDNA3"); +#else if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f16( (uint32_t*)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); @@ -670,22 +875,29 @@ __device__ __forceinline__ void compute_qk( mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } +#endif } } } if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { if (mma_d % 2 == 1) { - *k_smem_offset_r = - k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d / 2); + *k_smem_offset_r = k_smem->template advance_offset_by_column( + *k_smem_offset_r, mma_d / 2); } *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } else { - *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d) - + *k_smem_offset_r = k_smem->template advance_offset_by_column( + *k_smem_offset_r, mma_d) - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } } - *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * 2; + *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE; + +#if defined(PLATFORM_HIP_DEVICE) + *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * (QK_SMEM_COLUMN_ADVANCE); +#elif defined(PLATFORM_CUDA_DEVICE) *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); +#endif } template @@ -693,17 +905,21 @@ __device__ __forceinline__ void logits_transform( const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, - DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], const dim3 tid = threadIdx, - const uint32_t kv_head_idx = blockIdx.z) { + DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + const uint32_t lane_idx = tid.x; - uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2]; + uint32_t q[KTraits::NUM_MMA_Q][NAPTR], r[KTraits::NUM_MMA_Q][NAPTR]; float logits = 0., logitsTransformed = 0.; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + LIS * j, q[mma_q][j], r[mma_q][j]); } } @@ -713,11 +929,17 @@ __device__ __forceinline__ void logits_transform( #pragma unroll for (uint32_t mma_kv = 0; mma_kv < 
KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t q_idx = q[mma_q][reg_id % NAPTR]; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][reg_id % NAPTR]; + const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + (lane_idx % TPR); +#else const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + 2 * (lane_idx % 4) + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; +#endif #ifdef FP16_QK_REDUCTION_SUPPORTED if constexpr (std::is_same::value) { @@ -753,19 +975,22 @@ __device__ __forceinline__ void logits_mask( const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, const uint_fastdiv group_size, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], const dim3 tid = threadIdx, - const uint32_t kv_head_idx = blockIdx.z) { + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { const uint32_t lane_idx = tid.x; constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + + uint32_t q[NUM_MMA_Q][NAPTR], r[NUM_MMA_Q][NAPTR]; #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + LIS * j, q[mma_q][j], r[mma_q][j]); } } @@ -775,11 +1000,17 @@ __device__ __forceinline__ void logits_mask( #pragma unroll for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)]; + const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + (lane_idx % TPR); + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)]; +#else const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % 4) + + 2 * (lane_idx % TPR) + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; +#endif const bool mask = (!(MASK_MODE == MaskMode::kCausal ? 
(kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end))
@@ -795,11 +1026,13 @@ __device__ __forceinline__ void logits_mask(
 template <typename KTraits>
 __device__ __forceinline__ void update_mdo_states(
     typename KTraits::AttentionVariant variant,
-    typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
-    float (*o_frag)[KTraits::NUM_MMA_D_VO][8], typename KTraits::DTypeQKAccum (*m)[2],
-    float (*d)[2]) {
+    typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD],
+    float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD],
+    typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD],
+    float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) {
   using DTypeQKAccum = typename KTraits::DTypeQKAccum;
   using AttentionVariant = typename KTraits::AttentionVariant;
+  constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD;
   constexpr bool use_softmax = AttentionVariant::use_softmax;

   if constexpr (use_softmax) {
@@ -808,8 +1041,31 @@ __device__ __forceinline__ void update_mdo_states(
 #pragma unroll
     for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
 #pragma unroll
-      for (uint32_t j = 0; j < 2; ++j) {
+      for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) {
         float m_prev = m[mma_q][j];
+#if defined(PLATFORM_HIP_DEVICE)
+#pragma unroll
+        for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
+          m[mma_q][j] = max(m[mma_q][j], s_frag[mma_q][mma_kv][j]);
+        }
+        // Butterfly reduction across all threads in the band
+        m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x8));
+        m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x4));
+        m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2));
+        m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1));
+        float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale);
+        d[mma_q][j] *= o_scale;
+
+#pragma unroll
+        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
+          o_frag[mma_q][mma_d][j] *= o_scale;
+        }
+#pragma unroll
+        for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
+          s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2(
+              s_frag[mma_q][mma_kv][j] * sm_scale - m[mma_q][j] * sm_scale);
+        }
+#elif defined(PLATFORM_CUDA_DEVICE)
 #pragma unroll
         for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
           float m_local =
@@ -817,10 +1073,10 @@ __device__ __forceinline__ void update_mdo_states(
                   max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5]));
           m[mma_q][j] = max(m[mma_q][j], m_local);
         }
-        m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x2));
-        m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x1));
+        m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2));
+        m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1));

-        float o_scale = math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale);
+        float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale);
         d[mma_q][j] *= o_scale;
 #pragma unroll
         for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
@@ -831,18 +1087,22 @@
         }
 #pragma unroll
         for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
-          s_frag[mma_q][mma_kv][j * 2 + 0] = math::ptx_exp2(
+          s_frag[mma_q][mma_kv][j * 2 + 0] = gpu_iface::math::ptx_exp2(
              s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - m[mma_q][j] * sm_scale);
-          s_frag[mma_q][mma_kv][j * 2 + 1] = math::ptx_exp2(
+          s_frag[mma_q][mma_kv][j * 2 + 1] = gpu_iface::math::ptx_exp2(
              s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - m[mma_q][j] * sm_scale);
-          s_frag[mma_q][mma_kv][j * 2 + 4] = math::ptx_exp2(
+          s_frag[mma_q][mma_kv][j * 2 + 4] = gpu_iface::math::ptx_exp2(
              s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m[mma_q][j] * sm_scale);
-          s_frag[mma_q][mma_kv][j * 2 + 5] = math::ptx_exp2(
+          s_frag[mma_q][mma_kv][j * 2 + 5] = gpu_iface::math::ptx_exp2(
              s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale - m[mma_q][j] * sm_scale);
         }
+#endif
       }
     }
   } else if constexpr (std::is_same_v<DTypeQKAccum, half>) {
+#if defined(PLATFORM_HIP_DEVICE)
+    static_assert(false, "Half precision accumulator not yet implemented for AMD");
+#else
     const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2);
 #pragma unroll
     for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
@@ -852,18 +1112,19 @@
         m_prev[j] = m[mma_q][j];
 #pragma unroll
         for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
-          half2 m_local = __hmax2(*(half2*)&s_frag[mma_q][mma_kv][j * 2],
-                                  *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4]);
+          half2 m_local = gpu_iface::math::hmax2(*(half2*)&s_frag[mma_q][mma_kv][j * 2],
+                                                 *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4]);
           m[mma_q][j] = __hmax(m[mma_q][j], __hmax(m_local.x, m_local.y));
         }
       }
-      *(half2*)&m[mma_q] =
-          __hmax2(*(half2*)&m[mma_q], math::shfl_xor_sync(*(half2*)&m[mma_q], 0x2));
-      *(half2*)&m[mma_q] =
-          __hmax2(*(half2*)&m[mma_q], math::shfl_xor_sync(*(half2*)&m[mma_q], 0x1));
+      *(half2*)&m[mma_q] = gpu_iface::math::hmax2(
+          *(half2*)&m[mma_q], gpu_iface::math::shfl_xor_sync(*(half2*)&m[mma_q], 0x2));
+      *(half2*)&m[mma_q] = gpu_iface::math::hmax2(
+          *(half2*)&m[mma_q], gpu_iface::math::shfl_xor_sync(*(half2*)&m[mma_q], 0x1));
 #pragma unroll
       for (uint32_t j = 0; j < 2; ++j) {
-        float o_scale = math::ptx_exp2(float(m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x));
+        float o_scale =
+            gpu_iface::math::ptx_exp2(float(m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x));
         d[mma_q][j] *= o_scale;
 #pragma unroll
         for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
@@ -875,36 +1136,55 @@
         half2 m2 = make_half2(m[mma_q][j], m[mma_q][j]);
 #pragma unroll
         for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
-          *(half2*)&s_frag[mma_q][mma_kv][j * 2] =
-              math::ptx_exp2(*(half2*)&s_frag[mma_q][mma_kv][j * 2] * sm_scale - m2 * sm_scale);
-          *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = math::ptx_exp2(
+          *(half2*)&s_frag[mma_q][mma_kv][j * 2] = gpu_iface::math::ptx_exp2(
+              *(half2*)&s_frag[mma_q][mma_kv][j * 2] * sm_scale - m2 * sm_scale);
+          *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = gpu_iface::math::ptx_exp2(
               *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m2 * sm_scale);
         }
       }
     }
+#endif
   }
 }

 template <typename KTraits>
 __device__ __forceinline__ void compute_sfm_v(
-    smem_t* v_smem, uint32_t* v_smem_offset_r,
-    typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
-    float (*o_frag)[KTraits::NUM_MMA_D_VO][8], float (*d)[2]) {
+    smem_t* v_smem,
+    uint32_t* v_smem_offset_r,
+    typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD],
+    float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD],
+    float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const dim3 tid = threadIdx,
+    typename KTraits::DTypeQKAccum* qk_scratch = nullptr) {
   constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V;
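  // Illustrative sketch, not used by the kernel: the band-wide row maximum that
  // the HIP path in update_mdo_states above computes through
  // gpu_iface::math::shfl_xor_sync, written here with raw HIP shuffles. On the
  // CDNA3 16x16 MFMA tile, a full accumulator row lives in the 16 lanes that
  // share lane_idx / 16 (each lane holding column lane_idx % 16), so a
  // butterfly over lane bits 0-3 (XOR masks 8, 4, 2, 1) leaves every lane of
  // the band holding the row max; the online-softmax rescale
  //   m_new = max(m_prev, row_max),  d *= exp2(m_prev*s - m_new*s)
  // then follows as in the code above.
  auto band_row_max = [](float v) {
    for (int mask = 8; mask >= 1; mask >>= 1) {
      v = fmaxf(v, __shfl_xor(v, mask, 64));  // partner stays inside the 16-lane band
    }
    return v;
  };
  (void)band_row_max;  // illustration only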
constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; + constexpr uint32_t V_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; + typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] + [HALF_ELEMS_PER_THREAD]; - typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV][8]; if constexpr (std::is_same_v) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - vec_cast::cast<8>(s_frag_f16[mma_q][mma_kv], - s_frag[mma_q][mma_kv]); + vec_cast::template cast( + s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]); } } } +#if defined(PLATFORM_HIP_DEVICE) +// In-place transposition of the s_frag MMA tile to get the data into CDNA3 A-matrix layout. +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + mma::transpose_mma_tile(reinterpret_cast(s_frag_f16[mma_q][mma_kv])); + } + } +#endif + if constexpr (KTraits::AttentionVariant::use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -913,7 +1193,12 @@ __device__ __forceinline__ void compute_sfm_v( if constexpr (std::is_same_v) { mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag_f16[mma_q][mma_kv]); } else { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(!std::is_same_v, + "FP16 reduction path not implemented for CDNA3"); +#else mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag[mma_q][mma_kv]); +#endif } } } @@ -923,8 +1208,11 @@ __device__ __forceinline__ void compute_sfm_v( for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - uint32_t b_frag[4]; + uint32_t b_frag[INT32_ELEMS_PER_THREAD]; if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, "FP8 V path not implemented for CDNA3 yet"); +#else uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); @@ -933,11 +1221,16 @@ __device__ __forceinline__ void compute_sfm_v( } b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); - vec_cast::cast<8>( + vec_cast::template cast<8>( (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); swap(b_frag[1], b_frag[2]); +#endif } else { +#if defined(PLATFORM_HIP_DEVICE) + v_smem->load_matrix_m16n16_trans(*v_smem_offset_r, b_frag); +#else v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#endif } #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -951,33 +1244,44 @@ __device__ __forceinline__ void compute_sfm_v( } if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { if (mma_d % 2 == 1) { - *v_smem_offset_r = - v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, mma_d / 2); + *v_smem_offset_r = v_smem->template advance_offset_by_column( + *v_smem_offset_r, mma_d / 2); } } else { - *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, mma_d); + *v_smem_offset_r = v_smem->template advance_offset_by_column( + *v_smem_offset_r, mma_d); } } +#if defined(PLATFORM_CUDA_DEVICE) *v_smem_offset_r = v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>(*v_smem_offset_r) - 
sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO; +#elif defined(PLATFORM_HIP_DEVICE) + *v_smem_offset_r = + v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>(*v_smem_offset_r) - + V_SMEM_COLUMN_ADVANCE * KTraits::NUM_MMA_D_VO; +#endif } *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V; } template -__device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], - typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) { +__device__ __forceinline__ void normalize_d( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + if constexpr (AttentionVariant::use_softmax) { - float d_rcp[KTraits::NUM_MMA_Q][2]; + float d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD]; // compute reciprocal of d #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf)) - ? math::ptx_rcp(d[mma_q][j]) + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) + ? gpu_iface::math::ptx_rcp(d[mma_q][j]) : 0.f; } } @@ -987,9 +1291,14 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_V #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][reg_id % NAPTR]; +#else o_frag[mma_q][mma_d][reg_id] = o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2]; +#endif } } } @@ -997,14 +1306,15 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_V } template -__device__ __forceinline__ void finalize_m(typename KTraits::AttentionVariant variant, - typename KTraits::DTypeQKAccum (*m)[2]) { +__device__ __forceinline__ void finalize_m( + typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { if constexpr (variant.use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - if (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf)) { + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + if (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) { m[mma_q][j] *= variant.sm_scale_log2; } } @@ -1013,29 +1323,41 @@ __device__ __forceinline__ void finalize_m(typename KTraits::AttentionVariant va } /*! - * \brief Synchronize the states of the MDO kernel across the threadblock along - * threadIdx.z. + * \brief Synchronize the states of the MDO kernel across the threadblock along threadIdx.z. 
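 *
 * Each KV-split warp w contributes a partial per-row state (m_w, d_w, o_w) for
 * the same query rows. Writing m* = max_w m_w (all in the base-2 scale used by
 * ptx_exp2), the merge staged through cta_sync_md_smem / cta_sync_o_smem below
 * computes
 *   d_new = sum_w d_w * exp2(m_w - m*),   o_new = sum_w o_w * exp2(m_w - m*).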
*/ template __device__ __forceinline__ void threadblock_sync_mdo_states( - float (*o_frag)[KTraits::NUM_MMA_D_VO][8], typename KTraits::SharedStorage* smem_storage, - typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2], const uint32_t warp_idx, + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::SharedStorage* smem_storage, + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const uint32_t warp_idx, const uint32_t lane_idx, const dim3 tid = threadIdx) { + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + + static_assert(WARP_SIZE % TPR == 0, "THREADS_PER_BMATRIX_ROW_SET must divide WARP_SIZE"); + constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; + const uint32_t lane_group_idx = lane_idx / TPR; + // only necessary when blockDim.z > 1 if constexpr (KTraits::NUM_WARPS_KV > 1) { float* smem_o = smem_storage->cta_sync_o_smem; float2* smem_md = smem_storage->cta_sync_md_smem; - // o: [num_warps, NUM_MMA_Q, NUM_MMA_D_VO, WARP_SIZE(32), 8] + // o: [num_warps, + // NUM_MMA_Q, + // NUM_MMA_D_VO, + // WARP_SIZE, + // HALF_ELEMS_PER_THREAD] // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t::memcpy( + vec_t::memcpy( smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * KTraits::NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * - 8, + KTraits::HALF_ELEMS_PER_THREAD, o_frag[mma_q][mma_d]); } } @@ -1044,9 +1366,9 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * 8 + lane_idx / 4] = - make_float2(float(m[mma_q][j]), d[mma_q][j]); + for (uint32_t j = 0; j < NARPT; ++j) { + smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * NARPT + j) * GROUPS_PER_WARP + + lane_group_idx] = make_float2(float(m[mma_q][j]), d[mma_q][j]); } } @@ -1054,22 +1376,23 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( __syncthreads(); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - float o_scale[2][KTraits::NUM_WARPS_KV]; + float o_scale[NARPT][KTraits::NUM_WARPS_KV]; #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - float m_new = -math::inf, d_new = 1.f; + for (uint32_t j = 0; j < NARPT; ++j) { + float m_new = -gpu_iface::math::inf, d_new = 1.f; #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * - 2 + + NARPT + j) * - 8 + - lane_idx / 4]; + GROUPS_PER_WARP + + lane_group_idx]; float m_prev = m_new, d_prev = d_new; m_new = max(m_new, md.x); - d_new = d_prev * math::ptx_exp2(m_prev - m_new) + md.y * math::ptx_exp2(md.x - m_new); + d_new = d_prev * gpu_iface::math::ptx_exp2(m_prev - m_new) + + md.y * gpu_iface::math::ptx_exp2(md.x - m_new); } #pragma unroll @@ -1077,12 +1400,12 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * - 2 + + NARPT + j) * - 8 + - lane_idx / 4]; + GROUPS_PER_WARP + + lane_group_idx]; float mi = md.x; - o_scale[j][i] = math::ptx_exp2(float(mi - m_new)); + 
o_scale[j][i] = gpu_iface::math::ptx_exp2(float(mi - m_new)); } m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); d[mma_q][j] = d_new; @@ -1090,11 +1413,11 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t o_new; + vec_t o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; + vec_t oi; oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * @@ -1102,11 +1425,17 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_d) * WARP_SIZE + lane_idx) * - 8); + KTraits::HALF_ELEMS_PER_THREAD); #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + // CDNA3: Direct mapping - each reg_id corresponds to one accumulator row + o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; +#else + // CUDA: Grouped mapping - 2 elements per accumulator row o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; +#endif } } o_new.store(o_frag[mma_q][mma_d]); @@ -1119,11 +1448,11 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t o_new; + vec_t o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; + vec_t oi; oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * @@ -1131,9 +1460,9 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_d) * WARP_SIZE + lane_idx) * - 8); + KTraits::HALF_ELEMS_PER_THREAD); #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { o_new[reg_id] += oi[reg_id]; } } @@ -1146,12 +1475,19 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( template __device__ __forceinline__ void write_o_reg_gmem( - float (*o_frag)[KTraits::NUM_MMA_D_VO][8], smem_t* o_smem, + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + smem_t* o_smem, typename KTraits::DTypeO* o_ptr_base, const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv group_size, const dim3 tid = threadIdx) { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + const uint32_t warp_idx_x = get_warp_idx_q(tid.y); const uint32_t lane_idx = tid.x; @@ -1159,19 +1495,24 @@ __device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NAPTR; ++j) { uint32_t q, r; - group_size.divmod(o_packed_idx_base + lane_idx / 4 + mma_q * 16 + j * 8, q, r); + group_size.divmod(o_packed_idx_base + lane_idx / TPR + mma_q * 16 + j * 8, q, r); const uint32_t o_idx = q; #pragma unroll for (uint32_t mma_d = 0; mma_d < 
KTraits::NUM_MMA_D_VO; ++mma_d) {
        if (o_idx < qo_upper_bound) {
-          *reinterpret_cast<float2*>(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 +
-                                     (lane_idx % 4) * 2) =
+          auto base_addr = o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16;
+#if defined(PLATFORM_HIP_DEVICE)
+          // CDNA3: one fp32 element per lane, column = lane_idx % 16.
+          const uint32_t col_offset = lane_idx % 16;
+          *(base_addr + col_offset) = o_frag[mma_q][mma_d][j];
+#else
+          // CUDA: two float2 stores per lane at columns 2*(lane_idx % 4) and +8.
+          const uint32_t col_offset = (lane_idx % 4) * 2;
+          *reinterpret_cast<float2*>(base_addr + col_offset) =
             *reinterpret_cast<float2*>(&o_frag[mma_q][mma_d][j * 2]);
-          *reinterpret_cast<float2*>(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 +
-                                     8 + (lane_idx % 4) * 2) =
-            *reinterpret_cast<float2*>(&o_frag[mma_q][mma_d][4 + j * 2]);
+
+          *reinterpret_cast<float2*>(base_addr + 8 + col_offset) =
+            *reinterpret_cast<float2*>(&o_frag[mma_q][mma_d][4 + j * 2]);
+#endif
        }
      }
    }
@@ -1182,46 +1523,56 @@
  for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
 #pragma unroll
    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
-      uint32_t o_frag_f16[8 / 2];
-      vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);
+      uint32_t o_frag_f16[HALF_ELEMS_PER_THREAD / 2];
+      vec_cast<DTypeO, float>::template cast<HALF_ELEMS_PER_THREAD>((DTypeO*)o_frag_f16,
+                                                                    o_frag[mma_q][mma_d]);
 #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-      uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
+      uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
          (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
          mma_d * 2 + lane_idx / 16);
      o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
 #else
-      uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
-          (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
-      ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
+      uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
+          (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / TPR, mma_d * 2);
+#if defined(PLATFORM_HIP_DEVICE)
+      ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % TPR] = o_frag_f16[0];
+      uint32_t offset_2 = o_smem_offset_w + 2;
+      ((uint32_t*)(o_smem->base + offset_2))[lane_idx % 16] = o_frag_f16[1];
+#else
+      ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % TPR] = o_frag_f16[0];
      ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
          o_frag_f16[1];
-      ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
+      ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % TPR] = o_frag_f16[2];
      ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
          o_frag_f16[3];
+#endif
 #endif
    }
  }

-  uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
-      warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);
+  uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
+      warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / WARP_THREAD_COLS,
+      lane_idx % WARP_THREAD_COLS);

 #pragma unroll
  for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
 #pragma unroll
    for (uint32_t j = 0; j < 2 * 2; ++j) {
      uint32_t q, r;
-      group_size.divmod(o_packed_idx_base + lane_idx / 8 + mma_q * 16 + j * 4, q, r);
+      group_size.divmod(o_packed_idx_base + lane_idx / WARP_THREAD_COLS + mma_q * 16 + j * 4, q,
+                        r);
      const uint32_t o_idx = q;
-      DTypeO* o_ptr =
-          o_ptr_base + q * o_stride_n + r * o_stride_h + (lane_idx % 8) * upcast_size();
+      DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h +
+                      (lane_idx % WARP_THREAD_COLS) * upcast_size();
 #pragma unroll
      for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) {
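        // Each pass of this loop streams one vector per thread (128 bits on
        // the CUDA path's old store_128b; the HIP port parameterizes the width
        // via VECTOR_BIT_WIDTH) from the permuted o_smem tile to global
        // memory. The o_idx < qo_upper_bound guard below skips rows where the
        // CTA tile overhangs the actual query length.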
if (o_idx < qo_upper_bound) { - o_smem->store_128b(o_smem_offset_w, o_ptr); + o_smem->store_vector(o_smem_offset_w, o_ptr); } - o_ptr += 8 * upcast_size(); - o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do); + o_ptr += WARP_THREAD_COLS * upcast_size(); + o_smem_offset_w = o_smem->template advance_offset_by_column( + o_smem_offset_w, mma_do); } o_smem_offset_w = o_smem->template advance_offset_by_row<4, UPCAST_STRIDE_O>(o_smem_offset_w) - @@ -1264,8 +1615,8 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( const uint32_t kv_head_idx = blockIdx.z, const uint32_t num_chunks = gridDim.y, const uint32_t num_kv_heads = gridDim.z) { using DTypeQ = typename Params::DTypeQ; -#if (__CUDA_ARCH__ < 800) - if constexpr (std::is_same_v) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif @@ -1292,6 +1643,13 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = KTraits::LOGITS_INDEX_STRIDE; + [[maybe_unused]] constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = + KTraits::THREADS_PER_BMATRIX_ROW_SET; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; DTypeQ* q = params.q; DTypeKV* k = params.k; @@ -1307,7 +1665,6 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( const uint32_t k_stride_h = params.k_stride_h; const uint32_t v_stride_n = params.v_stride_n; const uint32_t v_stride_h = params.v_stride_h; - const int32_t maybe_window_left = params.window_left; const uint_fastdiv& group_size = params.group_size; static_assert(sizeof(DTypeQ) == 2); @@ -1325,10 +1682,10 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( AttentionVariant variant(params, /*batch_idx=*/0, smem); const uint32_t window_left = variant.window_left; - DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; - DTypeQKAccum m[NUM_MMA_Q][2]; - float d[NUM_MMA_Q][2]; + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; + DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; @@ -1340,29 +1697,30 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // cooperative fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); + smem_t qo_smem(smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; DTypeO* o_ptr_base = partition_kv ? 
o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h : o + (kv_head_idx * group_size) * o_stride_h; - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); - cp_async::commit_group(); + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + memory::commit_group(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - cp_async::wait_group<0>(); + memory::wait_group<0>(); block.sync(); q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, tid); block.sync(); } - smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); + smem_t k_smem(smem_storage.k_smem); + smem_t v_smem(smem_storage.v_smem); const uint32_t num_iterations = ceil_div(MASK_MODE == MaskMode::kCausal @@ -1387,19 +1745,28 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( DTypeKV* k_ptr = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + - kv_head_idx * k_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); DTypeKV* v_ptr = v + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + - kv_head_idx * v_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); +#if defined(PLATFORM_HIP_DEVICE) uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + - lane_idx % 8, - (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), - k_smem_offset_w = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16)); + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16); +#elif defined(PLATFORM_CUDA_DEVICE) + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8); + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + (lane_idx % 4) + 4 * (lane_idx / 16), + lane_idx / 4); +#endif + uint32_t k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), v_smem_offset_w = v_smem.template get_permuted_offset( @@ -1407,16 +1774,15 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( lane_idx % KV_THR_LAYOUT_COL); produce_kv(k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - cp_async::wait_group<1>(); + memory::wait_group<1>(); block.sync(); - if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary(chunk_start + iter * CTA_TILE_KV, 
&k_smem, &k_smem_offset_r, rope_freq, tid); @@ -1425,12 +1791,11 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - + // Apply logits transformation logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); - // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { logits_mask( @@ -1438,42 +1803,36 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } - // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d); - block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); - cp_async::commit_group(); - cp_async::wait_group<1>(); + memory::commit_group(); + memory::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); - + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, tid); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); } - cp_async::wait_group<0>(); + memory::wait_group<0>(); block.sync(); finalize_m(variant, m); // threadblock synchronization threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); - // normalize d normalize_d(o_frag, m, d); - // write back write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ partition_kv ? 
num_chunks * o_stride_n : o_stride_n, /*o_stride_h=*/o_stride_h, group_size, tid); - // write lse if constexpr (variant.use_softmax) { if (lse != nullptr || partition_kv) { @@ -1481,18 +1840,20 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { uint32_t q, r; - group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); + group_size.divmod(qo_packed_idx_base + lane_idx / THREADS_PER_BMATRIX_ROW_SET + + j * LOGITS_INDEX_STRIDE + mma_q * 16, + q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_len) { if (partition_kv) { lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[qo_idx * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -1500,7 +1861,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( } } } -#if (__CUDA_ARCH__ < 800) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) } #endif } @@ -1516,8 +1877,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache template -cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, - cudaStream_t stream) { +gpuError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, + gpuStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; @@ -1528,8 +1889,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { std::ostringstream err_msg; err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be " - "greater than or equal " - "to qo_len, got kv_len" + "greater than or equal to qo_len, got kv_len" << kv_len << " and qo_len " << qo_len; FLASHINFER_ERROR(err_msg.str()); } @@ -1550,10 +1910,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D float>::type; int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( - &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + @@ -1594,14 +1952,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; auto kernel = SinglePrefillWithKVCacheKernel; size_t smem_size = sizeof(typename KTraits::SharedStorage); - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); int num_blocks_per_sm = 0; int num_sm = 0; - FLASHINFER_CUDA_CALL( - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - 
&num_blocks_per_sm, kernel, num_threads, smem_size)); + FI_GPU_CALL(gpuDeviceGetAttribute(&num_sm, gpuDevAttrMultiProcessorCount, dev_id)); + FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); uint32_t max_num_kv_chunks = (num_blocks_per_sm * num_sm) / (num_kv_heads * ceil_div(qo_len * group_size, CTA_TILE_Q)); uint32_t num_chunks; @@ -1617,9 +1974,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D params.partition_kv = false; void* args[] = {(void*)¶ms}; dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // Use cooperative groups to increase occupancy params.partition_kv = true; @@ -1630,29 +1986,28 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D params.lse = tmp_lse; void* args[] = {(void*)¶ms}; dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, - HEAD_DIM_VO, stream)); + FI_GPU_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, + HEAD_DIM_VO, stream)); } else { - FLASHINFER_CUDA_CALL( + FI_GPU_CALL( AttentionSum(tmp, o, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream)); } } } }) }); - return cudaSuccess; + return gpuSuccess; } template __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel( const __grid_constant__ Params params) { using DTypeQ = typename Params::DTypeQ; -#if (__CUDA_ARCH__ < 800) - if constexpr (std::is_same_v) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif @@ -1680,6 +2035,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; DTypeQ* q = params.q; IdType* request_indices = params.request_indices; @@ -1700,7 +2056,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV const uint32_t k_stride_h = params.k_stride_h; const uint32_t v_stride_n = params.v_stride_n; const uint32_t v_stride_h = params.v_stride_h; - const int32_t maybe_window_left = params.window_left; const uint_fastdiv& group_size = params.group_size; static_assert(sizeof(DTypeQ) == 2); @@ -1756,16 +2111,16 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h; - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + uint32_t q_smem_offset_r = 
qo_smem.template get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); - cp_async::commit_group(); + memory::commit_group(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - cp_async::wait_group<0>(); + memory::wait_group<0>(); block.sync(); IdType* q_rope_offset = nullptr; @@ -1824,24 +2179,24 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); DTypeKV* v_ptr = v + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); produce_kv(k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - cp_async::wait_group<1>(); + memory::wait_group<1>(); block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { @@ -1878,8 +2233,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); - cp_async::commit_group(); - cp_async::wait_group<1>(); + memory::commit_group(); + memory::wait_group<1>(); block.sync(); // compute sfm*v @@ -1888,9 +2243,9 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); } - cp_async::wait_group<0>(); + memory::wait_group<0>(); block.sync(); finalize_m(variant, m); @@ -1925,10 +2280,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + qo_head_idx] = gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -1936,7 +2291,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV } } } -#if (__CUDA_ARCH__ < 800) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) } #endif } @@ -1947,8 +2302,8 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( const uint32_t bx = blockIdx.x, const uint32_t kv_head_idx = blockIdx.z, const uint32_t num_kv_heads = gridDim.z) { using DTypeQ = typename Params::DTypeQ; -#if (__CUDA_ARCH__ < 800) - if constexpr (std::is_same_v) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif @@ -1976,6 +2331,7 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( 
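// The paged variant below reuses the same two-deep producer/consumer pipeline
// as the single-prefill kernel above. As a sketch (names as in this patch;
// memory::commit_group / memory::wait_group are assumed to wrap cp.async-style
// group primitives on CUDA and their fence equivalents on HIP):
//
//   produce_kv(K, tile 0); memory::commit_group();
//   produce_kv(V, tile 0); memory::commit_group();
//   for (iter = 0; iter < num_iterations; ++iter) {
//     memory::wait_group<1>();  // K tile `iter` resident, V may still be in flight
//     block.sync();
//     ... compute_qk / logits_transform / update_mdo_states on K tile `iter` ...
//     produce_kv(K, tile iter + 1); memory::commit_group();
//     memory::wait_group<1>();  // V tile `iter` resident
//     block.sync();
//     ... compute_sfm_v on V tile `iter` ...
//     produce_kv(V, tile iter + 1); memory::commit_group();
//   }
//   memory::wait_group<0>();  // drain before the epilogue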
[[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; IdType* request_indices = params.request_indices; IdType* qo_tile_indices = params.qo_tile_indices; @@ -1988,7 +2344,6 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( bool* block_valid_mask = params.block_valid_mask; const paged_kv_t& paged_kv = params.paged_kv; const bool partition_kv = params.partition_kv; - const int32_t maybe_window_left = params.window_left; const uint_fastdiv& group_size = params.group_size; static_assert(sizeof(DTypeQ) == 2); @@ -2040,16 +2395,16 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( (kv_head_idx * group_size) * o_stride_h : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h; - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); - cp_async::commit_group(); + memory::commit_group(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - cp_async::wait_group<0>(); + memory::wait_group<0>(); block.sync(); IdType* q_rope_offset = nullptr; if constexpr (has_maybe_q_rope_offset_v) { @@ -2095,14 +2450,14 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( page_iter, entry_idx); thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); } page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, 0, thr_local_kv_offset, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, 0, thr_local_kv_offset, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); const uint32_t num_iterations = ceil_div( (MASK_MODE == MaskMode::kCausal @@ -2138,9 +2493,9 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( page_iter, entry_idx); thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); } - cp_async::wait_group<1>(); + memory::wait_group<1>(); block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { @@ -2173,8 +2528,8 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( block.sync(); page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, tid); - cp_async::commit_group(); - cp_async::wait_group<1>(); + memory::commit_group(); + memory::wait_group<1>(); block.sync(); // compute sfm*v @@ -2183,9 +2538,9 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( block.sync(); page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, tid); - cp_async::commit_group(); + memory::commit_group(); } - cp_async::wait_group<0>(); + memory::wait_group<0>(); block.sync(); finalize_m(variant, m); @@ -2220,10 +2575,10 @@ 
__device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + qo_head_idx] = gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -2231,7 +2586,7 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( } } } -#if (__CUDA_ARCH__ < 800) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) } #endif } @@ -2247,8 +2602,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC template -cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, - float* tmp_s, cudaStream_t stream) { +gpuError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, gpuStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; @@ -2263,11 +2618,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para // No request, skip // this won't happen in CUDAGraph mode because we fixed the // padded_batch_size - return cudaSuccess; + return gpuSuccess; } dim3 nblks(padded_batch_size, 1, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; using DTypeQKAccum = @@ -2275,10 +2630,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para float>::type; int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, - cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + @@ -2316,14 +2669,13 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para } else { size_t smem_size = sizeof(typename KTraits::SharedStorage); auto kernel = BatchPrefillWithRaggedKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (tmp_v == nullptr) { // do not partition kv params.partition_kv = false; void* args[] = {(void*)¶ms}; - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // partition kv params.partition_kv = true; @@ -2332,28 +2684,27 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para params.o = tmp_v; params.lse = tmp_s; void* args[] = {(void*)¶ms}; - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - 
tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); + FI_GPU_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); } else { - FLASHINFER_CUDA_CALL( - VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); + FI_GPU_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); } } } }); - return cudaSuccess; + return gpuSuccess; } template -cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, - float* tmp_s, cudaStream_t stream) { +gpuError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, gpuStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; @@ -2368,11 +2719,11 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param // No request, skip // this won't happen in CUDAGraph mode because we fixed the // padded_batch_size - return cudaSuccess; + return gpuSuccess; } dim3 nblks(padded_batch_size, 1, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; @@ -2381,10 +2732,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param float>::type; int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, - cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + @@ -2422,14 +2771,13 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param } else { size_t smem_size = sizeof(typename KTraits::SharedStorage); auto kernel = BatchPrefillWithPagedKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (tmp_v == nullptr) { // do not partition kv params.partition_kv = false; void* args[] = {(void*)¶ms}; - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { params.partition_kv = true; auto o = params.o; @@ -2437,21 +2785,20 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param params.o = tmp_v; params.lse = tmp_s; void* args[] = {(void*)¶ms}; - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM_VO, 
stream)); + FI_GPU_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); } else { - FLASHINFER_CUDA_CALL( - VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows, - params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); + FI_GPU_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); } } } }); - return cudaSuccess; + return gpuSuccess; } } // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h index f8bc7dd1d2..20bb87b0c2 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h @@ -67,33 +67,49 @@ __device__ void load_bmatrix_layout(T* arr, uint32_t* R, uint32_t dimY) { mma_impl::hip::load_quad_transposed_fragment<__half>(R, &arr[b_idx]); } -/// @brief Prints the four `half` values held in a thread's registers. -/// @tparam T The data type to interpret the registers as, must be `__half`. -/// @param R Pointer to the thread's registers (uint32_t[2]). -template -__device__ void print_register(uint32_t* R) { - static_assert(std::is_same_v, "Only supported for __half types"); - auto values = reinterpret_cast<__half*>(R); - printf("[%5.1f %5.1f %5.1f %5.1f]\n", __half2float(values[0]), __half2float(values[1]), - __half2float(values[2]), __half2float(values[3])); +/// @brief Prints a single MMA fragment (typically 4 or 8 elements). +/// @details Simple low-level printer for a single [ELEMS_PER_FRAGMENT] array. +/// Works for both A-matrix layout (row strip) and B-matrix layout (column strip). +/// @tparam T The data type of the fragment (e.g., float, __half). +/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4 or 8). +/// @param values Pointer to the fragment values. +template +__device__ void debug_print_frag(const T* values) { + printf("["); + for (uint32_t i = 0; i < ELEMS_PER_FRAGMENT; ++i) { + printf("%10.6f", float(values[i])); + if (i < ELEMS_PER_FRAGMENT - 1) printf(", "); + } + printf("]"); } -/// @brief Prints the full s_frag array from a single thread's registers. -template -__device__ void print_s_frag_register(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], - const dim3 tid = threadIdx) { - if (tid.x == 0 && tid.y == 0 && tid.z == 0) { - printf("Thread (0,0,0) s_frag registers:\n"); - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - const T* values = s_frag[mma_q][mma_kv]; - printf(" frag[%u][%u]: [%8.3f, %8.3f, %8.3f, %8.3f]\n", mma_q, mma_kv, float(values[0]), - float(values[1]), float(values[2]), float(values[3])); +/// @brief Prints all MMA fragments from a thread's registers. +/// @details Loops over [NUM_MMA_ROW][NUM_MMA_COL][ELEMS_PER_FRAGMENT] array. +/// Works for all fragment types: Q, K (A-matrix), S, O (B-matrix), etc. +/// @tparam T The data type of the fragments (e.g., float, __half). +/// @tparam NUM_MMA_ROW Number of MMA tiles in the row dimension. +/// @tparam NUM_MMA_COL Number of MMA tiles in the column dimension. +/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4 or 8). +/// @param frag The 3D fragment array from the thread's registers. 
+/// @param frag_name A string name to identify which fragment is being printed. +/// @param tidx The x component of the thread to print from. +/// @param tidy The y component of the thread to print from. +/// @param tidz The z component of the thread to print from. +template +__device__ void debug_print_frag_registers(const T (*frag)[NUM_MMA_COL][ELEMS_PER_FRAGMENT], + const char* frag_name = "frag", const uint32_t tidx = 0, + const uint32_t tidy = 0, const uint32_t tidz = 0) { + if (threadIdx.x == tidx && threadIdx.y == tidy && threadIdx.z == tidz) { + printf("Thread (%u,%u,%u) %s registers:\n", tidx, tidy, tidz, frag_name); + for (uint32_t mma_row = 0; mma_row < NUM_MMA_ROW; ++mma_row) { + for (uint32_t mma_col = 0; mma_col < NUM_MMA_COL; ++mma_col) { + printf(" %s[%u][%u]: ", frag_name, mma_row, mma_col); + debug_print_frag(frag[mma_row][mma_col]); + printf("\n"); } } printf("\n"); } - __syncthreads(); } /// @brief Prints a 2D LDS array to the console from a single thread. @@ -109,7 +125,7 @@ __device__ void print_lds_array(T* lds_array, uint32_t dimY, uint32_t dimX, printf("%s (%dx%d):\n", title, dimX, dimY); for (int y = 0; y < dimY; ++y) { for (int x = 0; x < dimX; ++x) { - printf("%5.1f ", __half2float(lds_array[y * dimX + x])); + printf("%5.1f,", __half2float(lds_array[y * dimX + x])); } printf("\n"); } @@ -122,10 +138,14 @@ __device__ void print_lds_array(T* lds_array, uint32_t dimY, uint32_t dimX, __device__ void print_lds_array(float* lds_array, uint32_t dimY, uint32_t dimX, const char* title = "LDS Array (float)") { if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { - printf("%s (%dx%d):\n", title, dimX, dimY); + printf("%s (%dx%d):\n", title, dimY, dimX); for (int y = 0; y < dimY; ++y) { for (int x = 0; x < dimX; ++x) { - printf("%8.3f ", lds_array[y * dimX + x]); + if (x == dimX - 1) { + printf("%10.6f", lds_array[y * dimX + x]); + } else { + printf("%10.6f ", lds_array[y * dimX + x]); + } } printf("\n"); } @@ -134,36 +154,106 @@ __device__ void print_lds_array(float* lds_array, uint32_t dimY, uint32_t dimX, __syncthreads(); } -/// @brief Materializes a 2D array of accumulator fragments from each thread's registers into a -/// 2D shared memory array. -/// @details This function is the inverse of the hardware's distribution of accumulator results. -/// It reconstructs a logical tile of the S = Q * K^T matrix in shared memory, -/// accounting for the partitioning of work across multiple warps. +/// @brief Prints a 1D LDS array of floats to the console from a single thread. +/// @details Useful for printing row-wise statistics like m or d values. +__device__ void print_lds_array_1d(float* lds_array, uint32_t dim, + const char* title = "LDS Array 1D (float)") { + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + printf("%s (%d elements):\n", title, dim); + for (int i = 0; i < dim; ++i) { + printf("%10.6f ", lds_array[i]); + if ((i + 1) % 16 == 0) printf("\n"); // Line break every 16 elements + } + if (dim % 16 != 0) printf("\n"); + printf("\n"); + } + __syncthreads(); +} + +/// @brief Writes an A-matrix fragment from registers to shared memory. +/// @details In the A-matrix layout, each thread owns a row slice of a 16x16 fragment. +/// Thread T_(16*c + r) owns row r, columns [4*c : 4*c+3]. +/// This function reconstructs the full logical tile from distributed row fragments. /// @tparam T The data type of the fragments and LDS array (e.g., float or half). -/// @tparam NUM_MMA_Q The number of fragments along the Q dimension (rows) per thread. 
-/// @tparam NUM_MMA_KV The number of fragments along the KV dimension (columns) per thread. -/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4 for float/half). -/// @param s_frag The 3D fragment array from the thread's registers. +/// @tparam NUM_MMA_ROW The number of fragments along the rows dimension per thread. +/// @tparam NUM_MMA_COL The number of fragments along the column dimension per thread. +/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4). +/// @param frag The 3D fragment array from the thread's registers. /// @param lds_scratchpad Pointer to the shared memory array. -/// @param lds_stride The width/stride of the lds_scratchpad (e.g., CTA_TILE_KV). +/// @param lds_stride The width/stride of the lds_scratchpad. /// @param tid The thread's index within the block (threadIdx). -template -__device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], - T* lds_scratchpad, const uint32_t lds_stride, - const dim3 tid = threadIdx) { +template +__device__ void write_amatrix_frag_to_lds(const T (*frag)[NUM_MMA_COL][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { const int lane_id = tid.x % 64; const int warp_idx_q = tid.y; // Calculate the starting row in the LDS tile for this entire warp. - const uint32_t warp_base_row = warp_idx_q * NUM_MMA_Q * MMA_COLS; + const uint32_t warp_base_row = warp_idx_q * NUM_MMA_ROW * MMA_COLS; #pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_row = 0; mma_row < NUM_MMA_ROW; ++mma_row) { #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (uint32_t mma_col = 0; mma_col < NUM_MMA_COL; ++mma_col) { + // -- Calculate the top-left corner of the 16x16 fragment this thread contributes to -- + const uint32_t frag_row_offset = mma_row * MMA_COLS; + const uint32_t frag_col_offset = mma_col * MMA_COLS; + + // -- Calculate the specific 1x4 element strip this thread writes within that fragment -- + // A-matrix layout: each thread handles a row strip. + // Thread lane_id = 16*c + r owns row r, columns [4*c : 4*c+3] + const uint32_t thread_row_in_frag = lane_id % MMA_COLS; + const uint32_t thread_start_col_in_frag = (lane_id / MMA_COLS) * MMA_ROWS_PER_THREAD; + + // -- Combine all offsets and write the 1x4 row strip to LDS -- + const T* values = frag[mma_row][mma_col]; + + // The row is fixed for all 4 elements in the strip. + const uint32_t final_row = warp_base_row + frag_row_offset + thread_row_in_frag; + + for (int i = 0; i < MMA_ROWS_PER_THREAD; ++i) { + // The column for this element is the thread's starting column + the element's index. + const uint32_t final_col = frag_col_offset + thread_start_col_in_frag + i; + + // Calculate destination and write the value. + T* dest = lds_scratchpad + final_row * lds_stride + final_col; + *dest = values[i]; + } + } + } +} + +/// @brief Generic function to materialize 2D fragment arrays into shared memory. +/// @details Works for both s_frag (attention scores) and o_frag (output accumulator). +/// Reconstructs a logical tile from distributed register fragments. +/// @tparam T The data type of the fragments and LDS array (e.g., float or half). +/// @tparam NUM_MMA_ROW The number of fragments along the rows dimension per thread. +/// @tparam NUM_MMA_COL The number of fragments along the column dimension per thread. 
+/// For s_frag: NUM_MMA_KV (KV sequence length) +/// For o_frag: NUM_MMA_D_VO (head dimension) +/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4). +/// @param frag The 3D fragment array from the thread's registers. +/// @param lds_scratchpad Pointer to the shared memory array. +/// @param lds_stride The width/stride of the lds_scratchpad. +/// @param tid The thread's index within the block (threadIdx). +template +__device__ void write_frag_to_lds(const T (*frag)[NUM_MMA_COL][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + const int lane_id = tid.x % 64; + const int warp_idx_q = tid.y; + + // Calculate the starting row in the LDS tile for this entire warp. + const uint32_t warp_base_row = warp_idx_q * NUM_MMA_ROW * MMA_COLS; + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_ROW; ++mma_q) { +#pragma unroll + for (uint32_t mma_col = 0; mma_col < NUM_MMA_COL; ++mma_col) { // -- Calculate the top-left corner of the 16x16 fragment this thread contributes to -- const uint32_t frag_row_offset = mma_q * MMA_COLS; - const uint32_t frag_col_offset = mma_kv * MMA_COLS; + const uint32_t frag_col_offset = mma_col * MMA_COLS; // -- Calculate the specific 4x1 element strip this thread writes within that fragment -- // This logic correctly materializes a B-layout fragment (column strip). @@ -174,7 +264,7 @@ __device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAG const uint32_t thread_col_in_frag = (lane_id % MMA_COLS); // -- Combine all offsets and write the 4x1 column strip to LDS -- - const T* values = s_frag[mma_q][mma_kv]; + const T* values = frag[mma_q][mma_col]; for (int i = 0; i < MMA_ROWS_PER_THREAD; ++i) { // The row for this element is the thread's starting row + the element's index in the strip. const uint32_t final_row = warp_base_row + frag_row_offset + thread_start_row_in_frag + i; @@ -189,13 +279,40 @@ __device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAG } } +/// @brief Convenience wrapper for s_frag (attention scores). +template +__device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + write_frag_to_lds(s_frag, lds_scratchpad, + lds_stride, tid); +} + +/// @brief Convenience wrapper for o_frag (output accumulator). +template +__device__ void write_o_frag_to_lds(const T (*o_frag)[NUM_MMA_D_VO][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + write_frag_to_lds(o_frag, lds_scratchpad, + lds_stride, tid); +} + +/// @brief Generic function to materialize 1D row-wise values (m or d) into shared memory. +/// @details Writes row-wise statistics (like max or denominator) from register arrays +/// to a 1D shared memory array, with one value per row. +/// @tparam T The data type (typically float). +/// @tparam NUM_MMA_Q The number of fragments along the Q dimension per thread. +/// @tparam NUM_ACCUM_ROWS_PER_THREAD The number of accumulator rows per thread (typically 4). +/// @param values The 2D array from registers [NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]. +/// @param lds_scratchpad Pointer to the 1D shared memory array. +/// @param tid The thread's index within the block (threadIdx). 
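+/// @note Worked sketch of the lane-to-row mapping (assuming MMA_COLS == 16
+/// and NUM_ACCUM_ROWS_PER_THREAD == 4, as in the fragment helpers above):
+/// only lanes 0, 16, 32 and 48 of each 64-lane wavefront write, covering
+/// rows 0-3, 4-7, 8-11 and 12-15 respectively of each 16-row MMA tile,
+/// offset by the warp's base row.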
template -__device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, - const dim3 tid = threadIdx) { +__device__ void write_row_values_to_lds(const T (*values)[NUM_ACCUM_ROWS_PER_THREAD], + T* lds_scratchpad, const dim3 tid = threadIdx) { const int lane_idx = tid.x; const int warp_idx_q = tid.y; - // Each group of 16 threads (a "row group") computes the max for 4 rows. + // Each group of 16 threads (a "row group") handles 4 rows. // We only need one thread from each group to write the results. if (lane_idx % MMA_COLS == 0) { // Base row index for this warp's Q tile @@ -212,13 +329,80 @@ __device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* l // e.g., lane 0 is in group 0, lane 16 is in group 1, etc. const uint32_t row_group_offset = (lane_idx / MMA_COLS) * NUM_ACCUM_ROWS_PER_THREAD; - // The final row index in the logical S matrix + // The final row index in the logical matrix const uint32_t final_row_idx = warp_base_row + mma_base_row + row_group_offset + j; - lds_scratchpad[final_row_idx] = m[mma_q][j]; + lds_scratchpad[final_row_idx] = values[mma_q][j]; } } } } +/// @brief Convenience wrapper for m (row-wise max) values. +template +__device__ void write_m_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_row_values_to_lds(m, lds_scratchpad, tid); +} + +/// @brief Convenience wrapper for d (denominator) values. +template +__device__ void write_d_to_lds(const T (*d)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_row_values_to_lds(d, lds_scratchpad, tid); +} + +// Legacy alias for backward compatibility +template +__device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_m_to_lds(m, lds_scratchpad, tid); +} + +/// @brief Reads O matrix from global memory and prints it. +/// @details This function reads back the O matrix that was written to global memory +/// by write_o_reg_gmem and prints it for validation. +/// @tparam DTypeO The data type of the O matrix in global memory (typically __half). +/// @param o_ptr_base Pointer to the base of the O matrix in global memory. +/// @param o_stride_n Stride between consecutive queries (sequence dimension). +/// @param o_stride_h Stride between consecutive heads. +/// @param num_rows Number of rows to read (typically CTA_TILE_Q = 128). +/// @param num_cols Number of columns to read (typically HEAD_DIM = 64). +/// @param qo_packed_idx_base Base index for query packing (for GQA). +/// @param group_size Group size for grouped query attention. +/// @param kv_head_idx The KV head index. +/// @param header_text Optional header text to print before the matrix. +/// @param tid Thread index. 
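+///
+/// Hypothetical call site, shown for orientation only (the argument values
+/// are assumptions drawn from the surrounding kernel, not a prescribed API):
+/// @code
+/// debug_print_o_from_gmem(params.o, params.o_stride_n, params.o_stride_h,
+///                         /*num_rows=*/CTA_TILE_Q, /*num_cols=*/HEAD_DIM_VO,
+///                         qo_packed_idx_base, params.group_size, kv_head_idx,
+///                         "O after write_o_reg_gmem");
+/// @endcode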
+template +__device__ void debug_print_o_from_gmem(DTypeO* o_ptr_base, const uint32_t o_stride_n, + const uint32_t o_stride_h, const uint32_t num_rows, + const uint32_t num_cols, const uint32_t qo_packed_idx_base, + const uint_fastdiv group_size, const uint32_t kv_head_idx, + const char* header_text = "O from global memory", + const dim3 tid = threadIdx) { + if (tid.x == 0 && tid.y == 0 && tid.z == 0) { + printf("\n%s (%dx%d):\n", header_text, num_rows, num_cols); + + for (uint32_t row = 0; row < num_rows; ++row) { + // Compute the q and r indices for GQA + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + row, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + + // Print row values + for (uint32_t col = 0; col < num_cols; ++col) { + DTypeO* ptr = o_ptr_base + q * o_stride_n + qo_head_idx * o_stride_h + col; + float val = float(*ptr); + printf("%10.6f", val); + if (col < num_cols - 1) { + printf(" "); + } + } + printf("\n"); + } + printf("\n"); + } + __syncthreads(); +} + } // namespace flashinfer::gpu_iface::debug_utils::hip diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 1316e454ce..eda659c107 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -111,6 +111,24 @@ __device__ __forceinline__ void transpose_inter_quad_fragments(uint32_t* R) { R[1] = __shfl_xor(R[1], xor_mask, 64); } +/// @brief Performs a full 16x16 in-register matrix transpose by combining intra-quad and +/// inter-quad fragment transpositions. +/// @details This function converts between A-matrix layout (row-major) and B/C/D-matrix layout +/// (column-major) for CDNA3 MFMA operations. It applies both +/// transpose_intra_quad_fragments and transpose_inter_quad_fragments to fully transpose a +/// 16x16 tile distributed across 64 threads. 
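+///
+/// Minimal usage sketch (assumes the two-dword fp16 fragment layout used
+/// throughout this file; load_fragment is the unified loader defined below):
+/// @code
+/// uint32_t R[2];               // one 16x16 fragment: four halves in two dwords
+/// load_fragment(R, smem_ptr);  // per-lane strip in the source layout
+/// transpose_mma_tile(R);       // same values, redistributed into the
+///                              // transposed (row <-> column) strip layout
+/// @endcode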
+/// +/// Use cases: +/// - B→A layout: Convert column slices to row slices (e.g., for rowsum where S must be +/// A-matrix) +/// - A→B layout: Convert row slices to column slices (if needed for other operations) +/// +/// @param R Pointer to 2 uint32_t registers containing the fragment data +__device__ __forceinline__ void transpose_mma_tile(uint32_t* R) { + transpose_intra_quad_fragments(R); + transpose_inter_quad_fragments(R); +} + // Single unified load function for all fragment types /// @param R [in] pointer to the register file to load the fragment into /// @param smem_ptr [in] pointer to the shared memory to load the fragment from @@ -141,8 +159,6 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u f16x4 A_fp16 = reinterpret_cast(A)[0]; f32x4 C_fp32 = reinterpret_cast(C)[0]; - // Perform MMA operation directly with fragments - if constexpr (std::is_same_v) { C_fp32 = __builtin_amdgcn_mfma_f32_16x16x16f16(A_fp16, B_fp16, C_fp32, 0, 0, 0); } else if constexpr (std::is_same_v) { @@ -191,7 +207,6 @@ __device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const template __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) { static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); - transpose_intra_quad_fragments(reinterpret_cast(s_frag)); f16x4 a = reinterpret_cast(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; f32x4 c = {d[0], d[1], d[2], d[3]}; diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index b015b116a7..57d73cb4c1 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -34,11 +34,14 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { } #if defined(PLATFORM_HIP_DEVICE) -template -__device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const T* smem_ptr) { - static_assert(std::is_same::value, - "Only __half is supported for load_quad_transposed_fragment"); - mma_detail::load_quad_transposed_fragment(R, smem_ptr); +/*! + * \brief Performs a full 16x16 in-register matrix transpose for CDNA3 MFMA tiles + * \details Converts between A-matrix layout (row-major) and B/C/D-matrix layout (column-major) + * by combining intra-quad and inter-quad fragment transpositions. + * \param R Pointer to 2 uint32_t registers containing the fragment data + */ +__device__ __forceinline__ void transpose_mma_tile(uint32_t* R) { + mma_detail::transpose_mma_tile(R); } #endif diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp new file mode 100644 index 0000000000..f02557d80c --- /dev/null +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -0,0 +1,619 @@ +// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
+// +// SPDX - License - Identifier : Apache 2.0 + +#include + +#include + +#include "../../utils/cpu_reference_hip.h" +#include "../../utils/flashinfer_prefill_ops.hip.h" +#include "../../utils/utils_hip.h" +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" + +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 + +using namespace flashinfer; + +#if 0 +template +void _TestComputeQKCorrectness(size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + bool causal, + QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, + float rtol = 1e-3, + float atol = 1e-3) +{ + // std::cout << "Testing compute_qk: qo_len=" << qo_len + // << ", kv_len=" << kv_len << ", num_qo_heads=" << num_qo_heads + // << ", num_kv_heads=" << num_kv_heads << ", head_dim=" << head_dim + // << std::endl; + + // Generate test data (same as original test) + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * + head_dim); // Still needed for params + std::vector o(qo_len * num_qo_heads * + head_dim); // Still needed for params + + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); // Initialize even though we won't use it + utils::vec_zero_(o); + + // GPU memory allocation (same pattern as original) + DTypeQ *q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), + hipMemcpyHostToDevice)); + + DTypeKV *k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeKV *v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeO *o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), + hipMemcpyHostToDevice)); + + DTypeO *tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + // Allocate output buffer for QK scores + const size_t qk_output_size = qo_len * kv_len * num_qo_heads; + float *qk_scores_d; + FI_GPU_CALL(hipMalloc(&qk_scores_d, qk_output_size * sizeof(float))); + + // std::cout << "Debug: Kernel launch parameters:" << std::endl; + // std::cout << " qo_len=" << qo_len << ", kv_len=" << kv_len << std::endl; + // std::cout << " num_qo_heads=" << num_qo_heads + // << ", num_kv_heads=" << num_kv_heads << std::endl; + // std::cout << " head_dim=" << head_dim << std::endl; + // std::cout << " qk_output_size=" << qk_output_size << std::endl; + // std::cout << " Launching ComputeQKStubCaller..." 
<< std::endl; + + // Call ComputeQKStubCaller instead of SinglePrefillWithKVCache + hipError_t status = + flashinfer::ComputeQKStubCaller( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, qk_scores_d, // Add qk_scores_d parameter + num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, + kv_layout, pos_encoding_mode, use_fp16_qk_reduction); + + // std::cout << " Kernel launch status: " << hipGetErrorString(status) + // << std::endl; + EXPECT_EQ(status, hipSuccess) + << "ComputeQKStubCaller kernel launch failed, error message: " + << hipGetErrorString(status); + + // Get GPU QK scores + std::vector gpu_qk_scores(qk_output_size); + FI_GPU_CALL(hipMemcpy(gpu_qk_scores.data(), qk_scores_d, + qk_output_size * sizeof(float), + hipMemcpyDeviceToHost)); + + // Check if GPU output is not empty + bool isEmpty = gpu_qk_scores.empty(); + EXPECT_EQ(isEmpty, false) << "GPU QK scores vector is empty"; + + // Compute CPU reference using our cpu_reference::compute_qk + std::vector cpu_qk_scores = cpu_reference::compute_qk( + q, k, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout); + + // Validate results (same pattern as original test) + size_t num_results_error_atol = 0; + bool nan_detected = false; + + // Compare element-by-element + size_t comparison_size = + std::min(gpu_qk_scores.size(), cpu_qk_scores.size()); + for (size_t i = 0; i < comparison_size; ++i) { + float gpu_val = gpu_qk_scores[i]; + float cpu_val = cpu_qk_scores[i]; + + if (isnan(gpu_val)) { + nan_detected = true; + } + + if (!utils::isclose(cpu_val, gpu_val, rtol, atol)) { + num_results_error_atol++; + if (num_results_error_atol <= 10) + { // Only print first 10 mismatches + std::cout << "QK mismatch at i=" << i << ", cpu_val=" << cpu_val + << ", gpu_val=" << gpu_val << std::endl; + } + } + } +#if 0 + // Calculate and report accuracy + float result_accuracy = + 1.0f - float(num_results_error_atol) / float(comparison_size); + + std::cout << "compute_qk results: num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len + << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal + << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" + << PosEncodingModeToString(pos_encoding_mode) + << ", qk_accuracy=" << result_accuracy << " (" + << num_results_error_atol << "/" << comparison_size + << " mismatches)" << std::endl; + + // Print some sample values for debugging + std::cout << "Sample QK scores (first 10): GPU vs CPU" << std::endl; + for (size_t i = 0; i < std::min(size_t(10), comparison_size); ++i) { + std::cout << " [" << i << "] GPU=" << gpu_qk_scores[i] + << ", CPU=" << cpu_qk_scores[i] << std::endl; + } + + // Assertions (slightly relaxed for initial testing) + EXPECT_GT(result_accuracy, 0.80) + << "compute_qk accuracy too low"; // Start with 80% + EXPECT_FALSE(nan_detected) << "NaN detected in compute_qk results"; +#endif + // Cleanup + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); + FI_GPU_CALL(hipFree(qk_scores_d)); +} + +#endif + +template +void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads, + size_t num_kv_heads, size_t head_dim, bool causal, + QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, uint32_t debug_thread_id, + uint32_t debug_warp_id, float rtol = 1e-3, + float atol = 1e-3) { + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len 
* num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); + + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); + // utils::vec_lexicographic_(q); + // utils::vec_lexicographic_(k); + // utils::vec_fill_(v, __float2half(1.0f)); + utils::vec_zero_(o); + + DTypeQ* q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), hipMemcpyHostToDevice)); + + DTypeKV* k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeKV* v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeO* o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), hipMemcpyHostToDevice)); + + DTypeO* tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + hipError_t status = flashinfer::SinglePrefillWithKVCache( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, + pos_encoding_mode, use_fp16_qk_reduction, debug_thread_id, debug_warp_id); + + EXPECT_EQ(status, hipSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " + << hipGetErrorString(status); + + std::vector o_h(o.size()); + FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), hipMemcpyDeviceToHost)); + + // Print the first 10 elements of the output vector for debugging + // std::cout << "Output vector (first 10 elements):"; + // std::cout << "[" << std::endl; + // for (int i = 0; i < 10; ++i) { + // std::cout << fi::con::explicit_casting(o_h[i]) << " + // "; + // } + // std::cout << "]" << std::endl; + + bool isEmpty = o_h.empty(); + EXPECT_EQ(isEmpty, false) << "Output vector is empty"; + + std::vector att_out; + std::vector o_ref = cpu_reference::single_mha( + q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout, + pos_encoding_mode, /*logits_soft_cap=*/8.0f, /*rope_scale=*/1.f, /*rope_theta=*/1e4, + /*use_soft_cap=*/true); + size_t num_results_error_atol = 0; + bool nan_detected = false; + + for (size_t i = 0; i < o_ref.size(); ++i) { + float o_h_val = fi::con::explicit_casting(o_h[i]); + float o_ref_val = fi::con::explicit_casting(o_ref[i]); + + if (isnan(o_h_val)) { + nan_detected = true; + } + + num_results_error_atol += (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); + // if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { + // std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val + // << ", o_h[i]=" << o_h_val << std::endl; + // } + } + // std::cout<<"Printing att_out vector:\n"; + // for(auto i: att_out) { + // std::cout << i << "\n"; + // } +#if 1 + float result_accuracy = 1. 
- float(num_results_error_atol) / float(o_ref.size()); + std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads + << ", qo_len=" << qo_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) + << ", result_accuracy=" << result_accuracy << std::endl; + + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "Nan detected in the result."; +#endif + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); +} + +// template +// void TestSinglePrefillKernelLongContextCorrectness(bool +// use_fp16_qk_reduction) +// { +// for (size_t qo_len : {1, 31, 63, 127}) { +// for (size_t kv_len : {31717}) { +// for (size_t num_heads : {1}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// template +// void TestSinglePrefillFP8KernelLongContextCorrectness(bool +// use_fp16_qk_reduction) { +// for (size_t qo_len : {1, 31, 63, 127}) { +// for (size_t kv_len : {31717}) { +// for (size_t num_heads : {1}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness<__half, DTypeKV, __half>( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillKernelShortContextCorrectness(bool +// use_fp16_qk_reduction) +// { +// float rtol = std::is_same::value ? 1e-2 : 1e-3; +// float atol = std::is_same::value ? 
1e-2 : 1e-3; +// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { +// for (size_t num_qo_heads : {32}) { +// for (size_t num_kv_heads : {4, 8, 32}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qkv_len, qkv_len, num_qo_heads, +// num_kv_heads, head_dim, causal, +// QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction, rtol, atol); +// } +// } +// } +// } +// } +// } +// } +// } + +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// template +// void TestSinglePrefillFP8KernelShortContextCorrectness(bool +// use_fp16_qk_reduction) { +// float rtol = 1e-3; +// float atol = 1e-3; +// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { +// for (size_t num_qo_heads : {32}) { +// for (size_t num_kv_heads : {4, 8, 32}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness<__half, DTypeKV, __half>( +// qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, +// causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction, rtol, atol); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillKernelCorrectness(bool use_fp16_qk_reduction) +// { +// for (size_t qo_len : {399, 400, 401}) { +// for (size_t kv_len : {533, 534, 535}) { +// for (size_t num_heads : {12}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction) +// { +// for (size_t qo_len : {399, 400, 401}) { +// for (size_t kv_len : {533, 534, 535}) { +// for (size_t num_heads : {12}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// __half, DTypeKV, __half>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessFP16) +// { +// TestSinglePrefillKernelLongContextCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) +// { +// TestSinglePrefillKernelLongContextCorrectness<__half, __half>(true); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessFP16) +// { +// TestSinglePrefillKernelShortContextCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) +// { +// 
TestSinglePrefillKernelShortContextCorrectness<__half, __half>(true); +// } + +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16) +// { +// TestSinglePrefillKernelCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) +// { +// TestSinglePrefillKernelCorrectness<__half, __half>(true); +// } + +// #ifdef FLASHINFER_ENABLE_BF16 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessBF16) +// { +// TestSinglePrefillKernelLongContextCorrectness<__hip_bfloat16, +// __hip_bfloat16>( +// false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessBF16) +// { +// TestSinglePrefillKernelShortContextCorrectness<__hip_bfloat16, +// __hip_bfloat16>( +// false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) +// { +// TestSinglePrefillKernelCorrectness<__hip_bfloat16, +// __hip_bfloat16>(false); +// } +// #endif + +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// #ifdef FLASHINFER_ENABLE_FP8_E4M3 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) { +// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false); +// } +// #endif + +// #ifdef FLASHINFER_ENABLE_FP8_E5M2 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessE5M2) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) { +// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessE5M2) { +// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false); +// } +// #endif + +int main(int argc, char** argv) { + // ::testing::InitGoogleTest(&argc, argv); + // return RUN_ALL_TESTS(); + using DTypeIn = __half; + using DTypeO = __half; + uint32_t debug_thread_id = 0; + uint32_t debug_warp_id = 0; + bool use_fp16_qk_reduction = false; + size_t qo_len = 128; + size_t kv_len = 128; + size_t num_heads = 1; + size_t head_dim = 64; + bool causal = false; + size_t pos_encoding_mode = 0; // 1 == kRopeLLama + size_t kv_layout = 0; + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "--thread" && i + 1 < argc) { + debug_thread_id = std::stoi(argv[++i]); + std::cout << "Debug thread ID set to: " << debug_thread_id << std::endl; + } else if (arg == "--warp" && i + 1 < argc) { + debug_warp_id = std::stoi(argv[++i]); + std::cout << "Debug warp ID set to: " << debug_warp_id << std::endl; + } else if (arg == "--qo_len" && i + 1 < argc) { + qo_len = std::stoi(argv[++i]); + } else if (arg == "--kv_len" && i + 1 < argc) { + kv_len = std::stoi(argv[++i]); + } else if (arg == "--heads" && i + 1 < argc) { + num_heads = std::stoi(argv[++i]); + } else if (arg == "--help") { + std::cout << "Usage: " << argv[0] << " [options]\n" + << "Options:\n" + << " --thread Debug thread ID (0-255 for 4 
warps)\n" + << " --warp Debug warp ID (0-3 for 4 warps)\n" + << " --qo_len Query/Output length (default: 128)\n" + << " --kv_len Key/Value length (default: 128)\n" + << " --heads Number of heads (default: 1)\n" + << " --help Show this help message\n"; + return 0; + } + } + + _TestSinglePrefillKernelCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction, debug_thread_id, debug_warp_id); +} + +// int main(int argc, char **argv) +// { +// // Test compute_qk first with simple parameters +// std::cout << "=== Testing compute_qk function ===" << std::endl; +// using DTypeIn = __half; +// using DTypeO = __half; +// bool use_fp16_qk_reduction = false; +// bool causal = false; +// size_t pos_encoding_mode = 0; +// size_t kv_layout = 0; + +// // Start with small dimensions for easier debugging +// _TestComputeQKCorrectness( +// 16, // qo_len - small for debugging +// 32, // kv_len +// 1, // num_qo_heads - single head +// 1, // num_kv_heads - single head +// 64, // head_dim +// causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); + +// std::cout << "\n=== Testing full single prefill ===" << std::endl; +// // Your existing test... +// size_t qo_len = 399; +// size_t kv_len = 533; +// size_t num_heads = 1; +// size_t head_dim = 64; + +// _TestSinglePrefillKernelCorrectness( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); + +// return 0; +// } diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 9703f30657..6cc5809754 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -8,6 +8,9 @@ #include #include +#include +#include + #include "flashinfer/attention/generic/page.cuh" #include "flashinfer/attention/generic/pos_enc.cuh" #include "flashinfer/exception.h" @@ -64,7 +67,7 @@ inline std::vector apply_llama_rope(const T* input, size_t D, size_t offs if (std::is_same_v) rst[k] = cos * fi::con::explicit_casting(input[k]) + sin * permuted_input[k]; } - return std::move(rst); + return rst; } template @@ -73,7 +76,8 @@ std::vector single_mha(const std::vector& q, const std::vect size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - float rope_scale = 1.f, float rope_theta = 1e4) { + float logits_soft_cap = 8.0f, float rope_scale = 1.f, + float rope_theta = 1e4, bool use_soft_cap = false) { assert(qo_len <= kv_len); assert(num_qo_heads % num_kv_heads == 0); float sm_scale = 1.f / std::sqrt(float(head_dim)); @@ -81,6 +85,7 @@ std::vector single_mha(const std::vector& q, const std::vect std::vector att(kv_len); std::vector q_rotary_local(head_dim); std::vector k_rotary_local(head_dim); + DISPATCH_head_dim(head_dim, HEAD_DIM, { tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { @@ -120,6 +125,11 @@ std::vector single_mha(const std::vector& q, const std::vect FLASHINFER_ERROR(err_msg.str()); } } + // apply soft cap if enabled + if (use_soft_cap) { + float soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap; + att[kv_idx] = std::tanh(att[kv_idx] / sm_scale * soft_cap_pre_tanh_scale); + } // apply mask if (causal && kv_idx > kv_len + q_idx - qo_len) { 
att[kv_idx] = -5e4; @@ -132,7 +142,6 @@ std::vector single_mha(const std::vector& q, const std::vect att[kv_idx] = std::exp(att[kv_idx] - max_val); denom += att[kv_idx]; } - // divide by denom for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { att[kv_idx] /= denom; diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h new file mode 100644 index 0000000000..b5f6bff293 --- /dev/null +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -0,0 +1,122 @@ +// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// +// SPDX - License - Identifier : Apache 2.0 + +#pragma once + +#include + +#include "flashinfer/attention/generic/allocator.h" +#include "flashinfer/attention/generic/default_prefill_params.cuh" +#include "flashinfer/attention/generic/exception.h" +#include "flashinfer/attention/generic/prefill.cuh" +#include "flashinfer/attention/generic/scheduler.cuh" +#include "flashinfer/attention/generic/variants.cuh" +#include "gpu_iface/enums.hpp" +#include "gpu_iface/layout.cuh" +#include "utils_hip.h" + +namespace flashinfer { + +template +hipError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, + hipStream_t stream); + +template +hipError_t SinglePrefillWithKVCacheCustomMask( + DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeO* o, DTypeO* tmp, float* lse, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool use_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, + float rope_theta = 1e4, hipStream_t stream = nullptr) { + const float sm_scale = 1.f; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = + get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/true, /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, k, v, custom_mask, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); + return SinglePrefillWithKVCacheDispatched(params, tmp, stream); + })})}); + return hipSuccess; +} + +/*! + * \brief FlashAttention prefill hip function for a single request. + * \tparam DTypeIn The data type of input + * \tparam DTypeO The data type of output + * \param q The query tensor. + * \param k The key tensor. + * \param v The value tensor. + * \param o The output tensor. + * \param tmp The temporary storage (only used for cooperative kernel). + * \param lse The logsumexp values. + * \param num_qo_heads The number of query and output heads. + * \param num_kv_heads The number of key and value heads. + * \param qo_len The length of query and output. + * \param kv_len The length of key and value. + * \param head_dim The dimension of each head. + * \param causal Whether to use causal attention. + * \param kv_layout The layout of input and output. 
+ * \param pos_encoding_mode The positional encoding mode. + * \param use_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. + * \param rope_scale The scaling factor used in RoPE interpolation. + * \param rope_theta The theta used in RoPE. + * \param stream The hip stream to execute the kernel on. + * \return status Indicates whether hip calls are successful + */ +template +hipError_t SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, DTypeO* tmp, + float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool use_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, + hipStream_t stream = nullptr) { + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = + get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/(MASK_MODE == MaskMode::kCustom), + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/true, /*use_alibi=*/false>; + Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/8.f, sm_scale, rope_scale, rope_theta); + return SinglePrefillWithKVCacheDispatched(params, tmp, + stream); + })})})}); + return hipSuccess; +} + +} // namespace flashinfer diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h index bed4b1d533..de6b6956c8 100644 --- a/libflashinfer/utils/utils_hip.h +++ b/libflashinfer/utils/utils_hip.h @@ -51,10 +51,41 @@ namespace utils { +enum Predicate { + Linear, + Ones, + Zeros, +}; + +template +void generate_data(std::vector& vec) { + if constexpr (Pred == Predicate::Linear) { + assert(vec.size() > 0); + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); + } + } + + else if constexpr (Pred == Predicate::Ones) { + vec_fill_(vec, fi::con::explicit_casting(1.0f)); + } + + else if constexpr (Pred == Predicate::Zeros) { + vec_zero_(vec); + } +} + +template +void vec_lexicographic_(std::vector& vec) { + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); + } +} + template void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::normal_distribution d{mean, std}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); @@ -65,7 +96,7 @@ void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { template void vec_uniform_(std::vector& vec, float a = 0.f, float b = 1.f) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::uniform_real_distribution d{a, b}; for (size_t i = 0; i < vec.size(); ++i) { float value = 
static_cast(d(gen)); @@ -86,7 +117,7 @@ void vec_fill_(std::vector& vec, T val) { template void vec_randint_(std::vector& vec, int low, int high) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::uniform_int_distribution d{low, high}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); From 17b45a8449c48dfd597ecafc6d7c412c37fe7fbc Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 5 Nov 2025 10:26:08 -0600 Subject: [PATCH 02/13] Update libflashinfer/utils/flashinfer_prefill_ops.hip.h Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb --- libflashinfer/utils/flashinfer_prefill_ops.hip.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index b5f6bff293..53f7fc837a 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -1,7 +1,7 @@ -// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// SPDX-FileCopyrightText: 2023-2025 Flashinfer team +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. // -// SPDX - License - Identifier : Apache 2.0 +// SPDX-License-Identifier: Apache-2.0 #pragma once From c99d968f8f2cd19799b9db16c6e49ef3f15fbf7b Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 5 Nov 2025 10:26:19 -0600 Subject: [PATCH 03/13] Update libflashinfer/tests/hip/test_single_prefill.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb --- libflashinfer/tests/hip/test_single_prefill.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index f02557d80c..e104788910 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -1,7 +1,7 @@ -// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// SPDX-FileCopyrightText: 2023-2025 Flashinfer team +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. // -// SPDX - License - Identifier : Apache 2.0 +// SPDX-License-Identifier: Apache-2.0 #include From 7f1b94d14ea12a1785f7f312f34ac4852182bb45 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 5 Nov 2025 10:49:28 -0600 Subject: [PATCH 04/13] Update libflashinfer/include/flashinfer/attention/generic/dispatch.cuh Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb --- .../include/flashinfer/attention/generic/dispatch.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh index abe0a3020e..eed3e1023f 100644 --- a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh @@ -1,7 +1,7 @@ -// SPDX - FileCopyrightText : 2023-2035 FlashInfer team. -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// SPDX-FileCopyrightText: 2023-2025 FlashInfer team. +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. 
// -// SPDX - License - Identifier : Apache 2.0 +// SPDX-License-Identifier: Apache-2.0 #pragma once From 6caf9cc4a699f0065a521bbd7aa555a7cf63c0d5 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 5 Nov 2025 20:06:35 +0000 Subject: [PATCH 05/13] Address review comments --- .../flashinfer/attention/generic/permuted_smem.cuh | 6 +++--- .../include/flashinfer/attention/generic/prefill.cuh | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 1045a0bda4..991bab2992 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -1,7 +1,7 @@ -// SPDX - FileCopyrightText : 2023-2035 FlashInfer team. -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// SPDX-FileCopyrightText : 2023-2035 FlashInfer team. +// SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc. // -// SPDX - License - Identifier : Apache 2.0 +// SPDX-License - Identifier : Apache 2.0 #ifndef FLASHINFER_PERMUTED_SMEM_CUH_ #define FLASHINFER_PERMUTED_SMEM_CUH_ diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index b6d4d8e87a..73bbd8500a 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1,7 +1,7 @@ -// SPDX - FileCopyrightText : 2023-2025 FlashInfer team. -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// SPDX-FileCopyrightText : 2023-2025 FlashInfer team. +// SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc. // -// SPDX - License - Identifier : Apache - 2.0 +// SPDX-License-Identifier : Apache - 2.0 #ifndef FLASHINFER_PREFILL_CUH_ #define FLASHINFER_PREFILL_CUH_ @@ -31,8 +31,8 @@ namespace flashinfer { DEFINE_HAS_MEMBER(maybe_q_rope_offset) DEFINE_HAS_MEMBER(maybe_k_rope_offset) -namespace cg = flashinfer::gpu_iface::cg; -namespace memory = flashinfer::gpu_iface::memory; +namespace cg = gpu_iface::cg; +namespace memory = gpu_iface::memory; namespace mma = gpu_iface::mma; using gpu_iface::vec_dtypes::vec_cast; From 9a654ec434ce1a6174fbecdac10e61d7f9f4507e Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 5 Nov 2025 20:07:17 +0000 Subject: [PATCH 06/13] Fix review comments --- .../include/flashinfer/attention/generic/permuted_smem.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 991bab2992..a1c7636ec7 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText : 2023-2035 FlashInfer team. // SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
 //
-// SPDX-License - Identifier : Apache 2.0
+// SPDX-License-Identifier : Apache 2.0
 
 #ifndef FLASHINFER_PERMUTED_SMEM_CUH_
 #define FLASHINFER_PERMUTED_SMEM_CUH_
 
From cb8aa749abd26bd32fcfb24bbd7c882f81a9f35c Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Wed, 5 Nov 2025 20:08:34 +0000
Subject: [PATCH 07/13] Fix SPDX header

---
 .../flashinfer/attention/generic/default_prefill_params.cuh | 6 +++---
 libflashinfer/include/flashinfer/attention/generic/page.cuh | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
index e8014f199e..9e28ff3940 100644
--- a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
@@ -1,7 +1,7 @@
-// SPDX - FileCopyrightText : 2023-2025 FlashInfer team.
-// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX - License - Identifier : Apache - 2.0
+// SPDX-License-Identifier : Apache - 2.0
 
 #ifndef FLASHINFER_PREFILL_PARAMS_CUH_
 #define FLASHINFER_PREFILL_PARAMS_CUH_
diff --git a/libflashinfer/include/flashinfer/attention/generic/page.cuh b/libflashinfer/include/flashinfer/attention/generic/page.cuh
index 14df54be80..c6cd7e89c6 100644
--- a/libflashinfer/include/flashinfer/attention/generic/page.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/page.cuh
@@ -1,7 +1,7 @@
-// SPDX - FileCopyrightText : 2023-2025 FlashInfer team.
-// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX - License - Identifier : Apache - 2.0
+// SPDX-License-Identifier : Apache - 2.0
 
 #pragma once
 
From b604e55a16dc7e17416bf29d65d310ae62e33d07 Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Wed, 5 Nov 2025 20:11:25 +0000
Subject: [PATCH 08/13] Remove dead code

---
 .../tests/hip/test_single_prefill.cpp         | 470 +-----------------
 1 file changed, 2 insertions(+), 468 deletions(-)

diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp
index e104788910..caac4cdc41 100644
--- a/libflashinfer/tests/hip/test_single_prefill.cpp
+++ b/libflashinfer/tests/hip/test_single_prefill.cpp
@@ -17,166 +17,6 @@
 
 using namespace flashinfer;
 
-#if 0
-template <typename DTypeQ, typename DTypeKV, typename DTypeO>
-void _TestComputeQKCorrectness(size_t qo_len,
-                               size_t kv_len,
-                               size_t num_qo_heads,
-                               size_t num_kv_heads,
-                               size_t head_dim,
-                               bool causal,
-                               QKVLayout kv_layout,
-                               PosEncodingMode pos_encoding_mode,
-                               bool use_fp16_qk_reduction,
-                               float rtol = 1e-3,
-                               float atol = 1e-3)
-{
-    // std::cout << "Testing compute_qk: qo_len=" << qo_len
-    //           << ", kv_len=" << kv_len << ", num_qo_heads=" << num_qo_heads
-    //           << ", num_kv_heads=" << num_kv_heads << ", head_dim=" << head_dim
-    //           << std::endl;
-
-    // Generate test data (same as original test)
-    std::vector<DTypeQ> q(qo_len * num_qo_heads * head_dim);
-    std::vector<DTypeKV> k(kv_len * num_kv_heads * head_dim);
-    std::vector<DTypeKV> v(kv_len * num_kv_heads *
-                           head_dim); // Still needed for params
-    std::vector<DTypeO> o(qo_len * num_qo_heads *
-                          head_dim); // Still needed for params
-
-    utils::vec_normal_(q);
-    utils::vec_normal_(k);
-    utils::vec_normal_(v); // Initialize even though we won't use it
-    utils::vec_zero_(o);
-
-    // GPU memory allocation (same pattern as original)
-    DTypeQ *q_d;
-    FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ)));
-    FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ),
-                          hipMemcpyHostToDevice));
-
-    DTypeKV *k_d;
-    FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV)));
-    FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV),
-                          hipMemcpyHostToDevice));
-
-    DTypeKV *v_d;
-    FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV)));
-    FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV),
-                          hipMemcpyHostToDevice));
-
-    DTypeO *o_d;
-    FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO)));
-    FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO),
-                          hipMemcpyHostToDevice));
-
-    DTypeO *tmp_d;
-    FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO)));
-
-    // Allocate output buffer for QK scores
-    const size_t qk_output_size = qo_len * kv_len * num_qo_heads;
-    float *qk_scores_d;
-    FI_GPU_CALL(hipMalloc(&qk_scores_d, qk_output_size * sizeof(float)));
-
-    // std::cout << "Debug: Kernel launch parameters:" << std::endl;
-    // std::cout << "  qo_len=" << qo_len << ", kv_len=" << kv_len << std::endl;
-    // std::cout << "  num_qo_heads=" << num_qo_heads
-    //           << ", num_kv_heads=" << num_kv_heads << std::endl;
-    // std::cout << "  head_dim=" << head_dim << std::endl;
-    // std::cout << "  qk_output_size=" << qk_output_size << std::endl;
-    // std::cout << "  Launching ComputeQKStubCaller..." << std::endl;
-
-    // Call ComputeQKStubCaller instead of SinglePrefillWithKVCache
-    hipError_t status =
-        flashinfer::ComputeQKStubCaller(
-            q_d, k_d, v_d, o_d, tmp_d,
-            /*lse=*/nullptr, qk_scores_d, // Add qk_scores_d parameter
-            num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal,
-            kv_layout, pos_encoding_mode, use_fp16_qk_reduction);
-
-    // std::cout << "  Kernel launch status: " << hipGetErrorString(status)
-    //           << std::endl;
-    EXPECT_EQ(status, hipSuccess)
-        << "ComputeQKStubCaller kernel launch failed, error message: "
-        << hipGetErrorString(status);
-
-    // Get GPU QK scores
-    std::vector<float> gpu_qk_scores(qk_output_size);
-    FI_GPU_CALL(hipMemcpy(gpu_qk_scores.data(), qk_scores_d,
-                          qk_output_size * sizeof(float),
-                          hipMemcpyDeviceToHost));
-
-    // Check if GPU output is not empty
-    bool isEmpty = gpu_qk_scores.empty();
-    EXPECT_EQ(isEmpty, false) << "GPU QK scores vector is empty";
-
-    // Compute CPU reference using our cpu_reference::compute_qk
-    std::vector<float> cpu_qk_scores = cpu_reference::compute_qk(
-        q, k, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout);
-
-    // Validate results (same pattern as original test)
-    size_t num_results_error_atol = 0;
-    bool nan_detected = false;
-
-    // Compare element-by-element
-    size_t comparison_size =
-        std::min(gpu_qk_scores.size(), cpu_qk_scores.size());
-    for (size_t i = 0; i < comparison_size; ++i) {
-        float gpu_val = gpu_qk_scores[i];
-        float cpu_val = cpu_qk_scores[i];
-
-        if (isnan(gpu_val)) {
-            nan_detected = true;
-        }
-
-        if (!utils::isclose(cpu_val, gpu_val, rtol, atol)) {
-            num_results_error_atol++;
-            if (num_results_error_atol <= 10)
-            { // Only print first 10 mismatches
-                std::cout << "QK mismatch at i=" << i << ", cpu_val=" << cpu_val
-                          << ", gpu_val=" << gpu_val << std::endl;
-            }
-        }
-    }
-#if 0
-    // Calculate and report accuracy
-    float result_accuracy =
-        1.0f - float(num_results_error_atol) / float(comparison_size);
-
-    std::cout << "compute_qk results: num_qo_heads=" << num_qo_heads
-              << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len
-              << ", kv_len=" << kv_len << ", head_dim=" << head_dim
-              << ", causal=" << causal
-              << ", kv_layout=" << QKVLayoutToString(kv_layout)
-              << ", pos_encoding_mode="
-              << PosEncodingModeToString(pos_encoding_mode)
-              << ", qk_accuracy=" << result_accuracy << " ("
-              << num_results_error_atol << "/" << comparison_size
-              << " mismatches)" << std::endl;
-
-    // Print some sample values for debugging
-    std::cout << "Sample QK scores (first 10): GPU vs CPU" << std::endl;
-    for (size_t i = 0; i < std::min(size_t(10), comparison_size); ++i) {
-        std::cout << "  [" << i << "] GPU=" << gpu_qk_scores[i]
-                  << ", CPU=" << cpu_qk_scores[i] << std::endl;
-    }
-
-    // Assertions (slightly relaxed for initial testing)
-    EXPECT_GT(result_accuracy, 0.80)
-        << "compute_qk accuracy too low"; // Start with 80%
-    EXPECT_FALSE(nan_detected) << "NaN detected in compute_qk results";
-#endif
-    // Cleanup
-    FI_GPU_CALL(hipFree(q_d));
-    FI_GPU_CALL(hipFree(k_d));
-    FI_GPU_CALL(hipFree(v_d));
-    FI_GPU_CALL(hipFree(o_d));
-    FI_GPU_CALL(hipFree(tmp_d));
-    FI_GPU_CALL(hipFree(qk_scores_d));
-}
-
-#endif
-
 template <typename DTypeQ, typename DTypeKV, typename DTypeO>
 void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads,
                                          size_t num_kv_heads, size_t head_dim, bool causal,
@@ -192,9 +32,6 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
   utils::vec_normal_(q);
   utils::vec_normal_(k);
   utils::vec_normal_(v);
-  // utils::vec_lexicographic_(q);
-  // utils::vec_lexicographic_(k);
-  // utils::vec_fill_(v, __float2half(1.0f));
   utils::vec_zero_(o);
 
   DTypeQ* q_d;
@@ -227,15 +64,6 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
   std::vector<DTypeO> o_h(o.size());
   FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), hipMemcpyDeviceToHost));
 
-  // Print the first 10 elements of the output vector for debugging
-  // std::cout << "Output vector (first 10 elements):";
-  // std::cout << "[" << std::endl;
-  // for (int i = 0; i < 10; ++i) {
-  //     std::cout << fi::con::explicit_casting(o_h[i]) << "
-  //     ";
-  // }
-  // std::cout << "]" << std::endl;
-
   bool isEmpty = o_h.empty();
   EXPECT_EQ(isEmpty, false) << "Output vector is empty";
 
@@ -256,16 +84,8 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
     }
     num_results_error_atol += (!utils::isclose(o_ref_val, o_h_val, rtol, atol));
-    // if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) {
-    //     std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val
-    //               << ", o_h[i]=" << o_h_val << std::endl;
-    // }
   }
-  // std::cout<<"Printing att_out vector:\n";
-  // for(auto i: att_out) {
-  //     std::cout << i << "\n";
-  // }
-#if 1
+
   float result_accuracy = 1. - float(num_results_error_atol) / float(o_ref.size());
   std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads
             << ", qo_len=" << qo_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim
@@ -275,7 +95,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
 
   EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed.";
   EXPECT_FALSE(nan_detected) << "Nan detected in the result.";
-#endif
+
   FI_GPU_CALL(hipFree(q_d));
   FI_GPU_CALL(hipFree(k_d));
   FI_GPU_CALL(hipFree(v_d));
@@ -283,256 +103,6 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
   FI_GPU_CALL(hipFree(o_d));
   FI_GPU_CALL(hipFree(tmp_d));
 }
 
-// template <typename DTypeIn, typename DTypeO>
-// void TestSinglePrefillKernelLongContextCorrectness(bool
-// use_fp16_qk_reduction)
-// {
-//     for (size_t qo_len : {1, 31, 63, 127}) {
-//         for (size_t kv_len : {31717}) {
-//             for (size_t num_heads : {1}) {
-//                 for (size_t head_dim : {64, 128, 256}) {
-//                     for (bool causal : {false, true}) {
-//                         for (size_t pos_encoding_mode : {0, 1}) {
-//                             for (size_t kv_layout : {0, 1}) {
-//                                 _TestSinglePrefillKernelCorrectness<
-//                                     DTypeIn, DTypeIn, DTypeO>(
-//                                     qo_len, kv_len, num_heads, num_heads,
-//                                     head_dim, causal, QKVLayout(kv_layout),
-//                                     PosEncodingMode(pos_encoding_mode),
-//                                     use_fp16_qk_reduction);
-//                             }
-//                         }
-//                     }
-//                 }
-//             }
-//         }
-//     }
-// }
-//***********************************************************************
-// The following tests are disabled because we dont support fp8 <-> float
-// conversions
-
-// template <typename DTypeKV>
-// void TestSinglePrefillFP8KernelLongContextCorrectness(bool
-// use_fp16_qk_reduction) {
-//     for (size_t qo_len : {1, 31, 63, 127}) {
-//         for (size_t kv_len : {31717}) {
-//             for (size_t num_heads : {1}) {
-//                 for (size_t head_dim : {64, 128, 256}) {
-//                     for (bool causal : {false, true}) {
-//                         for (size_t pos_encoding_mode : {0}) {
-//                             for (size_t kv_layout : {0, 1}) {
-//                                 _TestSinglePrefillKernelCorrectness<__half, DTypeKV, __half>(
-//                                     qo_len, kv_len, num_heads, num_heads, head_dim, causal,
-//                                     QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode),
-//                                     use_fp16_qk_reduction);
-//                             }
-//                         }
-//                     }
-//                 }
-//             }
-//         }
-//     }
-// }
-
-// template <typename DTypeIn, typename DTypeO>
-// void TestSinglePrefillKernelShortContextCorrectness(bool
-// use_fp16_qk_reduction)
-// {
-//     float rtol = std::is_same<DTypeO, __hip_bfloat16>::value ? 1e-2 : 1e-3;
-//     float atol = std::is_same<DTypeO, __hip_bfloat16>::value ? 1e-2 : 1e-3;
-//     for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
-//         for (size_t num_qo_heads : {32}) {
-//             for (size_t num_kv_heads : {4, 8, 32}) {
-//                 for (size_t head_dim : {64, 128, 256}) {
-//                     for (bool causal : {false, true}) {
-//                         for (size_t pos_encoding_mode : {0, 1}) {
-//                             for (size_t kv_layout : {0, 1}) {
-//                                 _TestSinglePrefillKernelCorrectness<
-//                                     DTypeIn, DTypeIn, DTypeO>(
-//                                     qkv_len, qkv_len, num_qo_heads,
-//                                     num_kv_heads, head_dim, causal,
-//                                     QKVLayout(kv_layout),
-//                                     PosEncodingMode(pos_encoding_mode),
-//                                     use_fp16_qk_reduction, rtol, atol);
-//                             }
-//                         }
-//                     }
-//                 }
-//             }
-//         }
-//     }
-// }
-
-//***********************************************************************
-// The following tests are disabled because we dont support fp8 <-> float
-// conversions
-
-// template <typename DTypeKV>
-// void TestSinglePrefillFP8KernelShortContextCorrectness(bool
-// use_fp16_qk_reduction) {
-//     float rtol = 1e-3;
-//     float atol = 1e-3;
-//     for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
-//         for (size_t num_qo_heads : {32}) {
-//             for (size_t num_kv_heads : {4, 8, 32}) {
-//                 for (size_t head_dim : {64, 128, 256}) {
-//                     for (bool causal : {false, true}) {
-//                         for (size_t pos_encoding_mode : {0}) {
-//                             for (size_t kv_layout : {0, 1}) {
-//                                 _TestSinglePrefillKernelCorrectness<__half, DTypeKV, __half>(
-//                                     qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim,
-//                                     causal, QKVLayout(kv_layout),
-//                                     PosEncodingMode(pos_encoding_mode),
-//                                     use_fp16_qk_reduction, rtol, atol);
-//                             }
-//                         }
-//                     }
-//                 }
-//             }
-//         }
-//     }
-// }
-
-// template <typename DTypeIn, typename DTypeO>
-// void TestSinglePrefillKernelCorrectness(bool use_fp16_qk_reduction)
-// {
-//     for (size_t qo_len : {399, 400, 401}) {
-//         for (size_t kv_len : {533, 534, 535}) {
-//             for (size_t num_heads : {12}) {
-//                 for (size_t head_dim : {64, 128, 256}) {
-//                     for (bool causal : {false, true}) {
-//                         for (size_t pos_encoding_mode : {0, 1}) {
-//                             for (size_t kv_layout : {0, 1}) {
-//                                 _TestSinglePrefillKernelCorrectness<
-//                                     DTypeIn, DTypeIn, DTypeO>(
-//                                     qo_len, kv_len, num_heads, num_heads,
-//                                     head_dim, causal, QKVLayout(kv_layout),
-//                                     PosEncodingMode(pos_encoding_mode),
-//                                     use_fp16_qk_reduction);
-//                             }
-//                         }
-//                     }
-//                 }
-//             }
-//         }
-//     }
-// }
-
-// template <typename DTypeKV>
-// void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction)
-// {
-//     for (size_t qo_len : {399, 400, 401}) {
-//         for (size_t kv_len : {533, 534, 535}) {
-//             for (size_t num_heads : {12}) {
-//                 for (size_t head_dim : {64, 128, 256}) {
-//                     for (bool causal : {false, true}) {
-//                         for (size_t pos_encoding_mode : {0}) {
-//                             for (size_t kv_layout : {0, 1}) {
-//                                 _TestSinglePrefillKernelCorrectness<
-//                                     __half, DTypeKV, __half>(
-//                                     qo_len, kv_len, num_heads, num_heads,
-//                                     head_dim, causal, QKVLayout(kv_layout),
-//                                     PosEncodingMode(pos_encoding_mode),
-//                                     use_fp16_qk_reduction);
-//                             }
-//                         }
-//                     }
-//                 }
-//             }
-//         }
-//     }
-// }
-
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelLongContextCorrectnessFP16)
-// {
-//     TestSinglePrefillKernelLongContextCorrectness<__half, __half>(false);
-// }
-
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum)
-// {
-//     TestSinglePrefillKernelLongContextCorrectness<__half, __half>(true);
-// }
-
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelShortContextCorrectnessFP16)
-// {
-//     TestSinglePrefillKernelShortContextCorrectness<__half, __half>(false);
-// }
-
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum)
-// {
-//     TestSinglePrefillKernelShortContextCorrectness<__half, __half>(true);
-// }
-
-// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16)
-// {
-//     TestSinglePrefillKernelCorrectness<__half, __half>(false);
-// }
-
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum)
-// {
-//     TestSinglePrefillKernelCorrectness<__half, __half>(true);
-// }
-
-// #ifdef FLASHINFER_ENABLE_BF16
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelLongContextCorrectnessBF16)
-// {
-//     TestSinglePrefillKernelLongContextCorrectness<__hip_bfloat16,
-//                                                   __hip_bfloat16>(
-//         false);
-// }
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelShortContextCorrectnessBF16)
-// {
-//     TestSinglePrefillKernelShortContextCorrectness<__hip_bfloat16,
-//                                                    __hip_bfloat16>(
-//         false);
-// }
-// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16)
-// {
-//     TestSinglePrefillKernelCorrectness<__hip_bfloat16,
-//                                        __hip_bfloat16>(false);
-// }
-// #endif
-
-//***********************************************************************
-// The following tests are disabled because we dont support fp8 <-> float
-// conversions
-
-// #ifdef FLASHINFER_ENABLE_FP8_E4M3
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelShortContextCorrectnessE4M3) {
-//     TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false);
-// }
-// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) {
-//     TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false);
-// }
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelLongContextCorrectnessE4M3) {
-//     TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false);
-// }
-// #endif
-
-// #ifdef FLASHINFER_ENABLE_FP8_E5M2
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelShortContextCorrectnessE5M2) {
-//     TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false);
-// }
-// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) {
-//     TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false);
-// }
-// TEST(FlashInferCorrectnessTest,
-// TestSinglePrefillKernelLongContextCorrectnessE5M2) {
-//     TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false);
-// }
-// #endif
-
 int main(int argc, char** argv) {
   // ::testing::InitGoogleTest(&argc, argv);
   // return RUN_ALL_TESTS();
@@ -581,39 +151,3 @@ int main(int argc, char** argv) {
       qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout),
       PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction, debug_thread_id, debug_warp_id);
 }
-
-// int main(int argc, char **argv)
-// {
-//     // Test compute_qk first with simple parameters
-//     std::cout << "=== Testing compute_qk function ===" << std::endl;
-//     using DTypeIn = __half;
-//     using DTypeO = __half;
-//     bool use_fp16_qk_reduction = false;
-//     bool causal = false;
-//     size_t pos_encoding_mode = 0;
-//     size_t kv_layout = 0;
-
-//     // Start with small dimensions for easier debugging
-//     _TestComputeQKCorrectness(
-//         16, // qo_len - small for debugging
-//         32, // kv_len
-//         1,  // num_qo_heads - single head
-//         1,  // num_kv_heads - single head
-//         64, // head_dim
-//         causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode),
-//         use_fp16_qk_reduction);
-
-//     std::cout << "\n=== Testing full single prefill ===" << std::endl;
-//     // Your existing test...
-//     size_t qo_len = 399;
-//     size_t kv_len = 533;
-//     size_t num_heads = 1;
-//     size_t head_dim = 64;
-
-//     _TestSinglePrefillKernelCorrectness(
-//         qo_len, kv_len, num_heads, num_heads, head_dim, causal,
-//         QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode),
-//         use_fp16_qk_reduction);
-
-//     return 0;
-// }

From 0c9c6261865c4622c007573b1a5fe522019e2adf Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Wed, 5 Nov 2025 20:12:59 +0000
Subject: [PATCH 09/13] Address review comments

---
 libflashinfer/tests/hip/test_single_prefill.cpp                | 2 +-
 libflashinfer/utils/cpu_reference_hip.h                        | 3 ---
 ...ashinfer_prefill_ops.hip.h => flashinfer_prefill_ops_hip.h} | 0
 3 files changed, 1 insertion(+), 4 deletions(-)
 rename libflashinfer/utils/{flashinfer_prefill_ops.hip.h => flashinfer_prefill_ops_hip.h} (100%)

diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp
index caac4cdc41..97647d6f4b 100644
--- a/libflashinfer/tests/hip/test_single_prefill.cpp
+++ b/libflashinfer/tests/hip/test_single_prefill.cpp
@@ -8,7 +8,7 @@
 #include
 
 #include "../../utils/cpu_reference_hip.h"
-#include "../../utils/flashinfer_prefill_ops.hip.h"
+#include "../../utils/flashinfer_prefill_ops_hip.h"
 #include "../../utils/utils_hip.h"
 #include "flashinfer/attention/generic/prefill.cuh"
 #include "gpu_iface/gpu_runtime_compat.hpp"
diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h
index 6cc5809754..555b1a481d 100644
--- a/libflashinfer/utils/cpu_reference_hip.h
+++ b/libflashinfer/utils/cpu_reference_hip.h
@@ -8,9 +8,6 @@
 #include
 #include
 
-#include
-#include
-
 #include "flashinfer/attention/generic/page.cuh"
 #include "flashinfer/attention/generic/pos_enc.cuh"
 #include "flashinfer/exception.h"
diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops_hip.h
similarity index 100%
rename from libflashinfer/utils/flashinfer_prefill_ops.hip.h
rename to libflashinfer/utils/flashinfer_prefill_ops_hip.h

From d2a4605151d31c1d9265845411d48391a6a80ba5 Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Wed, 5 Nov 2025 20:15:24 +0000
Subject: [PATCH 10/13] Apply fix based on copilot review

---
 libflashinfer/include/flashinfer/attention/generic/prefill.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
index 73bbd8500a..8008f0de64 100644
--- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
@@ -1511,7 +1511,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
               *reinterpret_cast(&o_frag[mma_q][mma_d][j * 2]);
 
           *reinterpret_cast(base_addr + 8 + col_offset * 2) =
-              *reinterpret_cast(&o_frag[mma_q][mma_d][$ + j * 2]);
+              *reinterpret_cast(&o_frag[mma_q][mma_d][4 + j * 2]);
 #endif
         }
       }

From 161cca6194e6dbf6244cc667faeb3f48343c00ad Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Wed, 5 Nov 2025 20:18:34 +0000
Subject: [PATCH 11/13] Address formatting fix

---
 libflashinfer/utils/utils_hip.h | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h
index de6b6956c8..ad82f78221 100644
--- a/libflashinfer/utils/utils_hip.h
+++ b/libflashinfer/utils/utils_hip.h
@@ -64,13 +64,9 @@ void generate_data(std::vector<T>& vec) {
     for (int i = 0; i < vec.size(); i++) {
         vec[i] =
            fi::con::explicit_casting<T>(static_cast<float>(i));
     }
-    }
-
-    else if constexpr (Pred == Predicate::Ones) {
+    } else if constexpr (Pred == Predicate::Ones) {
         vec_fill_(vec, fi::con::explicit_casting<T>(1.0f));
-    }
-
-    else if constexpr (Pred == Predicate::Zeros) {
+    } else if constexpr (Pred == Predicate::Zeros) {
         vec_zero_(vec);
     }
 }

From f86432bd4da40ed32d94b86e5d5d9e847fd0a4f7 Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Wed, 5 Nov 2025 20:20:15 +0000
Subject: [PATCH 12/13] Fix SPDX headers

---
 .../flashinfer/attention/generic/permuted_smem.cuh   | 2 +-
 .../include/flashinfer/attention/generic/prefill.cuh | 4 ++--
 libflashinfer/utils/cpu_reference_hip.h              | 7 +++----
 libflashinfer/utils/utils_hip.h                      | 6 +++---
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh
index a1c7636ec7..d0abe0d492 100644
--- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh
@@ -1,7 +1,7 @@
 // SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
 // SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX-License-Identifier : Apache 2.0
+// SPDX-License-Identifier : Apache-2.0
 
 #ifndef FLASHINFER_PERMUTED_SMEM_CUH_
 #define FLASHINFER_PERMUTED_SMEM_CUH_
diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
index 8008f0de64..88105add82 100644
--- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
@@ -1,7 +1,7 @@
 // SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
 // SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX-License-Identifier : Apache - 2.0
+// SPDX-License-Identifier : Apache-2.0
 
 #ifndef FLASHINFER_PREFILL_CUH_
 #define FLASHINFER_PREFILL_CUH_
diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h
index 555b1a481d..c74985cbd3 100644
--- a/libflashinfer/utils/cpu_reference_hip.h
+++ b/libflashinfer/utils/cpu_reference_hip.h
@@ -1,8 +1,7 @@
-// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team
-// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX - License - Identifier : Apache 2.0
-
+// SPDX-License-Identifier : Apache-2.0
 #pragma once
 
 #include
diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h
index ad82f78221..6c8bdeb6db 100644
--- a/libflashinfer/utils/utils_hip.h
+++ b/libflashinfer/utils/utils_hip.h
@@ -1,7 +1,7 @@
-// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team
-// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX - License - Identifier : Apache 2.0
+// SPDX-License-Identifier : Apache-2.0
 
 #pragma once
 
From 74f8dbde6c03342fa2bcd4a289acf99484e9938b Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Wed, 5 Nov 2025 20:21:58 +0000
Subject: [PATCH 13/13] More SPDX fixes

---
 .../flashinfer/attention/generic/default_decode_params.cuh  | 6 +++---
 libflashinfer/include/flashinfer/attention/generic/page.cuh | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/libflashinfer/include/flashinfer/attention/generic/default_decode_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_decode_params.cuh
index db6d28091a..8e743ed708 100644
--- a/libflashinfer/include/flashinfer/attention/generic/default_decode_params.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/default_decode_params.cuh
@@ -1,7 +1,7 @@
-// SPDX - FileCopyrightText : 2023-2035 FlashInfer team.
-// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX - License - Identifier : Apache 2.0
+// SPDX-License-Identifier : Apache-2.0
 
 #pragma once
 #ifndef FLASHINFER_DECODE_PARAMS_CUH_
 #define FLASHINFER_DECODE_PARAMS_CUH_
diff --git a/libflashinfer/include/flashinfer/attention/generic/page.cuh b/libflashinfer/include/flashinfer/attention/generic/page.cuh
index c6cd7e89c6..ab65b5447c 100644
--- a/libflashinfer/include/flashinfer/attention/generic/page.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/page.cuh
@@ -1,7 +1,7 @@
 // SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
 // SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
 //
-// SPDX-License-Identifier : Apache - 2.0
+// SPDX-License-Identifier : Apache-2.0
 
 #pragma once
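
Note on the test-utility hunk in PATCH 01/13: the one functional change there is seeding
std::mt19937 with the fixed literal 1234 instead of std::random_device, so every run of
test_single_prefill generates identical Q/K/V tensors and any GPU-versus-CPU-reference
mismatch replays deterministically. A minimal standalone sketch of that pattern, assuming
only the C++ standard library (the make_randint_vec helper below is illustrative and not
part of the repository):

#include <cstddef>
#include <random>
#include <vector>

// Illustrative fixed-seed analogue of utils::vec_randint_: seeding the
// Mersenne Twister with a literal makes the generated sequence identical
// on every run, so correctness-test failures reproduce bit-for-bit.
std::vector<float> make_randint_vec(std::size_t n, int low, int high) {
    std::mt19937 gen{1234};  // fixed seed; std::random_device{}() would vary per run
    std::uniform_int_distribution<int> d{low, high};
    std::vector<float> vec(n);
    for (std::size_t i = 0; i < n; ++i) {
        vec[i] = static_cast<float>(d(gen));
    }
    return vec;
}

The trade-off is input diversity: with a pinned seed, broadening coverage requires
changing the seed by hand or sweeping several seeds.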